1package tools
2
3import (
4 "context"
5 _ "embed"
6 "encoding/json"
7 "fmt"
8 "io"
9 "net/http"
10 "os"
11 "path/filepath"
12 "strings"
13 "time"
14
15 "github.com/charmbracelet/crush/internal/permission"
16 "github.com/charmbracelet/crush/internal/proto"
17)
18
type (
	// DownloadParams is the JSON input decoded by Run; it is an alias of
	// proto.DownloadParams.
	DownloadParams = proto.DownloadParams
	// DownloadPermissionsParams is the payload attached to the permission
	// request issued by Run; it is an alias of proto.DownloadPermissionsParams.
	DownloadPermissionsParams = proto.DownloadPermissionsParams
)

// downloadTool fetches a URL over HTTP(S) and writes the response body to a
// local file once the permission service has approved the request.
type downloadTool struct {
	client      *http.Client       // HTTP client used to perform the downloads
	permissions permission.Service // asked to approve each download before it runs
	workingDir  string             // base directory for resolving relative file paths
}

// DownloadToolName is the tool name reported by Name and Info.
const DownloadToolName = proto.DownloadToolName

// downloadDescription is the tool description returned by Info, embedded from
// download.md at build time.
//
//go:embed download.md
var downloadDescription []byte
34
35func NewDownloadTool(permissions permission.Service, workingDir string) BaseTool {
36 return &downloadTool{
37 client: &http.Client{
38 Timeout: 5 * time.Minute, // Default 5 minute timeout for downloads
39 Transport: &http.Transport{
40 MaxIdleConns: 100,
41 MaxIdleConnsPerHost: 10,
42 IdleConnTimeout: 90 * time.Second,
43 },
44 },
45 permissions: permissions,
46 workingDir: workingDir,
47 }
48}
49
50func (t *downloadTool) Name() string {
51 return DownloadToolName
52}
53
54func (t *downloadTool) Info() ToolInfo {
55 return ToolInfo{
56 Name: DownloadToolName,
57 Description: string(downloadDescription),
58 Parameters: map[string]any{
59 "url": map[string]any{
60 "type": "string",
61 "description": "The URL to download from",
62 },
63 "file_path": map[string]any{
64 "type": "string",
65 "description": "The local file path where the downloaded content should be saved",
66 },
67 "timeout": map[string]any{
68 "type": "number",
69 "description": "Optional timeout in seconds (max 600)",
70 },
71 },
72 Required: []string{"url", "file_path"},
73 }
74}
75
76func (t *downloadTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
77 var params DownloadParams
78 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
79 return NewTextErrorResponse("Failed to parse download parameters: " + err.Error()), nil
80 }
81
82 if params.URL == "" {
83 return NewTextErrorResponse("URL parameter is required"), nil
84 }
85
86 if params.FilePath == "" {
87 return NewTextErrorResponse("file_path parameter is required"), nil
88 }
89
90 if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
91 return NewTextErrorResponse("URL must start with http:// or https://"), nil
92 }
93
94 // Convert relative path to absolute path
95 var filePath string
96 if filepath.IsAbs(params.FilePath) {
97 filePath = params.FilePath
98 } else {
99 filePath = filepath.Join(t.workingDir, params.FilePath)
100 }
101
102 sessionID, messageID := GetContextValues(ctx)
103 if sessionID == "" || messageID == "" {
104 return ToolResponse{}, fmt.Errorf("session ID and message ID are required for downloading files")
105 }
106
107 p := t.permissions.Request(
108 permission.CreatePermissionRequest{
109 SessionID: sessionID,
110 Path: filePath,
111 ToolName: DownloadToolName,
112 Action: "download",
113 Description: fmt.Sprintf("Download file from URL: %s to %s", params.URL, filePath),
114 Params: DownloadPermissionsParams(params),
115 },
116 )
117
118 if !p {
119 return ToolResponse{}, permission.ErrorPermissionDenied
120 }
121
122 // Handle timeout with context
123 requestCtx := ctx
124 if params.Timeout > 0 {
125 maxTimeout := 600 // 10 minutes
126 if params.Timeout > maxTimeout {
127 params.Timeout = maxTimeout
128 }
129 var cancel context.CancelFunc
130 requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
131 defer cancel()
132 }
133
134 req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
135 if err != nil {
136 return ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
137 }
138
139 req.Header.Set("User-Agent", "crush/1.0")
140
141 resp, err := t.client.Do(req)
142 if err != nil {
143 return ToolResponse{}, fmt.Errorf("failed to download from URL: %w", err)
144 }
145 defer resp.Body.Close()
146
147 if resp.StatusCode != http.StatusOK {
148 return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
149 }
150
151 // Check content length if available
152 maxSize := int64(100 * 1024 * 1024) // 100MB
153 if resp.ContentLength > maxSize {
154 return NewTextErrorResponse(fmt.Sprintf("File too large: %d bytes (max %d bytes)", resp.ContentLength, maxSize)), nil
155 }
156
157 // Create parent directories if they don't exist
158 if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
159 return ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
160 }
161
162 // Create the output file
163 outFile, err := os.Create(filePath)
164 if err != nil {
165 return ToolResponse{}, fmt.Errorf("failed to create output file: %w", err)
166 }
167 defer outFile.Close()
168
169 // Copy data with size limit
170 limitedReader := io.LimitReader(resp.Body, maxSize)
171 bytesWritten, err := io.Copy(outFile, limitedReader)
172 if err != nil {
173 return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
174 }
175
176 // Check if we hit the size limit
177 if bytesWritten == maxSize {
178 // Clean up the file since it might be incomplete
179 os.Remove(filePath)
180 return NewTextErrorResponse(fmt.Sprintf("File too large: exceeded %d bytes limit", maxSize)), nil
181 }
182
183 contentType := resp.Header.Get("Content-Type")
184 responseMsg := fmt.Sprintf("Successfully downloaded %d bytes to %s", bytesWritten, filePath)
185 if contentType != "" {
186 responseMsg += fmt.Sprintf(" (Content-Type: %s)", contentType)
187 }
188
189 return NewTextResponse(responseMsg), nil
190}