Agate/grpc/client.go

237 lines
6.7 KiB
Go

package grpc
import (
"context"
"fmt"
"io"
"os"
"path/filepath"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"gitea.unprism.ru/KRBL/Agate/store"
)
// SnapshotClient is a client for a remote snapshot server. It pairs the
// underlying gRPC connection with the SnapshotService stub created on top of
// that connection.
type SnapshotClient struct {
// conn is the gRPC connection to the server; Close closes it.
conn *grpc.ClientConn
// client is the SnapshotService stub used to issue the RPCs.
client SnapshotServiceClient
}
// NewSnapshotClient builds a SnapshotClient for the server at address.
//
// The connection is created with insecure transport credentials. An error
// reported by grpc.NewClient is wrapped together with the address.
func NewSnapshotClient(address string) (*SnapshotClient, error) {
	transport := grpc.WithTransportCredentials(insecure.NewCredentials())

	cc, err := grpc.NewClient(address, transport)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to server at %s: %w", address, err)
	}

	// The service stub shares the connection stored on the client.
	sc := &SnapshotClient{conn: cc}
	sc.client = NewSnapshotServiceClient(cc)
	return sc, nil
}
// Close shuts down the connection to the server and returns the error, if
// any, reported by the connection. A client without a connection is a no-op.
func (c *SnapshotClient) Close() error {
	if c.conn == nil {
		return nil
	}
	return c.conn.Close()
}
// ListSnapshots asks the remote server for its snapshots and returns them
// converted to store.SnapshotInfo values, in the order the server sent them.
// An RPC failure is returned wrapped.
func (c *SnapshotClient) ListSnapshots(ctx context.Context) ([]store.SnapshotInfo, error) {
	resp, err := c.client.ListSnapshots(ctx, &ListSnapshotsRequest{})
	if err != nil {
		return nil, fmt.Errorf("failed to list snapshots: %w", err)
	}

	// Translate each gRPC message into its store counterpart, slot by slot.
	remote := resp.Snapshots
	infos := make([]store.SnapshotInfo, len(remote))
	for i := range remote {
		item := remote[i]
		infos[i] = store.SnapshotInfo{
			ID:           item.Id,
			Name:         item.Name,
			ParentID:     item.ParentId,
			CreationTime: item.CreationTime.AsTime(),
		}
	}
	return infos, nil
}
// FetchSnapshotDetails retrieves detailed information about a specific
// snapshot from the remote server and converts it to a store.Snapshot,
// including one store.FileInfo per file entry in the response.
//
// It returns an error if the RPC fails or if the response carries no Info
// message; the latter would otherwise cause a nil pointer dereference.
func (c *SnapshotClient) FetchSnapshotDetails(ctx context.Context, snapshotID string) (*store.Snapshot, error) {
	response, err := c.client.GetSnapshotDetails(ctx, &GetSnapshotDetailsRequest{
		SnapshotId: snapshotID,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to get snapshot details: %w", err)
	}

	// The Info message is optional on the wire; reject a response without it
	// instead of panicking on the field accesses below.
	if response.Info == nil {
		return nil, fmt.Errorf("failed to get snapshot details: server returned no info for snapshot %s", snapshotID)
	}

	// Convert gRPC snapshot details to store.Snapshot
	snapshot := &store.Snapshot{
		ID:           response.Info.Id,
		Name:         response.Info.Name,
		ParentID:     response.Info.ParentId,
		CreationTime: response.Info.CreationTime.AsTime(),
		Files:        make([]store.FileInfo, 0, len(response.Files)),
	}

	// Convert file info
	for _, file := range response.Files {
		snapshot.Files = append(snapshot.Files, store.FileInfo{
			Path:   file.Path,
			Size:   file.SizeBytes,
			IsDir:  file.IsDir,
			SHA256: file.Sha256Hash,
		})
	}
	return snapshot, nil
}
// DownloadSnapshot materializes the snapshot snapshotID under targetDir.
//
// Files are transferred one at a time. When localParentID is non-empty, the
// details of that snapshot are fetched (via FetchSnapshotDetails) and every
// file whose path and SHA256 match an entry of the parent is copied from
// targetDir/../localParentID instead of being downloaded; if the parent's
// details cannot be fetched or a copy fails, the file is downloaded normally.
//
// Paths reported by the server are validated: an entry whose path would
// resolve outside targetDir (for example via ".." components) aborts the
// download with an error instead of writing outside the target directory.
func (c *SnapshotClient) DownloadSnapshot(ctx context.Context, snapshotID string, targetDir string, localParentID string) error {
	// Get snapshot details
	snapshot, err := c.FetchSnapshotDetails(ctx, snapshotID)
	if err != nil {
		return fmt.Errorf("failed to get snapshot details: %w", err)
	}

	// Create target directory if it doesn't exist
	if err := os.MkdirAll(targetDir, 0755); err != nil {
		return fmt.Errorf("failed to create target directory: %w", err)
	}

	// If a local parent is specified, index its files by path. This is a
	// best-effort optimization: on error the map stays nil and every file is
	// downloaded.
	var localParentFiles map[string]store.FileInfo
	if localParentID != "" {
		if localParent, err := c.FetchSnapshotDetails(ctx, localParentID); err == nil {
			localParentFiles = make(map[string]store.FileInfo, len(localParent.Files))
			for _, file := range localParent.Files {
				localParentFiles[file.Path] = file
			}
		}
	}

	for _, file := range snapshot.Files {
		// The path comes from the server; never let it escape targetDir.
		targetPath, err := joinWithinDir(targetDir, file.Path)
		if err != nil {
			return fmt.Errorf("refusing to write file %s: %w", file.Path, err)
		}

		// Directory entries are created locally; there is nothing to download.
		if file.IsDir {
			if err := os.MkdirAll(targetPath, 0755); err != nil {
				return fmt.Errorf("failed to create directory %s: %w", targetPath, err)
			}
			continue
		}

		// Reuse the parent's copy when it has the same hash. Looking up a nil
		// map is safe and simply reports "not found".
		if parentFile, exists := localParentFiles[file.Path]; exists && parentFile.SHA256 == file.SHA256 {
			parentFilePath := filepath.Join(targetDir, "..", localParentID, file.Path)

			// Ensure parent directory exists
			if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
				return fmt.Errorf("failed to create directory for %s: %w", targetPath, err)
			}

			if err := copyFile(parentFilePath, targetPath); err == nil {
				continue
			} else {
				// If copy fails, fall back to downloading
				fmt.Printf("Failed to copy file %s, will download instead: %v\n", file.Path, err)
			}
		}

		// Download the file
		if err := c.downloadFile(ctx, snapshotID, file.Path, targetPath); err != nil {
			return fmt.Errorf("failed to download file %s: %w", file.Path, err)
		}
	}
	return nil
}

