From c0e3f79756a0f405da72cb0c78ce93e773bee073 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Tue, 6 Jan 2026 15:27:41 +0100 Subject: [PATCH] refactor: add agent tools and rename simple->compact --- internal/ui/chat/agent.go | 310 ++++++++++++++++++++++++++++++++ internal/ui/chat/bash.go | 10 +- internal/ui/chat/diagnostics.go | 6 +- internal/ui/chat/fetch.go | 114 +++++++++++- internal/ui/chat/file.go | 34 ++-- internal/ui/chat/messages.go | 1 + internal/ui/chat/search.go | 24 +-- internal/ui/chat/todos.go | 192 ++++++++++++++++++++ internal/ui/chat/tools.go | 185 +++++++++++++++---- internal/ui/model/chat.go | 36 ++++ internal/ui/model/ui.go | 160 ++++++++++++++++- internal/ui/styles/styles.go | 15 ++ 12 files changed, 1011 insertions(+), 76 deletions(-) create mode 100644 internal/ui/chat/agent.go create mode 100644 internal/ui/chat/todos.go diff --git a/internal/ui/chat/agent.go b/internal/ui/chat/agent.go new file mode 100644 index 0000000000000000000000000000000000000000..75c936a92ddfd1c75e9b2e49ec9ef7ee46f08ddf --- /dev/null +++ b/internal/ui/chat/agent.go @@ -0,0 +1,310 @@ +package chat + +import ( + "encoding/json" + "strings" + + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" + "charm.land/lipgloss/v2/tree" + "github.com/charmbracelet/crush/internal/agent" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/ui/anim" + "github.com/charmbracelet/crush/internal/ui/styles" +) + +// ----------------------------------------------------------------------------- +// Agent Tool +// ----------------------------------------------------------------------------- + +// NestedToolContainer is an interface for tool items that can contain nested tool calls. +type NestedToolContainer interface { + NestedTools() []ToolMessageItem + SetNestedTools(tools []ToolMessageItem) + AddNestedTool(tool ToolMessageItem) +} + +// AgentToolMessageItem is a message item that represents an agent tool call. +type AgentToolMessageItem struct { + *baseToolMessageItem + + nestedTools []ToolMessageItem +} + +var ( + _ ToolMessageItem = (*AgentToolMessageItem)(nil) + _ NestedToolContainer = (*AgentToolMessageItem)(nil) +) + +// NewAgentToolMessageItem creates a new [AgentToolMessageItem]. +func NewAgentToolMessageItem( + sty *styles.Styles, + toolCall message.ToolCall, + result *message.ToolResult, + canceled bool, +) *AgentToolMessageItem { + t := &AgentToolMessageItem{} + t.baseToolMessageItem = newBaseToolMessageItem(sty, toolCall, result, &AgentToolRenderContext{agent: t}, canceled) + // For the agent tool we keep spinning until the tool call is finished. + t.isSpinningFn = func(state IsSpinningState) bool { + return state.Result == nil && !state.Canceled + } + return t +} + +// Animate progresses the message animation if it should be spinning. +func (a *AgentToolMessageItem) Animate(msg anim.StepMsg) tea.Cmd { + if a.result != nil || a.canceled { + return nil + } + if msg.ID == a.ID() { + return a.anim.Animate(msg) + } + for _, nestedTool := range a.nestedTools { + if msg.ID != nestedTool.ID() { + continue + } + if s, ok := nestedTool.(Animatable); ok { + return s.Animate(msg) + } + } + return nil +} + +// NestedTools returns the nested tools. +func (a *AgentToolMessageItem) NestedTools() []ToolMessageItem { + return a.nestedTools +} + +// SetNestedTools sets the nested tools. +func (a *AgentToolMessageItem) SetNestedTools(tools []ToolMessageItem) { + a.nestedTools = tools + a.clearCache() +} + +// AddNestedTool adds a nested tool. +func (a *AgentToolMessageItem) AddNestedTool(tool ToolMessageItem) { + // Mark nested tools as simple (compact) rendering. + if s, ok := tool.(Compactable); ok { + s.SetCompact(true) + } + a.nestedTools = append(a.nestedTools, tool) + a.clearCache() +} + +// AgentToolRenderContext renders agent tool messages. +type AgentToolRenderContext struct { + agent *AgentToolMessageItem +} + +// RenderTool implements the [ToolRenderer] interface. +func (r *AgentToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string { + cappedWidth := cappedMessageWidth(width) + if !opts.ToolCall.Finished && !opts.Canceled && len(r.agent.nestedTools) == 0 { + return pendingTool(sty, "Agent", opts.Anim) + } + + var params agent.AgentParams + _ = json.Unmarshal([]byte(opts.ToolCall.Input), ¶ms) + + prompt := params.Prompt + prompt = strings.ReplaceAll(prompt, "\n", " ") + + header := toolHeader(sty, opts.Status(), "Agent", cappedWidth, opts.Compact) + if opts.Compact { + return header + } + + // Build the task tag and prompt. + taskTag := sty.Tool.AgentTaskTag.Render("Task") + taskTagWidth := lipgloss.Width(taskTag) + + // Calculate remaining width for prompt. + remainingWidth := min(cappedWidth-taskTagWidth-3, 120-taskTagWidth-3) // -3 for spacing + + promptText := sty.Tool.AgentPrompt.Width(remainingWidth).Render(prompt) + + header = lipgloss.JoinVertical( + lipgloss.Left, + header, + "", + lipgloss.JoinHorizontal( + lipgloss.Left, + taskTag, + " ", + promptText, + ), + ) + + // Build tree with nested tool calls. + childTools := tree.Root(header) + + for _, nestedTool := range r.agent.nestedTools { + childView := nestedTool.Render(remainingWidth) + childTools.Child(childView) + } + + // Build parts. + var parts []string + parts = append(parts, childTools.Enumerator(roundedEnumerator(2, taskTagWidth-5)).String()) + + // Show animation if still running. + if opts.Result == nil && !opts.Canceled { + parts = append(parts, "", opts.Anim.Render()) + } + + result := lipgloss.JoinVertical(lipgloss.Left, parts...) + + // Add body content when completed. + if opts.Result != nil && opts.Result.Content != "" { + body := toolOutputMarkdownContent(sty, opts.Result.Content, cappedWidth-toolBodyLeftPaddingTotal, opts.ExpandedContent) + return joinToolParts(result, body) + } + + return result +} + +// ----------------------------------------------------------------------------- +// Agentic Fetch Tool +// ----------------------------------------------------------------------------- + +// AgenticFetchToolMessageItem is a message item that represents an agentic fetch tool call. +type AgenticFetchToolMessageItem struct { + *baseToolMessageItem + + nestedTools []ToolMessageItem +} + +var ( + _ ToolMessageItem = (*AgenticFetchToolMessageItem)(nil) + _ NestedToolContainer = (*AgenticFetchToolMessageItem)(nil) +) + +// NewAgenticFetchToolMessageItem creates a new [AgenticFetchToolMessageItem]. +func NewAgenticFetchToolMessageItem( + sty *styles.Styles, + toolCall message.ToolCall, + result *message.ToolResult, + canceled bool, +) *AgenticFetchToolMessageItem { + t := &AgenticFetchToolMessageItem{} + t.baseToolMessageItem = newBaseToolMessageItem(sty, toolCall, result, &AgenticFetchToolRenderContext{fetch: t}, canceled) + // For the agentic fetch tool we keep spinning until the tool call is finished. + t.isSpinningFn = func(state IsSpinningState) bool { + return state.Result == nil && !state.Canceled + } + return t +} + +// NestedTools returns the nested tools. +func (a *AgenticFetchToolMessageItem) NestedTools() []ToolMessageItem { + return a.nestedTools +} + +// SetNestedTools sets the nested tools. +func (a *AgenticFetchToolMessageItem) SetNestedTools(tools []ToolMessageItem) { + a.nestedTools = tools + a.clearCache() +} + +// AddNestedTool adds a nested tool. +func (a *AgenticFetchToolMessageItem) AddNestedTool(tool ToolMessageItem) { + // Mark nested tools as simple (compact) rendering. + if s, ok := tool.(Compactable); ok { + s.SetCompact(true) + } + a.nestedTools = append(a.nestedTools, tool) + a.clearCache() +} + +// AgenticFetchToolRenderContext renders agentic fetch tool messages. +type AgenticFetchToolRenderContext struct { + fetch *AgenticFetchToolMessageItem +} + +// agenticFetchParams matches tools.AgenticFetchParams. +type agenticFetchParams struct { + URL string `json:"url,omitempty"` + Prompt string `json:"prompt"` +} + +// RenderTool implements the [ToolRenderer] interface. +func (r *AgenticFetchToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string { + cappedWidth := cappedMessageWidth(width) + if !opts.ToolCall.Finished && !opts.Canceled && len(r.fetch.nestedTools) == 0 { + return pendingTool(sty, "Agentic Fetch", opts.Anim) + } + + var params agenticFetchParams + _ = json.Unmarshal([]byte(opts.ToolCall.Input), ¶ms) + + prompt := params.Prompt + prompt = strings.ReplaceAll(prompt, "\n", " ") + + // Build header with optional URL param. + toolParams := []string{} + if params.URL != "" { + toolParams = append(toolParams, params.URL) + } + + header := toolHeader(sty, opts.Status(), "Agentic Fetch", cappedWidth, opts.Compact, toolParams...) + if opts.Compact { + return header + } + + // Build the prompt tag. + promptTag := sty.Base.Bold(true). + Padding(0, 1). + MarginLeft(2). + Background(sty.Green). + Foreground(sty.Border). + Render("Prompt") + promptTagWidth := lipgloss.Width(promptTag) + + // Calculate remaining width for prompt text. + remainingWidth := cappedWidth - promptTagWidth - 3 // -3 for spacing + if remainingWidth > 120-promptTagWidth-3 { + remainingWidth = 120 - promptTagWidth - 3 + } + + promptText := sty.Base.Width(remainingWidth).Render(prompt) + + header = lipgloss.JoinVertical( + lipgloss.Left, + header, + "", + lipgloss.JoinHorizontal( + lipgloss.Left, + promptTag, + " ", + promptText, + ), + ) + + // Build tree with nested tool calls. + childTools := tree.Root(header) + + for _, nestedTool := range r.fetch.nestedTools { + childView := nestedTool.Render(remainingWidth) + childTools.Child(childView) + } + + // Build parts. + var parts []string + parts = append(parts, childTools.Enumerator(roundedEnumerator(2, promptTagWidth-5)).String()) + + // Show animation if still running. + if opts.Result == nil && !opts.Canceled { + parts = append(parts, "", opts.Anim.Render()) + } + + result := lipgloss.JoinVertical(lipgloss.Left, parts...) + + // Add body content when completed. + if opts.Result != nil && opts.Result.Content != "" { + body := toolOutputMarkdownContent(sty, opts.Result.Content, cappedWidth-toolBodyLeftPaddingTotal, opts.ExpandedContent) + return joinToolParts(result, body) + } + + return result +} diff --git a/internal/ui/chat/bash.go b/internal/ui/chat/bash.go index 85b3c0db81756a47ec5a8009b24d975dcf4d3358..0202780cf1e670b48cb7f9a8b9d27a0fe44f5405 100644 --- a/internal/ui/chat/bash.go +++ b/internal/ui/chat/bash.go @@ -69,8 +69,8 @@ func (b *BashToolRenderContext) RenderTool(sty *styles.Styles, width int, opts * toolParams = append(toolParams, "background", "true") } - header := toolHeader(sty, opts.Status(), "Bash", cappedWidth, opts.Simple, toolParams...) - if opts.Simple { + header := toolHeader(sty, opts.Status(), "Bash", cappedWidth, opts.Compact, toolParams...) + if opts.Compact { return header } @@ -91,7 +91,7 @@ func (b *BashToolRenderContext) RenderTool(sty *styles.Styles, width int, opts * } bodyWidth := cappedWidth - toolBodyLeftPaddingTotal - body := sty.Tool.Body.Render(toolOutputPlainContent(sty, output, bodyWidth, opts.Expanded)) + body := sty.Tool.Body.Render(toolOutputPlainContent(sty, output, bodyWidth, opts.ExpandedContent)) return joinToolParts(header, body) } @@ -201,7 +201,7 @@ func (j *JobKillToolRenderContext) RenderTool(sty *styles.Styles, width int, opt // header → nested check → early state → body. func renderJobTool(sty *styles.Styles, opts *ToolRenderOpts, width int, action, shellID, description, content string) string { header := jobHeader(sty, opts.Status(), action, shellID, description, width) - if opts.Simple { + if opts.Compact { return header } @@ -214,7 +214,7 @@ func renderJobTool(sty *styles.Styles, opts *ToolRenderOpts, width int, action, } bodyWidth := width - toolBodyLeftPaddingTotal - body := sty.Tool.Body.Render(toolOutputPlainContent(sty, content, bodyWidth, opts.Expanded)) + body := sty.Tool.Body.Render(toolOutputPlainContent(sty, content, bodyWidth, opts.ExpandedContent)) return joinToolParts(header, body) } diff --git a/internal/ui/chat/diagnostics.go b/internal/ui/chat/diagnostics.go index 6d59141c651a4235cee3a65e97f456cd95c7dc77..8ca5436b9082033a9cbb0debedffec041833ea11 100644 --- a/internal/ui/chat/diagnostics.go +++ b/internal/ui/chat/diagnostics.go @@ -49,8 +49,8 @@ func (d *DiagnosticsToolRenderContext) RenderTool(sty *styles.Styles, width int, mainParam = fsext.PrettyPath(params.FilePath) } - header := toolHeader(sty, opts.Status(), "Diagnostics", cappedWidth, opts.Simple, mainParam) - if opts.Simple { + header := toolHeader(sty, opts.Status(), "Diagnostics", cappedWidth, opts.Compact, mainParam) + if opts.Compact { return header } @@ -63,6 +63,6 @@ func (d *DiagnosticsToolRenderContext) RenderTool(sty *styles.Styles, width int, } bodyWidth := cappedWidth - toolBodyLeftPaddingTotal - body := sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.Expanded)) + body := sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.ExpandedContent)) return joinToolParts(header, body) } diff --git a/internal/ui/chat/fetch.go b/internal/ui/chat/fetch.go index 53f5c066060266eb9f2f88844eb1640daf05e245..41e35c90004a76337e8ce3d59908cadf32ed699f 100644 --- a/internal/ui/chat/fetch.go +++ b/internal/ui/chat/fetch.go @@ -52,8 +52,8 @@ func (f *FetchToolRenderContext) RenderTool(sty *styles.Styles, width int, opts toolParams = append(toolParams, "timeout", formatTimeout(params.Timeout)) } - header := toolHeader(sty, opts.Status(), "Fetch", cappedWidth, opts.Simple, toolParams...) - if opts.Simple { + header := toolHeader(sty, opts.Status(), "Fetch", cappedWidth, opts.Compact, toolParams...) + if opts.Compact { return header } @@ -67,7 +67,7 @@ func (f *FetchToolRenderContext) RenderTool(sty *styles.Styles, width int, opts // Determine file extension for syntax highlighting based on format. file := getFileExtensionForFormat(params.Format) - body := toolOutputCodeContent(sty, file, opts.Result.Content, 0, cappedWidth, opts.Expanded) + body := toolOutputCodeContent(sty, file, opts.Result.Content, 0, cappedWidth, opts.ExpandedContent) return joinToolParts(header, body) } @@ -82,3 +82,111 @@ func getFileExtensionForFormat(format string) string { return "fetch.md" } } + +// ----------------------------------------------------------------------------- +// WebFetch Tool +// ----------------------------------------------------------------------------- + +// WebFetchToolMessageItem is a message item that represents a web_fetch tool call. +type WebFetchToolMessageItem struct { + *baseToolMessageItem +} + +var _ ToolMessageItem = (*WebFetchToolMessageItem)(nil) + +// NewWebFetchToolMessageItem creates a new [WebFetchToolMessageItem]. +func NewWebFetchToolMessageItem( + sty *styles.Styles, + toolCall message.ToolCall, + result *message.ToolResult, + canceled bool, +) ToolMessageItem { + return newBaseToolMessageItem(sty, toolCall, result, &WebFetchToolRenderContext{}, canceled) +} + +// WebFetchToolRenderContext renders web_fetch tool messages. +type WebFetchToolRenderContext struct{} + +// RenderTool implements the [ToolRenderer] interface. +func (w *WebFetchToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string { + cappedWidth := cappedMessageWidth(width) + if !opts.ToolCall.Finished && !opts.Canceled { + return pendingTool(sty, "Fetch", opts.Anim) + } + + var params tools.WebFetchParams + if err := json.Unmarshal([]byte(opts.ToolCall.Input), ¶ms); err != nil { + return toolErrorContent(sty, &message.ToolResult{Content: "Invalid parameters"}, cappedWidth) + } + + toolParams := []string{params.URL} + header := toolHeader(sty, opts.Status(), "Fetch", cappedWidth, opts.Compact, toolParams...) + if opts.Compact { + return header + } + + if earlyState, ok := toolEarlyStateContent(sty, opts, cappedWidth); ok { + return joinToolParts(header, earlyState) + } + + if opts.Result == nil || opts.Result.Content == "" { + return header + } + + body := toolOutputMarkdownContent(sty, opts.Result.Content, cappedWidth, opts.ExpandedContent) + return joinToolParts(header, body) +} + +// ----------------------------------------------------------------------------- +// WebSearch Tool +// ----------------------------------------------------------------------------- + +// WebSearchToolMessageItem is a message item that represents a web_search tool call. +type WebSearchToolMessageItem struct { + *baseToolMessageItem +} + +var _ ToolMessageItem = (*WebSearchToolMessageItem)(nil) + +// NewWebSearchToolMessageItem creates a new [WebSearchToolMessageItem]. +func NewWebSearchToolMessageItem( + sty *styles.Styles, + toolCall message.ToolCall, + result *message.ToolResult, + canceled bool, +) ToolMessageItem { + return newBaseToolMessageItem(sty, toolCall, result, &WebSearchToolRenderContext{}, canceled) +} + +// WebSearchToolRenderContext renders web_search tool messages. +type WebSearchToolRenderContext struct{} + +// RenderTool implements the [ToolRenderer] interface. +func (w *WebSearchToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string { + cappedWidth := cappedMessageWidth(width) + if !opts.ToolCall.Finished && !opts.Canceled { + return pendingTool(sty, "Search", opts.Anim) + } + + var params tools.WebSearchParams + if err := json.Unmarshal([]byte(opts.ToolCall.Input), ¶ms); err != nil { + return toolErrorContent(sty, &message.ToolResult{Content: "Invalid parameters"}, cappedWidth) + } + + toolParams := []string{params.Query} + header := toolHeader(sty, opts.Status(), "Search", cappedWidth, opts.Compact, toolParams...) + if opts.Compact { + return header + } + + if earlyState, ok := toolEarlyStateContent(sty, opts, cappedWidth); ok { + return joinToolParts(header, earlyState) + } + + if opts.Result == nil || opts.Result.Content == "" { + return header + } + + body := toolOutputMarkdownContent(sty, opts.Result.Content, cappedWidth, opts.ExpandedContent) + return joinToolParts(header, body) +} diff --git a/internal/ui/chat/file.go b/internal/ui/chat/file.go index 2c0a1e49440ed263735afbd0ab3104034a2f4634..ca0e0b4934e806bbed0c7826161bb2c91a10843f 100644 --- a/internal/ui/chat/file.go +++ b/internal/ui/chat/file.go @@ -56,8 +56,8 @@ func (v *ViewToolRenderContext) RenderTool(sty *styles.Styles, width int, opts * toolParams = append(toolParams, "offset", fmt.Sprintf("%d", params.Offset)) } - header := toolHeader(sty, opts.Status(), "View", cappedWidth, opts.Simple, toolParams...) - if opts.Simple { + header := toolHeader(sty, opts.Status(), "View", cappedWidth, opts.Compact, toolParams...) + if opts.Compact { return header } @@ -87,7 +87,7 @@ func (v *ViewToolRenderContext) RenderTool(sty *styles.Styles, width int, opts * } // Render code content with syntax highlighting. - body := toolOutputCodeContent(sty, params.FilePath, content, params.Offset, cappedWidth, opts.Expanded) + body := toolOutputCodeContent(sty, params.FilePath, content, params.Offset, cappedWidth, opts.ExpandedContent) return joinToolParts(header, body) } @@ -128,8 +128,8 @@ func (w *WriteToolRenderContext) RenderTool(sty *styles.Styles, width int, opts } file := fsext.PrettyPath(params.FilePath) - header := toolHeader(sty, opts.Status(), "Write", cappedWidth, opts.Simple, file) - if opts.Simple { + header := toolHeader(sty, opts.Status(), "Write", cappedWidth, opts.Compact, file) + if opts.Compact { return header } @@ -142,7 +142,7 @@ func (w *WriteToolRenderContext) RenderTool(sty *styles.Styles, width int, opts } // Render code content with syntax highlighting. - body := toolOutputCodeContent(sty, params.FilePath, params.Content, 0, cappedWidth, opts.Expanded) + body := toolOutputCodeContent(sty, params.FilePath, params.Content, 0, cappedWidth, opts.ExpandedContent) return joinToolParts(header, body) } @@ -183,8 +183,8 @@ func (e *EditToolRenderContext) RenderTool(sty *styles.Styles, width int, opts * } file := fsext.PrettyPath(params.FilePath) - header := toolHeader(sty, opts.Status(), "Edit", width, opts.Simple, file) - if opts.Simple { + header := toolHeader(sty, opts.Status(), "Edit", width, opts.Compact, file) + if opts.Compact { return header } @@ -200,12 +200,12 @@ func (e *EditToolRenderContext) RenderTool(sty *styles.Styles, width int, opts * var meta tools.EditResponseMetadata if err := json.Unmarshal([]byte(opts.Result.Metadata), &meta); err != nil { bodyWidth := width - toolBodyLeftPaddingTotal - body := sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.Expanded)) + body := sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.ExpandedContent)) return joinToolParts(header, body) } // Render diff. - body := toolOutputDiffContent(sty, file, meta.OldContent, meta.NewContent, width, opts.Expanded) + body := toolOutputDiffContent(sty, file, meta.OldContent, meta.NewContent, width, opts.ExpandedContent) return joinToolParts(header, body) } @@ -251,8 +251,8 @@ func (m *MultiEditToolRenderContext) RenderTool(sty *styles.Styles, width int, o toolParams = append(toolParams, "edits", fmt.Sprintf("%d", len(params.Edits))) } - header := toolHeader(sty, opts.Status(), "Multi-Edit", width, opts.Simple, toolParams...) - if opts.Simple { + header := toolHeader(sty, opts.Status(), "Multi-Edit", width, opts.Compact, toolParams...) + if opts.Compact { return header } @@ -268,12 +268,12 @@ func (m *MultiEditToolRenderContext) RenderTool(sty *styles.Styles, width int, o var meta tools.MultiEditResponseMetadata if err := json.Unmarshal([]byte(opts.Result.Metadata), &meta); err != nil { bodyWidth := width - toolBodyLeftPaddingTotal - body := sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.Expanded)) + body := sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.ExpandedContent)) return joinToolParts(header, body) } // Render diff with optional failed edits note. - body := toolOutputMultiEditDiffContent(sty, file, meta, len(params.Edits), width, opts.Expanded) + body := toolOutputMultiEditDiffContent(sty, file, meta, len(params.Edits), width, opts.ExpandedContent) return joinToolParts(header, body) } @@ -321,8 +321,8 @@ func (d *DownloadToolRenderContext) RenderTool(sty *styles.Styles, width int, op toolParams = append(toolParams, "timeout", formatTimeout(params.Timeout)) } - header := toolHeader(sty, opts.Status(), "Download", cappedWidth, opts.Simple, toolParams...) - if opts.Simple { + header := toolHeader(sty, opts.Status(), "Download", cappedWidth, opts.Compact, toolParams...) + if opts.Compact { return header } @@ -335,6 +335,6 @@ func (d *DownloadToolRenderContext) RenderTool(sty *styles.Styles, width int, op } bodyWidth := cappedWidth - toolBodyLeftPaddingTotal - body := sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.Expanded)) + body := sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.ExpandedContent)) return joinToolParts(header, body) } diff --git a/internal/ui/chat/messages.go b/internal/ui/chat/messages.go index f75a9f9328ff01809208bc30a58b3531e766033d..da55faa10c842f7ad66ad824f69564694855bb50 100644 --- a/internal/ui/chat/messages.go +++ b/internal/ui/chat/messages.go @@ -171,6 +171,7 @@ func ExtractMessageItems(sty *styles.Styles, msg *message.Message, toolResults m } items = append(items, NewToolMessageItem( sty, + msg.ID, tc, result, msg.FinishReason() == message.FinishReasonCanceled, diff --git a/internal/ui/chat/search.go b/internal/ui/chat/search.go index cd19eeef5712761a5c533686297f7baec65b87de..3430d7d5c8aebe6e93979284f659e74b60316ca9 100644 --- a/internal/ui/chat/search.go +++ b/internal/ui/chat/search.go @@ -50,8 +50,8 @@ func (g *GlobToolRenderContext) RenderTool(sty *styles.Styles, width int, opts * toolParams = append(toolParams, "path", params.Path) } - header := toolHeader(sty, opts.Status(), "Glob", cappedWidth, opts.Simple, toolParams...) - if opts.Simple { + header := toolHeader(sty, opts.Status(), "Glob", cappedWidth, opts.Compact, toolParams...) + if opts.Compact { return header } @@ -64,7 +64,7 @@ func (g *GlobToolRenderContext) RenderTool(sty *styles.Styles, width int, opts * } bodyWidth := cappedWidth - toolBodyLeftPaddingTotal - body := sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.Expanded)) + body := sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.ExpandedContent)) return joinToolParts(header, body) } @@ -115,8 +115,8 @@ func (g *GrepToolRenderContext) RenderTool(sty *styles.Styles, width int, opts * toolParams = append(toolParams, "literal", "true") } - header := toolHeader(sty, opts.Status(), "Grep", cappedWidth, opts.Simple, toolParams...) - if opts.Simple { + header := toolHeader(sty, opts.Status(), "Grep", cappedWidth, opts.Compact, toolParams...) + if opts.Compact { return header } @@ -129,7 +129,7 @@ func (g *GrepToolRenderContext) RenderTool(sty *styles.Styles, width int, opts * } bodyWidth := cappedWidth - toolBodyLeftPaddingTotal - body := sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.Expanded)) + body := sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.ExpandedContent)) return joinToolParts(header, body) } @@ -175,8 +175,8 @@ func (l *LSToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *To } path = fsext.PrettyPath(path) - header := toolHeader(sty, opts.Status(), "List", cappedWidth, opts.Simple, path) - if opts.Simple { + header := toolHeader(sty, opts.Status(), "List", cappedWidth, opts.Compact, path) + if opts.Compact { return header } @@ -189,7 +189,7 @@ func (l *LSToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *To } bodyWidth := cappedWidth - toolBodyLeftPaddingTotal - body := sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.Expanded)) + body := sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.ExpandedContent)) return joinToolParts(header, body) } @@ -237,8 +237,8 @@ func (s *SourcegraphToolRenderContext) RenderTool(sty *styles.Styles, width int, toolParams = append(toolParams, "context", formatNonZero(params.ContextWindow)) } - header := toolHeader(sty, opts.Status(), "Sourcegraph", cappedWidth, opts.Simple, toolParams...) - if opts.Simple { + header := toolHeader(sty, opts.Status(), "Sourcegraph", cappedWidth, opts.Compact, toolParams...) + if opts.Compact { return header } @@ -251,6 +251,6 @@ func (s *SourcegraphToolRenderContext) RenderTool(sty *styles.Styles, width int, } bodyWidth := cappedWidth - toolBodyLeftPaddingTotal - body := sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.Expanded)) + body := sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.ExpandedContent)) return joinToolParts(header, body) } diff --git a/internal/ui/chat/todos.go b/internal/ui/chat/todos.go new file mode 100644 index 0000000000000000000000000000000000000000..3f92de9b32287270298b8a20c463850a32d110b5 --- /dev/null +++ b/internal/ui/chat/todos.go @@ -0,0 +1,192 @@ +package chat + +import ( + "encoding/json" + "fmt" + "slices" + "strings" + + "github.com/charmbracelet/crush/internal/agent/tools" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/session" + "github.com/charmbracelet/crush/internal/ui/styles" + "github.com/charmbracelet/x/ansi" +) + +// ----------------------------------------------------------------------------- +// Todos Tool +// ----------------------------------------------------------------------------- + +// TodosToolMessageItem is a message item that represents a todos tool call. +type TodosToolMessageItem struct { + *baseToolMessageItem +} + +var _ ToolMessageItem = (*TodosToolMessageItem)(nil) + +// NewTodosToolMessageItem creates a new [TodosToolMessageItem]. +func NewTodosToolMessageItem( + sty *styles.Styles, + toolCall message.ToolCall, + result *message.ToolResult, + canceled bool, +) ToolMessageItem { + return newBaseToolMessageItem(sty, toolCall, result, &TodosToolRenderContext{}, canceled) +} + +// TodosToolRenderContext renders todos tool messages. +type TodosToolRenderContext struct{} + +// RenderTool implements the [ToolRenderer] interface. +func (t *TodosToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string { + cappedWidth := cappedMessageWidth(width) + if !opts.ToolCall.Finished && !opts.Canceled { + return pendingTool(sty, "To-Do", opts.Anim) + } + + var params tools.TodosParams + var meta tools.TodosResponseMetadata + var headerText string + var body string + + // Parse params for pending state (before result is available). + if err := json.Unmarshal([]byte(opts.ToolCall.Input), ¶ms); err == nil { + completedCount := 0 + inProgressTask := "" + for _, todo := range params.Todos { + if todo.Status == "completed" { + completedCount++ + } + if todo.Status == "in_progress" { + if todo.ActiveForm != "" { + inProgressTask = todo.ActiveForm + } else { + inProgressTask = todo.Content + } + } + } + + // Default display from params (used when pending or no metadata). + ratio := sty.Tool.TodoRatio.Render(fmt.Sprintf("%d/%d", completedCount, len(params.Todos))) + headerText = ratio + if inProgressTask != "" { + headerText = fmt.Sprintf("%s · %s", ratio, inProgressTask) + } + + // If we have metadata, use it for richer display. + if opts.Result != nil && opts.Result.Metadata != "" { + if err := json.Unmarshal([]byte(opts.Result.Metadata), &meta); err == nil { + if meta.IsNew { + if meta.JustStarted != "" { + headerText = fmt.Sprintf("created %d todos, starting first", meta.Total) + } else { + headerText = fmt.Sprintf("created %d todos", meta.Total) + } + body = formatTodosList(sty, meta.Todos, styles.ArrowRightIcon, cappedWidth) + } else { + // Build header based on what changed. + hasCompleted := len(meta.JustCompleted) > 0 + hasStarted := meta.JustStarted != "" + allCompleted := meta.Completed == meta.Total + + ratio := sty.Tool.TodoRatio.Render(fmt.Sprintf("%d/%d", meta.Completed, meta.Total)) + if hasCompleted && hasStarted { + text := sty.Subtle.Render(fmt.Sprintf(" · completed %d, starting next", len(meta.JustCompleted))) + headerText = fmt.Sprintf("%s%s", ratio, text) + } else if hasCompleted { + text := sty.Subtle.Render(fmt.Sprintf(" · completed %d", len(meta.JustCompleted))) + if allCompleted { + text = sty.Subtle.Render(" · completed all") + } + headerText = fmt.Sprintf("%s%s", ratio, text) + } else if hasStarted { + headerText = fmt.Sprintf("%s%s", ratio, sty.Subtle.Render(" · starting task")) + } else { + headerText = ratio + } + + // Build body with details. + if allCompleted { + // Show all todos when all are completed, like when created. + body = formatTodosList(sty, meta.Todos, styles.ArrowRightIcon, cappedWidth) + } else if meta.JustStarted != "" { + body = sty.Tool.TodoInProgressIcon.Render(styles.ArrowRightIcon+" ") + + sty.Base.Render(meta.JustStarted) + } + } + } + } + } + + toolParams := []string{headerText} + header := toolHeader(sty, opts.Status(), "To-Do", cappedWidth, opts.Compact, toolParams...) + if opts.Compact { + return header + } + + if earlyState, ok := toolEarlyStateContent(sty, opts, cappedWidth); ok { + return joinToolParts(header, earlyState) + } + + if body == "" { + return header + } + + return joinToolParts(header, sty.Tool.Body.Render(body)) +} + +// formatTodosList formats a list of todos for display. +func formatTodosList(sty *styles.Styles, todos []session.Todo, inProgressIcon string, width int) string { + if len(todos) == 0 { + return "" + } + + sorted := make([]session.Todo, len(todos)) + copy(sorted, todos) + sortTodos(sorted) + + var lines []string + for _, todo := range sorted { + var prefix string + textStyle := sty.Base + + switch todo.Status { + case session.TodoStatusCompleted: + prefix = sty.Tool.TodoCompletedIcon.Render(styles.TodoCompletedIcon) + " " + case session.TodoStatusInProgress: + prefix = sty.Tool.TodoInProgressIcon.Render(inProgressIcon + " ") + default: + prefix = sty.Tool.TodoPendingIcon.Render(styles.TodoPendingIcon) + " " + } + + text := todo.Content + if todo.Status == session.TodoStatusInProgress && todo.ActiveForm != "" { + text = todo.ActiveForm + } + line := prefix + textStyle.Render(text) + line = ansi.Truncate(line, width, "…") + + lines = append(lines, line) + } + + return strings.Join(lines, "\n") +} + +// sortTodos sorts todos by status: completed, in_progress, pending. +func sortTodos(todos []session.Todo) { + slices.SortStableFunc(todos, func(a, b session.Todo) int { + return statusOrder(a.Status) - statusOrder(b.Status) + }) +} + +// statusOrder returns the sort order for a todo status. +func statusOrder(s session.TodoStatus) int { + switch s { + case session.TodoStatusCompleted: + return 0 + case session.TodoStatusInProgress: + return 1 + default: + return 2 + } +} diff --git a/internal/ui/chat/tools.go b/internal/ui/chat/tools.go index d621cad4c8a23f2cb638c2b708509feb4c1b64c7..d336e6405ba2129b4c173bdf5a90154ae6c9c23f 100644 --- a/internal/ui/chat/tools.go +++ b/internal/ui/chat/tools.go @@ -6,6 +6,8 @@ import ( tea "charm.land/bubbletea/v2" "charm.land/lipgloss/v2" + "charm.land/lipgloss/v2/tree" + "github.com/charmbracelet/crush/internal/agent" "github.com/charmbracelet/crush/internal/agent/tools" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/ui/anim" @@ -38,14 +40,27 @@ type ToolMessageItem interface { ToolCall() message.ToolCall SetToolCall(tc message.ToolCall) SetResult(res *message.ToolResult) + MessageID() string + SetMessageID(id string) } -// Simplifiable is an interface for tool items that can render in a simplified mode. -// When simple mode is enabled, tools render as a compact single-line header. -type Simplifiable interface { - SetSimple(simple bool) +// Compactable is an interface for tool items that can render in a compacted mode. +// When compact mode is enabled, tools render as a compact single-line header. +type Compactable interface { + SetCompact(compact bool) } +// IsSpinningState contains the state passed to IsSpinningFn for custom spinning logic. +type IsSpinningState struct { + ToolCall message.ToolCall + Result *message.ToolResult + Canceled bool +} + +// IsSpinningFn is a function type for custom spinning logic. +// Returns true if the tool should show the spinning animation. +type IsSpinningFn func(state IsSpinningState) bool + // DefaultToolRenderContext implements the default [ToolRenderer] interface. type DefaultToolRenderContext struct{} @@ -60,8 +75,8 @@ type ToolRenderOpts struct { Result *message.ToolResult Canceled bool Anim *anim.Anim - Expanded bool - Simple bool + ExpandedContent bool + Compact bool IsSpinning bool PermissionRequested bool PermissionGranted bool @@ -106,18 +121,22 @@ type baseToolMessageItem struct { toolRenderer ToolRenderer toolCall message.ToolCall result *message.ToolResult + messageID string canceled bool permissionRequested bool permissionGranted bool // we use this so we can efficiently cache // tools that have a capped width (e.x bash.. and others) hasCappedWidth bool - // isSimple indicates this tool should render in simplified/compact mode. - isSimple bool + // isCompact indicates this tool should render in compact mode. + isCompact bool + // isSpinningFn allows tools to override the default spinning logic. + // If nil, uses the default: !toolCall.Finished && !canceled. + isSpinningFn IsSpinningFn - sty *styles.Styles - anim *anim.Anim - expanded bool + sty *styles.Styles + anim *anim.Anim + expandedContent bool } // newBaseToolMessageItem is the internal constructor for base tool message items. @@ -157,45 +176,58 @@ func newBaseToolMessageItem( // NewToolMessageItem creates a new [ToolMessageItem] based on the tool call name. // // It returns a specific tool message item type if implemented, otherwise it -// returns a generic tool message item. +// returns a generic tool message item. The messageID is the ID of the assistant +// message containing this tool call. func NewToolMessageItem( sty *styles.Styles, + messageID string, toolCall message.ToolCall, result *message.ToolResult, canceled bool, ) ToolMessageItem { + var item ToolMessageItem switch toolCall.Name { case tools.BashToolName: - return NewBashToolMessageItem(sty, toolCall, result, canceled) + item = NewBashToolMessageItem(sty, toolCall, result, canceled) case tools.JobOutputToolName: - return NewJobOutputToolMessageItem(sty, toolCall, result, canceled) + item = NewJobOutputToolMessageItem(sty, toolCall, result, canceled) case tools.JobKillToolName: - return NewJobKillToolMessageItem(sty, toolCall, result, canceled) + item = NewJobKillToolMessageItem(sty, toolCall, result, canceled) case tools.ViewToolName: - return NewViewToolMessageItem(sty, toolCall, result, canceled) + item = NewViewToolMessageItem(sty, toolCall, result, canceled) case tools.WriteToolName: - return NewWriteToolMessageItem(sty, toolCall, result, canceled) + item = NewWriteToolMessageItem(sty, toolCall, result, canceled) case tools.EditToolName: - return NewEditToolMessageItem(sty, toolCall, result, canceled) + item = NewEditToolMessageItem(sty, toolCall, result, canceled) case tools.MultiEditToolName: - return NewMultiEditToolMessageItem(sty, toolCall, result, canceled) + item = NewMultiEditToolMessageItem(sty, toolCall, result, canceled) case tools.GlobToolName: - return NewGlobToolMessageItem(sty, toolCall, result, canceled) + item = NewGlobToolMessageItem(sty, toolCall, result, canceled) case tools.GrepToolName: - return NewGrepToolMessageItem(sty, toolCall, result, canceled) + item = NewGrepToolMessageItem(sty, toolCall, result, canceled) case tools.LSToolName: - return NewLSToolMessageItem(sty, toolCall, result, canceled) + item = NewLSToolMessageItem(sty, toolCall, result, canceled) case tools.DownloadToolName: - return NewDownloadToolMessageItem(sty, toolCall, result, canceled) + item = NewDownloadToolMessageItem(sty, toolCall, result, canceled) case tools.FetchToolName: - return NewFetchToolMessageItem(sty, toolCall, result, canceled) + item = NewFetchToolMessageItem(sty, toolCall, result, canceled) case tools.SourcegraphToolName: - return NewSourcegraphToolMessageItem(sty, toolCall, result, canceled) + item = NewSourcegraphToolMessageItem(sty, toolCall, result, canceled) case tools.DiagnosticsToolName: - return NewDiagnosticsToolMessageItem(sty, toolCall, result, canceled) + item = NewDiagnosticsToolMessageItem(sty, toolCall, result, canceled) + case agent.AgentToolName: + item = NewAgentToolMessageItem(sty, toolCall, result, canceled) + case tools.AgenticFetchToolName: + item = NewAgenticFetchToolMessageItem(sty, toolCall, result, canceled) + case tools.WebFetchToolName: + item = NewWebFetchToolMessageItem(sty, toolCall, result, canceled) + case tools.WebSearchToolName: + item = NewWebSearchToolMessageItem(sty, toolCall, result, canceled) + case tools.TodosToolName: + item = NewTodosToolMessageItem(sty, toolCall, result, canceled) default: // TODO: Implement other tool items - return newBaseToolMessageItem( + item = newBaseToolMessageItem( sty, toolCall, result, @@ -203,11 +235,13 @@ func NewToolMessageItem( canceled, ) } + item.SetMessageID(messageID) + return item } -// SetSimple implements the Simplifiable interface. -func (t *baseToolMessageItem) SetSimple(simple bool) { - t.isSimple = simple +// SetCompact implements the Compactable interface. +func (t *baseToolMessageItem) SetCompact(compact bool) { + t.isCompact = compact t.clearCache() } @@ -243,6 +277,10 @@ func (t *baseToolMessageItem) Render(width int) string { style = t.sty.Chat.Message.ToolCallFocused } + if t.isCompact { + style = t.sty.Chat.Message.ToolCallCompact + } + content, height, ok := t.getCachedRender(toolItemWidth) // if we are spinning or there is no cache rerender if !ok || t.isSpinning() { @@ -251,8 +289,8 @@ func (t *baseToolMessageItem) Render(width int) string { Result: t.result, Canceled: t.canceled, Anim: t.anim, - Expanded: t.expanded, - Simple: t.isSimple, + ExpandedContent: t.expandedContent, + Compact: t.isCompact, PermissionRequested: t.permissionRequested, PermissionGranted: t.permissionGranted, IsSpinning: t.isSpinning(), @@ -283,6 +321,16 @@ func (t *baseToolMessageItem) SetResult(res *message.ToolResult) { t.clearCache() } +// MessageID returns the ID of the message containing this tool call. +func (t *baseToolMessageItem) MessageID() string { + return t.messageID +} + +// SetMessageID sets the ID of the message containing this tool call. +func (t *baseToolMessageItem) SetMessageID(id string) { + t.messageID = id +} + // SetPermissionRequested sets whether permission has been requested for this tool call. // TODO: Consider merging with SetPermissionGranted and add an interface for // permission management. @@ -301,12 +349,24 @@ func (t *baseToolMessageItem) SetPermissionGranted(granted bool) { // isSpinning returns true if the tool should show animation. func (t *baseToolMessageItem) isSpinning() bool { + if t.isSpinningFn != nil { + return t.isSpinningFn(IsSpinningState{ + ToolCall: t.toolCall, + Result: t.result, + Canceled: t.canceled, + }) + } return !t.toolCall.Finished && !t.canceled } +// SetIsSpinningFn sets a custom function to determine if the tool should spin. +func (t *baseToolMessageItem) SetIsSpinningFn(fn IsSpinningFn) { + t.isSpinningFn = fn +} + // ToggleExpanded toggles the expanded state of the thinking box. func (t *baseToolMessageItem) ToggleExpanded() { - t.expanded = !t.expanded + t.expandedContent = !t.expandedContent t.clearCache() } @@ -654,3 +714,62 @@ func toolOutputMultiEditDiffContent(sty *styles.Styles, file string, meta tools. return sty.Tool.Body.Render(formatted) } + +// roundedEnumerator creates a tree enumerator with rounded corners. +func roundedEnumerator(lPadding, width int) tree.Enumerator { + if width == 0 { + width = 2 + } + if lPadding == 0 { + lPadding = 1 + } + return func(children tree.Children, index int) string { + line := strings.Repeat("─", width) + padding := strings.Repeat(" ", lPadding) + if children.Length()-1 == index { + return padding + "╰" + line + } + return padding + "├" + line + } +} + +// toolOutputMarkdownContent renders markdown content with optional truncation. +func toolOutputMarkdownContent(sty *styles.Styles, content string, width int, expanded bool) string { + content = strings.ReplaceAll(content, "\r\n", "\n") + content = strings.ReplaceAll(content, "\t", " ") + content = strings.TrimSpace(content) + + // Cap width for readability. + if width > 120 { + width = 120 + } + + renderer := common.PlainMarkdownRenderer(sty, width) + rendered, err := renderer.Render(content) + if err != nil { + return toolOutputPlainContent(sty, content, width, expanded) + } + + lines := strings.Split(rendered, "\n") + maxLines := responseContextHeight + if expanded { + maxLines = len(lines) + } + + var out []string + for i, ln := range lines { + if i >= maxLines { + break + } + out = append(out, ln) + } + + if len(lines) > maxLines && !expanded { + out = append(out, sty.Tool.ContentTruncation. + Width(width). + Render(fmt.Sprintf(assistantMessageTruncateFormat, len(lines)-maxLines)), + ) + } + + return sty.Tool.Body.Render(strings.Join(out, "\n")) +} diff --git a/internal/ui/model/chat.go b/internal/ui/model/chat.go index 9c11a2d512d8355d66ad71c72db1eda078ecf85c..76a82a7d7242b4b089381685763750ee762c043c 100644 --- a/internal/ui/model/chat.go +++ b/internal/ui/model/chat.go @@ -77,6 +77,12 @@ func (m *Chat) SetMessages(msgs ...chat.MessageItem) { items := make([]list.Item, len(msgs)) for i, msg := range msgs { m.idInxMap[msg.ID()] = i + // Register nested tool IDs for tools that contain nested tools. + if container, ok := msg.(chat.NestedToolContainer); ok { + for _, nested := range container.NestedTools() { + m.idInxMap[nested.ID()] = i + } + } items[i] = msg } m.list.SetItems(items...) @@ -89,11 +95,41 @@ func (m *Chat) AppendMessages(msgs ...chat.MessageItem) { indexOffset := m.list.Len() for i, msg := range msgs { m.idInxMap[msg.ID()] = indexOffset + i + // Register nested tool IDs for tools that contain nested tools. + if container, ok := msg.(chat.NestedToolContainer); ok { + for _, nested := range container.NestedTools() { + m.idInxMap[nested.ID()] = indexOffset + i + } + } items[i] = msg } m.list.AppendItems(items...) } +// UpdateNestedToolIDs updates the ID map for nested tools within a container. +// Call this after modifying nested tools to ensure animations work correctly. +func (m *Chat) UpdateNestedToolIDs(containerID string) { + idx, ok := m.idInxMap[containerID] + if !ok { + return + } + + item, ok := m.list.ItemAt(idx).(chat.MessageItem) + if !ok { + return + } + + container, ok := item.(chat.NestedToolContainer) + if !ok { + return + } + + // Register all nested tool IDs to point to the container's index. + for _, nested := range container.NestedTools() { + m.idInxMap[nested.ID()] = idx + } +} + // Animate animates items in the chat list. Only propagates animation messages // to visible items to save CPU. When items are not visible, their animation ID // is tracked so it can be restarted when they become visible again. diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index 6072c4292c07aebbd92a1329944bc20b9826bb48..3d7499e519a24807b43677d2a5a0043b2e57d250 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -199,8 +199,15 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } case pubsub.Event[message.Message]: - // TODO: handle nested messages for agentic tools - if m.session == nil || msg.Payload.SessionID != m.session.ID { + // Check if this is a child session message for an agent tool. + if m.session == nil { + break + } + if msg.Payload.SessionID != m.session.ID { + // This might be a child session message from an agent tool. + if cmd := m.handleChildSessionMessage(msg); cmd != nil { + cmds = append(cmds, cmd) + } break } switch msg.Type { @@ -374,6 +381,9 @@ func (m *UI) setSessionMessages(msgs []message.Message) tea.Cmd { items = append(items, chat.ExtractMessageItems(m.com.Styles, msg, toolResultMap)...) } + // Load nested tool calls for agent/agentic_fetch tools. + m.loadNestedToolCalls(items) + // If the user switches between sessions while the agent is working we want // to make sure the animations are shown. for _, item := range items { @@ -392,6 +402,64 @@ func (m *UI) setSessionMessages(msgs []message.Message) tea.Cmd { return tea.Batch(cmds...) } +// loadNestedToolCalls recursively loads nested tool calls for agent/agentic_fetch tools. +func (m *UI) loadNestedToolCalls(items []chat.MessageItem) { + for _, item := range items { + nestedContainer, ok := item.(chat.NestedToolContainer) + if !ok { + continue + } + toolItem, ok := item.(chat.ToolMessageItem) + if !ok { + continue + } + + tc := toolItem.ToolCall() + messageID := toolItem.MessageID() + + // Get the agent tool session ID. + agentSessionID := m.com.App.Sessions.CreateAgentToolSessionID(messageID, tc.ID) + + // Fetch nested messages. + nestedMsgs, err := m.com.App.Messages.List(context.Background(), agentSessionID) + if err != nil || len(nestedMsgs) == 0 { + continue + } + + // Build tool result map for nested messages. + nestedMsgPtrs := make([]*message.Message, len(nestedMsgs)) + for i := range nestedMsgs { + nestedMsgPtrs[i] = &nestedMsgs[i] + } + nestedToolResultMap := chat.BuildToolResultMap(nestedMsgPtrs) + + // Extract nested tool items. + var nestedTools []chat.ToolMessageItem + for _, nestedMsg := range nestedMsgPtrs { + nestedItems := chat.ExtractMessageItems(m.com.Styles, nestedMsg, nestedToolResultMap) + for _, nestedItem := range nestedItems { + if nestedToolItem, ok := nestedItem.(chat.ToolMessageItem); ok { + // Mark nested tools as simple (compact) rendering. + if simplifiable, ok := nestedToolItem.(chat.Compactable); ok { + simplifiable.SetCompact(true) + } + nestedTools = append(nestedTools, nestedToolItem) + } + } + } + + // Recursively load nested tool calls for any agent tools within. + nestedMessageItems := make([]chat.MessageItem, len(nestedTools)) + for i, nt := range nestedTools { + nestedMessageItems[i] = nt + } + m.loadNestedToolCalls(nestedMessageItems) + + // Set nested tools on the parent. + nestedContainer.SetNestedTools(nestedTools) + } +} + // appendSessionMessage appends a new message to the current session in the chat // if the message is a tool result it will update the corresponding tool call message func (m *UI) appendSessionMessage(msg message.Message) tea.Cmd { @@ -455,7 +523,7 @@ func (m *UI) updateSessionMessage(msg message.Message) tea.Cmd { } } if existingToolItem == nil { - items = append(items, chat.NewToolMessageItem(m.com.Styles, tc, nil, false)) + items = append(items, chat.NewToolMessageItem(m.com.Styles, msg.ID, tc, nil, false)) } } @@ -474,6 +542,92 @@ func (m *UI) updateSessionMessage(msg message.Message) tea.Cmd { return tea.Batch(cmds...) } +// handleChildSessionMessage handles messages from child sessions (agent tools). +func (m *UI) handleChildSessionMessage(event pubsub.Event[message.Message]) tea.Cmd { + var cmds []tea.Cmd + + // Only process messages with tool calls or results. + if len(event.Payload.ToolCalls()) == 0 && len(event.Payload.ToolResults()) == 0 { + return nil + } + + // Check if this is an agent tool session and parse it. + childSessionID := event.Payload.SessionID + _, toolCallID, ok := m.com.App.Sessions.ParseAgentToolSessionID(childSessionID) + if !ok { + return nil + } + + // Find the parent agent tool item. + var agentItem chat.NestedToolContainer + for i := 0; i < m.chat.Len(); i++ { + item := m.chat.MessageItem(toolCallID) + if item == nil { + continue + } + if agent, ok := item.(chat.NestedToolContainer); ok { + if toolMessageItem, ok := item.(chat.ToolMessageItem); ok { + if toolMessageItem.ToolCall().ID == toolCallID { + // Verify this agent belongs to the correct parent message. + // We can't directly check parentMessageID on the item, so we trust the session parsing. + agentItem = agent + break + } + } + } + } + + if agentItem == nil { + return nil + } + + // Get existing nested tools. + nestedTools := agentItem.NestedTools() + + // Update or create nested tool calls. + for _, tc := range event.Payload.ToolCalls() { + found := false + for _, existingTool := range nestedTools { + if existingTool.ToolCall().ID == tc.ID { + existingTool.SetToolCall(tc) + found = true + break + } + } + if !found { + // Create a new nested tool item. + nestedItem := chat.NewToolMessageItem(m.com.Styles, event.Payload.ID, tc, nil, false) + if simplifiable, ok := nestedItem.(chat.Compactable); ok { + simplifiable.SetCompact(true) + } + if animatable, ok := nestedItem.(chat.Animatable); ok { + if cmd := animatable.StartAnimation(); cmd != nil { + cmds = append(cmds, cmd) + } + } + nestedTools = append(nestedTools, nestedItem) + } + } + + // Update nested tool results. + for _, tr := range event.Payload.ToolResults() { + for _, nestedTool := range nestedTools { + if nestedTool.ToolCall().ID == tr.ToolCallID { + nestedTool.SetResult(&tr) + break + } + } + } + + // Update the agent item with the new nested tools. + agentItem.SetNestedTools(nestedTools) + + // Update the chat so it updates the index map for animations to work as expected + m.chat.UpdateNestedToolIDs(toolCallID) + + return tea.Batch(cmds...) +} + func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd { var cmds []tea.Cmd diff --git a/internal/ui/styles/styles.go b/internal/ui/styles/styles.go index a0c7ad418a0c8d4dbeb041cfd48d5a16fd110622..d7d8d6d8a38432b77b8ce49e7d04961a6e489cc8 100644 --- a/internal/ui/styles/styles.go +++ b/internal/ui/styles/styles.go @@ -210,6 +210,7 @@ type Styles struct { ErrorDetails lipgloss.Style Attachment lipgloss.Style ToolCallFocused lipgloss.Style + ToolCallCompact lipgloss.Style ToolCallBlurred lipgloss.Style SectionHeader lipgloss.Style @@ -277,6 +278,12 @@ type Styles struct { // Agent task styles AgentTaskTag lipgloss.Style // Agent task tag (blue background, bold) AgentPrompt lipgloss.Style // Agent prompt text + + // Todo styles + TodoRatio lipgloss.Style // Todo ratio (e.g., "2/5") + TodoCompletedIcon lipgloss.Style // Completed todo icon + TodoInProgressIcon lipgloss.Style // In-progress todo icon + TodoPendingIcon lipgloss.Style // Pending todo icon } // Dialog styles @@ -1005,6 +1012,12 @@ func DefaultStyles() Styles { s.Tool.AgentTaskTag = base.Bold(true).Padding(0, 1).MarginLeft(2).Background(blueLight).Foreground(white) s.Tool.AgentPrompt = s.Muted + // Todo styles + s.Tool.TodoRatio = base.Foreground(blueDark) + s.Tool.TodoCompletedIcon = base.Foreground(green) + s.Tool.TodoInProgressIcon = base.Foreground(greenDark) + s.Tool.TodoPendingIcon = base.Foreground(fgMuted) + // Buttons s.ButtonFocus = lipgloss.NewStyle().Foreground(white).Background(secondary) s.ButtonBlur = s.Base.Background(bgSubtle) @@ -1082,6 +1095,8 @@ func DefaultStyles() Styles { BorderLeft(true). BorderForeground(greenDark) s.Chat.Message.ToolCallBlurred = s.Muted.PaddingLeft(2) + // No padding or border for compact tool calls within messages + s.Chat.Message.ToolCallCompact = s.Muted s.Chat.Message.SectionHeader = s.Base.PaddingLeft(2) // Thinking section styles