tools.go

  1package chat
  2
  3import (
  4	"fmt"
  5	"strings"
  6
  7	tea "charm.land/bubbletea/v2"
  8	"charm.land/lipgloss/v2"
  9	"github.com/charmbracelet/crush/internal/agent/tools"
 10	"github.com/charmbracelet/crush/internal/message"
 11	"github.com/charmbracelet/crush/internal/ui/anim"
 12	"github.com/charmbracelet/crush/internal/ui/common"
 13	"github.com/charmbracelet/crush/internal/ui/styles"
 14	"github.com/charmbracelet/x/ansi"
 15)
 16
 17// responseContextHeight limits the number of lines displayed in tool output.
 18const responseContextHeight = 10
 19
 20// toolBodyLeftPaddingTotal represents the padding that should be applied to each tool body
 21const toolBodyLeftPaddingTotal = 2
 22
 23// ToolStatus represents the current state of a tool call.
 24type ToolStatus int
 25
 26const (
 27	ToolStatusAwaitingPermission ToolStatus = iota
 28	ToolStatusRunning
 29	ToolStatusSuccess
 30	ToolStatusError
 31	ToolStatusCanceled
 32)
 33
 34// ToolMessageItem represents a tool call message in the chat UI.
 35type ToolMessageItem interface {
 36	MessageItem
 37
 38	ToolCall() message.ToolCall
 39	SetToolCall(tc message.ToolCall)
 40	SetResult(res *message.ToolResult)
 41}
 42
 43// DefaultToolRenderContext implements the default [ToolRenderer] interface.
 44type DefaultToolRenderContext struct{}
 45
 46// RenderTool implements the [ToolRenderer] interface.
 47func (d *DefaultToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string {
 48	return "TODO: Implement Tool Renderer For: " + opts.ToolCall.Name
 49}
 50
 51// ToolRenderOpts contains the data needed to render a tool call.
 52type ToolRenderOpts struct {
 53	ToolCall            message.ToolCall
 54	Result              *message.ToolResult
 55	Canceled            bool
 56	Anim                *anim.Anim
 57	Expanded            bool
 58	Nested              bool
 59	IsSpinning          bool
 60	PermissionRequested bool
 61	PermissionGranted   bool
 62}
 63
 64// Status returns the current status of the tool call.
 65func (opts *ToolRenderOpts) Status() ToolStatus {
 66	if opts.Canceled && opts.Result == nil {
 67		return ToolStatusCanceled
 68	}
 69	if opts.Result != nil {
 70		if opts.Result.IsError {
 71			return ToolStatusError
 72		}
 73		return ToolStatusSuccess
 74	}
 75	if opts.PermissionRequested && !opts.PermissionGranted {
 76		return ToolStatusAwaitingPermission
 77	}
 78	return ToolStatusRunning
 79}
 80
 81// ToolRenderer represents an interface for rendering tool calls.
 82type ToolRenderer interface {
 83	RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string
 84}
 85
 86// ToolRendererFunc is a function type that implements the [ToolRenderer] interface.
 87type ToolRendererFunc func(sty *styles.Styles, width int, opts *ToolRenderOpts) string
 88
 89// RenderTool implements the ToolRenderer interface.
 90func (f ToolRendererFunc) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string {
 91	return f(sty, width, opts)
 92}
 93
 94// baseToolMessageItem represents a tool call message that can be displayed in the UI.
 95type baseToolMessageItem struct {
 96	*highlightableMessageItem
 97	*cachedMessageItem
 98	*focusableMessageItem
 99
100	toolRenderer        ToolRenderer
101	toolCall            message.ToolCall
102	result              *message.ToolResult
103	canceled            bool
104	permissionRequested bool
105	permissionGranted   bool
106	// we use this so we can efficiently cache
107	// tools that have a capped width (e.x bash.. and others)
108	hasCappedWidth bool
109
110	sty      *styles.Styles
111	anim     *anim.Anim
112	expanded bool
113}
114
115// newBaseToolMessageItem is the internal constructor for base tool message items.
116func newBaseToolMessageItem(
117	sty *styles.Styles,
118	toolCall message.ToolCall,
119	result *message.ToolResult,
120	toolRenderer ToolRenderer,
121	canceled bool,
122) *baseToolMessageItem {
123	// we only do full width for diffs (as far as I know)
124	hasCappedWidth := toolCall.Name != tools.EditToolName && toolCall.Name != tools.MultiEditToolName
125
126	t := &baseToolMessageItem{
127		highlightableMessageItem: defaultHighlighter(sty),
128		cachedMessageItem:        &cachedMessageItem{},
129		focusableMessageItem:     &focusableMessageItem{},
130		sty:                      sty,
131		toolRenderer:             toolRenderer,
132		toolCall:                 toolCall,
133		result:                   result,
134		canceled:                 canceled,
135		hasCappedWidth:           hasCappedWidth,
136	}
137	t.anim = anim.New(anim.Settings{
138		ID:          toolCall.ID,
139		Size:        15,
140		GradColorA:  sty.Primary,
141		GradColorB:  sty.Secondary,
142		LabelColor:  sty.FgBase,
143		CycleColors: true,
144	})
145
146	return t
147}
148
149// NewToolMessageItem creates a new [ToolMessageItem] based on the tool call name.
150//
151// It returns a specific tool message item type if implemented, otherwise it
152// returns a generic tool message item.
153func NewToolMessageItem(
154	sty *styles.Styles,
155	toolCall message.ToolCall,
156	result *message.ToolResult,
157	canceled bool,
158) ToolMessageItem {
159	switch toolCall.Name {
160	case tools.BashToolName:
161		return NewBashToolMessageItem(sty, toolCall, result, canceled)
162	case tools.JobOutputToolName:
163		return NewJobOutputToolMessageItem(sty, toolCall, result, canceled)
164	case tools.JobKillToolName:
165		return NewJobKillToolMessageItem(sty, toolCall, result, canceled)
166	case tools.ViewToolName:
167		return NewViewToolMessageItem(sty, toolCall, result, canceled)
168	case tools.WriteToolName:
169		return NewWriteToolMessageItem(sty, toolCall, result, canceled)
170	case tools.EditToolName:
171		return NewEditToolMessageItem(sty, toolCall, result, canceled)
172	case tools.MultiEditToolName:
173		return NewMultiEditToolMessageItem(sty, toolCall, result, canceled)
174	case tools.GlobToolName:
175		return NewGlobToolMessageItem(sty, toolCall, result, canceled)
176	case tools.GrepToolName:
177		return NewGrepToolMessageItem(sty, toolCall, result, canceled)
178	case tools.LSToolName:
179		return NewLSToolMessageItem(sty, toolCall, result, canceled)
180	default:
181		// TODO: Implement other tool items
182		return newBaseToolMessageItem(
183			sty,
184			toolCall,
185			result,
186			&DefaultToolRenderContext{},
187			canceled,
188		)
189	}
190}
191
192// ID returns the unique identifier for this tool message item.
193func (t *baseToolMessageItem) ID() string {
194	return t.toolCall.ID
195}
196
197// StartAnimation starts the assistant message animation if it should be spinning.
198func (t *baseToolMessageItem) StartAnimation() tea.Cmd {
199	if !t.isSpinning() {
200		return nil
201	}
202	return t.anim.Start()
203}
204
205// Animate progresses the assistant message animation if it should be spinning.
206func (t *baseToolMessageItem) Animate(msg anim.StepMsg) tea.Cmd {
207	if !t.isSpinning() {
208		return nil
209	}
210	return t.anim.Animate(msg)
211}
212
213// Render renders the tool message item at the given width.
214func (t *baseToolMessageItem) Render(width int) string {
215	toolItemWidth := width - messageLeftPaddingTotal
216	if t.hasCappedWidth {
217		toolItemWidth = cappedMessageWidth(width)
218	}
219	style := t.sty.Chat.Message.ToolCallBlurred
220	if t.focused {
221		style = t.sty.Chat.Message.ToolCallFocused
222	}
223
224	content, height, ok := t.getCachedRender(toolItemWidth)
225	// if we are spinning or there is no cache rerender
226	if !ok || t.isSpinning() {
227		content = t.toolRenderer.RenderTool(t.sty, toolItemWidth, &ToolRenderOpts{
228			ToolCall:            t.toolCall,
229			Result:              t.result,
230			Canceled:            t.canceled,
231			Anim:                t.anim,
232			Expanded:            t.expanded,
233			PermissionRequested: t.permissionRequested,
234			PermissionGranted:   t.permissionGranted,
235			IsSpinning:          t.isSpinning(),
236		})
237		height = lipgloss.Height(content)
238		// cache the rendered content
239		t.setCachedRender(content, toolItemWidth, height)
240	}
241
242	highlightedContent := t.renderHighlighted(content, toolItemWidth, height)
243	return style.Render(highlightedContent)
244}
245
246// ToolCall returns the tool call associated with this message item.
247func (t *baseToolMessageItem) ToolCall() message.ToolCall {
248	return t.toolCall
249}
250
251// SetToolCall sets the tool call associated with this message item.
252func (t *baseToolMessageItem) SetToolCall(tc message.ToolCall) {
253	t.toolCall = tc
254	t.clearCache()
255}
256
257// SetResult sets the tool result associated with this message item.
258func (t *baseToolMessageItem) SetResult(res *message.ToolResult) {
259	t.result = res
260	t.clearCache()
261}
262
263// SetPermissionRequested sets whether permission has been requested for this tool call.
264// TODO: Consider merging with SetPermissionGranted and add an interface for
265// permission management.
266func (t *baseToolMessageItem) SetPermissionRequested(requested bool) {
267	t.permissionRequested = requested
268	t.clearCache()
269}
270
271// SetPermissionGranted sets whether permission has been granted for this tool call.
272// TODO: Consider merging with SetPermissionRequested and add an interface for
273// permission management.
274func (t *baseToolMessageItem) SetPermissionGranted(granted bool) {
275	t.permissionGranted = granted
276	t.clearCache()
277}
278
279// isSpinning returns true if the tool should show animation.
280func (t *baseToolMessageItem) isSpinning() bool {
281	return !t.toolCall.Finished && !t.canceled
282}
283
284// ToggleExpanded toggles the expanded state of the thinking box.
285func (t *baseToolMessageItem) ToggleExpanded() {
286	t.expanded = !t.expanded
287	t.clearCache()
288}
289
290// HandleMouseClick implements MouseClickable.
291func (t *baseToolMessageItem) HandleMouseClick(btn ansi.MouseButton, x, y int) bool {
292	if btn != ansi.MouseLeft {
293		return false
294	}
295	t.ToggleExpanded()
296	return true
297}
298
299// pendingTool renders a tool that is still in progress with an animation.
300func pendingTool(sty *styles.Styles, name string, anim *anim.Anim) string {
301	icon := sty.Tool.IconPending.Render()
302	toolName := sty.Tool.NameNormal.Render(name)
303
304	var animView string
305	if anim != nil {
306		animView = anim.Render()
307	}
308
309	return fmt.Sprintf("%s %s %s", icon, toolName, animView)
310}
311
312// toolEarlyStateContent handles error/cancelled/pending states before content rendering.
313// Returns the rendered output and true if early state was handled.
314func toolEarlyStateContent(sty *styles.Styles, opts *ToolRenderOpts, width int) (string, bool) {
315	var msg string
316	switch opts.Status() {
317	case ToolStatusError:
318		msg = toolErrorContent(sty, opts.Result, width)
319	case ToolStatusCanceled:
320		msg = sty.Tool.StateCancelled.Render("Canceled.")
321	case ToolStatusAwaitingPermission:
322		msg = sty.Tool.StateWaiting.Render("Requesting permission...")
323	case ToolStatusRunning:
324		msg = sty.Tool.StateWaiting.Render("Waiting for tool response...")
325	default:
326		return "", false
327	}
328	return msg, true
329}
330
331// toolErrorContent formats an error message with ERROR tag.
332func toolErrorContent(sty *styles.Styles, result *message.ToolResult, width int) string {
333	if result == nil {
334		return ""
335	}
336	errContent := strings.ReplaceAll(result.Content, "\n", " ")
337	errTag := sty.Tool.ErrorTag.Render("ERROR")
338	tagWidth := lipgloss.Width(errTag)
339	errContent = ansi.Truncate(errContent, width-tagWidth-3, "…")
340	return fmt.Sprintf("%s %s", errTag, sty.Tool.ErrorMessage.Render(errContent))
341}
342
343// toolIcon returns the status icon for a tool call.
344// toolIcon returns the status icon for a tool call based on its status.
345func toolIcon(sty *styles.Styles, status ToolStatus) string {
346	switch status {
347	case ToolStatusSuccess:
348		return sty.Tool.IconSuccess.String()
349	case ToolStatusError:
350		return sty.Tool.IconError.String()
351	case ToolStatusCanceled:
352		return sty.Tool.IconCancelled.String()
353	default:
354		return sty.Tool.IconPending.String()
355	}
356}
357
358// toolParamList formats parameters as "main (key=value, ...)" with truncation.
359// toolParamList formats tool parameters as "main (key=value, ...)" with truncation.
360func toolParamList(sty *styles.Styles, params []string, width int) string {
361	// minSpaceForMainParam is the min space required for the main param
362	// if this is less that the value set we will only show the main param nothing else
363	const minSpaceForMainParam = 30
364	if len(params) == 0 {
365		return ""
366	}
367
368	mainParam := params[0]
369
370	// Build key=value pairs from remaining params (consecutive key, value pairs).
371	var kvPairs []string
372	for i := 1; i+1 < len(params); i += 2 {
373		if params[i+1] != "" {
374			kvPairs = append(kvPairs, fmt.Sprintf("%s=%s", params[i], params[i+1]))
375		}
376	}
377
378	// Try to include key=value pairs if there's enough space.
379	output := mainParam
380	if len(kvPairs) > 0 {
381		partsStr := strings.Join(kvPairs, ", ")
382		if remaining := width - lipgloss.Width(partsStr) - 3; remaining >= minSpaceForMainParam {
383			output = fmt.Sprintf("%s (%s)", mainParam, partsStr)
384		}
385	}
386
387	if width >= 0 {
388		output = ansi.Truncate(output, width, "…")
389	}
390	return sty.Tool.ParamMain.Render(output)
391}
392
393// toolHeader builds the tool header line: "● ToolName params..."
394func toolHeader(sty *styles.Styles, status ToolStatus, name string, width int, nested bool, params ...string) string {
395	icon := toolIcon(sty, status)
396	nameStyle := sty.Tool.NameNormal
397	if nested {
398		nameStyle = sty.Tool.NameNested
399	}
400	toolName := nameStyle.Render(name)
401	prefix := fmt.Sprintf("%s %s ", icon, toolName)
402	prefixWidth := lipgloss.Width(prefix)
403	remainingWidth := width - prefixWidth
404	paramsStr := toolParamList(sty, params, remainingWidth)
405	return prefix + paramsStr
406}
407
408// toolOutputPlainContent renders plain text with optional expansion support.
409func toolOutputPlainContent(sty *styles.Styles, content string, width int, expanded bool) string {
410	content = strings.ReplaceAll(content, "\r\n", "\n")
411	content = strings.ReplaceAll(content, "\t", "    ")
412	content = strings.TrimSpace(content)
413	lines := strings.Split(content, "\n")
414
415	maxLines := responseContextHeight
416	if expanded {
417		maxLines = len(lines) // Show all
418	}
419
420	var out []string
421	for i, ln := range lines {
422		if i >= maxLines {
423			break
424		}
425		ln = " " + ln
426		if lipgloss.Width(ln) > width {
427			ln = ansi.Truncate(ln, width, "…")
428		}
429		out = append(out, sty.Tool.ContentLine.Width(width).Render(ln))
430	}
431
432	wasTruncated := len(lines) > responseContextHeight
433
434	if !expanded && wasTruncated {
435		out = append(out, sty.Tool.ContentTruncation.
436			Width(width).
437			Render(fmt.Sprintf("… (%d lines) [click or space to expand]", len(lines)-responseContextHeight)))
438	}
439
440	return strings.Join(out, "\n")
441}
442
443// toolOutputCodeContent renders code with syntax highlighting and line numbers.
444func toolOutputCodeContent(sty *styles.Styles, path, content string, offset, width int, expanded bool) string {
445	content = strings.ReplaceAll(content, "\r\n", "\n")
446	content = strings.ReplaceAll(content, "\t", "    ")
447
448	lines := strings.Split(content, "\n")
449	maxLines := responseContextHeight
450	if expanded {
451		maxLines = len(lines)
452	}
453
454	// Truncate if needed.
455	displayLines := lines
456	if len(lines) > maxLines {
457		displayLines = lines[:maxLines]
458	}
459
460	bg := sty.Tool.ContentCodeBg
461	highlighted, _ := common.SyntaxHighlight(sty, strings.Join(displayLines, "\n"), path, bg)
462	highlightedLines := strings.Split(highlighted, "\n")
463
464	// Calculate line number width.
465	maxLineNumber := len(displayLines) + offset
466	maxDigits := getDigits(maxLineNumber)
467	numFmt := fmt.Sprintf("%%%dd", maxDigits)
468
469	bodyWidth := width - toolBodyLeftPaddingTotal
470	codeWidth := bodyWidth - maxDigits - 4 // -4 for line number padding
471
472	var out []string
473	for i, ln := range highlightedLines {
474		lineNum := sty.Tool.ContentLineNumber.Render(fmt.Sprintf(numFmt, i+1+offset))
475
476		if lipgloss.Width(ln) > codeWidth {
477			ln = ansi.Truncate(ln, codeWidth, "…")
478		}
479
480		codeLine := sty.Tool.ContentCodeLine.
481			Width(codeWidth).
482			PaddingLeft(2).
483			Render(ln)
484
485		out = append(out, lipgloss.JoinHorizontal(lipgloss.Left, lineNum, codeLine))
486	}
487
488	// Add truncation message if needed.
489	if len(lines) > maxLines && !expanded {
490		truncMsg := sty.Tool.ContentCodeTruncation.
491			Width(bodyWidth).
492			Render(fmt.Sprintf("… (%d lines) [click or space to expand]", len(lines)-maxLines))
493		out = append(out, truncMsg)
494	}
495
496	return sty.Tool.Body.Render(strings.Join(out, "\n"))
497}
498
499// toolOutputImageContent renders image data with size info.
500func toolOutputImageContent(sty *styles.Styles, data, mediaType string) string {
501	dataSize := len(data) * 3 / 4
502	sizeStr := formatSize(dataSize)
503
504	loaded := sty.Base.Foreground(sty.Green).Render("Loaded")
505	arrow := sty.Base.Foreground(sty.GreenDark).Render("→")
506	typeStyled := sty.Base.Render(mediaType)
507	sizeStyled := sty.Subtle.Render(sizeStr)
508
509	return sty.Tool.Body.Render(fmt.Sprintf("%s %s %s %s", loaded, arrow, typeStyled, sizeStyled))
510}
511
// getDigits returns the number of decimal digits in n, ignoring its sign;
// getDigits(0) is 1.
//
// The value is divided toward zero instead of being negated first, so the
// minimum int (whose negation overflows back to itself) is counted correctly.
func getDigits(n int) int {
	digits := 1
	for n /= 10; n != 0; n /= 10 {
		digits++
	}
	return digits
}
527
// formatSize renders a byte count in human-readable form with 1024-based
// units. Counts below 1024 are shown as "N B"; larger counts get one decimal
// place in "KB" or, from 1024*1024 on, "MB".
func formatSize(bytes int) string {
	const kb, mb = 1024, 1024 * 1024
	if bytes < kb {
		return fmt.Sprintf("%d B", bytes)
	}

	unit, divisor := "KB", float64(kb)
	if bytes >= mb {
		unit, divisor = "MB", float64(mb)
	}
	return fmt.Sprintf("%.1f %s", float64(bytes)/divisor, unit)
}
543
544// toolOutputDiffContent renders a diff between old and new content.
545func toolOutputDiffContent(sty *styles.Styles, file, oldContent, newContent string, width int, expanded bool) string {
546	bodyWidth := width - toolBodyLeftPaddingTotal
547
548	formatter := common.DiffFormatter(sty).
549		Before(file, oldContent).
550		After(file, newContent).
551		Width(bodyWidth)
552
553	// Use split view for wide terminals.
554	if width > 120 {
555		formatter = formatter.Split()
556	}
557
558	formatted := formatter.String()
559	lines := strings.Split(formatted, "\n")
560
561	// Truncate if needed.
562	maxLines := responseContextHeight
563	if expanded {
564		maxLines = len(lines)
565	}
566
567	if len(lines) > maxLines && !expanded {
568		truncMsg := sty.Tool.DiffTruncation.
569			Width(bodyWidth).
570			Render(fmt.Sprintf("… (%d lines) [click or space to expand]", len(lines)-maxLines))
571		formatted = strings.Join(lines[:maxLines], "\n") + "\n" + truncMsg
572	}
573
574	return sty.Tool.Body.Render(formatted)
575}
576
577// toolOutputMultiEditDiffContent renders a diff with optional failed edits note.
578func toolOutputMultiEditDiffContent(sty *styles.Styles, file string, meta tools.MultiEditResponseMetadata, totalEdits, width int, expanded bool) string {
579	bodyWidth := width - toolBodyLeftPaddingTotal
580
581	formatter := common.DiffFormatter(sty).
582		Before(file, meta.OldContent).
583		After(file, meta.NewContent).
584		Width(bodyWidth)
585
586	// Use split view for wide terminals.
587	if width > 120 {
588		formatter = formatter.Split()
589	}
590
591	formatted := formatter.String()
592	lines := strings.Split(formatted, "\n")
593
594	// Truncate if needed.
595	maxLines := responseContextHeight
596	if expanded {
597		maxLines = len(lines)
598	}
599
600	if len(lines) > maxLines && !expanded {
601		truncMsg := sty.Tool.DiffTruncation.
602			Width(bodyWidth).
603			Render(fmt.Sprintf("… (%d lines) [click or space to expand]", len(lines)-maxLines))
604		formatted = strings.Join(lines[:maxLines], "\n") + "\n" + truncMsg
605	}
606
607	// Add failed edits note if any exist.
608	if len(meta.EditsFailed) > 0 {
609		noteTag := sty.Tool.NoteTag.Render("Note")
610		noteMsg := fmt.Sprintf("%d of %d edits succeeded", meta.EditsApplied, totalEdits)
611		note := fmt.Sprintf("%s %s", noteTag, sty.Tool.NoteMessage.Render(noteMsg))
612		formatted = formatted + "\n\n" + note
613	}
614
615	return sty.Tool.Body.Render(formatted)
616}