refactor(chat): add file tools

Kujtim Hoxha created

Change summary

internal/ui/chat/bash.go     |   2 
internal/ui/chat/file.go     | 278 ++++++++++++++++++++++++++++++++++++++
internal/ui/chat/tools.go    | 192 +++++++++++++++++++++++++
internal/ui/styles/styles.go |  18 +-
4 files changed, 479 insertions(+), 11 deletions(-)

Detailed changes

internal/ui/chat/bash.go 🔗

@@ -69,7 +69,7 @@ func (b *BashToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *
 		toolParams = append(toolParams, "background", "true")
 	}
 
-	header := toolHeader(sty, opts.Status(), "Bash", cappedWidth, toolParams...)
+	header := toolHeader(sty, opts.Status(), "Bash", cappedWidth, opts.Nested, toolParams...)
 	if opts.Nested {
 		return header
 	}

internal/ui/chat/file.go 🔗

@@ -0,0 +1,278 @@
+package chat
+
+import (
+	"encoding/json"
+	"fmt"
+	"strings"
+
+	"github.com/charmbracelet/crush/internal/agent/tools"
+	"github.com/charmbracelet/crush/internal/fsext"
+	"github.com/charmbracelet/crush/internal/message"
+	"github.com/charmbracelet/crush/internal/ui/styles"
+)
+
+// -----------------------------------------------------------------------------
+// View Tool
+// -----------------------------------------------------------------------------
+
+// ViewToolMessageItem is a message item that represents a view tool call.
+type ViewToolMessageItem struct {
+	*baseToolMessageItem
+}
+
+var _ ToolMessageItem = (*ViewToolMessageItem)(nil)
+
+// NewViewToolMessageItem creates a new [ViewToolMessageItem].
+func NewViewToolMessageItem(
+	sty *styles.Styles,
+	toolCall message.ToolCall,
+	result *message.ToolResult,
+	canceled bool,
+) ToolMessageItem {
+	return newBaseToolMessageItem(sty, toolCall, result, &ViewToolRenderContext{}, canceled)
+}
+
+// ViewToolRenderContext renders view tool messages.
+type ViewToolRenderContext struct{}
+
+// RenderTool implements the [ToolRenderer] interface.
+func (v *ViewToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string {
+	cappedWidth := cappedMessageWidth(width)
+	if !opts.ToolCall.Finished && !opts.Canceled {
+		return pendingTool(sty, "View", opts.Anim)
+	}
+
+	var params tools.ViewParams
+	if err := json.Unmarshal([]byte(opts.ToolCall.Input), &params); err != nil {
+		return toolErrorContent(sty, &message.ToolResult{Content: "Invalid parameters"}, cappedWidth)
+	}
+
+	file := fsext.PrettyPath(params.FilePath)
+	toolParams := []string{file}
+	if params.Limit != 0 {
+		toolParams = append(toolParams, "limit", fmt.Sprintf("%d", params.Limit))
+	}
+	if params.Offset != 0 {
+		toolParams = append(toolParams, "offset", fmt.Sprintf("%d", params.Offset))
+	}
+
+	header := toolHeader(sty, opts.Status(), "View", cappedWidth, opts.Nested, toolParams...)
+	if opts.Nested {
+		return header
+	}
+
+	if earlyState, ok := toolEarlyStateContent(sty, opts, cappedWidth); ok {
+		return joinToolParts(header, earlyState)
+	}
+
+	if opts.Result == nil {
+		return header
+	}
+
+	// Handle image content.
+	if opts.Result.Data != "" && strings.HasPrefix(opts.Result.MIMEType, "image/") {
+		body := toolOutputImageContent(sty, opts.Result.Data, opts.Result.MIMEType)
+		return joinToolParts(header, body)
+	}
+
+	// Try to get content from metadata first (contains actual file content).
+	var meta tools.ViewResponseMetadata
+	content := opts.Result.Content
+	if err := json.Unmarshal([]byte(opts.Result.Metadata), &meta); err == nil && meta.Content != "" {
+		content = meta.Content
+	}
+
+	if content == "" {
+		return header
+	}
+
+	// Render code content with syntax highlighting.
+	body := toolOutputCodeContent(sty, params.FilePath, content, params.Offset, cappedWidth, opts.Expanded)
+	return joinToolParts(header, body)
+}
+
+// -----------------------------------------------------------------------------
+// Write Tool
+// -----------------------------------------------------------------------------
+
+// WriteToolMessageItem is a message item that represents a write tool call.
+type WriteToolMessageItem struct {
+	*baseToolMessageItem
+}
+
+var _ ToolMessageItem = (*WriteToolMessageItem)(nil)
+
+// NewWriteToolMessageItem creates a new [WriteToolMessageItem].
+func NewWriteToolMessageItem(
+	sty *styles.Styles,
+	toolCall message.ToolCall,
+	result *message.ToolResult,
+	canceled bool,
+) ToolMessageItem {
+	return newBaseToolMessageItem(sty, toolCall, result, &WriteToolRenderContext{}, canceled)
+}
+
+// WriteToolRenderContext renders write tool messages.
+type WriteToolRenderContext struct{}
+
+// RenderTool implements the [ToolRenderer] interface.
+func (w *WriteToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string {
+	cappedWidth := cappedMessageWidth(width)
+	if !opts.ToolCall.Finished && !opts.Canceled {
+		return pendingTool(sty, "Write", opts.Anim)
+	}
+
+	var params tools.WriteParams
+	if err := json.Unmarshal([]byte(opts.ToolCall.Input), &params); err != nil {
+		return toolErrorContent(sty, &message.ToolResult{Content: "Invalid parameters"}, cappedWidth)
+	}
+
+	file := fsext.PrettyPath(params.FilePath)
+	header := toolHeader(sty, opts.Status(), "Write", cappedWidth, opts.Nested, file)
+	if opts.Nested {
+		return header
+	}
+
+	if earlyState, ok := toolEarlyStateContent(sty, opts, cappedWidth); ok {
+		return joinToolParts(header, earlyState)
+	}
+
+	if params.Content == "" {
+		return header
+	}
+
+	// Render code content with syntax highlighting.
+	body := toolOutputCodeContent(sty, params.FilePath, params.Content, 0, cappedWidth, opts.Expanded)
+	return joinToolParts(header, body)
+}
+
+// -----------------------------------------------------------------------------
+// Edit Tool
+// -----------------------------------------------------------------------------
+
+// EditToolMessageItem is a message item that represents an edit tool call.
+type EditToolMessageItem struct {
+	*baseToolMessageItem
+}
+
+var _ ToolMessageItem = (*EditToolMessageItem)(nil)
+
+// NewEditToolMessageItem creates a new [EditToolMessageItem].
+func NewEditToolMessageItem(
+	sty *styles.Styles,
+	toolCall message.ToolCall,
+	result *message.ToolResult,
+	canceled bool,
+) ToolMessageItem {
+	return newBaseToolMessageItem(sty, toolCall, result, &EditToolRenderContext{}, canceled)
+}
+
+// EditToolRenderContext renders edit tool messages.
+type EditToolRenderContext struct{}
+
+// RenderTool implements the [ToolRenderer] interface.
+func (e *EditToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string {
+	// Edit tool uses full width for diffs.
+	if !opts.ToolCall.Finished && !opts.Canceled {
+		return pendingTool(sty, "Edit", opts.Anim)
+	}
+
+	var params tools.EditParams
+	if err := json.Unmarshal([]byte(opts.ToolCall.Input), &params); err != nil {
+		return toolErrorContent(sty, &message.ToolResult{Content: "Invalid parameters"}, width)
+	}
+
+	file := fsext.PrettyPath(params.FilePath)
+	header := toolHeader(sty, opts.Status(), "Edit", width, opts.Nested, file)
+	if opts.Nested {
+		return header
+	}
+
+	if earlyState, ok := toolEarlyStateContent(sty, opts, width); ok {
+		return joinToolParts(header, earlyState)
+	}
+
+	if opts.Result == nil {
+		return header
+	}
+
+	// Get diff content from metadata.
+	var meta tools.EditResponseMetadata
+	if err := json.Unmarshal([]byte(opts.Result.Metadata), &meta); err != nil {
+		bodyWidth := width - toolBodyLeftPaddingTotal
+		body := sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.Expanded))
+		return joinToolParts(header, body)
+	}
+
+	// Render diff.
+	body := toolOutputDiffContent(sty, file, meta.OldContent, meta.NewContent, width, opts.Expanded)
+	return joinToolParts(header, body)
+}
+
+// -----------------------------------------------------------------------------
+// MultiEdit Tool
+// -----------------------------------------------------------------------------
+
+// MultiEditToolMessageItem is a message item that represents a multi-edit tool call.
+type MultiEditToolMessageItem struct {
+	*baseToolMessageItem
+}
+
+var _ ToolMessageItem = (*MultiEditToolMessageItem)(nil)
+
+// NewMultiEditToolMessageItem creates a new [MultiEditToolMessageItem].
+func NewMultiEditToolMessageItem(
+	sty *styles.Styles,
+	toolCall message.ToolCall,
+	result *message.ToolResult,
+	canceled bool,
+) ToolMessageItem {
+	return newBaseToolMessageItem(sty, toolCall, result, &MultiEditToolRenderContext{}, canceled)
+}
+
+// MultiEditToolRenderContext renders multi-edit tool messages.
+type MultiEditToolRenderContext struct{}
+
+// RenderTool implements the [ToolRenderer] interface.
+func (m *MultiEditToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string {
+	// MultiEdit tool uses full width for diffs.
+	if !opts.ToolCall.Finished && !opts.Canceled {
+		return pendingTool(sty, "Multi-Edit", opts.Anim)
+	}
+
+	var params tools.MultiEditParams
+	if err := json.Unmarshal([]byte(opts.ToolCall.Input), &params); err != nil {
+		return toolErrorContent(sty, &message.ToolResult{Content: "Invalid parameters"}, width)
+	}
+
+	file := fsext.PrettyPath(params.FilePath)
+	toolParams := []string{file}
+	if len(params.Edits) > 0 {
+		toolParams = append(toolParams, "edits", fmt.Sprintf("%d", len(params.Edits)))
+	}
+
+	header := toolHeader(sty, opts.Status(), "Multi-Edit", width, opts.Nested, toolParams...)
+	if opts.Nested {
+		return header
+	}
+
+	if earlyState, ok := toolEarlyStateContent(sty, opts, width); ok {
+		return joinToolParts(header, earlyState)
+	}
+
+	if opts.Result == nil {
+		return header
+	}
+
+	// Get diff content from metadata.
+	var meta tools.MultiEditResponseMetadata
+	if err := json.Unmarshal([]byte(opts.Result.Metadata), &meta); err != nil {
+		bodyWidth := width - toolBodyLeftPaddingTotal
+		body := sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.Expanded))
+		return joinToolParts(header, body)
+	}
+
+	// Render diff with optional failed edits note.
+	body := toolOutputMultiEditDiffContent(sty, file, meta, len(params.Edits), width, opts.Expanded)
+	return joinToolParts(header, body)
+}

