1package tools
2
3import (
4 "context"
5 _ "embed"
6 "encoding/json"
7 "fmt"
8 "io"
9 "net/http"
10 "os"
11 "path/filepath"
12 "strings"
13 "time"
14
15 "github.com/charmbracelet/crush/internal/permission"
16)
17
type (
	// DownloadParams is the JSON input accepted by the download tool.
	// URL and FilePath are required; Timeout is an optional limit in seconds.
	DownloadParams struct {
		URL      string `json:"url"`
		FilePath string `json:"file_path"`
		Timeout  int    `json:"timeout,omitempty"`
	}

	// DownloadPermissionsParams carries the same data as DownloadParams and is
	// attached to the permission request. Its fields must keep the same names,
	// order and types as DownloadParams so the two can be converted directly.
	DownloadPermissionsParams struct {
		URL      string `json:"url"`
		FilePath string `json:"file_path"`
		Timeout  int    `json:"timeout,omitempty"`
	}
)
29
30type downloadTool struct {
31 client *http.Client
32 permissions permission.Service
33 workingDir string
34}
35
36const DownloadToolName = "download"
37
38//go:embed download.md
39var downloadDescription []byte
40
// NewDownloadTool returns the download tool. permissions is the service
// consulted before each download, and workingDir is the directory that
// relative file paths are resolved against.
//
// The HTTP transport starts from a clone of http.DefaultTransport, so it
// keeps proxy settings from the environment (HTTP_PROXY, HTTPS_PROXY,
// NO_PROXY), dial and TLS handshake timeouts, and HTTP/2 support. A bare
// &http.Transport{} would silently drop all of those. Only the idle
// connection pool is tuned here.
func NewDownloadTool(permissions permission.Service, workingDir string) BaseTool {
	transport, ok := http.DefaultTransport.(*http.Transport)
	if ok {
		transport = transport.Clone()
	} else {
		// http.DefaultTransport was replaced by something that is not an
		// *http.Transport; fall back to a fresh one that still honors proxies.
		transport = &http.Transport{Proxy: http.ProxyFromEnvironment}
	}
	transport.MaxIdleConns = 100
	transport.MaxIdleConnsPerHost = 10
	transport.IdleConnTimeout = 90 * time.Second

	return &downloadTool{
		client: &http.Client{
			Timeout:   5 * time.Minute, // Default 5 minute timeout for downloads
			Transport: transport,
		},
		permissions: permissions,
		workingDir:  workingDir,
	}
}
55
56func (t *downloadTool) Name() string {
57 return DownloadToolName
58}
59
60func (t *downloadTool) Info() ToolInfo {
61 return ToolInfo{
62 Name: DownloadToolName,
63 Description: string(downloadDescription),
64 Parameters: map[string]any{
65 "url": map[string]any{
66 "type": "string",
67 "description": "The URL to download from",
68 },
69 "file_path": map[string]any{
70 "type": "string",
71 "description": "The local file path where the downloaded content should be saved",
72 },
73 "timeout": map[string]any{
74 "type": "number",
75 "description": "Optional timeout in seconds (max 600)",
76 },
77 },
78 Required: []string{"url", "file_path"},
79 }
80}
81
// Run performs one download described by call.Input, a JSON-encoded
// DownloadParams.
//
// Steps:
//   - Validate the parameters (http/https URL, non-empty file path).
//   - Resolve a relative file_path against the tool's working directory.
//   - Request permission for the session.
//   - Stream the response body into the destination file, creating parent
//     directories as needed.
//
// Bodies larger than 100 MB are rejected, and an oversized or partially
// written file is removed.
//
// Parameter problems, non-200 statuses and oversized bodies come back as
// text error responses with a nil error. A missing session/message ID, a
// denied permission, and transport or file I/O failures are returned as Go
// errors.
func (t *downloadTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
	const (
		maxTimeoutSeconds = 600               // 10 minutes
		maxSize           = 100 * 1024 * 1024 // 100MB
	)

	var params DownloadParams
	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
		return NewTextErrorResponse("Failed to parse download parameters: " + err.Error()), nil
	}

	if params.URL == "" {
		return NewTextErrorResponse("URL parameter is required"), nil
	}

	if params.FilePath == "" {
		return NewTextErrorResponse("file_path parameter is required"), nil
	}

	if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
		return NewTextErrorResponse("URL must start with http:// or https://"), nil
	}

	// Convert relative path to absolute path
	filePath := params.FilePath
	if !filepath.IsAbs(filePath) {
		filePath = filepath.Join(t.workingDir, filePath)
	}

	sessionID, messageID := GetContextValues(ctx)
	if sessionID == "" || messageID == "" {
		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for downloading files")
	}

	p := t.permissions.Request(
		permission.CreatePermissionRequest{
			SessionID:   sessionID,
			Path:        filePath,
			ToolName:    DownloadToolName,
			Action:      "download",
			Description: fmt.Sprintf("Download file from URL: %s to %s", params.URL, filePath),
			Params:      DownloadPermissionsParams(params),
		},
	)

	if !p {
		return ToolResponse{}, permission.ErrorPermissionDenied
	}

	// An explicit timeout is enforced through the request context, which also
	// bounds reading the body. The shared client's own Timeout (used when no
	// timeout is requested) would otherwise cut any longer explicit timeout
	// short. For that case, use a shallow copy of the client with its Timeout
	// cleared; the copy shares the Transport and its connection pool.
	client := t.client
	requestCtx := ctx
	if params.Timeout > 0 {
		timeoutSeconds := params.Timeout
		if timeoutSeconds > maxTimeoutSeconds {
			timeoutSeconds = maxTimeoutSeconds
		}
		var cancel context.CancelFunc
		requestCtx, cancel = context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
		defer cancel()

		unbounded := *t.client
		unbounded.Timeout = 0
		client = &unbounded
	}

	req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
	if err != nil {
		return ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
	}

	req.Header.Set("User-Agent", "crush/1.0")

	resp, err := client.Do(req)
	if err != nil {
		return ToolResponse{}, fmt.Errorf("failed to download from URL: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
	}

	// Reject early when the server announces an oversized body; -1 (unknown)
	// falls through to the streaming check below.
	if resp.ContentLength > maxSize {
		return NewTextErrorResponse(fmt.Sprintf("File too large: %d bytes (max %d bytes)", resp.ContentLength, maxSize)), nil
	}

	// Create parent directories if they don't exist
	if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
		return ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
	}

	outFile, err := os.Create(filePath)
	if err != nil {
		return ToolResponse{}, fmt.Errorf("failed to create output file: %w", err)
	}

	// Copy at most one byte past the limit. A body of exactly maxSize bytes
	// is accepted; anything longer shows up as bytesWritten > maxSize.
	bytesWritten, copyErr := io.Copy(outFile, io.LimitReader(resp.Body, maxSize+1))
	// Close before any cleanup: buffered write errors surface here, and an
	// open file cannot be removed on Windows.
	closeErr := outFile.Close()

	switch {
	case copyErr != nil:
		_ = os.Remove(filePath) // best-effort cleanup of the partial file
		return ToolResponse{}, fmt.Errorf("failed to write file: %w", copyErr)
	case closeErr != nil:
		_ = os.Remove(filePath) // best-effort cleanup; contents may be incomplete
		return ToolResponse{}, fmt.Errorf("failed to close output file: %w", closeErr)
	case bytesWritten > maxSize:
		_ = os.Remove(filePath) // body exceeded the limit; discard the truncated copy
		return NewTextErrorResponse(fmt.Sprintf("File too large: exceeded %d bytes limit", maxSize)), nil
	}

	contentType := resp.Header.Get("Content-Type")
	responseMsg := fmt.Sprintf("Successfully downloaded %d bytes to %s", bytesWritten, filePath)
	if contentType != "" {
		responseMsg += fmt.Sprintf(" (Content-Type: %s)", contentType)
	}

	return NewTextResponse(responseMsg), nil
}