1package tools
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "io"
8 "net/http"
9 "os"
10 "path/filepath"
11 "strings"
12 "time"
13
14 "github.com/charmbracelet/crush/internal/permission"
15)
16
// DownloadParams is the JSON input accepted by the download tool.
//
// Its fields must stay identical (names, types and order) to those of
// DownloadPermissionsParams: Run converts one into the other with a direct
// type conversion.
type DownloadParams struct {
	// URL is the address to fetch; Run only accepts http:// and https:// URLs.
	URL string `json:"url"`

	// FilePath is where the downloaded content is written. Relative paths are
	// resolved against the tool's working directory.
	FilePath string `json:"file_path"`

	// Timeout is an optional request timeout in seconds. Values <= 0 leave
	// the client's default in place; larger values are capped at 600.
	Timeout int `json:"timeout,omitempty"`
}
22
// DownloadPermissionsParams is attached as Params to the permission request
// that Run raises before starting a download.
//
// It mirrors DownloadParams field for field: Run builds it with a direct type
// conversion, so both structs must keep identical fields in the same order.
type DownloadPermissionsParams struct {
	// URL is the address that will be fetched.
	URL string `json:"url"`

	// FilePath is the resolved destination path.
	FilePath string `json:"file_path"`

	// Timeout is the requested timeout in seconds, as supplied by the caller.
	Timeout int `json:"timeout,omitempty"`
}
28
29type downloadTool struct {
30 client *http.Client
31 permissions permission.Service
32 workingDir string
33}
34
// DownloadToolName is the identifier reported by the download tool's Name
// and Info methods.
const DownloadToolName = "download"

// downloadToolDescription is the usage guide returned as the Description in
// the download tool's Info.
const downloadToolDescription = `Downloads binary data from a URL and saves it to a local file.

WHEN TO USE THIS TOOL:
- Use when you need to download files, images, or other binary data from URLs
- Helpful for downloading assets, documents, or any file type
- Useful for saving remote content locally for processing or storage

HOW TO USE:
- Provide the URL to download from
- Specify the local file path where the content should be saved
- Optionally set a timeout for the request

FEATURES:
- Downloads any file type (binary or text)
- Automatically creates parent directories if they don't exist
- Handles large files efficiently with streaming
- Sets reasonable timeouts to prevent hanging
- Validates input parameters before making requests

LIMITATIONS:
- Maximum file size is 100MB
- Only supports HTTP and HTTPS protocols
- Cannot handle authentication or cookies
- Some websites may block automated requests
- Will overwrite existing files without warning

TIPS:
- Use absolute paths or paths relative to the working directory
- Set appropriate timeouts for large files or slow connections`
67
68func NewDownloadTool(permissions permission.Service, workingDir string) BaseTool {
69 return &downloadTool{
70 client: &http.Client{
71 Timeout: 5 * time.Minute, // Default 5 minute timeout for downloads
72 Transport: &http.Transport{
73 MaxIdleConns: 100,
74 MaxIdleConnsPerHost: 10,
75 IdleConnTimeout: 90 * time.Second,
76 },
77 },
78 permissions: permissions,
79 workingDir: workingDir,
80 }
81}
82
83func (t *downloadTool) Name() string {
84 return DownloadToolName
85}
86
87func (t *downloadTool) Info() ToolInfo {
88 return ToolInfo{
89 Name: DownloadToolName,
90 Description: downloadToolDescription,
91 Parameters: map[string]any{
92 "url": map[string]any{
93 "type": "string",
94 "description": "The URL to download from",
95 },
96 "file_path": map[string]any{
97 "type": "string",
98 "description": "The local file path where the downloaded content should be saved",
99 },
100 "timeout": map[string]any{
101 "type": "number",
102 "description": "Optional timeout in seconds (max 600)",
103 },
104 },
105 Required: []string{"url", "file_path"},
106 }
107}
108
109func (t *downloadTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
110 var params DownloadParams
111 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
112 return NewTextErrorResponse("Failed to parse download parameters: " + err.Error()), nil
113 }
114
115 if params.URL == "" {
116 return NewTextErrorResponse("URL parameter is required"), nil
117 }
118
119 if params.FilePath == "" {
120 return NewTextErrorResponse("file_path parameter is required"), nil
121 }
122
123 if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
124 return NewTextErrorResponse("URL must start with http:// or https://"), nil
125 }
126
127 // Convert relative path to absolute path
128 var filePath string
129 if filepath.IsAbs(params.FilePath) {
130 filePath = params.FilePath
131 } else {
132 filePath = filepath.Join(t.workingDir, params.FilePath)
133 }
134
135 sessionID, messageID := GetContextValues(ctx)
136 if sessionID == "" || messageID == "" {
137 return ToolResponse{}, fmt.Errorf("session ID and message ID are required for downloading files")
138 }
139
140 p := t.permissions.Request(
141 permission.CreatePermissionRequest{
142 SessionID: sessionID,
143 Path: filePath,
144 ToolName: DownloadToolName,
145 Action: "download",
146 Description: fmt.Sprintf("Download file from URL: %s to %s", params.URL, filePath),
147 Params: DownloadPermissionsParams(params),
148 },
149 )
150
151 if !p {
152 return ToolResponse{}, permission.ErrorPermissionDenied
153 }
154
155 // Handle timeout with context
156 requestCtx := ctx
157 if params.Timeout > 0 {
158 maxTimeout := 600 // 10 minutes
159 if params.Timeout > maxTimeout {
160 params.Timeout = maxTimeout
161 }
162 var cancel context.CancelFunc
163 requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
164 defer cancel()
165 }
166
167 req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
168 if err != nil {
169 return ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
170 }
171
172 req.Header.Set("User-Agent", "crush/1.0")
173
174 resp, err := t.client.Do(req)
175 if err != nil {
176 return ToolResponse{}, fmt.Errorf("failed to download from URL: %w", err)
177 }
178 defer resp.Body.Close()
179
180 if resp.StatusCode != http.StatusOK {
181 return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
182 }
183
184 // Check content length if available
185 maxSize := int64(100 * 1024 * 1024) // 100MB
186 if resp.ContentLength > maxSize {
187 return NewTextErrorResponse(fmt.Sprintf("File too large: %d bytes (max %d bytes)", resp.ContentLength, maxSize)), nil
188 }
189
190 // Create parent directories if they don't exist
191 if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
192 return ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
193 }
194
195 // Create the output file
196 outFile, err := os.Create(filePath)
197 if err != nil {
198 return ToolResponse{}, fmt.Errorf("failed to create output file: %w", err)
199 }
200 defer outFile.Close()
201
202 // Copy data with size limit
203 limitedReader := io.LimitReader(resp.Body, maxSize)
204 bytesWritten, err := io.Copy(outFile, limitedReader)
205 if err != nil {
206 return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
207 }
208
209 // Check if we hit the size limit
210 if bytesWritten == maxSize {
211 // Clean up the file since it might be incomplete
212 os.Remove(filePath)
213 return NewTextErrorResponse(fmt.Sprintf("File too large: exceeded %d bytes limit", maxSize)), nil
214 }
215
216 contentType := resp.Header.Get("Content-Type")
217 responseMsg := fmt.Sprintf("Successfully downloaded %d bytes to %s", bytesWritten, filePath)
218 if contentType != "" {
219 responseMsg += fmt.Sprintf(" (Content-Type: %s)", contentType)
220 }
221
222 return NewTextResponse(responseMsg), nil
223}