package fileutils

import (
	"fmt"
	"io"
	"os"
	"time"

	"github.com/pkg/sftp"
)

// ProgressCallback receives upload progress updates. percent is the share of
// the file transferred so far, expressed as a percentage where 100 means the
// transfer is complete; remaining is the estimated time left until completion.
type ProgressCallback func(percent float64, remaining time.Duration)

// UploadFile copies the local file at localPath to remotePath over the given
// SFTP client. It delegates to UploadFileWithProgress with a nil progress
// callback, so no progress is reported during the transfer.
func UploadFile(client *sftp.Client, localPath, remotePath string) error {
	var noProgress ProgressCallback // nil: progress reporting disabled
	return UploadFileWithProgress(client, localPath, remotePath, noProgress)
}

// UploadFileWithProgress uploads the local file at localPath to remotePath on
// the SFTP server reachable through client, using client.Create for the
// remote side.
//
// If progressCallback is nil the data is streamed with io.Copy. Otherwise the
// file is copied in 32KB chunks and progressCallback is invoked after each
// chunk with the completed percentage (clamped to 0–100) and an estimate of
// the time remaining, extrapolated from the average throughput so far. For an
// empty source file the callback is invoked once with 100% on completion.
//
// An error from closing the remote file is returned when the upload
// otherwise succeeded, because a failed close on a written file can mean the
// remote copy is incomplete.
func UploadFileWithProgress(client *sftp.Client, localPath, remotePath string, progressCallback ProgressCallback) (err error) {
	// Open the local file
	localFile, err := os.Open(localPath)
	if err != nil {
		return fmt.Errorf("failed to open local file: %w", err)
	}
	// Read-only handle: a close error here cannot lose uploaded data.
	defer localFile.Close()

	// Get file size
	fileInfo, err := localFile.Stat()
	if err != nil {
		return fmt.Errorf("failed to get file info: %w", err)
	}
	fileSize := fileInfo.Size()

	// Create the remote file
	remoteFile, err := client.Create(remotePath)
	if err != nil {
		return fmt.Errorf("failed to create remote file: %w", err)
	}
	// Surface the close error of the written file instead of dropping it,
	// but never mask an earlier, more specific error.
	defer func() {
		if closeErr := remoteFile.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("failed to close remote file: %w", closeErr)
		}
	}()

	// If no progress callback is provided, just copy the file
	if progressCallback == nil {
		if _, err = io.Copy(remoteFile, localFile); err != nil {
			return fmt.Errorf("failed to upload file: %w", err)
		}
		return nil
	}

	// Copy with progress reporting
	buffer := make([]byte, 32*1024) // 32KB buffer
	var totalBytesWritten int64
	startTime := time.Now()

	for {
		bytesRead, readErr := localFile.Read(buffer)
		if bytesRead > 0 {
			bytesWritten, writeErr := remoteFile.Write(buffer[:bytesRead])
			// io.Writer must return an error on a short write; guard anyway
			// so a misbehaving writer cannot silently truncate the upload.
			if writeErr == nil && bytesWritten != bytesRead {
				writeErr = io.ErrShortWrite
			}
			if writeErr != nil {
				return fmt.Errorf("failed to write to remote file: %w", writeErr)
			}
			totalBytesWritten += int64(bytesWritten)

			percentage, estimatedTimeRemaining := uploadProgress(totalBytesWritten, fileSize, time.Since(startTime))
			progressCallback(percentage, estimatedTimeRemaining)
		}

		// Process any bytes returned alongside an error before handling it.
		if readErr != nil {
			if readErr == io.EOF {
				break
			}
			return fmt.Errorf("failed to read from local file: %w", readErr)
		}
	}

	// An empty file produces no chunks; still tell the caller we finished.
	if totalBytesWritten == 0 {
		progressCallback(100, 0)
	}

	return nil
}

// uploadProgress converts the number of bytes written so far into a
// completion percentage and an estimated time remaining.
//
// The percentage is clamped to [0, 100] and the estimate is never negative,
// so a source file that grows or shrinks during the upload cannot produce
// out-of-range values. A non-positive total is reported as complete. No
// estimate (zero) is produced until some bytes have been written and some
// time has elapsed, which avoids dividing by zero.
func uploadProgress(written, total int64, elapsed time.Duration) (float64, time.Duration) {
	if total <= 0 {
		return 100, 0
	}

	percentage := float64(written) / float64(total) * 100
	if percentage < 0 {
		percentage = 0
	} else if percentage > 100 {
		percentage = 100
	}

	var eta time.Duration
	remaining := total - written
	if written > 0 && remaining > 0 && elapsed > 0 {
		bytesPerSecond := float64(written) / elapsed.Seconds()
		eta = time.Duration(float64(remaining) / bytesPerSecond * float64(time.Second))
	}
	return percentage, eta
}