1package tools
2
3import (
4 "context"
5 _ "embed"
6 "fmt"
7 "io"
8 "net/http"
9 "os"
10 "path/filepath"
11 "strings"
12 "time"
13
14 "charm.land/fantasy"
15 "github.com/charmbracelet/crush/internal/filepathext"
16 "github.com/charmbracelet/crush/internal/permission"
17)
18
// DownloadParams holds the arguments the model passes to the download tool.
// NOTE(review): the description tags presumably feed the tool's input schema
// via fantasy.NewAgentTool — confirm.
type DownloadParams struct {
	// URL to fetch; the handler rejects values without an http:// or https:// prefix.
	URL string `json:"url" description:"The URL to download from"`
	// FilePath is the destination, joined with the tool's working directory
	// via filepathext.SmartJoin.
	FilePath string `json:"file_path" description:"The local file path where the downloaded content should be saved"`
	// Timeout is in seconds; the handler caps it at 600. Zero or negative
	// means the caller did not request a per-call timeout.
	Timeout int `json:"timeout,omitempty" description:"Optional timeout in seconds (max 600)"`
}
24
// DownloadPermissionsParams is the payload attached to the permission request
// for a download. It is built with a direct type conversion from
// DownloadParams, so its fields must stay identical to DownloadParams' fields
// in name, type, and order (only struct tags may differ).
type DownloadPermissionsParams struct {
	URL      string `json:"url"`
	FilePath string `json:"file_path"`
	Timeout  int    `json:"timeout,omitempty"`
}
30
// DownloadToolName is the name the download tool is registered under; it is
// also reported as the tool name in its permission requests.
const DownloadToolName = "download"
32
// downloadDescription is the tool description handed to fantasy.NewAgentTool,
// embedded from download.md at build time.
//
//go:embed download.md
var downloadDescription []byte
35
36func NewDownloadTool(permissions permission.Service, workingDir string, client *http.Client) fantasy.AgentTool {
37 if client == nil {
38 client = &http.Client{
39 Timeout: 5 * time.Minute, // Default 5 minute timeout for downloads
40 Transport: &http.Transport{
41 MaxIdleConns: 100,
42 MaxIdleConnsPerHost: 10,
43 IdleConnTimeout: 90 * time.Second,
44 },
45 }
46 }
47 return fantasy.NewAgentTool(
48 DownloadToolName,
49 string(downloadDescription),
50 func(ctx context.Context, params DownloadParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
51 if params.URL == "" {
52 return fantasy.NewTextErrorResponse("URL parameter is required"), nil
53 }
54
55 if params.FilePath == "" {
56 return fantasy.NewTextErrorResponse("file_path parameter is required"), nil
57 }
58
59 if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
60 return fantasy.NewTextErrorResponse("URL must start with http:// or https://"), nil
61 }
62
63 filePath := filepathext.SmartJoin(workingDir, params.FilePath)
64
65 sessionID := GetSessionFromContext(ctx)
66 if sessionID == "" {
67 return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for downloading files")
68 }
69
70 p := permissions.Request(
71 permission.CreatePermissionRequest{
72 SessionID: sessionID,
73 Path: filePath,
74 ToolName: DownloadToolName,
75 Action: "download",
76 Description: fmt.Sprintf("Download file from URL: %s to %s", params.URL, filePath),
77 Params: DownloadPermissionsParams(params),
78 },
79 )
80
81 if !p {
82 return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
83 }
84
85 // Handle timeout with context
86 requestCtx := ctx
87 if params.Timeout > 0 {
88 maxTimeout := 600 // 10 minutes
89 if params.Timeout > maxTimeout {
90 params.Timeout = maxTimeout
91 }
92 var cancel context.CancelFunc
93 requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
94 defer cancel()
95 }
96
97 req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
98 if err != nil {
99 return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
100 }
101
102 req.Header.Set("User-Agent", "crush/1.0")
103
104 resp, err := client.Do(req)
105 if err != nil {
106 return fantasy.ToolResponse{}, fmt.Errorf("failed to download from URL: %w", err)
107 }
108 defer resp.Body.Close()
109
110 if resp.StatusCode != http.StatusOK {
111 return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
112 }
113
114 // Check content length if available
115 maxSize := int64(100 * 1024 * 1024) // 100MB
116 if resp.ContentLength > maxSize {
117 return fantasy.NewTextErrorResponse(fmt.Sprintf("File too large: %d bytes (max %d bytes)", resp.ContentLength, maxSize)), nil
118 }
119
120 // Create parent directories if they don't exist
121 if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
122 return fantasy.ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
123 }
124
125 // Create the output file
126 outFile, err := os.Create(filePath)
127 if err != nil {
128 return fantasy.ToolResponse{}, fmt.Errorf("failed to create output file: %w", err)
129 }
130 defer outFile.Close()
131
132 // Copy data with size limit
133 limitedReader := io.LimitReader(resp.Body, maxSize)
134 bytesWritten, err := io.Copy(outFile, limitedReader)
135 if err != nil {
136 return fantasy.ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
137 }
138
139 // Check if we hit the size limit
140 if bytesWritten == maxSize {
141 // Clean up the file since it might be incomplete
142 os.Remove(filePath)
143 return fantasy.NewTextErrorResponse(fmt.Sprintf("File too large: exceeded %d bytes limit", maxSize)), nil
144 }
145
146 contentType := resp.Header.Get("Content-Type")
147 responseMsg := fmt.Sprintf("Successfully downloaded %d bytes to %s", bytesWritten, filePath)
148 if contentType != "" {
149 responseMsg += fmt.Sprintf(" (Content-Type: %s)", contentType)
150 }
151
152 return fantasy.NewTextResponse(responseMsg), nil
153 })
154}