1package tools
2
3import (
4 "cmp"
5 "context"
6 _ "embed"
7 "fmt"
8 "html/template"
9 "io"
10 "net/http"
11 "os"
12 "path/filepath"
13 "strings"
14 "time"
15
16 "charm.land/fantasy"
17 "github.com/charmbracelet/crush/internal/filepathext"
18 "github.com/charmbracelet/crush/internal/permission"
19)
20
// DownloadParams holds the arguments of a download tool call. The json tags
// name the input fields; the description tags carry per-field help text
// (presumably consumed when fantasy builds the tool's input schema — confirm).
type DownloadParams struct {
	// URL to fetch; the tool rejects values not starting with http:// or https://.
	URL string `json:"url" description:"The URL to download from"`
	// FilePath is the destination file, joined with the tool's working
	// directory via filepathext.SmartJoin.
	FilePath string `json:"file_path" description:"The local file path where the downloaded content should be saved"`
	// Timeout is an optional per-call limit in seconds. Values above 600 are
	// clamped to 600; values <= 0 mean the call does not request its own timeout.
	Timeout int `json:"timeout,omitempty" description:"Optional timeout in seconds (max 600)"`
}
26
// DownloadPermissionsParams is the payload attached to a download permission
// request. It mirrors DownloadParams field-for-field (names and types) so a
// DownloadParams value can be converted to it directly; struct tags are
// ignored by Go conversions. Keep the two structs in sync.
type DownloadPermissionsParams struct {
	// URL is the source URL being downloaded.
	URL string `json:"url"`
	// FilePath is the destination path as supplied by the caller.
	FilePath string `json:"file_path"`
	// Timeout is the requested per-call timeout in seconds, omitted when zero.
	Timeout int `json:"timeout,omitempty"`
}
32
// DownloadToolName is the name the download tool is registered under; it is
// also recorded as the tool name on the tool's permission requests.
const DownloadToolName = "download"

// downloadDescriptionTmpl is the raw text of download.md.tpl, embedded into
// the binary at build time.
//
//go:embed download.md.tpl
var downloadDescriptionTmpl []byte

// downloadDescriptionTpl is the parsed description template. template.Must
// panics during package initialization if the embedded template fails to parse.
//
// NOTE(review): this is html/template, which applies HTML-contextual escaping
// to interpolated values; the template is Markdown, so text/template may be
// what was intended — confirm against renderTemplate's parameter type before
// switching.
var downloadDescriptionTpl = template.Must(
	template.New("downloadDescription").
		Parse(string(downloadDescriptionTmpl)),
)
42
// downloadDescriptionData holds the values substituted into the download
// tool's description template.
type downloadDescriptionData struct {
	// MaxDownloadTimeout is the largest per-call timeout, in seconds, that the
	// rendered description advertises.
	MaxDownloadTimeout int
}
46
47func downloadDescription() string {
48 return renderTemplate(downloadDescriptionTpl, downloadDescriptionData{
49 MaxDownloadTimeout: 600,
50 })
51}
52
53func NewDownloadTool(permissions permission.Service, workingDir string, client *http.Client) fantasy.AgentTool {
54 if client == nil {
55 transport := http.DefaultTransport.(*http.Transport).Clone()
56 transport.MaxIdleConns = 100
57 transport.MaxIdleConnsPerHost = 10
58 transport.IdleConnTimeout = 90 * time.Second
59
60 client = &http.Client{
61 Timeout: 5 * time.Minute, // Default 5 minute timeout for downloads
62 Transport: transport,
63 }
64 }
65 return fantasy.NewParallelAgentTool(
66 DownloadToolName,
67 downloadDescription(),
68 func(ctx context.Context, params DownloadParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
69 if params.URL == "" {
70 return fantasy.NewTextErrorResponse("URL parameter is required"), nil
71 }
72
73 if params.FilePath == "" {
74 return fantasy.NewTextErrorResponse("file_path parameter is required"), nil
75 }
76
77 if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
78 return fantasy.NewTextErrorResponse("URL must start with http:// or https://"), nil
79 }
80
81 filePath := filepathext.SmartJoin(workingDir, params.FilePath)
82 relPath, _ := filepath.Rel(workingDir, filePath)
83 relPath = filepath.ToSlash(cmp.Or(relPath, filePath))
84
85 sessionID := GetSessionFromContext(ctx)
86 if sessionID == "" {
87 return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for downloading files")
88 }
89
90 p, err := permissions.Request(
91 ctx,
92 permission.CreatePermissionRequest{
93 SessionID: sessionID,
94 Path: filePath,
95 ToolName: DownloadToolName,
96 Action: "download",
97 Description: fmt.Sprintf("Download file from URL: %s to %s", params.URL, filePath),
98 Params: DownloadPermissionsParams(params),
99 },
100 )
101 if err != nil {
102 return fantasy.ToolResponse{}, err
103 }
104 if !p {
105 return NewPermissionDeniedResponse(), nil
106 }
107
108 // Handle timeout with context
109 requestCtx := ctx
110 if params.Timeout > 0 {
111 maxTimeout := 600 // 10 minutes
112 if params.Timeout > maxTimeout {
113 params.Timeout = maxTimeout
114 }
115 var cancel context.CancelFunc
116 requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
117 defer cancel()
118 }
119
120 req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
121 if err != nil {
122 return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
123 }
124
125 req.Header.Set("User-Agent", "crush/1.0")
126
127 resp, err := client.Do(req)
128 if err != nil {
129 return fantasy.ToolResponse{}, fmt.Errorf("failed to download from URL: %w", err)
130 }
131 defer resp.Body.Close()
132
133 if resp.StatusCode != http.StatusOK {
134 return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
135 }
136
137 // Create parent directories if they don't exist
138 if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
139 return fantasy.ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
140 }
141
142 // Create the output file
143 outFile, err := os.Create(filePath)
144 if err != nil {
145 return fantasy.ToolResponse{}, fmt.Errorf("failed to create output file: %w", err)
146 }
147 defer outFile.Close()
148
149 // Copy data without an explicit size limit.
150 // The overall download is still constrained by the HTTP client's timeout
151 // and any upstream server limits.
152 bytesWritten, err := io.Copy(outFile, resp.Body)
153 if err != nil {
154 return fantasy.ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
155 }
156
157 contentType := resp.Header.Get("Content-Type")
158 responseMsg := fmt.Sprintf("Successfully downloaded %d bytes to %s", bytesWritten, relPath)
159 if contentType != "" {
160 responseMsg += fmt.Sprintf(" (Content-Type: %s)", contentType)
161 }
162
163 return fantasy.NewTextResponse(responseMsg), nil
164 },
165 )
166}