refactor: add more tools

Kujtim Hoxha created

- diagnostics
- sourcegraph
- fetch
- download

Change summary

internal/ui/chat/bash.go        |  6 +-
internal/ui/chat/diagnostics.go | 68 ++++++++++++++++++++++++++++
internal/ui/chat/fetch.go       | 84 +++++++++++++++++++++++++++++++++++
internal/ui/chat/file.go        | 78 +++++++++++++++++++++++++++++---
internal/ui/chat/search.go      | 74 ++++++++++++++++++++++++++++--
internal/ui/chat/tools.go       | 48 ++++++++++++++++++-
6 files changed, 337 insertions(+), 21 deletions(-)

Detailed changes

internal/ui/chat/bash.go 🔗

@@ -69,8 +69,8 @@ func (b *BashToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *
 		toolParams = append(toolParams, "background", "true")
 	}
 
-	header := toolHeader(sty, opts.Status(), "Bash", cappedWidth, opts.Nested, toolParams...)
-	if opts.Nested {
+	header := toolHeader(sty, opts.Status(), "Bash", cappedWidth, opts.Simple, toolParams...)
+	if opts.Simple {
 		return header
 	}
 
@@ -201,7 +201,7 @@ func (j *JobKillToolRenderContext) RenderTool(sty *styles.Styles, width int, opt
 // header → nested check → early state → body.
 func renderJobTool(sty *styles.Styles, opts *ToolRenderOpts, width int, action, shellID, description, content string) string {
 	header := jobHeader(sty, opts.Status(), action, shellID, description, width)
-	if opts.Nested {
+	if opts.Simple {
 		return header
 	}
 

internal/ui/chat/diagnostics.go 🔗

@@ -0,0 +1,68 @@
+package chat
+
+import (
+	"encoding/json"
+
+	"github.com/charmbracelet/crush/internal/agent/tools"
+	"github.com/charmbracelet/crush/internal/fsext"
+	"github.com/charmbracelet/crush/internal/message"
+	"github.com/charmbracelet/crush/internal/ui/styles"
+)
+
+// -----------------------------------------------------------------------------
+// Diagnostics Tool
+// -----------------------------------------------------------------------------
+
+// DiagnosticsToolMessageItem is a message item that represents a diagnostics tool call.
+type DiagnosticsToolMessageItem struct {
+	*baseToolMessageItem
+}
+
+var _ ToolMessageItem = (*DiagnosticsToolMessageItem)(nil)
+
+// NewDiagnosticsToolMessageItem creates a new [DiagnosticsToolMessageItem].
+func NewDiagnosticsToolMessageItem(
+	sty *styles.Styles,
+	toolCall message.ToolCall,
+	result *message.ToolResult,
+	canceled bool,
+) ToolMessageItem {
+	return newBaseToolMessageItem(sty, toolCall, result, &DiagnosticsToolRenderContext{}, canceled)
+}
+
+// DiagnosticsToolRenderContext renders diagnostics tool messages.
+type DiagnosticsToolRenderContext struct{}
+
+// RenderTool implements the [ToolRenderer] interface.
+func (d *DiagnosticsToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string {
+	cappedWidth := cappedMessageWidth(width)
+	if !opts.ToolCall.Finished && !opts.Canceled {
+		return pendingTool(sty, "Diagnostics", opts.Anim)
+	}
+
+	var params tools.DiagnosticsParams
+	_ = json.Unmarshal([]byte(opts.ToolCall.Input), &params)
+
+	// Show "project" if no file path, otherwise show the file path.
+	mainParam := "project"
+	if params.FilePath != "" {
+		mainParam = fsext.PrettyPath(params.FilePath)
+	}
+
+	header := toolHeader(sty, opts.Status(), "Diagnostics", cappedWidth, opts.Simple, mainParam)
+	if opts.Simple {
+		return header
+	}
+
+	if earlyState, ok := toolEarlyStateContent(sty, opts, cappedWidth); ok {
+		return joinToolParts(header, earlyState)
+	}
+
+	if opts.Result == nil || opts.Result.Content == "" {
+		return header
+	}
+
+	bodyWidth := cappedWidth - toolBodyLeftPaddingTotal
+	body := sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.Expanded))
+	return joinToolParts(header, body)
+}

internal/ui/chat/fetch.go 🔗

@@ -0,0 +1,84 @@
+package chat
+
+import (
+	"encoding/json"
+
+	"github.com/charmbracelet/crush/internal/agent/tools"
+	"github.com/charmbracelet/crush/internal/message"
+	"github.com/charmbracelet/crush/internal/ui/styles"
+)
+
+// -----------------------------------------------------------------------------
+// Fetch Tool
+// -----------------------------------------------------------------------------
+
+// FetchToolMessageItem is a message item that represents a fetch tool call.
+type FetchToolMessageItem struct {
+	*baseToolMessageItem
+}
+
+var _ ToolMessageItem = (*FetchToolMessageItem)(nil)
+
+// NewFetchToolMessageItem creates a new [FetchToolMessageItem].
+func NewFetchToolMessageItem(
+	sty *styles.Styles,
+	toolCall message.ToolCall,
+	result *message.ToolResult,
+	canceled bool,
+) ToolMessageItem {
+	return newBaseToolMessageItem(sty, toolCall, result, &FetchToolRenderContext{}, canceled)
+}
+
+// FetchToolRenderContext renders fetch tool messages.
+type FetchToolRenderContext struct{}
+
+// RenderTool implements the [ToolRenderer] interface.
+func (f *FetchToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string {
+	cappedWidth := cappedMessageWidth(width)
+	if !opts.ToolCall.Finished && !opts.Canceled {
+		return pendingTool(sty, "Fetch", opts.Anim)
+	}
+
+	var params tools.FetchParams
+	if err := json.Unmarshal([]byte(opts.ToolCall.Input), &params); err != nil {
+		return toolErrorContent(sty, &message.ToolResult{Content: "Invalid parameters"}, cappedWidth)
+	}
+
+	toolParams := []string{params.URL}
+	if params.Format != "" {
+		toolParams = append(toolParams, "format", params.Format)
+	}
+	if params.Timeout != 0 {
+		toolParams = append(toolParams, "timeout", formatTimeout(params.Timeout))
+	}
+
+	header := toolHeader(sty, opts.Status(), "Fetch", cappedWidth, opts.Simple, toolParams...)
+	if opts.Simple {
+		return header
+	}
+
+	if earlyState, ok := toolEarlyStateContent(sty, opts, cappedWidth); ok {
+		return joinToolParts(header, earlyState)
+	}
+
+	if opts.Result == nil || opts.Result.Content == "" {
+		return header
+	}
+
+	// Determine file extension for syntax highlighting based on format.
+	file := getFileExtensionForFormat(params.Format)
+	body := toolOutputCodeContent(sty, file, opts.Result.Content, 0, cappedWidth, opts.Expanded)
+	return joinToolParts(header, body)
+}
+
+// getFileExtensionForFormat returns a filename with appropriate extension for syntax highlighting.
+func getFileExtensionForFormat(format string) string {
+	switch format {
+	case "text":
+		return "fetch.txt"
+	case "html":
+		return "fetch.html"
+	default:
+		return "fetch.md"
+	}
+}

internal/ui/chat/file.go 🔗

@@ -56,8 +56,8 @@ func (v *ViewToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *
 		toolParams = append(toolParams, "offset", fmt.Sprintf("%d", params.Offset))
 	}
 
-	header := toolHeader(sty, opts.Status(), "View", cappedWidth, opts.Nested, toolParams...)
-	if opts.Nested {
+	header := toolHeader(sty, opts.Status(), "View", cappedWidth, opts.Simple, toolParams...)
+	if opts.Simple {
 		return header
 	}
 
@@ -128,8 +128,8 @@ func (w *WriteToolRenderContext) RenderTool(sty *styles.Styles, width int, opts
 	}
 
 	file := fsext.PrettyPath(params.FilePath)
-	header := toolHeader(sty, opts.Status(), "Write", cappedWidth, opts.Nested, file)
-	if opts.Nested {
+	header := toolHeader(sty, opts.Status(), "Write", cappedWidth, opts.Simple, file)
+	if opts.Simple {
 		return header
 	}
 
@@ -183,8 +183,8 @@ func (e *EditToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *
 	}
 
 	file := fsext.PrettyPath(params.FilePath)
-	header := toolHeader(sty, opts.Status(), "Edit", width, opts.Nested, file)
-	if opts.Nested {
+	header := toolHeader(sty, opts.Status(), "Edit", width, opts.Simple, file)
+	if opts.Simple {
 		return header
 	}
 
@@ -251,8 +251,8 @@ func (m *MultiEditToolRenderContext) RenderTool(sty *styles.Styles, width int, o
 		toolParams = append(toolParams, "edits", fmt.Sprintf("%d", len(params.Edits)))
 	}
 
-	header := toolHeader(sty, opts.Status(), "Multi-Edit", width, opts.Nested, toolParams...)
-	if opts.Nested {
+	header := toolHeader(sty, opts.Status(), "Multi-Edit", width, opts.Simple, toolParams...)
+	if opts.Simple {
 		return header
 	}
 
@@ -276,3 +276,65 @@ func (m *MultiEditToolRenderContext) RenderTool(sty *styles.Styles, width int, o
 	body := toolOutputMultiEditDiffContent(sty, file, meta, len(params.Edits), width, opts.Expanded)
 	return joinToolParts(header, body)
 }
+
+// -----------------------------------------------------------------------------
+// Download Tool
+// -----------------------------------------------------------------------------
+
+// DownloadToolMessageItem is a message item that represents a download tool call.
+type DownloadToolMessageItem struct {
+	*baseToolMessageItem
+}
+
+var _ ToolMessageItem = (*DownloadToolMessageItem)(nil)
+
+// NewDownloadToolMessageItem creates a new [DownloadToolMessageItem].
+func NewDownloadToolMessageItem(
+	sty *styles.Styles,
+	toolCall message.ToolCall,
+	result *message.ToolResult,
+	canceled bool,
+) ToolMessageItem {
+	return newBaseToolMessageItem(sty, toolCall, result, &DownloadToolRenderContext{}, canceled)
+}
+
+// DownloadToolRenderContext renders download tool messages.
+type DownloadToolRenderContext struct{}
+
+// RenderTool implements the [ToolRenderer] interface.
+func (d *DownloadToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string {
+	cappedWidth := cappedMessageWidth(width)
+	if !opts.ToolCall.Finished && !opts.Canceled {
+		return pendingTool(sty, "Download", opts.Anim)
+	}
+
+	var params tools.DownloadParams
+	if err := json.Unmarshal([]byte(opts.ToolCall.Input), &params); err != nil {
+		return toolErrorContent(sty, &message.ToolResult{Content: "Invalid parameters"}, cappedWidth)
+	}
+
+	toolParams := []string{params.URL}
+	if params.FilePath != "" {
+		toolParams = append(toolParams, "file_path", fsext.PrettyPath(params.FilePath))
+	}
+	if params.Timeout != 0 {
+		toolParams = append(toolParams, "timeout", formatTimeout(params.Timeout))
+	}
+
+	header := toolHeader(sty, opts.Status(), "Download", cappedWidth, opts.Simple, toolParams...)
+	if opts.Simple {
+		return header
+	}
+
+	if earlyState, ok := toolEarlyStateContent(sty, opts, cappedWidth); ok {
+		return joinToolParts(header, earlyState)
+	}
+
+	if opts.Result == nil || opts.Result.Content == "" {
+		return header
+	}
+
+	bodyWidth := cappedWidth - toolBodyLeftPaddingTotal
+	body := sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.Expanded))
+	return joinToolParts(header, body)
+}

internal/ui/chat/search.go 🔗

@@ -50,8 +50,8 @@ func (g *GlobToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *
 		toolParams = append(toolParams, "path", params.Path)
 	}
 
-	header := toolHeader(sty, opts.Status(), "Glob", cappedWidth, opts.Nested, toolParams...)
-	if opts.Nested {
+	header := toolHeader(sty, opts.Status(), "Glob", cappedWidth, opts.Simple, toolParams...)
+	if opts.Simple {
 		return header
 	}
 
@@ -115,8 +115,8 @@ func (g *GrepToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *
 		toolParams = append(toolParams, "literal", "true")
 	}
 
-	header := toolHeader(sty, opts.Status(), "Grep", cappedWidth, opts.Nested, toolParams...)
-	if opts.Nested {
+	header := toolHeader(sty, opts.Status(), "Grep", cappedWidth, opts.Simple, toolParams...)
+	if opts.Simple {
 		return header
 	}
 
@@ -175,8 +175,70 @@ func (l *LSToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *To
 	}
 	path = fsext.PrettyPath(path)
 
-	header := toolHeader(sty, opts.Status(), "List", cappedWidth, opts.Nested, path)
-	if opts.Nested {
+	header := toolHeader(sty, opts.Status(), "List", cappedWidth, opts.Simple, path)
+	if opts.Simple {
+		return header
+	}
+
+	if earlyState, ok := toolEarlyStateContent(sty, opts, cappedWidth); ok {
+		return joinToolParts(header, earlyState)
+	}
+
+	if opts.Result == nil || opts.Result.Content == "" {
+		return header
+	}
+
+	bodyWidth := cappedWidth - toolBodyLeftPaddingTotal
+	body := sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.Expanded))
+	return joinToolParts(header, body)
+}
+
+// -----------------------------------------------------------------------------
+// Sourcegraph Tool
+// -----------------------------------------------------------------------------
+
+// SourcegraphToolMessageItem is a message item that represents a sourcegraph tool call.
+type SourcegraphToolMessageItem struct {
+	*baseToolMessageItem
+}
+
+var _ ToolMessageItem = (*SourcegraphToolMessageItem)(nil)
+
+// NewSourcegraphToolMessageItem creates a new [SourcegraphToolMessageItem].
+func NewSourcegraphToolMessageItem(
+	sty *styles.Styles,
+	toolCall message.ToolCall,
+	result *message.ToolResult,
+	canceled bool,
+) ToolMessageItem {
+	return newBaseToolMessageItem(sty, toolCall, result, &SourcegraphToolRenderContext{}, canceled)
+}
+
+// SourcegraphToolRenderContext renders sourcegraph tool messages.
+type SourcegraphToolRenderContext struct{}
+
+// RenderTool implements the [ToolRenderer] interface.
+func (s *SourcegraphToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string {
+	cappedWidth := cappedMessageWidth(width)
+	if !opts.ToolCall.Finished && !opts.Canceled {
+		return pendingTool(sty, "Sourcegraph", opts.Anim)
+	}
+
+	var params tools.SourcegraphParams
+	if err := json.Unmarshal([]byte(opts.ToolCall.Input), &params); err != nil {
+		return toolErrorContent(sty, &message.ToolResult{Content: "Invalid parameters"}, cappedWidth)
+	}
+
+	toolParams := []string{params.Query}
+	if params.Count != 0 {
+		toolParams = append(toolParams, "count", formatNonZero(params.Count))
+	}
+	if params.ContextWindow != 0 {
+		toolParams = append(toolParams, "context", formatNonZero(params.ContextWindow))
+	}
+
+	header := toolHeader(sty, opts.Status(), "Sourcegraph", cappedWidth, opts.Simple, toolParams...)
+	if opts.Simple {
 		return header
 	}
 

internal/ui/chat/tools.go 🔗

@@ -40,6 +40,12 @@ type ToolMessageItem interface {
 	SetResult(res *message.ToolResult)
 }
 
+// Simplifiable is an interface for tool items that can render in a simplified mode.
+// When simple mode is enabled, tools render as a compact single-line header.
+type Simplifiable interface {
+	SetSimple(simple bool)
+}
+
 // DefaultToolRenderContext implements the default [ToolRenderer] interface.
 type DefaultToolRenderContext struct{}
 
@@ -55,7 +61,7 @@ type ToolRenderOpts struct {
 	Canceled            bool
 	Anim                *anim.Anim
 	Expanded            bool
-	Nested              bool
+	Simple              bool
 	IsSpinning          bool
 	PermissionRequested bool
 	PermissionGranted   bool
@@ -106,6 +112,8 @@ type baseToolMessageItem struct {
 	// we use this so we can efficiently cache
 	// tools that have a capped width (e.x bash.. and others)
 	hasCappedWidth bool
+	// isSimple indicates this tool should render in simplified/compact mode.
+	isSimple bool
 
 	sty      *styles.Styles
 	anim     *anim.Anim
@@ -177,6 +185,14 @@ func NewToolMessageItem(
 		return NewGrepToolMessageItem(sty, toolCall, result, canceled)
 	case tools.LSToolName:
 		return NewLSToolMessageItem(sty, toolCall, result, canceled)
+	case tools.DownloadToolName:
+		return NewDownloadToolMessageItem(sty, toolCall, result, canceled)
+	case tools.FetchToolName:
+		return NewFetchToolMessageItem(sty, toolCall, result, canceled)
+	case tools.SourcegraphToolName:
+		return NewSourcegraphToolMessageItem(sty, toolCall, result, canceled)
+	case tools.DiagnosticsToolName:
+		return NewDiagnosticsToolMessageItem(sty, toolCall, result, canceled)
 	default:
 		// TODO: Implement other tool items
 		return newBaseToolMessageItem(
@@ -189,6 +205,12 @@ func NewToolMessageItem(
 	}
 }
 
+// SetSimple implements the Simplifiable interface.
+func (t *baseToolMessageItem) SetSimple(simple bool) {
+	t.isSimple = simple
+	t.clearCache()
+}
+
 // ID returns the unique identifier for this tool message item.
 func (t *baseToolMessageItem) ID() string {
 	return t.toolCall.ID
@@ -230,6 +252,7 @@ func (t *baseToolMessageItem) Render(width int) string {
 			Canceled:            t.canceled,
 			Anim:                t.anim,
 			Expanded:            t.expanded,
+			Simple:              t.isSimple,
 			PermissionRequested: t.permissionRequested,
 			PermissionGranted:   t.permissionGranted,
 			IsSpinning:          t.isSpinning(),
@@ -487,10 +510,10 @@ func toolOutputCodeContent(sty *styles.Styles, path, content string, offset, wid
 
 	// Add truncation message if needed.
 	if len(lines) > maxLines && !expanded {
-		truncMsg := sty.Tool.ContentCodeTruncation.
+		out = append(out, sty.Tool.ContentCodeTruncation.
 			Width(bodyWidth).
-			Render(fmt.Sprintf(assistantMessageTruncateFormat, len(lines)-maxLines))
-		out = append([]string{truncMsg}, out...)
+			Render(fmt.Sprintf(assistantMessageTruncateFormat, len(lines)-maxLines)),
+		)
 	}
 
 	return sty.Tool.Body.Render(strings.Join(out, "\n"))
@@ -574,6 +597,23 @@ func toolOutputDiffContent(sty *styles.Styles, file, oldContent, newContent stri
 	return sty.Tool.Body.Render(formatted)
 }
 
+// formatTimeout converts timeout seconds to a duration string (e.g., "30s").
+// Returns empty string if timeout is 0.
+func formatTimeout(timeout int) string {
+	if timeout == 0 {
+		return ""
+	}
+	return fmt.Sprintf("%ds", timeout)
+}
+
+// formatNonZero returns string representation of non-zero integers, empty string for zero.
+func formatNonZero(value int) string {
+	if value == 0 {
+		return ""
+	}
+	return fmt.Sprintf("%d", value)
+}
+
 // toolOutputMultiEditDiffContent renders a diff with optional failed edits note.
 func toolOutputMultiEditDiffContent(sty *styles.Styles, file string, meta tools.MultiEditResponseMetadata, totalEdits, width int, expanded bool) string {
 	bodyWidth := width - toolBodyLeftPaddingTotal