1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"encoding/json"
  7	"fmt"
  8	"io"
  9	"net/http"
 10	"os"
 11	"path/filepath"
 12	"strings"
 13	"time"
 14
 15	"github.com/charmbracelet/crush/internal/permission"
 16	"github.com/charmbracelet/crush/internal/proto"
 17)
 18
 19type (
 20	DownloadParams            = proto.DownloadParams
 21	DownloadPermissionsParams = proto.DownloadPermissionsParams
 22)
 23
// downloadTool fetches a URL over HTTP(S) and saves the response body to a
// local file.
type downloadTool struct {
	// client performs the HTTP requests.
	client *http.Client
	// permissions is asked for approval before any download is attempted.
	permissions permission.Service
	// workingDir is the base directory against which relative destination
	// paths are resolved.
	workingDir string
}
 29
// DownloadToolName is the name of the download tool, re-exported from the
// proto package. It is reported by Name and Info and attached to permission
// requests.
const DownloadToolName = proto.DownloadToolName
 31
// downloadDescription holds the contents of download.md, embedded at build
// time. Info exposes it as the tool's description.
//
//go:embed download.md
var downloadDescription []byte
 34
 35func NewDownloadTool(permissions permission.Service, workingDir string) BaseTool {
 36	return &downloadTool{
 37		client: &http.Client{
 38			Timeout: 5 * time.Minute, // Default 5 minute timeout for downloads
 39			Transport: &http.Transport{
 40				MaxIdleConns:        100,
 41				MaxIdleConnsPerHost: 10,
 42				IdleConnTimeout:     90 * time.Second,
 43			},
 44		},
 45		permissions: permissions,
 46		workingDir:  workingDir,
 47	}
 48}
 49
 50func (t *downloadTool) Name() string {
 51	return DownloadToolName
 52}
 53
 54func (t *downloadTool) Info() ToolInfo {
 55	return ToolInfo{
 56		Name:        DownloadToolName,
 57		Description: string(downloadDescription),
 58		Parameters: map[string]any{
 59			"url": map[string]any{
 60				"type":        "string",
 61				"description": "The URL to download from",
 62			},
 63			"file_path": map[string]any{
 64				"type":        "string",
 65				"description": "The local file path where the downloaded content should be saved",
 66			},
 67			"timeout": map[string]any{
 68				"type":        "number",
 69				"description": "Optional timeout in seconds (max 600)",
 70			},
 71		},
 72		Required: []string{"url", "file_path"},
 73	}
 74}
 75
 76func (t *downloadTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
 77	var params DownloadParams
 78	if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
 79		return NewTextErrorResponse("Failed to parse download parameters: " + err.Error()), nil
 80	}
 81
 82	if params.URL == "" {
 83		return NewTextErrorResponse("URL parameter is required"), nil
 84	}
 85
 86	if params.FilePath == "" {
 87		return NewTextErrorResponse("file_path parameter is required"), nil
 88	}
 89
 90	if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
 91		return NewTextErrorResponse("URL must start with http:// or https://"), nil
 92	}
 93
 94	// Convert relative path to absolute path
 95	var filePath string
 96	if filepath.IsAbs(params.FilePath) {
 97		filePath = params.FilePath
 98	} else {
 99		filePath = filepath.Join(t.workingDir, params.FilePath)
100	}
101
102	sessionID, messageID := GetContextValues(ctx)
103	if sessionID == "" || messageID == "" {
104		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for downloading files")
105	}
106
107	p := t.permissions.Request(
108		permission.CreatePermissionRequest{
109			SessionID:   sessionID,
110			Path:        filePath,
111			ToolName:    DownloadToolName,
112			Action:      "download",
113			Description: fmt.Sprintf("Download file from URL: %s to %s", params.URL, filePath),
114			Params:      DownloadPermissionsParams(params),
115		},
116	)
117
118	if !p {
119		return ToolResponse{}, permission.ErrorPermissionDenied
120	}
121
122	// Handle timeout with context
123	requestCtx := ctx
124	if params.Timeout > 0 {
125		maxTimeout := 600 // 10 minutes
126		if params.Timeout > maxTimeout {
127			params.Timeout = maxTimeout
128		}
129		var cancel context.CancelFunc
130		requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
131		defer cancel()
132	}
133
134	req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
135	if err != nil {
136		return ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
137	}
138
139	req.Header.Set("User-Agent", "crush/1.0")
140
141	resp, err := t.client.Do(req)
142	if err != nil {
143		return ToolResponse{}, fmt.Errorf("failed to download from URL: %w", err)
144	}
145	defer resp.Body.Close()
146
147	if resp.StatusCode != http.StatusOK {
148		return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
149	}
150
151	// Check content length if available
152	maxSize := int64(100 * 1024 * 1024) // 100MB
153	if resp.ContentLength > maxSize {
154		return NewTextErrorResponse(fmt.Sprintf("File too large: %d bytes (max %d bytes)", resp.ContentLength, maxSize)), nil
155	}
156
157	// Create parent directories if they don't exist
158	if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
159		return ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
160	}
161
162	// Create the output file
163	outFile, err := os.Create(filePath)
164	if err != nil {
165		return ToolResponse{}, fmt.Errorf("failed to create output file: %w", err)
166	}
167	defer outFile.Close()
168
169	// Copy data with size limit
170	limitedReader := io.LimitReader(resp.Body, maxSize)
171	bytesWritten, err := io.Copy(outFile, limitedReader)
172	if err != nil {
173		return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
174	}
175
176	// Check if we hit the size limit
177	if bytesWritten == maxSize {
178		// Clean up the file since it might be incomplete
179		os.Remove(filePath)
180		return NewTextErrorResponse(fmt.Sprintf("File too large: exceeded %d bytes limit", maxSize)), nil
181	}
182
183	contentType := resp.Header.Get("Content-Type")
184	responseMsg := fmt.Sprintf("Successfully downloaded %d bytes to %s", bytesWritten, filePath)
185	if contentType != "" {
186		responseMsg += fmt.Sprintf(" (Content-Type: %s)", contentType)
187	}
188
189	return NewTextResponse(responseMsg), nil
190}