download.go

  1package tools
  2
  3import (
  4	"cmp"
  5	"context"
  6	_ "embed"
  7	"fmt"
  8	"html/template"
  9	"io"
 10	"net/http"
 11	"os"
 12	"path/filepath"
 13	"strings"
 14	"time"
 15
 16	"charm.land/fantasy"
 17	"github.com/charmbracelet/crush/internal/filepathext"
 18	"github.com/charmbracelet/crush/internal/permission"
 19)
 20
 21type DownloadParams struct {
 22	URL      string `json:"url" description:"The URL to download from"`
 23	FilePath string `json:"file_path" description:"The local file path where the downloaded content should be saved"`
 24	Timeout  int    `json:"timeout,omitempty" description:"Optional timeout in seconds (max 600)"`
 25}
 26
 27type DownloadPermissionsParams struct {
 28	URL      string `json:"url"`
 29	FilePath string `json:"file_path"`
 30	Timeout  int    `json:"timeout,omitempty"`
 31}
 32
 33const DownloadToolName = "download"
 34
 35//go:embed download.md.tpl
 36var downloadDescriptionTmpl []byte
 37
 38var downloadDescriptionTpl = template.Must(
 39	template.New("downloadDescription").
 40		Parse(string(downloadDescriptionTmpl)),
 41)
 42
 43type downloadDescriptionData struct {
 44	MaxDownloadTimeout int
 45}
 46
 47func downloadDescription() string {
 48	return renderTemplate(downloadDescriptionTpl, downloadDescriptionData{
 49		MaxDownloadTimeout: 600,
 50	})
 51}
 52
 53func NewDownloadTool(permissions permission.Service, workingDir string, client *http.Client) fantasy.AgentTool {
 54	if client == nil {
 55		transport := http.DefaultTransport.(*http.Transport).Clone()
 56		transport.MaxIdleConns = 100
 57		transport.MaxIdleConnsPerHost = 10
 58		transport.IdleConnTimeout = 90 * time.Second
 59
 60		client = &http.Client{
 61			Timeout:   5 * time.Minute, // Default 5 minute timeout for downloads
 62			Transport: transport,
 63		}
 64	}
 65	return fantasy.NewParallelAgentTool(
 66		DownloadToolName,
 67		downloadDescription(),
 68		func(ctx context.Context, params DownloadParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 69			if params.URL == "" {
 70				return fantasy.NewTextErrorResponse("URL parameter is required"), nil
 71			}
 72
 73			if params.FilePath == "" {
 74				return fantasy.NewTextErrorResponse("file_path parameter is required"), nil
 75			}
 76
 77			if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
 78				return fantasy.NewTextErrorResponse("URL must start with http:// or https://"), nil
 79			}
 80
 81			filePath := filepathext.SmartJoin(workingDir, params.FilePath)
 82			relPath, _ := filepath.Rel(workingDir, filePath)
 83			relPath = filepath.ToSlash(cmp.Or(relPath, filePath))
 84
 85			sessionID := GetSessionFromContext(ctx)
 86			if sessionID == "" {
 87				return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for downloading files")
 88			}
 89
 90			p, err := permissions.Request(
 91				ctx,
 92				permission.CreatePermissionRequest{
 93					SessionID:   sessionID,
 94					Path:        filePath,
 95					ToolName:    DownloadToolName,
 96					Action:      "download",
 97					Description: fmt.Sprintf("Download file from URL: %s to %s", params.URL, filePath),
 98					Params:      DownloadPermissionsParams(params),
 99				},
100			)
101			if err != nil {
102				return fantasy.ToolResponse{}, err
103			}
104			if !p {
105				return NewPermissionDeniedResponse(), nil
106			}
107
108			// Handle timeout with context
109			requestCtx := ctx
110			if params.Timeout > 0 {
111				maxTimeout := 600 // 10 minutes
112				if params.Timeout > maxTimeout {
113					params.Timeout = maxTimeout
114				}
115				var cancel context.CancelFunc
116				requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
117				defer cancel()
118			}
119
120			req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
121			if err != nil {
122				return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
123			}
124
125			req.Header.Set("User-Agent", "crush/1.0")
126
127			resp, err := client.Do(req)
128			if err != nil {
129				return fantasy.ToolResponse{}, fmt.Errorf("failed to download from URL: %w", err)
130			}
131			defer resp.Body.Close()
132
133			if resp.StatusCode != http.StatusOK {
134				return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
135			}
136
137			// Create parent directories if they don't exist
138			if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
139				return fantasy.ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
140			}
141
142			// Create the output file
143			outFile, err := os.Create(filePath)
144			if err != nil {
145				return fantasy.ToolResponse{}, fmt.Errorf("failed to create output file: %w", err)
146			}
147			defer outFile.Close()
148
149			// Copy data without an explicit size limit.
150			// The overall download is still constrained by the HTTP client's timeout
151			// and any upstream server limits.
152			bytesWritten, err := io.Copy(outFile, resp.Body)
153			if err != nil {
154				return fantasy.ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
155			}
156
157			contentType := resp.Header.Get("Content-Type")
158			responseMsg := fmt.Sprintf("Successfully downloaded %d bytes to %s", bytesWritten, relPath)
159			if contentType != "" {
160				responseMsg += fmt.Sprintf(" (Content-Type: %s)", contentType)
161			}
162
163			return fantasy.NewTextResponse(responseMsg), nil
164		},
165	)
166}