// joinWithinDir joins rel onto base and returns the result, or an error when
// the cleaned result lies outside base. The check is purely lexical; it does
// not resolve symbolic links.
func joinWithinDir(base, rel string) (string, error) {
	joined := filepath.Join(base, rel)

	fromBase, err := filepath.Rel(base, joined)
	if err != nil {
		return "", fmt.Errorf("cannot resolve path %q: %w", rel, err)
	}

	// A relative path that is ".." or begins with "../" climbs out of base.
	parentPrefix := ".." + string(filepath.Separator)
	if fromBase == ".." || (len(fromBase) >= len(parentPrefix) && fromBase[:len(parentPrefix)] == parentPrefix) {
		return "", fmt.Errorf("path %q escapes directory %s", rel, base)
	}
	return joined, nil
}
// downloadFile streams the file filePath of snapshot snapshotID from the
// server and writes it to targetPath, creating the parent directory if needed.
//
// The stream's context is cancelled when the function returns, so an early
// exit (directory/file creation or write failure) does not leave the stream
// open. An error from closing the file is reported, and a file that could not
// be written completely is removed rather than left truncated on disk.
func (c *SnapshotClient) downloadFile(ctx context.Context, snapshotID, filePath, targetPath string) (err error) {
	// Cancelling on return releases the stream's resources on every exit path.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Create the request
	req := &DownloadFileRequest{
		SnapshotId: snapshotID,
		FilePath:   filePath,
	}

	// Start streaming the file
	stream, err := c.client.DownloadFile(ctx, req)
	if err != nil {
		return fmt.Errorf("failed to start file download: %w", err)
	}

	// Ensure the target directory exists
	if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
		return fmt.Errorf("failed to create directory for %s: %w", targetPath, err)
	}

	// Create the target file
	file, err := os.Create(targetPath)
	if err != nil {
		return fmt.Errorf("failed to create file %s: %w", targetPath, err)
	}
	defer func() {
		// Close can report a deferred write failure; surface it unless an
		// earlier error is already being returned.
		if closeErr := file.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("error closing file %s: %w", targetPath, closeErr)
		}
		if err != nil {
			// Best effort: do not leave a partial file that looks complete.
			_ = os.Remove(targetPath)
		}
	}()

	// Receive and write chunks until the server signals end of stream.
	for {
		resp, recvErr := stream.Recv()
		if recvErr == io.EOF {
			return nil
		}
		if recvErr != nil {
			return fmt.Errorf("error receiving file chunk: %w", recvErr)
		}
		if _, writeErr := file.Write(resp.ChunkData); writeErr != nil {
			return fmt.Errorf("error writing to file: %w", writeErr)
		}
	}
}
// copyFile copies the contents of the file at src into dst, creating dst or
// truncating it if it already exists. The source is opened first, so dst is
// left untouched when src cannot be opened.
//
// It returns the first error encountered while opening, creating, copying, or
// closing the destination; the close error matters because it can be the only
// report of a failed write.
func copyFile(src, dst string) (err error) {
	in, err := os.Open(src)
	if err != nil {
		return err
	}
	// Closing a read-only file cannot lose data, so its error is ignored.
	defer in.Close()

	out, err := os.Create(dst)
	if err != nil {
		return err
	}
	defer func() {
		if closeErr := out.Close(); err == nil {
			err = closeErr
		}
	}()

	_, err = io.Copy(out, in)
	return err
}
// ConnectToServer creates a new client for the server at the specified
// address. It delegates to NewSnapshotClient and returns its result unchanged.
func ConnectToServer(address string) (*SnapshotClient, error) {
return NewSnapshotClient(address)
}