download.go

  1package tools
  2
  3import (
  4	"cmp"
  5	"context"
  6	_ "embed"
  7	"fmt"
  8	"io"
  9	"net/http"
 10	"os"
 11	"path/filepath"
 12	"strings"
 13	"time"
 14
 15	"charm.land/fantasy"
 16	"github.com/charmbracelet/crush/internal/filepathext"
 17	"github.com/charmbracelet/crush/internal/permission"
 18)
 19
 20type DownloadParams struct {
 21	URL      string `json:"url" description:"The URL to download from"`
 22	FilePath string `json:"file_path" description:"The local file path where the downloaded content should be saved"`
 23	Timeout  int    `json:"timeout,omitempty" description:"Optional timeout in seconds (max 600)"`
 24}
 25
 26type DownloadPermissionsParams struct {
 27	URL      string `json:"url"`
 28	FilePath string `json:"file_path"`
 29	Timeout  int    `json:"timeout,omitempty"`
 30}
 31
 32const DownloadToolName = "download"
 33
 34//go:embed download.md
 35var downloadDescription []byte
 36
 37func NewDownloadTool(permissions permission.Service, workingDir string, client *http.Client) fantasy.AgentTool {
 38	if client == nil {
 39		client = &http.Client{
 40			Timeout: 5 * time.Minute, // Default 5 minute timeout for downloads
 41			Transport: &http.Transport{
 42				MaxIdleConns:        100,
 43				MaxIdleConnsPerHost: 10,
 44				IdleConnTimeout:     90 * time.Second,
 45			},
 46		}
 47	}
 48	return fantasy.NewParallelAgentTool(
 49		DownloadToolName,
 50		string(downloadDescription),
 51		func(ctx context.Context, params DownloadParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 52			if params.URL == "" {
 53				return fantasy.NewTextErrorResponse("URL parameter is required"), nil
 54			}
 55
 56			if params.FilePath == "" {
 57				return fantasy.NewTextErrorResponse("file_path parameter is required"), nil
 58			}
 59
 60			if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
 61				return fantasy.NewTextErrorResponse("URL must start with http:// or https://"), nil
 62			}
 63
 64			filePath := filepathext.SmartJoin(workingDir, params.FilePath)
 65			relPath, _ := filepath.Rel(workingDir, filePath)
 66			relPath = filepath.ToSlash(cmp.Or(relPath, filePath))
 67
 68			sessionID := GetSessionFromContext(ctx)
 69			if sessionID == "" {
 70				return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for downloading files")
 71			}
 72
 73			p, err := permissions.Request(ctx,
 74				permission.CreatePermissionRequest{
 75					SessionID:   sessionID,
 76					Path:        filePath,
 77					ToolName:    DownloadToolName,
 78					Action:      "download",
 79					Description: fmt.Sprintf("Download file from URL: %s to %s", params.URL, filePath),
 80					Params:      DownloadPermissionsParams(params),
 81				},
 82			)
 83			if err != nil {
 84				return fantasy.ToolResponse{}, err
 85			}
 86			if !p.Granted {
 87				if p.Message != "" {
 88					return fantasy.NewTextErrorResponse("User denied permission." + permission.UserCommentaryTag(p.Message)), nil
 89				}
 90				return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
 91			}
 92
 93			// Handle timeout with context
 94			requestCtx := ctx
 95			if params.Timeout > 0 {
 96				maxTimeout := 600 // 10 minutes
 97				if params.Timeout > maxTimeout {
 98					params.Timeout = maxTimeout
 99				}
100				var cancel context.CancelFunc
101				requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
102				defer cancel()
103			}
104
105			req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
106			if err != nil {
107				return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
108			}
109
110			req.Header.Set("User-Agent", "crush/1.0")
111
112			resp, err := client.Do(req)
113			if err != nil {
114				return fantasy.ToolResponse{}, fmt.Errorf("failed to download from URL: %w", err)
115			}
116			defer resp.Body.Close()
117
118			if resp.StatusCode != http.StatusOK {
119				return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
120			}
121
122			// Create parent directories if they don't exist
123			if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
124				return fantasy.ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
125			}
126
127			// Create the output file
128			outFile, err := os.Create(filePath)
129			if err != nil {
130				return fantasy.ToolResponse{}, fmt.Errorf("failed to create output file: %w", err)
131			}
132			defer outFile.Close()
133
134			// Copy data without an explicit size limit.
135			// The overall download is still constrained by the HTTP client's timeout
136			// and any upstream server limits.
137			bytesWritten, err := io.Copy(outFile, resp.Body)
138			if err != nil {
139				return fantasy.ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
140			}
141
142			contentType := resp.Header.Get("Content-Type")
143			responseMsg := fmt.Sprintf("Successfully downloaded %d bytes to %s", bytesWritten, relPath)
144			if contentType != "" {
145				responseMsg += fmt.Sprintf(" (Content-Type: %s)", contentType)
146			}
147
148			return fantasy.NewTextResponse(p.AppendCommentary(responseMsg)), nil
149		})
150}