1package tools
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"io"
  8	"net/http"
  9	"os"
 10	"path/filepath"
 11	"strings"
 12	"time"
 13
 14	"github.com/charmbracelet/crush/internal/permission"
 15)
 16
// DownloadParams is the JSON input accepted by the download tool's Run.
type DownloadParams struct {
	// URL is the address to fetch; Run rejects values that do not start with
	// http:// or https://.
	URL      string `json:"url"`
	// FilePath is the destination on disk. A relative path is joined onto the
	// tool's working directory; an absolute path is used as given.
	FilePath string `json:"file_path"`
	// Timeout is an optional request timeout in seconds. Values <= 0 add no
	// per-request timeout; values above 600 are clamped to 600.
	Timeout  int    `json:"timeout,omitempty"`
}
 22
// DownloadPermissionsParams is the payload attached to the permission request
// that Run issues before downloading. Run builds it by converting a
// DownloadParams value directly, so both structs must keep identical field
// names, types, and order.
type DownloadPermissionsParams struct {
	// URL is the address that will be fetched.
	URL      string `json:"url"`
	// FilePath is the destination path exactly as the caller supplied it.
	FilePath string `json:"file_path"`
	// Timeout is the caller-requested timeout in seconds, before clamping.
	Timeout  int    `json:"timeout,omitempty"`
}
 28
// downloadTool implements the "download" tool: it fetches a URL over HTTP(S)
// and writes the response body to a local file.
type downloadTool struct {
	// client performs the HTTP requests; one instance is created by the
	// constructor and reused for every call.
	client      *http.Client
	// permissions is asked to approve each download before it starts.
	permissions permission.Service
	// workingDir is the base directory for relative destination paths.
	workingDir  string
}
 34
// DownloadToolName is the name the download tool is registered under and
// reports from Name and Info. downloadToolDescription is the model-facing
// usage text returned from Info; its LIMITATIONS section (100MB cap, http/https
// only, overwrite without warning) describes checks performed in Run and must
// be kept in sync with them.
const (
	DownloadToolName        = "download"
	downloadToolDescription = `Downloads binary data from a URL and saves it to a local file.

WHEN TO USE THIS TOOL:
- Use when you need to download files, images, or other binary data from URLs
- Helpful for downloading assets, documents, or any file type
- Useful for saving remote content locally for processing or storage

HOW TO USE:
- Provide the URL to download from
- Specify the local file path where the content should be saved
- Optionally set a timeout for the request

FEATURES:
- Downloads any file type (binary or text)
- Automatically creates parent directories if they don't exist
- Handles large files efficiently with streaming
- Sets reasonable timeouts to prevent hanging
- Validates input parameters before making requests

LIMITATIONS:
- Maximum file size is 100MB
- Only supports HTTP and HTTPS protocols
- Cannot handle authentication or cookies
- Some websites may block automated requests
- Will overwrite existing files without warning

TIPS:
- Use absolute paths or paths relative to the working directory
- Set appropriate timeouts for large files or slow connections`
)
 67
 68func NewDownloadTool(permissions permission.Service, workingDir string) BaseTool {
 69	return &downloadTool{
 70		client: &http.Client{
 71			Timeout: 5 * time.Minute, // Default 5 minute timeout for downloads
 72			Transport: &http.Transport{
 73				MaxIdleConns:        100,
 74				MaxIdleConnsPerHost: 10,
 75				IdleConnTimeout:     90 * time.Second,
 76			},
 77		},
 78		permissions: permissions,
 79		workingDir:  workingDir,
 80	}
 81}
 82
