1package tools
2
3import (
4 "context"
5 _ "embed"
6 "fmt"
7 "io"
8 "net/http"
9 "os"
10 "path/filepath"
11 "strings"
12 "time"
13
14 "charm.land/fantasy"
15 "github.com/charmbracelet/crush/internal/permission"
16)
17
type (
	// DownloadParams holds the arguments of a download tool call: the source
	// URL, the destination path, and an optional timeout in seconds. The
	// `description` tags presumably feed the tool's parameter schema.
	DownloadParams struct {
		URL      string `json:"url" description:"The URL to download from"`
		FilePath string `json:"file_path" description:"The local file path where the downloaded content should be saved"`
		Timeout  int    `json:"timeout,omitempty" description:"Optional timeout in seconds (max 600)"`
	}

	// DownloadPermissionsParams is the payload attached to the permission
	// request for a download. It must keep exactly the same fields as
	// DownloadParams so the tool can convert one into the other directly.
	DownloadPermissionsParams struct {
		URL      string `json:"url"`
		FilePath string `json:"file_path"`
		Timeout  int    `json:"timeout,omitempty"`
	}
)
29
// DownloadToolName is the name the download tool is registered under; it is
// also recorded as the ToolName on the tool's permission requests.
const DownloadToolName = "download"

// downloadDescription is the content of download.md, embedded at build time
// and passed (as a string) as the tool's description.
//
//go:embed download.md
var downloadDescription []byte
34
35func NewDownloadTool(permissions permission.Service, workingDir string, client *http.Client) fantasy.AgentTool {
36 if client == nil {
37 client = &http.Client{
38 Timeout: 5 * time.Minute, // Default 5 minute timeout for downloads
39 Transport: &http.Transport{
40 MaxIdleConns: 100,
41 MaxIdleConnsPerHost: 10,
42 IdleConnTimeout: 90 * time.Second,
43 },
44 }
45 }
46 return fantasy.NewAgentTool(
47 DownloadToolName,
48 string(downloadDescription),
49 func(ctx context.Context, params DownloadParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
50 if params.URL == "" {
51 return fantasy.NewTextErrorResponse("URL parameter is required"), nil
52 }
53
54 if params.FilePath == "" {
55 return fantasy.NewTextErrorResponse("file_path parameter is required"), nil
56 }
57
58 if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
59 return fantasy.NewTextErrorResponse("URL must start with http:// or https://"), nil
60 }
61
62 // Convert relative path to absolute path
63 var filePath string
64 if filepath.IsAbs(params.FilePath) {
65 filePath = params.FilePath
66 } else {
67 filePath = filepath.Join(workingDir, params.FilePath)
68 }
69
70 sessionID := GetSessionFromContext(ctx)
71 if sessionID == "" {
72 return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for downloading files")
73 }
74
75 p := permissions.Request(
76 permission.CreatePermissionRequest{
77 SessionID: sessionID,
78 Path: filePath,
79 ToolName: DownloadToolName,
80 Action: "download",
81 Description: fmt.Sprintf("Download file from URL: %s to %s", params.URL, filePath),
82 Params: DownloadPermissionsParams(params),
83 },
84 )
85
86 if !p {
87 return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
88 }
89
90 // Handle timeout with context
91 requestCtx := ctx
92 if params.Timeout > 0 {
93 maxTimeout := 600 // 10 minutes
94 if params.Timeout > maxTimeout {
95 params.Timeout = maxTimeout
96 }
97 var cancel context.CancelFunc
98 requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
99 defer cancel()
100 }
101
102 req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
103 if err != nil {
104 return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
105 }
106
107 req.Header.Set("User-Agent", "crush/1.0")
108
109 resp, err := client.Do(req)
110 if err != nil {
111 return fantasy.ToolResponse{}, fmt.Errorf("failed to download from URL: %w", err)
112 }
113 defer resp.Body.Close()
114
115 if resp.StatusCode != http.StatusOK {
116 return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
117 }
118
119 // Check content length if available
120 maxSize := int64(100 * 1024 * 1024) // 100MB
121 if resp.ContentLength > maxSize {
122 return fantasy.NewTextErrorResponse(fmt.Sprintf("File too large: %d bytes (max %d bytes)", resp.ContentLength, maxSize)), nil
123 }
124
125 // Create parent directories if they don't exist
126 if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
127 return fantasy.ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
128 }
129
130 // Create the output file
131 outFile, err := os.Create(filePath)
132 if err != nil {
133 return fantasy.ToolResponse{}, fmt.Errorf("failed to create output file: %w", err)
134 }
135 defer outFile.Close()
136
137 // Copy data with size limit
138 limitedReader := io.LimitReader(resp.Body, maxSize)
139 bytesWritten, err := io.Copy(outFile, limitedReader)
140 if err != nil {
141 return fantasy.ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
142 }
143
144 // Check if we hit the size limit
145 if bytesWritten == maxSize {
146 // Clean up the file since it might be incomplete
147 os.Remove(filePath)
148 return fantasy.NewTextErrorResponse(fmt.Sprintf("File too large: exceeded %d bytes limit", maxSize)), nil
149 }
150
151 contentType := resp.Header.Get("Content-Type")
152 responseMsg := fmt.Sprintf("Successfully downloaded %d bytes to %s", bytesWritten, filePath)
153 if contentType != "" {
154 responseMsg += fmt.Sprintf(" (Content-Type: %s)", contentType)
155 }
156
157 return fantasy.NewTextResponse(responseMsg), nil
158 })
159}