refactor(ui): chat: abstract tool message rendering

Ayman Bagabas created

Change summary

internal/ui/chat/bash.go     |  33 +++++++++++
internal/ui/chat/messages.go |  12 ----
internal/ui/chat/tools.go    | 106 ++++++++++++++++++++++++++++---------
internal/ui/model/ui.go      |   4 
4 files changed, 113 insertions(+), 42 deletions(-)

Detailed changes

internal/ui/chat/bash.go 🔗

@@ -5,11 +5,40 @@ import (
 	"strings"
 
 	"github.com/charmbracelet/crush/internal/agent/tools"
+	"github.com/charmbracelet/crush/internal/message"
 	"github.com/charmbracelet/crush/internal/ui/styles"
 )
 
-// BashToolRenderer renders a bash tool call.
-func BashToolRenderer(sty *styles.Styles, width int, opts *ToolRenderOpts) string {
+// BashToolMessageItem is a message item that represents a bash tool call.
+type BashToolMessageItem struct {
+	*baseToolMessageItem
+}
+
+var _ ToolMessageItem = (*BashToolMessageItem)(nil)
+
+// NewBashToolMessageItem creates a new [BashToolMessageItem].
+func NewBashToolMessageItem(
+	sty *styles.Styles,
+	toolCall message.ToolCall,
+	result *message.ToolResult,
+	canceled bool,
+) ToolMessageItem {
+	return newBaseToolMessageItem(
+		sty,
+		toolCall,
+		result,
+		&BashToolRenderContext{},
+		canceled,
+	)
+}
+
+// BashToolRenderContext holds context for rendering bash tool messages.
+//
+// It implements the [ToolRenderer] interface.
+type BashToolRenderContext struct{}
+
+// RenderTool implements the [ToolRenderer] interface.
+func (b *BashToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string {
 	cappedWidth := cappedMessageWidth(width)
 	const toolName = "Bash"
 	if !opts.ToolCall.Finished && !opts.Canceled {

internal/ui/chat/messages.go 🔗

@@ -8,7 +8,6 @@ import (
 	"strings"
 
 	tea "charm.land/bubbletea/v2"
-	"github.com/charmbracelet/crush/internal/agent/tools"
 	"github.com/charmbracelet/crush/internal/message"
 	"github.com/charmbracelet/crush/internal/ui/anim"
 	"github.com/charmbracelet/crush/internal/ui/list"
@@ -182,17 +181,6 @@ func ExtractMessageItems(sty *styles.Styles, msg *message.Message, toolResults m
 	return []MessageItem{}
 }
 
-// ToolRenderer returns the appropriate [ToolRenderFunc] for a given tool call.
-// this should be used for nested tools as well.
-func ToolRenderer(tc message.ToolCall) ToolRenderFunc {
-	switch tc.Name {
-	case tools.BashToolName:
-		return BashToolRenderer
-	default:
-		return DefaultToolRenderer
-	}
-}
-
 // shouldRenderAssistantMessage determines if an assistant message should be rendered
 //
 // In some cases the assistant message only has tools so we do not want to render an

internal/ui/chat/tools.go 🔗

@@ -30,6 +30,23 @@ const (
 	ToolStatusCanceled
 )
 
+// ToolMessageItem represents a tool call message in the chat UI.
+type ToolMessageItem interface {
+	MessageItem
+
+	ToolCall() message.ToolCall
+	SetToolCall(tc message.ToolCall)
+	SetResult(res *message.ToolResult)
+}
+
+// DefaultToolRenderContext implements the default [ToolRenderer] interface.
+type DefaultToolRenderContext struct{}
+
+// RenderTool implements the [ToolRenderer] interface.
+func (d *DefaultToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string {
+	return "TODO: Implement Tool Renderer For: " + opts.ToolCall.Name
+}
+
 // ToolRenderOpts contains the data needed to render a tool call.
 type ToolRenderOpts struct {
 	ToolCall            message.ToolCall
@@ -60,21 +77,26 @@ func (opts *ToolRenderOpts) Status() ToolStatus {
 	return ToolStatusRunning
 }
 
-// ToolRenderFunc is a function that renders a tool call to a string.
-type ToolRenderFunc func(sty *styles.Styles, width int, t *ToolRenderOpts) string
+// ToolRenderer represents an interface for rendering tool calls.
+type ToolRenderer interface {
+	RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string
+}
 
-// DefaultToolRenderer is a placeholder renderer for tools without a custom renderer.
-func DefaultToolRenderer(sty *styles.Styles, width int, opts *ToolRenderOpts) string {
-	return "TODO: Implement Tool Renderer For: " + opts.ToolCall.Name
+// ToolRendererFunc is a function type that implements the [ToolRenderer] interface.
+type ToolRendererFunc func(sty *styles.Styles, width int, opts *ToolRenderOpts) string
+
+// RenderTool implements the ToolRenderer interface.
+func (f ToolRendererFunc) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string {
+	return f(sty, width, opts)
 }
 
-// ToolMessageItem represents a tool call message that can be displayed in the UI.
-type ToolMessageItem struct {
+// baseToolMessageItem represents a tool call message that can be displayed in the UI.
+type baseToolMessageItem struct {
 	*highlightableMessageItem
 	*cachedMessageItem
 	*focusableMessageItem
 
-	renderFunc          ToolRenderFunc
+	toolRenderer        ToolRenderer
 	toolCall            message.ToolCall
 	result              *message.ToolResult
 	canceled            bool
@@ -89,21 +111,23 @@ type ToolMessageItem struct {
 	expanded bool
 }
 
-// NewToolMessageItem creates a new tool message item with the given renderFunc.
-func NewToolMessageItem(
+// newBaseToolMessageItem is the internal constructor for base tool message items.
+func newBaseToolMessageItem(
 	sty *styles.Styles,
 	toolCall message.ToolCall,
 	result *message.ToolResult,
+	toolRenderer ToolRenderer,
 	canceled bool,
-) *ToolMessageItem {
+) *baseToolMessageItem {
 	// we only do full width for diffs (as far as I know)
 	hasCappedWidth := toolCall.Name != tools.EditToolName && toolCall.Name != tools.MultiEditToolName
-	t := &ToolMessageItem{
+
+	t := &baseToolMessageItem{
 		highlightableMessageItem: defaultHighlighter(sty),
 		cachedMessageItem:        &cachedMessageItem{},
 		focusableMessageItem:     &focusableMessageItem{},
 		sty:                      sty,
-		renderFunc:               ToolRenderer(toolCall),
+		toolRenderer:             toolRenderer,
 		toolCall:                 toolCall,
 		result:                   result,
 		canceled:                 canceled,
@@ -117,16 +141,42 @@ func NewToolMessageItem(
 		LabelColor:  sty.FgBase,
 		CycleColors: true,
 	})
+
 	return t
 }
 
+// NewToolMessageItem creates a new [ToolMessageItem] based on the tool call name.
+//
+// It returns a specific tool message item type if implemented, otherwise it
+// returns a generic tool message item.
+func NewToolMessageItem(
+	sty *styles.Styles,
+	toolCall message.ToolCall,
+	result *message.ToolResult,
+	canceled bool,
+) ToolMessageItem {
+	switch toolCall.Name {
+	case tools.BashToolName:
+		return NewBashToolMessageItem(sty, toolCall, result, canceled)
+	default:
+		// TODO: Implement other tool items
+		return newBaseToolMessageItem(
+			sty,
+			toolCall,
+			result,
+			&DefaultToolRenderContext{},
+			canceled,
+		)
+	}
+}
+
 // ID returns the unique identifier for this tool message item.
-func (t *ToolMessageItem) ID() string {
+func (t *baseToolMessageItem) ID() string {
 	return t.toolCall.ID
 }
 
 // StartAnimation starts the assistant message animation if it should be spinning.
-func (t *ToolMessageItem) StartAnimation() tea.Cmd {
+func (t *baseToolMessageItem) StartAnimation() tea.Cmd {
 	if !t.isSpinning() {
 		return nil
 	}
@@ -134,7 +184,7 @@ func (t *ToolMessageItem) StartAnimation() tea.Cmd {
 }
 
 // Animate progresses the assistant message animation if it should be spinning.
-func (t *ToolMessageItem) Animate(msg anim.StepMsg) tea.Cmd {
+func (t *baseToolMessageItem) Animate(msg anim.StepMsg) tea.Cmd {
 	if !t.isSpinning() {
 		return nil
 	}
@@ -142,7 +192,7 @@ func (t *ToolMessageItem) Animate(msg anim.StepMsg) tea.Cmd {
 }
 
 // Render renders the tool message item at the given width.
-func (t *ToolMessageItem) Render(width int) string {
+func (t *baseToolMessageItem) Render(width int) string {
 	toolItemWidth := width - messageLeftPaddingTotal
 	if t.hasCappedWidth {
 		toolItemWidth = cappedMessageWidth(width)
@@ -155,7 +205,7 @@ func (t *ToolMessageItem) Render(width int) string {
 	content, height, ok := t.getCachedRender(toolItemWidth)
 	// if we are spinning or there is no cache rerender
 	if !ok || t.isSpinning() {
-		content = t.renderFunc(t.sty, toolItemWidth, &ToolRenderOpts{
+		content = t.toolRenderer.RenderTool(t.sty, toolItemWidth, &ToolRenderOpts{
 			ToolCall:            t.toolCall,
 			Result:              t.result,
 			Canceled:            t.canceled,
@@ -175,47 +225,51 @@ func (t *ToolMessageItem) Render(width int) string {
 }
 
 // ToolCall returns the tool call associated with this message item.
-func (t *ToolMessageItem) ToolCall() message.ToolCall {
+func (t *baseToolMessageItem) ToolCall() message.ToolCall {
 	return t.toolCall
 }
 
 // SetToolCall sets the tool call associated with this message item.
-func (t *ToolMessageItem) SetToolCall(tc message.ToolCall) {
+func (t *baseToolMessageItem) SetToolCall(tc message.ToolCall) {
 	t.toolCall = tc
 	t.clearCache()
 }
 
 // SetResult sets the tool result associated with this message item.
-func (t *ToolMessageItem) SetResult(res *message.ToolResult) {
+func (t *baseToolMessageItem) SetResult(res *message.ToolResult) {
 	t.result = res
 	t.clearCache()
 }
 
 // SetPermissionRequested sets whether permission has been requested for this tool call.
-func (t *ToolMessageItem) SetPermissionRequested(requested bool) {
+// TODO: Consider merging with SetPermissionGranted and add an interface for
+// permission management.
+func (t *baseToolMessageItem) SetPermissionRequested(requested bool) {
 	t.permissionRequested = requested
 	t.clearCache()
 }
 
 // SetPermissionGranted sets whether permission has been granted for this tool call.
-func (t *ToolMessageItem) SetPermissionGranted(granted bool) {
+// TODO: Consider merging with SetPermissionRequested and add an interface for
+// permission management.
+func (t *baseToolMessageItem) SetPermissionGranted(granted bool) {
 	t.permissionGranted = granted
 	t.clearCache()
 }
 
 // isSpinning returns true if the tool should show animation.
-func (t *ToolMessageItem) isSpinning() bool {
+func (t *baseToolMessageItem) isSpinning() bool {
 	return !t.toolCall.Finished && !t.canceled
 }
 
 // ToggleExpanded toggles the expanded state of the thinking box.
-func (t *ToolMessageItem) ToggleExpanded() {
+func (t *baseToolMessageItem) ToggleExpanded() {
 	t.expanded = !t.expanded
 	t.clearCache()
 }
 
 // HandleMouseClick implements MouseClickable.
-func (t *ToolMessageItem) HandleMouseClick(btn ansi.MouseButton, x, y int) bool {
+func (t *baseToolMessageItem) HandleMouseClick(btn ansi.MouseButton, x, y int) bool {
 	if btn != ansi.MouseLeft {
 		return false
 	}

internal/ui/model/ui.go 🔗

@@ -426,7 +426,7 @@ func (m *UI) appendSessionMessage(msg message.Message) tea.Cmd {
 				// we should have an item!
 				continue
 			}
-			if toolMsgItem, ok := toolItem.(*chat.ToolMessageItem); ok {
+			if toolMsgItem, ok := toolItem.(chat.ToolMessageItem); ok {
 				toolMsgItem.SetResult(&tr)
 			}
 		}
@@ -451,7 +451,7 @@ func (m *UI) updateSessionMessage(msg message.Message) tea.Cmd {
 	var items []chat.MessageItem
 	for _, tc := range msg.ToolCalls() {
 		existingToolItem := m.chat.MessageItem(tc.ID)
-		if toolItem, ok := existingToolItem.(*chat.ToolMessageItem); ok {
+		if toolItem, ok := existingToolItem.(chat.ToolMessageItem); ok {
 			existingToolCall := toolItem.ToolCall()
 			// only update if finished state changed or input changed
 			// to avoid clearing the cache