tool_base.go

  1package chat
  2
  3import (
  4	"encoding/json"
  5	"fmt"
  6	"slices"
  7	"strings"
  8
  9	"charm.land/bubbles/v2/key"
 10	tea "charm.land/bubbletea/v2"
 11	"charm.land/lipgloss/v2"
 12	"github.com/charmbracelet/crush/internal/ansiext"
 13	"github.com/charmbracelet/crush/internal/message"
 14	"github.com/charmbracelet/crush/internal/session"
 15	"github.com/charmbracelet/crush/internal/ui/common"
 16	"github.com/charmbracelet/crush/internal/ui/common/anim"
 17	"github.com/charmbracelet/crush/internal/ui/styles"
 18	"github.com/charmbracelet/x/ansi"
 19)
 20
// responseContextHeight is the maximum number of output lines the content
// renderers in this file show while a tool item is collapsed. Lines beyond
// this limit are replaced by a truncation hint until the item is expanded.
const responseContextHeight = 10
 23
// ToolStatus represents the current state of a tool call. Values are derived
// from a ToolCallContext by its Status method.
type ToolStatus int

const (
	// ToolStatusAwaitingPermission means permission was requested but has not
	// been granted yet.
	ToolStatusAwaitingPermission ToolStatus = iota
	// ToolStatusRunning means the call is neither cancelled, finished with a
	// result, nor waiting on a permission request.
	ToolStatusRunning
	// ToolStatusSuccess means a result is present and is not flagged as an error.
	ToolStatusSuccess
	// ToolStatusError means a result is present and is flagged as an error.
	ToolStatusError
	// ToolStatusCancelled means the call was cancelled; this takes precedence
	// over every other state.
	ToolStatusCancelled
)
 34
// ToolCallContext provides the context needed for rendering a tool call.
type ToolCallContext struct {
	// Call is the tool call being rendered; its ID, Name and Finished fields
	// are read by the helpers in this file.
	Call message.ToolCall
	// Result is the tool's result. It only counts as present when it is
	// non-nil and carries a non-empty ToolCallID (see HasResult).
	Result *message.ToolResult
	// Cancelled marks the call as cancelled; Status reports it before
	// anything else.
	Cancelled bool
	// PermissionRequested and PermissionGranted together decide whether
	// Status reports ToolStatusAwaitingPermission.
	PermissionRequested bool
	PermissionGranted   bool
	// IsNested selects the nested name style and a smaller animation.
	IsNested bool
	// Styles is the style set used by every render helper.
	Styles *styles.Styles

	// NestedCalls holds the nested tool calls (set through SetNestedCalls).
	NestedCalls []ToolCallContext
}
 47
 48// Status returns the current status of the tool call.
 49func (ctx *ToolCallContext) Status() ToolStatus {
 50	if ctx.Cancelled {
 51		return ToolStatusCancelled
 52	}
 53	if ctx.HasResult() {
 54		if ctx.Result.IsError {
 55			return ToolStatusError
 56		}
 57		return ToolStatusSuccess
 58	}
 59	if ctx.PermissionRequested && !ctx.PermissionGranted {
 60		return ToolStatusAwaitingPermission
 61	}
 62	return ToolStatusRunning
 63}
 64
 65// HasResult returns true if the tool call has a completed result.
 66func (ctx *ToolCallContext) HasResult() bool {
 67	return ctx.Result != nil && ctx.Result.ToolCallID != ""
 68}
 69
 70// toolStyles provides common FocusStylable and HighlightStylable implementations.
 71type toolStyles struct {
 72	sty *styles.Styles
 73}
 74
 75func (s toolStyles) FocusStyle() lipgloss.Style {
 76	return s.sty.Chat.Message.ToolCallFocused
 77}
 78
 79func (s toolStyles) BlurStyle() lipgloss.Style {
 80	return s.sty.Chat.Message.ToolCallBlurred
 81}
 82
 83func (s toolStyles) HighlightStyle() lipgloss.Style {
 84	return s.sty.TextSelection
 85}
 86
