tools.go

  1package chat
  2
  3import (
  4	"fmt"
  5	"strings"
  6
  7	tea "charm.land/bubbletea/v2"
  8	"charm.land/lipgloss/v2"
  9	"github.com/charmbracelet/crush/internal/agent/tools"
 10	"github.com/charmbracelet/crush/internal/message"
 11	"github.com/charmbracelet/crush/internal/ui/anim"
 12	"github.com/charmbracelet/crush/internal/ui/common"
 13	"github.com/charmbracelet/crush/internal/ui/styles"
 14	"github.com/charmbracelet/x/ansi"
 15)
 16
 17// responseContextHeight limits the number of lines displayed in tool output.
 18const responseContextHeight = 10
 19
 20// toolBodyLeftPaddingTotal represents the padding that should be applied to each tool body
 21const toolBodyLeftPaddingTotal = 2
 22
 23// ToolStatus represents the current state of a tool call.
 24type ToolStatus int
 25
 26const (
 27	ToolStatusAwaitingPermission ToolStatus = iota
 28	ToolStatusRunning
 29	ToolStatusSuccess
 30	ToolStatusError
 31	ToolStatusCanceled
 32)
 33
 34// ToolMessageItem represents a tool call message in the chat UI.
 35type ToolMessageItem interface {
 36	MessageItem
 37
 38	ToolCall() message.ToolCall
 39	SetToolCall(tc message.ToolCall)
 40	SetResult(res *message.ToolResult)
 41}
 42
 43// DefaultToolRenderContext implements the default [ToolRenderer] interface.
 44type DefaultToolRenderContext struct{}
 45
 46// RenderTool implements the [ToolRenderer] interface.
 47func (d *DefaultToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string {
 48	return "TODO: Implement Tool Renderer For: " + opts.ToolCall.Name
 49}
 50
 51// ToolRenderOpts contains the data needed to render a tool call.
 52type ToolRenderOpts struct {
 53	ToolCall            message.ToolCall
 54	Result              *message.ToolResult
 55	Canceled            bool
 56	Anim                *anim.Anim
 57	Expanded            bool
 58	Nested              bool
 59	IsSpinning          bool
 60	PermissionRequested bool
 61	PermissionGranted   bool
 62}
 63
 64// Status returns the current status of the tool call.
 65func (opts *ToolRenderOpts) Status() ToolStatus {
 66	if opts.Canceled && opts.Result == nil {
 67		return ToolStatusCanceled
 68	}
 69	if opts.Result != nil {
 70		if opts.Result.IsError {
 71			return ToolStatusError
 72		}
 73		return ToolStatusSuccess
 74	}
 75	if opts.PermissionRequested && !opts.PermissionGranted {
 76		return ToolStatusAwaitingPermission
 77	}
 78	return ToolStatusRunning
 79}
 80
 81// ToolRenderer represents an interface for rendering tool calls.
 82type ToolRenderer interface {
 83	RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string
 84}
 85
 86// ToolRendererFunc is a function type that implements the [ToolRenderer] interface.
 87type ToolRendererFunc func(sty *styles.Styles, width int, opts *ToolRenderOpts) string
 88
 89// RenderTool implements the ToolRenderer interface.
 90func (f ToolRendererFunc) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string {
 91	return f(sty, width, opts)
 92}
 93
 94// baseToolMessageItem represents a tool call message that can be displayed in the UI.
 95type baseToolMessageItem struct {
 96	*highlightableMessageItem
 97	*cachedMessageItem
 98	*focusableMessageItem
 99
100	toolRenderer        ToolRenderer
101	toolCall            message.ToolCall
102	result              *message.ToolResult
103	canceled            bool
104	permissionRequested bool
105	permissionGranted   bool
106	// we use this so we can efficiently cache
107	// tools that have a capped width (e.x bash.. and others)
108	hasCappedWidth bool
109
110	sty      *styles.Styles
111	anim     *anim.Anim
112	expanded bool
113}
114
115// newBaseToolMessageItem is the internal constructor for base tool message items.
116func newBaseToolMessageItem(
117	sty *styles.Styles,
118	toolCall message.ToolCall,
119	result *message.ToolResult,
120	toolRenderer ToolRenderer,
121	canceled bool,
122) *baseToolMessageItem {
123	// we only do full width for diffs (as far as I know)
124	hasCappedWidth := toolCall.Name != tools.EditToolName && toolCall.Name != tools.MultiEditToolName
125
126	t := &baseToolMessageItem{
127		highlightableMessageItem: defaultHighlighter(sty),
128		cachedMessageItem:        &cachedMessageItem{},
129		focusableMessageItem:     &focusableMessageItem{},
130		sty:                      sty,
131		toolRenderer:             toolRenderer,
132		toolCall:                 toolCall,
133		result:                   result,
134		canceled:                 canceled,
135		hasCappedWidth:           hasCappedWidth,
136	}
137	t.anim = anim.New(anim.Settings{
138		ID:          toolCall.ID,
139		Size:        15,
140		GradColorA:  sty.Primary,
141		GradColorB:  sty.Secondary,
142		LabelColor:  sty.FgBase,
143		CycleColors: true,
144	})
145
146	return t
147}
148
149// NewToolMessageItem creates a new [ToolMessageItem] based on the tool call name.
150//
151// It returns a specific tool message item type if implemented, otherwise it
152// returns a generic tool message item.
153func NewToolMessageItem(
154	sty *styles.Styles,
155	toolCall message.ToolCall,
156	result *message.ToolResult,
157	canceled bool,
158) ToolMessageItem {
159	switch toolCall.Name {
160	case tools.BashToolName:
161		return NewBashToolMessageItem(sty, toolCall, result, canceled)
162	case tools.JobOutputToolName:
163		return NewJobOutputToolMessageItem(sty, toolCall, result, canceled)
164	case tools.JobKillToolName:
165		return NewJobKillToolMessageItem(sty, toolCall, result, canceled)
166	case tools.ViewToolName:
167		return NewViewToolMessageItem(sty, toolCall, result, canceled)
168	case tools.WriteToolName:
169		return NewWriteToolMessageItem(sty, toolCall, result, canceled)
170	case tools.EditToolName:
171		return NewEditToolMessageItem(sty, toolCall, result, canceled)
172	case tools.MultiEditToolName:
173		return NewMultiEditToolMessageItem(sty, toolCall, result, canceled)
174	default:
175		// TODO: Implement other tool items
176		return newBaseToolMessageItem(
177			sty,
178			toolCall,
179			result,
180			&DefaultToolRenderContext{},
181			canceled,
182		)
183	}
184}
185
186// ID returns the unique identifier for this tool message item.
187func (t *baseToolMessageItem) ID() string {
188	return t.toolCall.ID
189}
190
191// StartAnimation starts the assistant message animation if it should be spinning.
192func (t *baseToolMessageItem) StartAnimation() tea.Cmd {
193	if !t.isSpinning() {
194		return nil
195	}
196	return t.anim.Start()
197}
198
199// Animate progresses the assistant message animation if it should be spinning.
200func (t *baseToolMessageItem) Animate(msg anim.StepMsg) tea.Cmd {
201	if !t.isSpinning() {
202		return nil
203	}
204	return t.anim.Animate(msg)
205}
206
207// Render renders the tool message item at the given width.
208func (t *baseToolMessageItem) Render(width int) string {
209	toolItemWidth := width - messageLeftPaddingTotal
210	if t.hasCappedWidth {
211		toolItemWidth = cappedMessageWidth(width)
212	}
213	style := t.sty.Chat.Message.ToolCallBlurred
214	if t.focused {
215		style = t.sty.Chat.Message.ToolCallFocused
216	}
217
218	content, height, ok := t.getCachedRender(toolItemWidth)
219	// if we are spinning or there is no cache rerender
220	if !ok || t.isSpinning() {
221		content = t.toolRenderer.RenderTool(t.sty, toolItemWidth, &ToolRenderOpts{
222			ToolCall:            t.toolCall,
223			Result:              t.result,
224			Canceled:            t.canceled,
225			Anim:                t.anim,
226			Expanded:            t.expanded,
227			PermissionRequested: t.permissionRequested,
228			PermissionGranted:   t.permissionGranted,
229			IsSpinning:          t.isSpinning(),
230		})
231		height = lipgloss.Height(content)
232		// cache the rendered content
233		t.setCachedRender(content, toolItemWidth, height)
234	}
235
236	highlightedContent := t.renderHighlighted(content, toolItemWidth, height)
237	return style.Render(highlightedContent)
238}
239
240// ToolCall returns the tool call associated with this message item.
241func (t *baseToolMessageItem) ToolCall() message.ToolCall {
242	return t.toolCall
243}
244
245// SetToolCall sets the tool call associated with this message item.
246func (t *baseToolMessageItem) SetToolCall(tc message.ToolCall) {
247	t.toolCall = tc
248	t.clearCache()
249}
250
251// SetResult sets the tool result associated with this message item.
252func (t *baseToolMessageItem) SetResult(res *message.ToolResult) {
253	t.result = res
254	t.clearCache()
255}
256
257// SetPermissionRequested sets whether permission has been requested for this tool call.
258// TODO: Consider merging with SetPermissionGranted and add an interface for
259// permission management.
260func (t *baseToolMessageItem) SetPermissionRequested(requested bool) {
261	t.permissionRequested = requested
262	t.clearCache()
263}
264
265// SetPermissionGranted sets whether permission has been granted for this tool call.
266// TODO: Consider merging with SetPermissionRequested and add an interface for
267// permission management.
268func (t *baseToolMessageItem) SetPermissionGranted(granted bool) {
269	t.permissionGranted = granted
270	t.clearCache()
271}
272
273// isSpinning returns true if the tool should show animation.
274func (t *baseToolMessageItem) isSpinning() bool {
275	return !t.toolCall.Finished && !t.canceled
276}
277
278// ToggleExpanded toggles the expanded state of the thinking box.
279func (t *baseToolMessageItem) ToggleExpanded() {
280	t.expanded = !t.expanded
281	t.clearCache()
282}
283
284// HandleMouseClick implements MouseClickable.
285func (t *baseToolMessageItem) HandleMouseClick(btn ansi.MouseButton, x, y int) bool {
286	if btn != ansi.MouseLeft {
287		return false
288	}
289	t.ToggleExpanded()
290	return true
291}
292
293// pendingTool renders a tool that is still in progress with an animation.
294func pendingTool(sty *styles.Styles, name string, anim *anim.Anim) string {
295	icon := sty.Tool.IconPending.Render()
296	toolName := sty.Tool.NameNormal.Render(name)
297
298	var animView string
299	if anim != nil {
300		animView = anim.Render()
301	}
302
303	return fmt.Sprintf("%s %s %s", icon, toolName, animView)
304}
305
306// toolEarlyStateContent handles error/cancelled/pending states before content rendering.
307// Returns the rendered output and true if early state was handled.
308func toolEarlyStateContent(sty *styles.Styles, opts *ToolRenderOpts, width int) (string, bool) {
309	var msg string
310	switch opts.Status() {
311	case ToolStatusError:
312		msg = toolErrorContent(sty, opts.Result, width)
313	case ToolStatusCanceled:
314		msg = sty.Tool.StateCancelled.Render("Canceled.")
315	case ToolStatusAwaitingPermission:
316		msg = sty.Tool.StateWaiting.Render("Requesting permission...")
317	case ToolStatusRunning:
318		msg = sty.Tool.StateWaiting.Render("Waiting for tool response...")
319	default:
320		return "", false
321	}
322	return msg, true
323}
324
325// toolErrorContent formats an error message with ERROR tag.
326func toolErrorContent(sty *styles.Styles, result *message.ToolResult, width int) string {
327	if result == nil {
328		return ""
329	}
330	errContent := strings.ReplaceAll(result.Content, "\n", " ")
331	errTag := sty.Tool.ErrorTag.Render("ERROR")
332	tagWidth := lipgloss.Width(errTag)
333	errContent = ansi.Truncate(errContent, width-tagWidth-3, "…")
334	return fmt.Sprintf("%s %s", errTag, sty.Tool.ErrorMessage.Render(errContent))
335}
336
337// toolIcon returns the status icon for a tool call.
338// toolIcon returns the status icon for a tool call based on its status.
339func toolIcon(sty *styles.Styles, status ToolStatus) string {
340	switch status {
341	case ToolStatusSuccess:
342		return sty.Tool.IconSuccess.String()
343	case ToolStatusError:
344		return sty.Tool.IconError.String()
345	case ToolStatusCanceled:
346		return sty.Tool.IconCancelled.String()
347	default:
348		return sty.Tool.IconPending.String()
349	}
350}
351
352// toolParamList formats parameters as "main (key=value, ...)" with truncation.
353// toolParamList formats tool parameters as "main (key=value, ...)" with truncation.
354func toolParamList(sty *styles.Styles, params []string, width int) string {
355	// minSpaceForMainParam is the min space required for the main param
356	// if this is less that the value set we will only show the main param nothing else
357	const minSpaceForMainParam = 30
358	if len(params) == 0 {
359		return ""
360	}
361
362	mainParam := params[0]
363
364	// Build key=value pairs from remaining params (consecutive key, value pairs).
365	var kvPairs []string
366	for i := 1; i+1 < len(params); i += 2 {
367		if params[i+1] != "" {
368			kvPairs = append(kvPairs, fmt.Sprintf("%s=%s", params[i], params[i+1]))
369		}
370	}
371
372	// Try to include key=value pairs if there's enough space.
373	output := mainParam
374	if len(kvPairs) > 0 {
375		partsStr := strings.Join(kvPairs, ", ")
376		if remaining := width - lipgloss.Width(partsStr) - 3; remaining >= minSpaceForMainParam {
377			output = fmt.Sprintf("%s (%s)", mainParam, partsStr)
378		}
379	}
380
381	if width >= 0 {
382		output = ansi.Truncate(output, width, "…")
383	}
384	return sty.Tool.ParamMain.Render(output)
385}
386
387// toolHeader builds the tool header line: "● ToolName params..."
388func toolHeader(sty *styles.Styles, status ToolStatus, name string, width int, nested bool, params ...string) string {
389	icon := toolIcon(sty, status)
390	nameStyle := sty.Tool.NameNormal
391	if nested {
392		nameStyle = sty.Tool.NameNested
393	}
394	toolName := nameStyle.Render(name)
395	prefix := fmt.Sprintf("%s %s ", icon, toolName)
396	prefixWidth := lipgloss.Width(prefix)
397	remainingWidth := width - prefixWidth
398	paramsStr := toolParamList(sty, params, remainingWidth)
399	return prefix + paramsStr
400}
401
402// toolOutputPlainContent renders plain text with optional expansion support.
403func toolOutputPlainContent(sty *styles.Styles, content string, width int, expanded bool) string {
404	content = strings.ReplaceAll(content, "\r\n", "\n")
405	content = strings.ReplaceAll(content, "\t", "    ")
406	content = strings.TrimSpace(content)
407	lines := strings.Split(content, "\n")
408
409	maxLines := responseContextHeight
410	if expanded {
411		maxLines = len(lines) // Show all
412	}
413
414	var out []string
415	for i, ln := range lines {
416		if i >= maxLines {
417			break
418		}
419		ln = " " + ln
420		if lipgloss.Width(ln) > width {
421			ln = ansi.Truncate(ln, width, "…")
422		}
423		out = append(out, sty.Tool.ContentLine.Width(width).Render(ln))
424	}
425
426	wasTruncated := len(lines) > responseContextHeight
427
428	if !expanded && wasTruncated {
429		out = append(out, sty.Tool.ContentTruncation.
430			Width(width).
431			Render(fmt.Sprintf("… (%d lines) [click or space to expand]", len(lines)-responseContextHeight)))
432	}
433
434	return strings.Join(out, "\n")
435}
436
437// toolOutputCodeContent renders code with syntax highlighting and line numbers.
438func toolOutputCodeContent(sty *styles.Styles, path, content string, offset, width int, expanded bool) string {
439	content = strings.ReplaceAll(content, "\r\n", "\n")
440	content = strings.ReplaceAll(content, "\t", "    ")
441
442	lines := strings.Split(content, "\n")
443	maxLines := responseContextHeight
444	if expanded {
445		maxLines = len(lines)
446	}
447
448	// Truncate if needed.
449	displayLines := lines
450	if len(lines) > maxLines {
451		displayLines = lines[:maxLines]
452	}
453
454	bg := sty.Tool.ContentCodeBg
455	highlighted, _ := common.SyntaxHighlight(sty, strings.Join(displayLines, "\n"), path, bg)
456	highlightedLines := strings.Split(highlighted, "\n")
457
458	// Calculate line number width.
459	maxLineNumber := len(displayLines) + offset
460	maxDigits := getDigits(maxLineNumber)
461	numFmt := fmt.Sprintf("%%%dd", maxDigits)
462
463	bodyWidth := width - toolBodyLeftPaddingTotal
464	codeWidth := bodyWidth - maxDigits - 4 // -4 for line number padding
465
466	var out []string
467	for i, ln := range highlightedLines {
468		lineNum := sty.Tool.ContentLineNumber.Render(fmt.Sprintf(numFmt, i+1+offset))
469
470		if lipgloss.Width(ln) > codeWidth {
471			ln = ansi.Truncate(ln, codeWidth, "…")
472		}
473
474		codeLine := sty.Tool.ContentCodeLine.
475			Width(codeWidth).
476			PaddingLeft(2).
477			Render(ln)
478
479		out = append(out, lipgloss.JoinHorizontal(lipgloss.Left, lineNum, codeLine))
480	}
481
482	// Add truncation message if needed.
483	if len(lines) > maxLines && !expanded {
484		truncMsg := sty.Tool.ContentCodeTruncation.
485			Width(bodyWidth).
486			Render(fmt.Sprintf("… (%d lines) [click or space to expand]", len(lines)-maxLines))
487		out = append(out, truncMsg)
488	}
489
490	return sty.Tool.Body.Render(strings.Join(out, "\n"))
491}
492
493// toolOutputImageContent renders image data with size info.
494func toolOutputImageContent(sty *styles.Styles, data, mediaType string) string {
495	dataSize := len(data) * 3 / 4
496	sizeStr := formatSize(dataSize)
497
498	loaded := sty.Base.Foreground(sty.Green).Render("Loaded")
499	arrow := sty.Base.Foreground(sty.GreenDark).Render("→")
500	typeStyled := sty.Base.Render(mediaType)
501	sizeStyled := sty.Subtle.Render(sizeStr)
502
503	return sty.Tool.Body.Render(fmt.Sprintf("%s %s %s %s", loaded, arrow, typeStyled, sizeStyled))
504}
505
// getDigits returns the number of decimal digits in n, ignoring the sign.
// Zero has one digit.
//
// The loop divides toward zero instead of negating n first, so it also works
// for math.MinInt, whose negation overflows.
func getDigits(n int) int {
	if n == 0 {
		return 1
	}
	digits := 0
	for ; n != 0; n /= 10 {
		digits++
	}
	return digits
}
521
// formatSize formats a byte count for display using binary (1024-based)
// units. Values below 1024 are shown as "N B", values below 1024*1024 as
// "X.Y KB", and anything larger as "X.Y MB".
func formatSize(bytes int) string {
	const unit = 1024
	if bytes < unit {
		return fmt.Sprintf("%d B", bytes)
	}
	if bytes < unit*unit {
		return fmt.Sprintf("%.1f KB", float64(bytes)/unit)
	}
	return fmt.Sprintf("%.1f MB", float64(bytes)/(unit*unit))
}
537
538// toolOutputDiffContent renders a diff between old and new content.
539func toolOutputDiffContent(sty *styles.Styles, file, oldContent, newContent string, width int, expanded bool) string {
540	bodyWidth := width - toolBodyLeftPaddingTotal
541
542	formatter := common.DiffFormatter(sty).
543		Before(file, oldContent).
544		After(file, newContent).
545		Width(bodyWidth)
546
547	// Use split view for wide terminals.
548	if width > 120 {
549		formatter = formatter.Split()
550	}
551
552	formatted := formatter.String()
553	lines := strings.Split(formatted, "\n")
554
555	// Truncate if needed.
556	maxLines := responseContextHeight
557	if expanded {
558		maxLines = len(lines)
559	}
560
561	if len(lines) > maxLines && !expanded {
562		truncMsg := sty.Tool.DiffTruncation.
563			Width(bodyWidth).
564			Render(fmt.Sprintf("… (%d lines) [click or space to expand]", len(lines)-maxLines))
565		formatted = strings.Join(lines[:maxLines], "\n") + "\n" + truncMsg
566	}
567
568	return sty.Tool.Body.Render(formatted)
569}
570
571// toolOutputMultiEditDiffContent renders a diff with optional failed edits note.
572func toolOutputMultiEditDiffContent(sty *styles.Styles, file string, meta tools.MultiEditResponseMetadata, totalEdits, width int, expanded bool) string {
573	bodyWidth := width - toolBodyLeftPaddingTotal
574
575	formatter := common.DiffFormatter(sty).
576		Before(file, meta.OldContent).
577		After(file, meta.NewContent).
578		Width(bodyWidth)
579
580	// Use split view for wide terminals.
581	if width > 120 {
582		formatter = formatter.Split()
583	}
584
585	formatted := formatter.String()
586	lines := strings.Split(formatted, "\n")
587
588	// Truncate if needed.
589	maxLines := responseContextHeight
590	if expanded {
591		maxLines = len(lines)
592	}
593
594	if len(lines) > maxLines && !expanded {
595		truncMsg := sty.Tool.DiffTruncation.
596			Width(bodyWidth).
597			Render(fmt.Sprintf("… (%d lines) [click or space to expand]", len(lines)-maxLines))
598		formatted = strings.Join(lines[:maxLines], "\n") + "\n" + truncMsg
599	}
600
601	// Add failed edits note if any exist.
602	if len(meta.EditsFailed) > 0 {
603		noteTag := sty.Tool.NoteTag.Render("Note")
604		noteMsg := fmt.Sprintf("%d of %d edits succeeded", meta.EditsApplied, totalEdits)
605		note := fmt.Sprintf("%s %s", noteTag, sty.Tool.NoteMessage.Render(noteMsg))
606		formatted = formatted + "\n\n" + note
607	}
608
609	return sty.Tool.Body.Render(formatted)
610}