// Name returns the tool's registered name, DownloadToolName.
func (t *downloadTool) Name() string {
	return DownloadToolName
}
 86
 87func (t *downloadTool) Info() ToolInfo {
 88	return ToolInfo{
 89		Name:        DownloadToolName,
 90		Description: downloadToolDescription,
 91		Parameters: map[string]any{
 92			"url": map[string]any{
 93				"type":        "string",
 94				"description": "The URL to download from",
 95			},
 96			"file_path": map[string]any{
 97				"type":        "string",
 98				"description": "The local file path where the downloaded content should be saved",
 99			},
100			"timeout": map[string]any{
101				"type":        "number",
102				"description": "Optional timeout in seconds (max 600)",
103			},
104		},
105		Required: []string{"url", "file_path"},
106	}
107}
108
109func (t *downloadTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
110	var params DownloadParams
111	if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
112		return NewTextErrorResponse("Failed to parse download parameters: " + err.Error()), nil
113	}
114
115	if params.URL == "" {
116		return NewTextErrorResponse("URL parameter is required"), nil
117	}
118
119	if params.FilePath == "" {
120		return NewTextErrorResponse("file_path parameter is required"), nil
121	}
122
123	if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
124		return NewTextErrorResponse("URL must start with http:// or https://"), nil
125	}
126
127	// Convert relative path to absolute path
128	var filePath string
129	if filepath.IsAbs(params.FilePath) {
130		filePath = params.FilePath
131	} else {
132		filePath = filepath.Join(t.workingDir, params.FilePath)
133	}
134
135	sessionID, messageID := GetContextValues(ctx)
136	if sessionID == "" || messageID == "" {
137		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for downloading files")
138	}
139
140	p := t.permissions.Request(
141		permission.CreatePermissionRequest{
142			SessionID:   sessionID,
143			Path:        filePath,
144			ToolName:    DownloadToolName,
145			Action:      "download",
146			Description: fmt.Sprintf("Download file from URL: %s to %s", params.URL, filePath),
147			Params:      DownloadPermissionsParams(params),
148		},
149	)
150
151	if !p {
152		return ToolResponse{}, permission.ErrorPermissionDenied
153	}
154
155	// Handle timeout with context
156	requestCtx := ctx
157	if params.Timeout > 0 {
158		maxTimeout := 600 // 10 minutes
159		if params.Timeout > maxTimeout {
160			params.Timeout = maxTimeout
161		}
162		var cancel context.CancelFunc
163		requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
164		defer cancel()
165	}
166
167	req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
168	if err != nil {
169		return ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
170	}
171
172	req.Header.Set("User-Agent", "crush/1.0")
173
174	resp, err := t.client.Do(req)
175	if err != nil {
176		return ToolResponse{}, fmt.Errorf("failed to download from URL: %w", err)
177	}
178	defer resp.Body.Close()
179
180	if resp.StatusCode != http.StatusOK {
181		return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
182	}
183
184	// Check content length if available
185	maxSize := int64(100 * 1024 * 1024) // 100MB
186	if resp.ContentLength > maxSize {
187		return NewTextErrorResponse(fmt.Sprintf("File too large: %d bytes (max %d bytes)", resp.ContentLength, maxSize)), nil
188	}
189
190	// Create parent directories if they don't exist
191	if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
192		return ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
193	}
194
195	// Create the output file
196	outFile, err := os.Create(filePath)
197	if err != nil {
198		return ToolResponse{}, fmt.Errorf("failed to create output file: %w", err)
199	}
200	defer outFile.Close()
201
202	// Copy data with size limit
203	limitedReader := io.LimitReader(resp.Body, maxSize)
204	bytesWritten, err := io.Copy(outFile, limitedReader)
205	if err != nil {
206		return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
207	}
208
209	// Check if we hit the size limit
210	if bytesWritten == maxSize {
211		// Clean up the file since it might be incomplete
212		os.Remove(filePath)
213		return NewTextErrorResponse(fmt.Sprintf("File too large: exceeded %d bytes limit", maxSize)), nil
214	}
215
216	contentType := resp.Header.Get("Content-Type")
217	responseMsg := fmt.Sprintf("Successfully downloaded %d bytes to %s", bytesWritten, filePath)
218	if contentType != "" {
219		responseMsg += fmt.Sprintf(" (Content-Type: %s)", contentType)
220	}
221
222	return NewTextResponse(responseMsg), nil
223}