// toolItem provides common base functionality for all tool items.
type toolItem struct {
	toolStyles
	// id is the tool call ID copied from the context at construction time.
	id string
	// ctx is the rendering context; the Set*/Update* methods mutate it.
	ctx ToolCallContext
	// expanded is toggled by a left click or the space key, and lifts the
	// responseContextHeight line limit in the content renderers.
	expanded bool
	// wasTruncated is written by the content renderers and records whether
	// the content exceeds responseContextHeight lines; click and key
	// handling is ignored while it is false.
	wasTruncated bool
	// spinning reports whether the animation should currently be advanced.
	spinning bool
	// anim is the "Working" animation; nil is tolerated by updateAnimation
	// and renderPending.
	anim *anim.Anim
}
 97
 98// newToolItem creates a new toolItem with the given context.
 99func newToolItem(ctx ToolCallContext) toolItem {
100	animSize := 15
101	if ctx.IsNested {
102		animSize = 10
103	}
104
105	t := toolItem{
106		toolStyles: toolStyles{sty: ctx.Styles},
107		id:         ctx.Call.ID,
108		ctx:        ctx,
109		spinning:   shouldSpin(ctx),
110		anim: anim.New(anim.Settings{
111			Size:        animSize,
112			Label:       "Working",
113			GradColorA:  ctx.Styles.Primary,
114			GradColorB:  ctx.Styles.Secondary,
115			LabelColor:  ctx.Styles.FgBase,
116			CycleColors: true,
117		}),
118	}
119
120	return t
121}
122
123// shouldSpin returns true if the tool should show animation.
124func shouldSpin(ctx ToolCallContext) bool {
125	return !ctx.Call.Finished && !ctx.Cancelled
126}
127
128// ID implements Identifiable.
129func (t *toolItem) ID() string {
130	return t.id
131}
132
133// HandleMouseClick implements list.MouseClickable.
134func (t *toolItem) HandleMouseClick(btn ansi.MouseButton, x, y int) bool {
135	if btn != ansi.MouseLeft || !t.wasTruncated {
136		return false
137	}
138
139	t.expanded = !t.expanded
140	return true
141}
142
143// HandleKeyPress implements list.KeyPressable.
144func (t *toolItem) HandleKeyPress(msg tea.KeyPressMsg) bool {
145	if !t.wasTruncated {
146		return false
147	}
148
149	if key.Matches(msg, key.NewBinding(key.WithKeys("space"))) {
150		t.expanded = !t.expanded
151		return true
152	}
153
154	return false
155}
156
157// updateAnimation handles animation updates and returns true if changed.
158func (t *toolItem) updateAnimation(msg tea.Msg) (tea.Cmd, bool) {
159	if !t.spinning || t.anim == nil {
160		return nil, false
161	}
162
163	switch msg.(type) {
164	case anim.StepMsg:
165		updatedAnim, cmd := t.anim.Update(msg)
166		t.anim = updatedAnim
167		return cmd, cmd != nil
168	}
169
170	return nil, false
171}
172
173// InitAnimation initializes and starts the animation.
174func (t *toolItem) InitAnimation() tea.Cmd {
175	t.spinning = shouldSpin(t.ctx)
176	return t.anim.Init()
177}
178
179// SetResult updates the tool call with a result.
180func (t *toolItem) SetResult(result message.ToolResult) {
181	t.ctx.Result = &result
182	t.ctx.Call.Finished = true
183	t.spinning = false
184}
185
186// SetCancelled marks the tool call as cancelled.
187func (t *toolItem) SetCancelled() {
188	t.ctx.Cancelled = true
189	t.spinning = false
190}
191
192// UpdateCall updates the tool call data.
193func (t *toolItem) UpdateCall(call message.ToolCall) {
194	t.ctx.Call = call
195	if call.Finished {
196		t.spinning = false
197	}
198}
199
200// SetNestedCalls sets the nested tool calls for agent tools.
201func (t *toolItem) SetNestedCalls(calls []ToolCallContext) {
202	t.ctx.NestedCalls = calls
203}
204
205// Context returns the current tool call context.
206func (t *toolItem) Context() *ToolCallContext {
207	return &t.ctx
208}
209
210// renderPending returns the pending state view with animation.
211func (t *toolItem) renderPending() string {
212	icon := t.sty.Tool.IconPending.Render()
213
214	var toolName string
215	if t.ctx.IsNested {
216		toolName = t.sty.Tool.NameNested.Render(prettifyToolName(t.ctx.Call.Name))
217	} else {
218		toolName = t.sty.Tool.NameNormal.Render(prettifyToolName(t.ctx.Call.Name))
219	}
220
221	var animView string
222	if t.anim != nil {
223		animView = t.anim.View()
224	}
225
226	return fmt.Sprintf("%s %s %s", icon, toolName, animView)
227}
228
// unmarshalParams decodes the JSON text in input into target, which must be
// a pointer as required by json.Unmarshal. Any decoding error is returned
// unchanged.
func unmarshalParams(input string, target any) error {
	raw := []byte(input)
	if err := json.Unmarshal(raw, target); err != nil {
		return err
	}
	return nil
}
233
// ParamBuilder accumulates the flat parameter list shown in tool headers:
// an optional main value followed by alternating key/value entries.
type ParamBuilder struct {
	args []string
}

