1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"fmt"
  7	"io"
  8	"net/http"
  9	"os"
 10	"path/filepath"
 11	"strings"
 12	"time"
 13
 14	"github.com/charmbracelet/crush/internal/permission"
 15	"github.com/charmbracelet/fantasy/ai"
 16)
 17
 18type DownloadParams struct {
 19	URL      string `json:"url" description:"The URL to download from"`
 20	FilePath string `json:"file_path" description:"The local file path where the downloaded content should be saved"`
 21	Timeout  int    `json:"timeout,omitempty" description:"Optional timeout in seconds (max 600)"`
 22}
 23
 24type DownloadPermissionsParams struct {
 25	URL      string `json:"url"`
 26	FilePath string `json:"file_path"`
 27	Timeout  int    `json:"timeout,omitempty"`
 28}
 29
// DownloadToolName is the name passed to ai.NewAgentTool for the download tool
// and reported as the ToolName in its permission requests.
const DownloadToolName = "download"

// downloadDescription is the tool description text, embedded at build time
// from download.md and passed to ai.NewAgentTool.
//
//go:embed download.md
var downloadDescription []byte
 34
 35func NewDownloadTool(permissions permission.Service, workingDir string, client *http.Client) ai.AgentTool {
 36	if client == nil {
 37		client = &http.Client{
 38			Timeout: 5 * time.Minute, // Default 5 minute timeout for downloads
 39			Transport: &http.Transport{
 40				MaxIdleConns:        100,
 41				MaxIdleConnsPerHost: 10,
 42				IdleConnTimeout:     90 * time.Second,
 43			},
 44		}
 45	}
 46	return ai.NewAgentTool(
 47		DownloadToolName,
 48		string(downloadDescription),
 49		func(ctx context.Context, params DownloadParams, call ai.ToolCall) (ai.ToolResponse, error) {
 50			if params.URL == "" {
 51				return ai.NewTextErrorResponse("URL parameter is required"), nil
 52			}
 53
 54			if params.FilePath == "" {
 55				return ai.NewTextErrorResponse("file_path parameter is required"), nil
 56			}
 57
 58			if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
 59				return ai.NewTextErrorResponse("URL must start with http:// or https://"), nil
 60			}
 61
 62			// Convert relative path to absolute path
 63			var filePath string
 64			if filepath.IsAbs(params.FilePath) {
 65				filePath = params.FilePath
 66			} else {
 67				filePath = filepath.Join(workingDir, params.FilePath)
 68			}
 69
 70			sessionID := GetSessionFromContext(ctx)
 71			if sessionID == "" {
 72				return ai.ToolResponse{}, fmt.Errorf("session ID is required for downloading files")
 73			}
 74
 75			p := permissions.Request(
 76				permission.CreatePermissionRequest{
 77					SessionID:   sessionID,
 78					Path:        filePath,
 79					ToolName:    DownloadToolName,
 80					Action:      "download",
 81					Description: fmt.Sprintf("Download file from URL: %s to %s", params.URL, filePath),
 82					Params:      DownloadPermissionsParams(params),
 83				},
 84			)
 85
 86			if !p {
 87				return ai.ToolResponse{}, permission.ErrorPermissionDenied
 88			}
 89
 90			// Handle timeout with context
 91			requestCtx := ctx
 92			if params.Timeout > 0 {
 93				maxTimeout := 600 // 10 minutes
 94				if params.Timeout > maxTimeout {
 95					params.Timeout = maxTimeout
 96				}
 97				var cancel context.CancelFunc
 98				requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
 99				defer cancel()
100			}
101
102			req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
103			if err != nil {
104				return ai.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
105			}
106
107			req.Header.Set("User-Agent", "crush/1.0")
108
109			resp, err := client.Do(req)
110			if err != nil {
111				return ai.ToolResponse{}, fmt.Errorf("failed to download from URL: %w", err)
112			}
113			defer resp.Body.Close()
114
115			if resp.StatusCode != http.StatusOK {
116				return ai.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
117			}
118
119			// Check content length if available
120			maxSize := int64(100 * 1024 * 1024) // 100MB
121			if resp.ContentLength > maxSize {
122				return ai.NewTextErrorResponse(fmt.Sprintf("File too large: %d bytes (max %d bytes)", resp.ContentLength, maxSize)), nil
123			}
124
125			// Create parent directories if they don't exist
126			if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
127				return ai.ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
128			}
129
130			// Create the output file
131			outFile, err := os.Create(filePath)
132			if err != nil {
133				return ai.ToolResponse{}, fmt.Errorf("failed to create output file: %w", err)
134			}
135			defer outFile.Close()
136
137			// Copy data with size limit
138			limitedReader := io.LimitReader(resp.Body, maxSize)
139			bytesWritten, err := io.Copy(outFile, limitedReader)
140			if err != nil {
141				return ai.ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
142			}
143
144			// Check if we hit the size limit
145			if bytesWritten == maxSize {
146				// Clean up the file since it might be incomplete
147				os.Remove(filePath)
148				return ai.NewTextErrorResponse(fmt.Sprintf("File too large: exceeded %d bytes limit", maxSize)), nil
149			}
150
151			contentType := resp.Header.Get("Content-Type")
152			responseMsg := fmt.Sprintf("Successfully downloaded %d bytes to %s", bytesWritten, filePath)
153			if contentType != "" {
154				responseMsg += fmt.Sprintf(" (Content-Type: %s)", contentType)
155			}
156
157			return ai.NewTextResponse(responseMsg), nil
158		})
159}