1package tools
2
3import (
4 "cmp"
5 "context"
6 _ "embed"
7 "fmt"
8 "html/template"
9 "io"
10 "net/http"
11 "os"
12 "path/filepath"
13 "strings"
14 "time"
15
16 "charm.land/fantasy"
17 "github.com/charmbracelet/crush/internal/filepathext"
18 "github.com/charmbracelet/crush/internal/permission"
19)
20
// DownloadParams are the arguments supplied when invoking the download tool.
// The json and description tags are presumably what fantasy uses to build the
// tool's input schema — confirm before changing them.
//
//   - URL must start with http:// or https://.
//   - FilePath is joined onto the tool's working directory before use.
//   - Timeout is in seconds; values above 600 are clamped to 600.
//
// NOTE: the field names, types, and order must stay identical to
// DownloadPermissionsParams, because NewDownloadTool converts one into the
// other with a direct type conversion.
type DownloadParams struct {
	URL      string `json:"url" description:"The URL to download from"`
	FilePath string `json:"file_path" description:"The local file path where the downloaded content should be saved"`
	Timeout  int    `json:"timeout,omitempty" description:"Optional timeout in seconds (max 600)"`
}
26
// DownloadPermissionsParams is the payload attached to the permission request
// raised before a download starts. It mirrors DownloadParams field-for-field
// so that NewDownloadTool can convert between the two directly.
// The conversion happens before the timeout is clamped, so Timeout carries
// the value exactly as requested.
type DownloadPermissionsParams struct {
	URL      string `json:"url"`
	FilePath string `json:"file_path"`
	Timeout  int    `json:"timeout,omitempty"`
}
32
// DownloadToolName is the name the download tool is created under and the
// tool name reported in its permission requests.
const (
	DownloadToolName = "download"
)
34
// downloadDescriptionTmpl holds the raw text of download.md.tpl, embedded into
// the binary at build time. It is parsed by downloadDescriptionTpl.
//
//go:embed download.md.tpl
var downloadDescriptionTmpl []byte
37
38var downloadDescriptionTpl = template.Must(
39 template.New("downloadDescription").
40 Parse(string(downloadDescriptionTmpl)),
41)
42
// downloadDescriptionData is the data passed to downloadDescriptionTpl when
// rendering the tool description.
type downloadDescriptionData struct {
	// MaxDownloadTimeout is the largest per-call timeout, in seconds, that the
	// description advertises. NOTE(review): presumably referenced by name from
	// download.md.tpl — verify before renaming.
	MaxDownloadTimeout int
}
46
47func downloadDescription() string {
48 return renderTemplate(downloadDescriptionTpl, downloadDescriptionData{
49 MaxDownloadTimeout: 600,
50 })
51}
52
53func NewDownloadTool(permissions permission.Service, workingDir string, client *http.Client) fantasy.AgentTool {
54 if client == nil {
55 transport := http.DefaultTransport.(*http.Transport).Clone()
56 transport.MaxIdleConns = 100
57 transport.MaxIdleConnsPerHost = 10
58 transport.IdleConnTimeout = 90 * time.Second
59
60 client = &http.Client{
61 Timeout: 5 * time.Minute, // Default 5 minute timeout for downloads
62 Transport: transport,
63 }
64 }
65 return fantasy.NewParallelAgentTool(
66 DownloadToolName,
67 downloadDescription(),
68 func(ctx context.Context, params DownloadParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
69 if params.URL == "" {
70 return fantasy.NewTextErrorResponse("URL parameter is required"), nil
71 }
72
73 if params.FilePath == "" {
74 return fantasy.NewTextErrorResponse("file_path parameter is required"), nil
75 }
76
77 if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
78 return fantasy.NewTextErrorResponse("URL must start with http:// or https://"), nil
79 }
80
81 filePath := filepathext.SmartJoin(workingDir, params.FilePath)
82 relPath, _ := filepath.Rel(workingDir, filePath)
83 relPath = filepath.ToSlash(cmp.Or(relPath, filePath))
84
85 sessionID := GetSessionFromContext(ctx)
86 if sessionID == "" {
87 return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for downloading files")
88 }
89
90 p, err := permissions.Request(ctx,
91 permission.CreatePermissionRequest{
92 SessionID: sessionID,
93 Path: filePath,
94 ToolName: DownloadToolName,
95 Action: "download",
96 Description: fmt.Sprintf("Download file from URL: %s to %s", params.URL, filePath),
97 Params: DownloadPermissionsParams(params),
98 },
99 )
100 if err != nil {
101 return fantasy.ToolResponse{}, err
102 }
103 if !p {
104 return NewPermissionDeniedResponse(), nil
105 }
106
107 // Handle timeout with context
108 requestCtx := ctx
109 if params.Timeout > 0 {
110 maxTimeout := 600 // 10 minutes
111 if params.Timeout > maxTimeout {
112 params.Timeout = maxTimeout
113 }
114 var cancel context.CancelFunc
115 requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
116 defer cancel()
117 }
118
119 req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
120 if err != nil {
121 return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
122 }
123
124 req.Header.Set("User-Agent", "crush/1.0")
125
126 resp, err := client.Do(req)
127 if err != nil {
128 return fantasy.ToolResponse{}, fmt.Errorf("failed to download from URL: %w", err)
129 }
130 defer resp.Body.Close()
131
132 if resp.StatusCode != http.StatusOK {
133 return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
134 }
135
136 // Create parent directories if they don't exist
137 if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
138 return fantasy.ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
139 }
140
141 // Create the output file
142 outFile, err := os.Create(filePath)
143 if err != nil {
144 return fantasy.ToolResponse{}, fmt.Errorf("failed to create output file: %w", err)
145 }
146 defer outFile.Close()
147
148 // Copy data without an explicit size limit.
149 // The overall download is still constrained by the HTTP client's timeout
150 // and any upstream server limits.
151 bytesWritten, err := io.Copy(outFile, resp.Body)
152 if err != nil {
153 return fantasy.ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
154 }
155
156 contentType := resp.Header.Get("Content-Type")
157 responseMsg := fmt.Sprintf("Successfully downloaded %d bytes to %s", bytesWritten, relPath)
158 if contentType != "" {
159 responseMsg += fmt.Sprintf(" (Content-Type: %s)", contentType)
160 }
161
162 return fantasy.NewTextResponse(responseMsg), nil
163 })
164}