diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 4e8070158652f04205f51ca6d38ba1f5db81ef2a..907a2348f838aa2f2ba6792db9b768eb656904a8 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -102,6 +102,7 @@ func NewAgent( cwd := cfg.WorkingDir() allTools := []tools.BaseTool{ tools.NewBashTool(permissions, cwd), + tools.NewDownloadTool(permissions, cwd), tools.NewEditTool(lspClients, permissions, history, cwd), tools.NewFetchTool(permissions, cwd), tools.NewGlobTool(cwd), diff --git a/internal/llm/tools/download.go b/internal/llm/tools/download.go new file mode 100644 index 0000000000000000000000000000000000000000..fc0c33a846305d002df2bd6e21a54cbe088a511e --- /dev/null +++ b/internal/llm/tools/download.go @@ -0,0 +1,223 @@ +package tools + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "github.com/charmbracelet/crush/internal/permission" +) + +type DownloadParams struct { + URL string `json:"url"` + FilePath string `json:"file_path"` + Timeout int `json:"timeout,omitempty"` +} + +type DownloadPermissionsParams struct { + URL string `json:"url"` + FilePath string `json:"file_path"` + Timeout int `json:"timeout,omitempty"` +} + +type downloadTool struct { + client *http.Client + permissions permission.Service + workingDir string +} + +const ( + DownloadToolName = "download" + downloadToolDescription = `Downloads binary data from a URL and saves it to a local file. + +WHEN TO USE THIS TOOL: +- Use when you need to download files, images, or other binary data from URLs +- Helpful for downloading assets, documents, or any file type +- Useful for saving remote content locally for processing or storage + +HOW TO USE: +- Provide the URL to download from +- Specify the local file path where the content should be saved +- Optionally set a timeout for the request + +FEATURES: +- Downloads any file type (binary or text) +- Automatically creates parent directories if they don't exist +- Handles large files efficiently with streaming +- Sets reasonable timeouts to prevent hanging +- Validates input parameters before making requests + +LIMITATIONS: +- Maximum file size is 100MB +- Only supports HTTP and HTTPS protocols +- Cannot handle authentication or cookies +- Some websites may block automated requests +- Will overwrite existing files without warning + +TIPS: +- Use absolute paths or paths relative to the working directory +- Set appropriate timeouts for large files or slow connections` +) + +func NewDownloadTool(permissions permission.Service, workingDir string) BaseTool { + return &downloadTool{ + client: &http.Client{ + Timeout: 5 * time.Minute, // Default 5 minute timeout for downloads + Transport: &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + }, + }, + permissions: permissions, + workingDir: workingDir, + } +} + +func (t *downloadTool) Name() string { + return DownloadToolName +} + +func (t *downloadTool) Info() ToolInfo { + return ToolInfo{ + Name: DownloadToolName, + Description: downloadToolDescription, + Parameters: map[string]any{ + "url": map[string]any{ + "type": "string", + "description": "The URL to download from", + }, + "file_path": map[string]any{ + "type": "string", + "description": "The local file path where the downloaded content should be saved", + }, + "timeout": map[string]any{ + "type": "number", + "description": "Optional timeout in seconds (max 600)", + }, + }, + Required: []string{"url", "file_path"}, + } +} + +func (t *downloadTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { + var params DownloadParams + if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { + return NewTextErrorResponse("Failed to parse download parameters: " + err.Error()), nil + } + + if params.URL == "" { + return NewTextErrorResponse("URL parameter is required"), nil + } + + if params.FilePath == "" { + return NewTextErrorResponse("file_path parameter is required"), nil + } + + if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") { + return NewTextErrorResponse("URL must start with http:// or https://"), nil + } + + // Convert relative path to absolute path + var filePath string + if filepath.IsAbs(params.FilePath) { + filePath = params.FilePath + } else { + filePath = filepath.Join(t.workingDir, params.FilePath) + } + + sessionID, messageID := GetContextValues(ctx) + if sessionID == "" || messageID == "" { + return ToolResponse{}, fmt.Errorf("session ID and message ID are required for downloading files") + } + + p := t.permissions.Request( + permission.CreatePermissionRequest{ + SessionID: sessionID, + Path: filePath, + ToolName: DownloadToolName, + Action: "download", + Description: fmt.Sprintf("Download file from URL: %s to %s", params.URL, filePath), + Params: DownloadPermissionsParams(params), + }, + ) + + if !p { + return ToolResponse{}, permission.ErrorPermissionDenied + } + + // Handle timeout with context + requestCtx := ctx + if params.Timeout > 0 { + maxTimeout := 600 // 10 minutes + if params.Timeout > maxTimeout { + params.Timeout = maxTimeout + } + var cancel context.CancelFunc + requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second) + defer cancel() + } + + req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil) + if err != nil { + return ToolResponse{}, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("User-Agent", "crush/1.0") + + resp, err := t.client.Do(req) + if err != nil { + return ToolResponse{}, fmt.Errorf("failed to download from URL: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil + } + + // Check content length if available + maxSize := int64(100 * 1024 * 1024) // 100MB + if resp.ContentLength > maxSize { + return NewTextErrorResponse(fmt.Sprintf("File too large: %d bytes (max %d bytes)", resp.ContentLength, maxSize)), nil + } + + // Create parent directories if they don't exist + if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil { + return ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err) + } + + // Create the output file + outFile, err := os.Create(filePath) + if err != nil { + return ToolResponse{}, fmt.Errorf("failed to create output file: %w", err) + } + defer outFile.Close() + + // Copy data with size limit + limitedReader := io.LimitReader(resp.Body, maxSize) + bytesWritten, err := io.Copy(outFile, limitedReader) + if err != nil { + return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) + } + + // Check if we hit the size limit + if bytesWritten == maxSize { + // Clean up the file since it might be incomplete + os.Remove(filePath) + return NewTextErrorResponse(fmt.Sprintf("File too large: exceeded %d bytes limit", maxSize)), nil + } + + contentType := resp.Header.Get("Content-Type") + responseMsg := fmt.Sprintf("Successfully downloaded %d bytes to %s", bytesWritten, filePath) + if contentType != "" { + responseMsg += fmt.Sprintf(" (Content-Type: %s)", contentType) + } + + return NewTextResponse(responseMsg), nil +} diff --git a/internal/llm/tools/fetch.go b/internal/llm/tools/fetch.go index 28e15d19cee8219ccc4575ed036f29e8286db229..1e44151b1124c643d2ddd428144e66c5d365e609 100644 --- a/internal/llm/tools/fetch.go +++ b/internal/llm/tools/fetch.go @@ -8,6 +8,7 @@ import ( "net/http" "strings" "time" + "unicode/utf8" md "github.com/JohannesKaufmann/html-to-markdown" "github.com/PuerkitoBio/goquery" @@ -182,6 +183,11 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error } content := string(body) + + isValidUt8 := utf8.ValidString(content) + if !isValidUt8 { + return NewTextErrorResponse("Response content is not valid UTF-8"), nil + } contentType := resp.Header.Get("Content-Type") switch format { @@ -191,9 +197,8 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error if err != nil { return NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil } - return NewTextResponse(text), nil + content = text } - return NewTextResponse(content), nil case "markdown": if strings.Contains(contentType, "text/html") { @@ -201,17 +206,36 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error if err != nil { return NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil } - return NewTextResponse(markdown), nil + content = markdown } - return NewTextResponse("```\n" + content + "\n```"), nil + content = "```\n" + content + "\n```" case "html": - return NewTextResponse(content), nil - - default: - return NewTextResponse(content), nil + // return only the body of the HTML document + if strings.Contains(contentType, "text/html") { + doc, err := goquery.NewDocumentFromReader(strings.NewReader(content)) + if err != nil { + return NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil + } + body, err := doc.Find("body").Html() + if err != nil { + return NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil + } + if body == "" { + return NewTextErrorResponse("No body content found in HTML"), nil + } + content = "\n\n" + body + "\n\n" + } + } + // calculate byte size of content + contentSize := int64(len(content)) + if contentSize > MaxReadSize { + content = content[:MaxReadSize] + content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxReadSize) } + + return NewTextResponse(content), nil } func extractTextFromHTML(html string) (string, error) { @@ -220,7 +244,7 @@ func extractTextFromHTML(html string) (string, error) { return "", err } - text := doc.Text() + text := doc.Find("body").Text() text = strings.Join(strings.Fields(text), " ") return text, nil diff --git a/internal/llm/tools/view.go b/internal/llm/tools/view.go index 27bbc237209e64637cfefb0f4ff1097f96641c2e..d8ca7e9e8c7880a760e8eb2096c83914a9dc13b5 100644 --- a/internal/llm/tools/view.go +++ b/internal/llm/tools/view.go @@ -9,6 +9,7 @@ import ( "os" "path/filepath" "strings" + "unicode/utf8" "github.com/charmbracelet/crush/internal/lsp" ) @@ -173,11 +174,15 @@ func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) isImage, imageType := isImageFile(filePath) // TODO: handle images if isImage { - return NewTextErrorResponse(fmt.Sprintf("This is an image file of type: %s\nUse a different tool to process images", imageType)), nil + return NewTextErrorResponse(fmt.Sprintf("This is an image file of type: %s\n", imageType)), nil } // Read the file content content, lineCount, err := readTextFile(filePath, params.Offset, params.Limit) + isValidUt8 := utf8.ValidString(content) + if !isValidUt8 { + return NewTextErrorResponse("File content is not valid UTF-8"), nil + } if err != nil { return ToolResponse{}, fmt.Errorf("error reading file: %w", err) } diff --git a/internal/tui/components/chat/messages/renderer.go b/internal/tui/components/chat/messages/renderer.go index 053870476f5a47d67eb827bbc4143c619049f13f..898780495242819e42257ba72d90a02e10e1cd71 100644 --- a/internal/tui/components/chat/messages/renderer.go +++ b/internal/tui/components/chat/messages/renderer.go @@ -162,6 +162,7 @@ func (br baseRenderer) renderError(v *toolCallCmp, message string) string { // Register tool renderers func init() { registry.register(tools.BashToolName, func() renderer { return bashRenderer{} }) + registry.register(tools.DownloadToolName, func() renderer { return downloadRenderer{} }) registry.register(tools.ViewToolName, func() renderer { return viewRenderer{} }) registry.register(tools.EditToolName, func() renderer { return editRenderer{} }) registry.register(tools.WriteToolName, func() renderer { return writeRenderer{} }) @@ -376,6 +377,32 @@ func formatTimeout(timeout int) string { return (time.Duration(timeout) * time.Second).String() } +// ----------------------------------------------------------------------------- +// Download renderer +// ----------------------------------------------------------------------------- + +// downloadRenderer handles file downloading with URL and file path display +type downloadRenderer struct { + baseRenderer +} + +// Render displays the download URL and destination file path with timeout parameter +func (dr downloadRenderer) Render(v *toolCallCmp) string { + var params tools.DownloadParams + var args []string + if err := dr.unmarshalParams(v.call.Input, ¶ms); err == nil { + args = newParamBuilder(). + addMain(params.URL). + addKeyValue("file_path", fsext.PrettyPath(params.FilePath)). + addKeyValue("timeout", formatTimeout(params.Timeout)). + build() + } + + return dr.renderWithParams(v, "Download", args, func() string { + return renderPlainContent(v, v.result.Content) + }) +} + // ----------------------------------------------------------------------------- // Glob renderer // ----------------------------------------------------------------------------- @@ -758,6 +785,8 @@ func prettifyToolName(name string) string { return "Agent" case tools.BashToolName: return "Bash" + case tools.DownloadToolName: + return "Download" case tools.EditToolName: return "Edit" case tools.FetchToolName: diff --git a/internal/tui/components/dialogs/permissions/permissions.go b/internal/tui/components/dialogs/permissions/permissions.go index dd8668ad393fefdb0161a933f6baf7e7250ce05d..1b41094c9c69ba91bbbefdf86e7040cd77d3ce8e 100644 --- a/internal/tui/components/dialogs/permissions/permissions.go +++ b/internal/tui/components/dialogs/permissions/permissions.go @@ -252,6 +252,30 @@ func (p *permissionDialogCmp) renderHeader() string { switch p.permission.ToolName { case tools.BashToolName: headerParts = append(headerParts, t.S().Muted.Width(p.width).Render("Command")) + case tools.DownloadToolName: + params := p.permission.Params.(tools.DownloadPermissionsParams) + urlKey := t.S().Muted.Render("URL") + urlValue := t.S().Text. + Width(p.width - lipgloss.Width(urlKey)). + Render(fmt.Sprintf(" %s", params.URL)) + fileKey := t.S().Muted.Render("File") + filePath := t.S().Text. + Width(p.width - lipgloss.Width(fileKey)). + Render(fmt.Sprintf(" %s", fsext.PrettyPath(params.FilePath))) + headerParts = append(headerParts, + lipgloss.JoinHorizontal( + lipgloss.Left, + urlKey, + urlValue, + ), + baseStyle.Render(strings.Repeat(" ", p.width)), + lipgloss.JoinHorizontal( + lipgloss.Left, + fileKey, + filePath, + ), + baseStyle.Render(strings.Repeat(" ", p.width)), + ) case tools.EditToolName: params := p.permission.Params.(tools.EditPermissionsParams) fileKey := t.S().Muted.Render("File") @@ -299,6 +323,8 @@ func (p *permissionDialogCmp) getOrGenerateContent() string { switch p.permission.ToolName { case tools.BashToolName: content = p.generateBashContent() + case tools.DownloadToolName: + content = p.generateDownloadContent() case tools.EditToolName: content = p.generateEditContent() case tools.WriteToolName: @@ -391,6 +417,24 @@ func (p *permissionDialogCmp) generateWriteContent() string { return "" } +func (p *permissionDialogCmp) generateDownloadContent() string { + t := styles.CurrentTheme() + baseStyle := t.S().Base.Background(t.BgSubtle) + if pr, ok := p.permission.Params.(tools.DownloadPermissionsParams); ok { + content := fmt.Sprintf("URL: %s\nFile: %s", pr.URL, fsext.PrettyPath(pr.FilePath)) + if pr.Timeout > 0 { + content += fmt.Sprintf("\nTimeout: %ds", pr.Timeout) + } + + finalContent := baseStyle. + Padding(1, 2). + Width(p.contentViewPort.Width()). + Render(content) + return finalContent + } + return "" +} + func (p *permissionDialogCmp) generateFetchContent() string { t := styles.CurrentTheme() baseStyle := t.S().Base.Background(t.BgSubtle) @@ -526,6 +570,9 @@ func (p *permissionDialogCmp) SetSize() tea.Cmd { case tools.BashToolName: p.width = int(float64(p.wWidth) * 0.8) p.height = int(float64(p.wHeight) * 0.3) + case tools.DownloadToolName: + p.width = int(float64(p.wWidth) * 0.8) + p.height = int(float64(p.wHeight) * 0.4) case tools.EditToolName: p.width = int(float64(p.wWidth) * 0.8) p.height = int(float64(p.wHeight) * 0.8)