1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"encoding/json"
  7	"fmt"
  8	"io"
  9	"net/http"
 10	"os"
 11	"path/filepath"
 12	"strings"
 13	"time"
 14
 15	"github.com/charmbracelet/crush/internal/permission"
 16)
 17
 18type DownloadParams struct {
 19	URL      string `json:"url"`
 20	FilePath string `json:"file_path"`
 21	Timeout  int    `json:"timeout,omitempty"`
 22}
 23
 24type DownloadPermissionsParams struct {
 25	URL      string `json:"url"`
 26	FilePath string `json:"file_path"`
 27	Timeout  int    `json:"timeout,omitempty"`
 28}
 29
 30type downloadTool struct {
 31	client      *http.Client
 32	permissions permission.Service
 33	workingDir  string
 34}
 35
 36const DownloadToolName = "download"
 37
 38//go:embed download.md
 39var downloadDescription []byte
 40
 41func NewDownloadTool(permissions permission.Service, workingDir string) BaseTool {
 42	return &downloadTool{
 43		client: &http.Client{
 44			Timeout: 5 * time.Minute, // Default 5 minute timeout for downloads
 45			Transport: &http.Transport{
 46				MaxIdleConns:        100,
 47				MaxIdleConnsPerHost: 10,
 48				IdleConnTimeout:     90 * time.Second,
 49			},
 50		},
 51		permissions: permissions,
 52		workingDir:  workingDir,
 53	}
 54}
 55
 56func (t *downloadTool) Name() string {
 57	return DownloadToolName
 58}
 59
 60func (t *downloadTool) Info() ToolInfo {
 61	return ToolInfo{
 62		Name:        DownloadToolName,
 63		Description: string(downloadDescription),
 64		Parameters: map[string]any{
 65			"url": map[string]any{
 66				"type":        "string",
 67				"description": "The URL to download from",
 68			},
 69			"file_path": map[string]any{
 70				"type":        "string",
 71				"description": "The local file path where the downloaded content should be saved",
 72			},
 73			"timeout": map[string]any{
 74				"type":        "number",
 75				"description": "Optional timeout in seconds (max 600)",
 76			},
 77		},
 78		Required: []string{"url", "file_path"},
 79	}
 80}
 81
 82func (t *downloadTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
 83	var params DownloadParams
 84	if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
 85		return NewTextErrorResponse("Failed to parse download parameters: " + err.Error()), nil
 86	}
 87
 88	if params.URL == "" {
 89		return NewTextErrorResponse("URL parameter is required"), nil
 90	}
 91
 92	if params.FilePath == "" {
 93		return NewTextErrorResponse("file_path parameter is required"), nil
 94	}
 95
 96	if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
 97		return NewTextErrorResponse("URL must start with http:// or https://"), nil
 98	}
 99
100	// Convert relative path to absolute path
101	var filePath string
102	if filepath.IsAbs(params.FilePath) {
103		filePath = params.FilePath
104	} else {
105		filePath = filepath.Join(t.workingDir, params.FilePath)
106	}
107
108	sessionID, messageID := GetContextValues(ctx)
109	if sessionID == "" || messageID == "" {
110		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for downloading files")
111	}
112
113	p := t.permissions.Request(
114		permission.CreatePermissionRequest{
115			SessionID:   sessionID,
116			Path:        filePath,
117			ToolName:    DownloadToolName,
118			Action:      "download",
119			Description: fmt.Sprintf("Download file from URL: %s to %s", params.URL, filePath),
120			Params:      DownloadPermissionsParams(params),
121		},
122	)
123
124	if !p {
125		return ToolResponse{}, permission.ErrorPermissionDenied
126	}
127
128	// Handle timeout with context
129	requestCtx := ctx
130	if params.Timeout > 0 {
131		maxTimeout := 600 // 10 minutes
132		if params.Timeout > maxTimeout {
133			params.Timeout = maxTimeout
134		}
135		var cancel context.CancelFunc
136		requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
137		defer cancel()
138	}
139
140	req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
141	if err != nil {
142		return ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
143	}
144
145	req.Header.Set("User-Agent", "crush/1.0")
146
147	resp, err := t.client.Do(req)
148	if err != nil {
149		return ToolResponse{}, fmt.Errorf("failed to download from URL: %w", err)
150	}
151	defer resp.Body.Close()
152
153	if resp.StatusCode != http.StatusOK {
154		return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
155	}
156
157	// Check content length if available
158	maxSize := int64(100 * 1024 * 1024) // 100MB
159	if resp.ContentLength > maxSize {
160		return NewTextErrorResponse(fmt.Sprintf("File too large: %d bytes (max %d bytes)", resp.ContentLength, maxSize)), nil
161	}
162
163	// Create parent directories if they don't exist
164	if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
165		return ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
166	}
167
168	// Create the output file
169	outFile, err := os.Create(filePath)
170	if err != nil {
171		return ToolResponse{}, fmt.Errorf("failed to create output file: %w", err)
172	}
173	defer outFile.Close()
174
175	// Copy data with size limit
176	limitedReader := io.LimitReader(resp.Body, maxSize)
177	bytesWritten, err := io.Copy(outFile, limitedReader)
178	if err != nil {
179		return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
180	}
181
182	// Check if we hit the size limit
183	if bytesWritten == maxSize {
184		// Clean up the file since it might be incomplete
185		os.Remove(filePath)
186		return NewTextErrorResponse(fmt.Sprintf("File too large: exceeded %d bytes limit", maxSize)), nil
187	}
188
189	contentType := resp.Header.Get("Content-Type")
190	responseMsg := fmt.Sprintf("Successfully downloaded %d bytes to %s", bytesWritten, filePath)
191	if contentType != "" {
192		responseMsg += fmt.Sprintf(" (Content-Type: %s)", contentType)
193	}
194
195	return NewTextResponse(responseMsg), nil
196}