internal/ui/chat/tools.go 🔗

@@ -9,6 +9,7 @@ import (
 	"github.com/charmbracelet/crush/internal/agent/tools"
 	"github.com/charmbracelet/crush/internal/message"
 	"github.com/charmbracelet/crush/internal/ui/anim"
+	"github.com/charmbracelet/crush/internal/ui/common"
 	"github.com/charmbracelet/crush/internal/ui/styles"
 	"github.com/charmbracelet/x/ansi"
 )
@@ -162,6 +163,14 @@ func NewToolMessageItem(
 		return NewJobOutputToolMessageItem(sty, toolCall, result, canceled)
 	case tools.JobKillToolName:
 		return NewJobKillToolMessageItem(sty, toolCall, result, canceled)
+	case tools.ViewToolName:
+		return NewViewToolMessageItem(sty, toolCall, result, canceled)
+	case tools.WriteToolName:
+		return NewWriteToolMessageItem(sty, toolCall, result, canceled)
+	case tools.EditToolName:
+		return NewEditToolMessageItem(sty, toolCall, result, canceled)
+	case tools.MultiEditToolName:
+		return NewMultiEditToolMessageItem(sty, toolCall, result, canceled)
 	default:
 		// TODO: Implement other tool items
 		return newBaseToolMessageItem(
@@ -376,9 +385,13 @@ func toolParamList(sty *styles.Styles, params []string, width int) string {
 }
 
 // toolHeader builds the tool header line: "● ToolName params..."
-func toolHeader(sty *styles.Styles, status ToolStatus, name string, width int, params ...string) string {
+func toolHeader(sty *styles.Styles, status ToolStatus, name string, width int, nested bool, params ...string) string {
 	icon := toolIcon(sty, status)
-	toolName := sty.Tool.NameNested.Render(name)
+	nameStyle := sty.Tool.NameNormal
+	if nested {
+		nameStyle = sty.Tool.NameNested
+	}
+	toolName := nameStyle.Render(name)
 	prefix := fmt.Sprintf("%s %s ", icon, toolName)
 	prefixWidth := lipgloss.Width(prefix)
 	remainingWidth := width - prefixWidth
@@ -420,3 +433,178 @@ func toolOutputPlainContent(sty *styles.Styles, content string, width int, expan
 
 	return strings.Join(out, "\n")
 }
+
+// toolOutputCodeContent renders code with syntax highlighting and line numbers.
+func toolOutputCodeContent(sty *styles.Styles, path, content string, offset, width int, expanded bool) string {
+	content = strings.ReplaceAll(content, "\r\n", "\n")
+	content = strings.ReplaceAll(content, "\t", "    ")
+
+	lines := strings.Split(content, "\n")
+	maxLines := responseContextHeight
+	if expanded {
+		maxLines = len(lines)
+	}
+
+	// Truncate if needed.
+	displayLines := lines
+	if len(lines) > maxLines {
+		displayLines = lines[:maxLines]
+	}
+
+	bg := sty.Tool.ContentCodeBg
+	highlighted, _ := common.SyntaxHighlight(sty, strings.Join(displayLines, "\n"), path, bg)
+	highlightedLines := strings.Split(highlighted, "\n")
+
+	// Calculate line number width.
+	maxLineNumber := len(displayLines) + offset
+	maxDigits := getDigits(maxLineNumber)
+	numFmt := fmt.Sprintf("%%%dd", maxDigits)
+
+	bodyWidth := width - toolBodyLeftPaddingTotal
+	codeWidth := bodyWidth - maxDigits - 4 // -4 for line number padding
+
+	var out []string
+	for i, ln := range highlightedLines {
+		lineNum := sty.Tool.ContentLineNumber.Render(fmt.Sprintf(numFmt, i+1+offset))
+
+		if lipgloss.Width(ln) > codeWidth {
+			ln = ansi.Truncate(ln, codeWidth, "…")
+		}
+
+		codeLine := sty.Tool.ContentCodeLine.
+			Width(codeWidth).
+			PaddingLeft(2).
+			Render(ln)
+
+		out = append(out, lipgloss.JoinHorizontal(lipgloss.Left, lineNum, codeLine))
+	}
+
+	// Add truncation message if needed.
+	if len(lines) > maxLines && !expanded {
+		truncMsg := sty.Tool.ContentCodeTruncation.
+			Width(bodyWidth).
+			Render(fmt.Sprintf("… (%d lines) [click or space to expand]", len(lines)-maxLines))
+		out = append(out, truncMsg)
+	}
+
+	return sty.Tool.Body.Render(strings.Join(out, "\n"))
+}
+
+// toolOutputImageContent renders image data with size info.
+func toolOutputImageContent(sty *styles.Styles, data, mediaType string) string {
+	dataSize := len(data) * 3 / 4
+	sizeStr := formatSize(dataSize)
+
+	loaded := sty.Base.Foreground(sty.Green).Render("Loaded")
+	arrow := sty.Base.Foreground(sty.GreenDark).Render("→")
+	typeStyled := sty.Base.Render(mediaType)
+	sizeStyled := sty.Subtle.Render(sizeStr)
+
+	return sty.Tool.Body.Render(fmt.Sprintf("%s %s %s %s", loaded, arrow, typeStyled, sizeStyled))
+}
+
+// getDigits returns the number of digits in a number.
+func getDigits(n int) int {
+	if n == 0 {
+		return 1
+	}
+	if n < 0 {
+		n = -n
+	}
+	digits := 0
+	for n > 0 {
+		n /= 10
+		digits++
+	}
+	return digits
+}
+
+// formatSize formats byte size into human readable format.
+func formatSize(bytes int) string {
+	const (
+		kb = 1024
+		mb = kb * 1024
+	)
+	switch {
+	case bytes >= mb:
+		return fmt.Sprintf("%.1f MB", float64(bytes)/float64(mb))
+	case bytes >= kb:
+		return fmt.Sprintf("%.1f KB", float64(bytes)/float64(kb))
+	default:
+		return fmt.Sprintf("%d B", bytes)
+	}
+}
+
+// toolOutputDiffContent renders a diff between old and new content.
+func toolOutputDiffContent(sty *styles.Styles, file, oldContent, newContent string, width int, expanded bool) string {
+	bodyWidth := width - toolBodyLeftPaddingTotal
+
+	formatter := common.DiffFormatter(sty).
+		Before(file, oldContent).
+		After(file, newContent).
+		Width(bodyWidth)
+
+	// Use split view for wide terminals.
+	if width > 120 {
+		formatter = formatter.Split()
+	}
+
+	formatted := formatter.String()
+	lines := strings.Split(formatted, "\n")
+
+	// Truncate if needed.
+	maxLines := responseContextHeight
+	if expanded {
+		maxLines = len(lines)
+	}
+
+	if len(lines) > maxLines && !expanded {
+		truncMsg := sty.Tool.DiffTruncation.
+			Width(bodyWidth).
+			Render(fmt.Sprintf("… (%d lines) [click or space to expand]", len(lines)-maxLines))
+		formatted = strings.Join(lines[:maxLines], "\n") + "\n" + truncMsg
+	}
+
+	return sty.Tool.Body.Render(formatted)
+}
+
+// toolOutputMultiEditDiffContent renders a diff with optional failed edits note.
+func toolOutputMultiEditDiffContent(sty *styles.Styles, file string, meta tools.MultiEditResponseMetadata, totalEdits, width int, expanded bool) string {
+	bodyWidth := width - toolBodyLeftPaddingTotal
+
+	formatter := common.DiffFormatter(sty).
+		Before(file, meta.OldContent).
+		After(file, meta.NewContent).
+		Width(bodyWidth)
+
+	// Use split view for wide terminals.
+	if width > 120 {
+		formatter = formatter.Split()
+	}
+
+	formatted := formatter.String()
+	lines := strings.Split(formatted, "\n")
+
+	// Truncate if needed.
+	maxLines := responseContextHeight
+	if expanded {
+		maxLines = len(lines)
+	}
+
+	if len(lines) > maxLines && !expanded {
+		truncMsg := sty.Tool.DiffTruncation.
+			Width(bodyWidth).
+			Render(fmt.Sprintf("… (%d lines) [click or space to expand]", len(lines)-maxLines))
+		formatted = strings.Join(lines[:maxLines], "\n") + "\n" + truncMsg
+	}
+
+	// Add failed edits note if any exist.
+	if len(meta.EditsFailed) > 0 {
+		noteTag := sty.Tool.NoteTag.Render("Note")
+		noteMsg := fmt.Sprintf("%d of %d edits succeeded", meta.EditsApplied, totalEdits)
+		note := fmt.Sprintf("%s %s", noteTag, sty.Tool.NoteMessage.Render(noteMsg))
+		formatted = formatted + "\n\n" + note
+	}
+
+	return sty.Tool.Body.Render(formatted)
+}

internal/ui/styles/styles.go 🔗

@@ -238,11 +238,12 @@ type Styles struct {
 		ParamKey  lipgloss.Style // Parameter keys
 
 		// Content rendering styles
-		ContentLine       lipgloss.Style // Individual content line with background and width
-		ContentTruncation lipgloss.Style // Truncation message "… (N lines)"
-		ContentCodeLine   lipgloss.Style // Code line with background and width
-		ContentCodeBg     color.Color    // Background color for syntax highlighting
-		Body              lipgloss.Style // Body content padding (PaddingLeft(2))
+		ContentLine           lipgloss.Style // Individual content line with background and width
+		ContentTruncation     lipgloss.Style // Truncation message "… (N lines)"
+		ContentCodeLine       lipgloss.Style // Code line with background and width
+		ContentCodeTruncation lipgloss.Style // Code truncation message with bgBase
+		ContentCodeBg         color.Color    // Background color for syntax highlighting
+		Body                  lipgloss.Style // Body content padding (PaddingLeft(2))
 
 		// Deprecated - kept for backward compatibility
 		ContentBg         lipgloss.Style // Content background
@@ -970,14 +971,15 @@ func DefaultStyles() Styles {
 	// Content rendering - prepared styles that accept width parameter
 	s.Tool.ContentLine = s.Muted.Background(bgBaseLighter)
 	s.Tool.ContentTruncation = s.Muted.Background(bgBaseLighter)
-	s.Tool.ContentCodeLine = s.Base.Background(bgBaseLighter)
+	s.Tool.ContentCodeLine = s.Base.Background(bgBase)
+	s.Tool.ContentCodeTruncation = s.Muted.Background(bgBase).PaddingLeft(2)
 	s.Tool.ContentCodeBg = bgBase
 	s.Tool.Body = base.PaddingLeft(2)
 
 	// Deprecated - kept for backward compatibility
 	s.Tool.ContentBg = s.Muted.Background(bgBaseLighter)
 	s.Tool.ContentText = s.Muted
-	s.Tool.ContentLineNumber = s.Subtle
+	s.Tool.ContentLineNumber = base.Foreground(fgMuted).Background(bgBase).PaddingRight(1).PaddingLeft(1)
 
 	s.Tool.StateWaiting = base.Foreground(fgSubtle)
 	s.Tool.StateCancelled = base.Foreground(fgSubtle)
@@ -987,7 +989,7 @@ func DefaultStyles() Styles {
 
 	// Diff and multi-edit styles
 	s.Tool.DiffTruncation = s.Muted.Background(bgBaseLighter).PaddingLeft(2)
-	s.Tool.NoteTag = base.Padding(0, 1).Background(yellow).Foreground(white)
+	s.Tool.NoteTag = base.Padding(0, 1).Background(info).Foreground(white)
 	s.Tool.NoteMessage = base.Foreground(fgHalfMuted)
 
 	// Job header styles