// NewParamBuilder returns an empty builder with room for four entries.
func NewParamBuilder() *ParamBuilder {
	pb := new(ParamBuilder)
	pb.args = make([]string, 0, 4)
	return pb
}

// Main appends the main (positional) parameter. An empty value is skipped.
// The builder is returned to allow chaining.
func (pb *ParamBuilder) Main(value string) *ParamBuilder {
	if value == "" {
		return pb
	}
	pb.args = append(pb.args, value)
	return pb
}

// KeyValue appends key followed by value. Pairs with an empty value are
// skipped. The builder is returned to allow chaining.
func (pb *ParamBuilder) KeyValue(key, value string) *ParamBuilder {
	if value == "" {
		return pb
	}
	pb.args = append(pb.args, key, value)
	return pb
}

// Flag appends key followed by the literal "true" when value is set, and
// nothing otherwise. The builder is returned to allow chaining.
func (pb *ParamBuilder) Flag(key string, value bool) *ParamBuilder {
	if !value {
		return pb
	}
	pb.args = append(pb.args, key, "true")
	return pb
}

// Build returns the accumulated parameter list. The returned slice is the
// builder's own storage, not a copy.
func (pb *ParamBuilder) Build() []string {
	built := pb.args
	return built
}
272
273// renderToolIcon returns the status icon for a tool call.
274func renderToolIcon(status ToolStatus, sty *styles.Styles) string {
275	switch status {
276	case ToolStatusSuccess:
277		return sty.Tool.IconSuccess.String()
278	case ToolStatusError:
279		return sty.Tool.IconError.String()
280	case ToolStatusCancelled:
281		return sty.Tool.IconCancelled.String()
282	default:
283		return sty.Tool.IconPending.String()
284	}
285}
286
287// renderToolHeader builds the tool header line: "● ToolName params..."
288func renderToolHeader(ctx *ToolCallContext, name string, width int, params ...string) string {
289	sty := ctx.Styles
290	icon := renderToolIcon(ctx.Status(), sty)
291
292	var toolName string
293	if ctx.IsNested {
294		toolName = sty.Tool.NameNested.Render(name)
295	} else {
296		toolName = sty.Tool.NameNormal.Render(name)
297	}
298
299	prefix := fmt.Sprintf("%s %s ", icon, toolName)
300	prefixWidth := lipgloss.Width(prefix)
301	remainingWidth := width - prefixWidth
302
303	paramsStr := renderParamList(params, remainingWidth, sty)
304	return prefix + paramsStr
305}
306
307// renderParamList formats parameters as "main (key=value, ...)" with truncation.
308func renderParamList(params []string, width int, sty *styles.Styles) string {
309	if len(params) == 0 {
310		return ""
311	}
312
313	mainParam := params[0]
314	if width >= 0 && lipgloss.Width(mainParam) > width {
315		mainParam = ansi.Truncate(mainParam, width, "…")
316	}
317
318	if len(params) == 1 {
319		return sty.Tool.ParamMain.Render(mainParam)
320	}
321
322	// Build key=value pairs from remaining params.
323	otherParams := params[1:]
324	if len(otherParams)%2 != 0 {
325		otherParams = append(otherParams, "")
326	}
327
328	var parts []string
329	for i := 0; i < len(otherParams); i += 2 {
330		key := otherParams[i]
331		value := otherParams[i+1]
332		if value == "" {
333			continue
334		}
335		parts = append(parts, fmt.Sprintf("%s=%s", key, value))
336	}
337
338	if len(parts) == 0 {
339		return sty.Tool.ParamMain.Render(ansi.Truncate(mainParam, width, "…"))
340	}
341
342	partsRendered := strings.Join(parts, ", ")
343	remainingWidth := width - lipgloss.Width(partsRendered) - 3 // " ()"
344	if remainingWidth < 30 {
345		// Not enough space for params, just show main.
346		return sty.Tool.ParamMain.Render(ansi.Truncate(mainParam, width, "…"))
347	}
348
349	fullParam := fmt.Sprintf("%s (%s)", mainParam, partsRendered)
350	return sty.Tool.ParamMain.Render(ansi.Truncate(fullParam, width, "…"))
351}
352
353// renderEarlyState handles error/cancelled/pending states before content rendering.
354// Returns the rendered output and true if early state was handled.
355func renderEarlyState(ctx *ToolCallContext, header string, width int) (string, bool) {
356	sty := ctx.Styles
357
358	var msg string
359	switch ctx.Status() {
360	case ToolStatusError:
361		msg = renderToolError(ctx, width)
362	case ToolStatusCancelled:
363		msg = sty.Tool.StateCancelled.Render("Canceled.")
364	case ToolStatusAwaitingPermission:
365		msg = sty.Tool.StateWaiting.Render("Requesting permission...")
366	case ToolStatusRunning:
367		msg = sty.Tool.StateWaiting.Render("Waiting for tool response...")
368	default:
369		return "", false
370	}
371
372	msg = sty.Tool.BodyPadding.Render(msg)
373	return lipgloss.JoinVertical(lipgloss.Left, header, "", msg), true
374}
375
376// renderToolError formats an error message with ERROR tag.
377func renderToolError(ctx *ToolCallContext, width int) string {
378	sty := ctx.Styles
379	errContent := strings.ReplaceAll(ctx.Result.Content, "\n", " ")
380	errTag := sty.Tool.ErrorTag.Render("ERROR")
381	tagWidth := lipgloss.Width(errTag)
382	errContent = ansi.Truncate(errContent, width-tagWidth-3, "…")
383	return fmt.Sprintf("%s %s", errTag, sty.Tool.ErrorMessage.Render(errContent))
384}
385
386// joinHeaderBody combines header and body with proper padding.
387func joinHeaderBody(header, body string, sty *styles.Styles) string {
388	if body == "" {
389		return header
390	}
391	body = sty.Tool.BodyPadding.Render(body)
392	return lipgloss.JoinVertical(lipgloss.Left, header, "", body)
393}
394
395// renderPlainContent renders plain text with optional expansion support.
396func renderPlainContent(content string, width int, sty *styles.Styles, item *toolItem) string {
397	content = strings.ReplaceAll(content, "\r\n", "\n")
398	content = strings.ReplaceAll(content, "\t", "    ")
399	content = strings.TrimSpace(content)
400	lines := strings.Split(content, "\n")
401
402	expanded := item != nil && item.expanded
403	maxLines := responseContextHeight
404	if expanded {
405		maxLines = len(lines) // Show all
406	}
407
408	var out []string
409	for i, ln := range lines {
410		if i >= maxLines {
411			break
412		}
413		ln = " " + ln
414		if lipgloss.Width(ln) > width {
415			ln = ansi.Truncate(ln, width, "…")
416		}
417		out = append(out, sty.Tool.ContentLine.Width(width).Render(ln))
418	}
419
420	wasTruncated := len(lines) > responseContextHeight
421	if item != nil {
422		item.wasTruncated = wasTruncated
423	}
424
425	if !expanded && wasTruncated {
426		out = append(out, sty.Tool.ContentTruncation.
427			Width(width).
428			Render(fmt.Sprintf("… (%d lines) [click or space to expand]", len(lines)-responseContextHeight)))
429	}
430
431	return strings.Join(out, "\n")
432}
433
// formatNonZero returns the decimal representation of value, or an empty
// string when value is zero.
func formatNonZero(value int) string {
	if value != 0 {
		return fmt.Sprint(value)
	}
	return ""
}
441
442// renderCodeContent renders syntax-highlighted code with line numbers and optional expansion.
443func renderCodeContent(path, content string, offset, width int, sty *styles.Styles, item *toolItem) string {
444	content = strings.ReplaceAll(content, "\r\n", "\n")
445	content = strings.ReplaceAll(content, "\t", "    ")
446
447	lines := strings.Split(content, "\n")
448
449	maxLines := responseContextHeight
450	if item != nil && item.expanded {
451		maxLines = len(lines)
452	}
453
454	truncated := lines
455	if len(lines) > maxLines {
456		truncated = lines[:maxLines]
457	}
458
459	// Escape ANSI sequences in content.
460	for i, ln := range truncated {
461		truncated[i] = ansiext.Escape(ln)
462	}
463
464	// Apply syntax highlighting.
465	bg := sty.Tool.ContentCodeBg
466	highlighted, _ := common.SyntaxHighlight(sty, strings.Join(truncated, "\n"), path, bg)
467	highlightedLines := strings.Split(highlighted, "\n")
468
469	// Calculate gutter width for line numbers.
470	maxLineNum := offset + len(highlightedLines)
471	maxDigits := getDigits(maxLineNum)
472	numFmt := fmt.Sprintf("%%%dd", maxDigits)
473
474	// Calculate available width for code (accounting for gutter).
475	const numPR, numPL, codePR, codePL = 1, 1, 1, 2
476	codeWidth := width - maxDigits - numPL - numPR - 2
477
478	var out []string
479	for i, ln := range highlightedLines {
480		lineNum := sty.Base.
481			Foreground(sty.FgMuted).
482			Background(bg).
483			PaddingRight(numPR).
484			PaddingLeft(numPL).
485			Render(fmt.Sprintf(numFmt, offset+i+1))
486
487		codeLine := sty.Base.
488			Width(codeWidth).
489			Background(bg).
490			PaddingRight(codePR).
491			PaddingLeft(codePL).
492			Render(ansi.Truncate(ln, codeWidth-codePL-codePR, "…"))
493
494		out = append(out, lipgloss.JoinHorizontal(lipgloss.Left, lineNum, codeLine))
495	}
496
497	wasTruncated := len(lines) > responseContextHeight
498	if item != nil {
499		item.wasTruncated = wasTruncated
500	}
501
502	expanded := item != nil && item.expanded
503
504	if !expanded && wasTruncated {
505		msg := fmt.Sprintf(" …(%d lines) [click or space to expand]", len(lines)-responseContextHeight)
506		out = append(out, sty.Muted.Background(bg).Render(msg))
507	}
508
509	return lipgloss.JoinVertical(lipgloss.Left, out...)
510}
511
512// renderMarkdownContent renders markdown with optional expansion support.
513func renderMarkdownContent(content string, width int, sty *styles.Styles, item *toolItem) string {
514	content = strings.ReplaceAll(content, "\r\n", "\n")
515	content = strings.ReplaceAll(content, "\t", "    ")
516	content = strings.TrimSpace(content)
517
518	cappedWidth := min(width, 120)
519	renderer := common.PlainMarkdownRenderer(sty, cappedWidth)
520	rendered, err := renderer.Render(content)
521	if err != nil {
522		return renderPlainContent(content, width, sty, nil)
523	}
524
525	lines := strings.Split(rendered, "\n")
526
527	maxLines := responseContextHeight
528	if item != nil && item.expanded {
529		maxLines = len(lines)
530	}
531
532	var out []string
533	for i, ln := range lines {
534		if i >= maxLines {
535			break
536		}
537		out = append(out, ln)
538	}
539
540	wasTruncated := len(lines) > responseContextHeight
541	if item != nil {
542		item.wasTruncated = wasTruncated
543	}
544
545	expanded := item != nil && item.expanded
546
547	if !expanded && wasTruncated {
548		out = append(out, sty.Tool.ContentTruncation.
549			Width(cappedWidth-2).
550			Render(fmt.Sprintf("… (%d lines) [click or space to expand]", len(lines)-responseContextHeight)))
551	}
552
553	return sty.Tool.ContentLine.Render(strings.Join(out, "\n"))
554}
555
556// renderDiffContent renders a diff with optional expansion support.
557func renderDiffContent(file, oldContent, newContent string, width int, sty *styles.Styles, item *toolItem) string {
558	formatter := common.DiffFormatter(sty).
559		Before(file, oldContent).
560		After(file, newContent).
561		Width(width)
562
563	if width > 120 {
564		formatter = formatter.Split()
565	}
566
567	formatted := formatter.String()
568	lines := strings.Split(formatted, "\n")
569
570	wasTruncated := len(lines) > responseContextHeight
571	if item != nil {
572		item.wasTruncated = wasTruncated
573	}
574
575	expanded := item != nil && item.expanded
576
577	if !expanded && wasTruncated {
578		truncateMsg := sty.Tool.DiffTruncation.
579			Width(width).
580			Render(fmt.Sprintf("… (%d lines) [click or space to expand]", len(lines)-responseContextHeight))
581		formatted = strings.Join(lines[:responseContextHeight], "\n") + "\n" + truncateMsg
582	}
583
584	return formatted
585}
586
587// renderImageContent renders image data with optional text content.
588func renderImageContent(data, mediaType, textContent string, sty *styles.Styles) string {
589	dataSize := len(data) * 3 / 4 // Base64 to bytes approximation.
590	sizeStr := formatSize(dataSize)
591
592	loaded := sty.Tool.IconSuccess.String()
593	arrow := sty.Tool.NameNested.Render("→")
594	typeStyled := sty.Base.Render(mediaType)
595	sizeStyled := sty.Subtle.Render(sizeStr)
596
597	imageDisplay := fmt.Sprintf("%s %s %s %s", loaded, arrow, typeStyled, sizeStyled)
598
599	if strings.TrimSpace(textContent) != "" {
600		textDisplay := sty.Tool.ContentLine.Render(textContent)
601		return lipgloss.JoinVertical(lipgloss.Left, textDisplay, "", imageDisplay)
602	}
603
604	return imageDisplay
605}
606
607// renderMediaContent renders non-image media content.
608func renderMediaContent(mediaType, textContent string, sty *styles.Styles) string {
609	loaded := sty.Tool.IconSuccess.String()
610	arrow := sty.Tool.NameNested.Render("→")
611	typeStyled := sty.Base.Render(mediaType)
612	mediaDisplay := fmt.Sprintf("%s %s %s", loaded, arrow, typeStyled)
613
614	if strings.TrimSpace(textContent) != "" {
615		textDisplay := sty.Tool.ContentLine.Render(textContent)
616		return lipgloss.JoinVertical(lipgloss.Left, textDisplay, "", mediaDisplay)
617	}
618
619	return mediaDisplay
620}
621
// formatSize formats a byte count as a human-readable size using 1024-based
// units: whole bytes below 1 KB, otherwise one decimal place in KB or MB.
//
// Values that would be rounded up to "1024.0 KB" by the one-decimal format
// are shown in MB instead, so the KB figure never reaches 1024.
func formatSize(bytes int) string {
	const (
		kib = 1024
		mib = 1024 * kib
		// Any KB value at or above this limit is printed as "1024.0" by
		// "%.1f"; no integer byte count maps to exactly this value.
		kbRoundLimit = 1023.95
	)

	if bytes < kib {
		return fmt.Sprintf("%d B", bytes)
	}
	if kb := float64(bytes) / kib; kb < kbRoundLimit {
		return fmt.Sprintf("%.1f KB", kb)
	}
	return fmt.Sprintf("%.1f MB", float64(bytes)/mib)
}
632
// getDigits returns the number of decimal digits in n, ignoring the sign.
// Zero has one digit.
//
// The count is taken by dividing n towards zero without negating it first,
// so the minimum int value, whose negation overflows, is counted correctly.
func getDigits(n int) int {
	if n == 0 {
		return 1
	}
	digits := 0
	for ; n != 0; n /= 10 {
		digits++
	}
	return digits
}
648
649// formatTodosList formats a list of todos with status icons.
650func formatTodosList(todos []session.Todo, width int, sty *styles.Styles) string {
651	if len(todos) == 0 {
652		return ""
653	}
654
655	sorted := make([]session.Todo, len(todos))
656	copy(sorted, todos)
657	slices.SortStableFunc(sorted, func(a, b session.Todo) int {
658		return todoStatusOrder(a.Status) - todoStatusOrder(b.Status)
659	})
660
661	var lines []string
662	for _, todo := range sorted {
663		var prefix string
664		var textStyle lipgloss.Style
665
666		switch todo.Status {
667		case session.TodoStatusCompleted:
668			prefix = sty.Base.Foreground(sty.Green).Render(styles.TodoCompletedIcon) + " "
669			textStyle = sty.Base
670		case session.TodoStatusInProgress:
671			prefix = sty.Base.Foreground(sty.GreenDark).Render(styles.ArrowRightIcon) + " "
672			textStyle = sty.Base
673		default:
674			prefix = sty.Muted.Render(styles.TodoPendingIcon) + " "
675			textStyle = sty.Base
676		}
677
678		text := todo.Content
679		if todo.Status == session.TodoStatusInProgress && todo.ActiveForm != "" {
680			text = todo.ActiveForm
681		}
682
683		line := prefix + textStyle.Render(text)
684		line = ansi.Truncate(line, width, "…")
685		lines = append(lines, line)
686	}
687
688	return strings.Join(lines, "\n")
689}
690
691// todoStatusOrder returns sort order for todo statuses.
692func todoStatusOrder(s session.TodoStatus) int {
693	switch s {
694	case session.TodoStatusCompleted:
695		return 0
696	case session.TodoStatusInProgress:
697		return 1
698	default:
699		return 2
700	}
701}