download.go

  1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"fmt"
  7	"io"
  8	"net/http"
  9	"os"
 10	"path/filepath"
 11	"strings"
 12	"time"
 13
 14	"charm.land/fantasy"
 15	"github.com/charmbracelet/crush/internal/filepathext"
 16	"github.com/charmbracelet/crush/internal/permission"
 17)
 18
// DownloadParams holds the arguments of a download tool call. The json tags
// name the keys of the tool-call input; the description tags are presumably
// consumed by fantasy to describe each argument — TODO confirm.
//
// Its fields must stay identical in name, type and order to
// DownloadPermissionsParams, because NewDownloadTool converts one into the
// other directly.
type DownloadParams struct {
	// URL to fetch; NewDownloadTool rejects values not starting with
	// http:// or https://.
	URL string `json:"url" description:"The URL to download from"`
	// FilePath is the destination, resolved against the tool's working
	// directory by NewDownloadTool.
	FilePath string `json:"file_path" description:"The local file path where the downloaded content should be saved"`
	// Timeout in seconds; values <= 0 request no per-call timeout and values
	// above 600 are capped at 600 by NewDownloadTool.
	Timeout int `json:"timeout,omitempty" description:"Optional timeout in seconds (max 600)"`
}
 24
// DownloadPermissionsParams is the Params payload attached to the permission
// request that NewDownloadTool issues before downloading.
//
// It mirrors DownloadParams field for field (same names, types and order) so
// that a DownloadParams value can be converted to it directly; keep the two
// structs in sync.
type DownloadPermissionsParams struct {
	URL      string `json:"url"`
	FilePath string `json:"file_path"`
	Timeout  int    `json:"timeout,omitempty"`
}
 30
// DownloadToolName is the name passed to fantasy.NewAgentTool for the
// download tool and recorded as ToolName on its permission requests.
const DownloadToolName = "download"

// downloadDescription is the tool description handed to fantasy.NewAgentTool,
// embedded at build time from download.md.
//
//go:embed download.md
var downloadDescription []byte
 35
 36func NewDownloadTool(permissions permission.Service, workingDir string, client *http.Client) fantasy.AgentTool {
 37	if client == nil {
 38		client = &http.Client{
 39			Timeout: 5 * time.Minute, // Default 5 minute timeout for downloads
 40			Transport: &http.Transport{
 41				MaxIdleConns:        100,
 42				MaxIdleConnsPerHost: 10,
 43				IdleConnTimeout:     90 * time.Second,
 44			},
 45		}
 46	}
 47	return fantasy.NewAgentTool(
 48		DownloadToolName,
 49		string(downloadDescription),
 50		func(ctx context.Context, params DownloadParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 51			if params.URL == "" {
 52				return fantasy.NewTextErrorResponse("URL parameter is required"), nil
 53			}
 54
 55			if params.FilePath == "" {
 56				return fantasy.NewTextErrorResponse("file_path parameter is required"), nil
 57			}
 58
 59			if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
 60				return fantasy.NewTextErrorResponse("URL must start with http:// or https://"), nil
 61			}
 62
 63			filePath := filepathext.SmartJoin(workingDir, params.FilePath)
 64
 65			sessionID := GetSessionFromContext(ctx)
 66			if sessionID == "" {
 67				return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for downloading files")
 68			}
 69
 70			p := permissions.Request(
 71				permission.CreatePermissionRequest{
 72					SessionID:   sessionID,
 73					Path:        filePath,
 74					ToolName:    DownloadToolName,
 75					Action:      "download",
 76					Description: fmt.Sprintf("Download file from URL: %s to %s", params.URL, filePath),
 77					Params:      DownloadPermissionsParams(params),
 78				},
 79			)
 80
 81			if !p {
 82				return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
 83			}
 84
 85			// Handle timeout with context
 86			requestCtx := ctx
 87			if params.Timeout > 0 {
 88				maxTimeout := 600 // 10 minutes
 89				if params.Timeout > maxTimeout {
 90					params.Timeout = maxTimeout
 91				}
 92				var cancel context.CancelFunc
 93				requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
 94				defer cancel()
 95			}
 96
 97			req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
 98			if err != nil {
 99				return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
100			}
101
102			req.Header.Set("User-Agent", "crush/1.0")
103
104			resp, err := client.Do(req)
105			if err != nil {
106				return fantasy.ToolResponse{}, fmt.Errorf("failed to download from URL: %w", err)
107			}
108			defer resp.Body.Close()
109
110			if resp.StatusCode != http.StatusOK {
111				return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
112			}
113
114			// Check content length if available
115			maxSize := int64(100 * 1024 * 1024) // 100MB
116			if resp.ContentLength > maxSize {
117				return fantasy.NewTextErrorResponse(fmt.Sprintf("File too large: %d bytes (max %d bytes)", resp.ContentLength, maxSize)), nil
118			}
119
120			// Create parent directories if they don't exist
121			if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
122				return fantasy.ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
123			}
124
125			// Create the output file
126			outFile, err := os.Create(filePath)
127			if err != nil {
128				return fantasy.ToolResponse{}, fmt.Errorf("failed to create output file: %w", err)
129			}
130			defer outFile.Close()
131
132			// Copy data with size limit
133			limitedReader := io.LimitReader(resp.Body, maxSize)
134			bytesWritten, err := io.Copy(outFile, limitedReader)
135			if err != nil {
136				return fantasy.ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
137			}
138
139			// Check if we hit the size limit
140			if bytesWritten == maxSize {
141				// Clean up the file since it might be incomplete
142				os.Remove(filePath)
143				return fantasy.NewTextErrorResponse(fmt.Sprintf("File too large: exceeded %d bytes limit", maxSize)), nil
144			}
145
146			contentType := resp.Header.Get("Content-Type")
147			responseMsg := fmt.Sprintf("Successfully downloaded %d bytes to %s", bytesWritten, filePath)
148			if contentType != "" {
149				responseMsg += fmt.Sprintf(" (Content-Type: %s)", contentType)
150			}
151
152			return fantasy.NewTextResponse(responseMsg), nil
153		})
154}