tool_base.go

  1package chat
  2
  3import (
  4	"encoding/json"
  5	"fmt"
  6	"slices"
  7	"strings"
  8
  9	"charm.land/bubbles/v2/key"
 10	tea "charm.land/bubbletea/v2"
 11	"charm.land/lipgloss/v2"
 12	"github.com/charmbracelet/crush/internal/ansiext"
 13	"github.com/charmbracelet/crush/internal/message"
 14	"github.com/charmbracelet/crush/internal/session"
 15	"github.com/charmbracelet/crush/internal/ui/common"
 16	"github.com/charmbracelet/crush/internal/ui/styles"
 17	"github.com/charmbracelet/x/ansi"
 18)
 19
// responseContextHeight is the maximum number of content lines a tool result
// shows while collapsed. The render helpers in this file cut content at this
// height and append an "… (N lines) [click or space to expand]" hint.
const responseContextHeight = 10

// ToolStatus represents the current state of a tool call, as derived by
// ToolCallContext.Status.
type ToolStatus int

// Tool call states. Values are assigned with iota in declaration order.
const (
	// ToolStatusAwaitingPermission: permission was requested but not granted,
	// and no result has arrived.
	ToolStatusAwaitingPermission ToolStatus = iota
	// ToolStatusRunning: no result yet and not blocked on permission.
	ToolStatusRunning
	// ToolStatusSuccess: a result arrived and it is not flagged as an error.
	ToolStatusSuccess
	// ToolStatusError: a result arrived with IsError set.
	ToolStatusError
	// ToolStatusCancelled: the call was cancelled; Status checks this before
	// looking at any result.
	ToolStatusCancelled
)
 33
// ToolCallContext provides the context needed for rendering a tool call: the
// call itself, its result (if any), cancellation and permission flags, and the
// styles used by the render helpers in this file.
type ToolCallContext struct {
	Call                message.ToolCall    // The tool invocation being rendered.
	Result              *message.ToolResult // HasResult treats nil or an empty ToolCallID as "no result".
	Cancelled           bool                // Checked first by Status; overrides any result.
	PermissionRequested bool                // Whether a permission prompt was raised for this call.
	PermissionGranted   bool                // Whether that prompt was approved.
	IsNested            bool                // Selects the nested tool-name style in renderToolHeader.
	Styles              *styles.Styles      // Styles used by all render helpers.

	// NestedCalls holds child tool calls for agent/agentic_fetch.
	NestedCalls []ToolCallContext
}
 47
 48// Status returns the current status of the tool call.
 49func (ctx *ToolCallContext) Status() ToolStatus {
 50	if ctx.Cancelled {
 51		return ToolStatusCancelled
 52	}
 53	if ctx.HasResult() {
 54		if ctx.Result.IsError {
 55			return ToolStatusError
 56		}
 57		return ToolStatusSuccess
 58	}
 59	// No result yet - check permission state.
 60	if ctx.PermissionRequested && !ctx.PermissionGranted {
 61		return ToolStatusAwaitingPermission
 62	}
 63	return ToolStatusRunning
 64}
 65
 66// HasResult returns true if the tool call has a completed result.
 67func (ctx *ToolCallContext) HasResult() bool {
 68	return ctx.Result != nil && ctx.Result.ToolCallID != ""
 69}
 70
 71// toolStyles provides common FocusStylable and HighlightStylable implementations.
 72// Embed this in tool items to avoid repeating style methods.
 73type toolStyles struct {
 74	sty *styles.Styles
 75}
 76
 77func (s toolStyles) FocusStyle() lipgloss.Style {
 78	return s.sty.Chat.Message.ToolCallFocused
 79}
 80
 81func (s toolStyles) BlurStyle() lipgloss.Style {
 82	return s.sty.Chat.Message.ToolCallBlurred
 83}
 84
 85func (s toolStyles) HighlightStyle() lipgloss.Style {
 86	return s.sty.TextSelection
 87}
 88
// toolItem provides common base functionality for all tool items: an ID,
// embedded focus/blur/highlight styles, and the expand/collapse state consumed
// by the truncating render helpers in this file.
type toolItem struct {
	toolStyles
	id           string // Tool call ID; newToolItem copies it from ctx.Call.ID.
	expanded     bool   // Whether truncated content is expanded; toggled by click or space.
	wasTruncated bool   // Whether the last render was truncated; set by the render helpers, gates the toggle handlers.
}
 96
 97// newToolItem creates a new toolItem with the given context.
 98func newToolItem(ctx ToolCallContext) toolItem {
 99	return toolItem{
100		toolStyles: toolStyles{sty: ctx.Styles},
101		id:         ctx.Call.ID,
102	}
103}
104
105// ID implements Identifiable.
106func (t *toolItem) ID() string {
107	return t.id
108}
109
110// HandleMouseClick implements list.MouseClickable.
111func (t *toolItem) HandleMouseClick(btn ansi.MouseButton, x, y int) bool {
112	// Only handle left clicks on truncated content.
113	if btn != ansi.MouseLeft || !t.wasTruncated {
114		return false
115	}
116
117	// Toggle expansion.
118	t.expanded = !t.expanded
119	return true
120}
121
122// HandleKeyPress implements list.KeyPressable.
123func (t *toolItem) HandleKeyPress(msg tea.KeyPressMsg) bool {
124	// Only handle space key on truncated content.
125	if !t.wasTruncated {
126		return false
127	}
128
129	if key.Matches(msg, key.NewBinding(key.WithKeys("space"))) {
130		// Toggle expansion.
131		t.expanded = !t.expanded
132		return true
133	}
134
135	return false
136}
137
// unmarshalParams decodes the JSON text in input into target, which must be a
// non-nil pointer as json.Unmarshal requires. The decoder's error is returned
// unchanged.
func unmarshalParams(input string, target any) error {
	raw := []byte(input)
	err := json.Unmarshal(raw, target)
	return err
}
142
// ParamBuilder accumulates the parameter list for a tool header. The list holds
// an optional main value first, followed by alternating key/value entries, the
// shape renderParamList consumes. Empty values and false flags are skipped
// entirely, so nothing is appended for them.
type ParamBuilder struct {
	args []string
}

// NewParamBuilder returns an empty builder with room for a few parameters.
func NewParamBuilder() *ParamBuilder {
	pb := new(ParamBuilder)
	pb.args = make([]string, 0, 4)
	return pb
}

// Main appends value as the leading positional parameter unless it is empty.
// NOTE(review): when value is empty nothing is appended, so a following
// KeyValue key would land in the leading (main) slot; confirm callers never
// combine an empty main value with key/value pairs.
func (pb *ParamBuilder) Main(value string) *ParamBuilder {
	if value == "" {
		return pb
	}
	pb.args = append(pb.args, value)
	return pb
}

// KeyValue appends key followed by value, unless value is empty.
func (pb *ParamBuilder) KeyValue(key, value string) *ParamBuilder {
	if value == "" {
		return pb
	}
	pb.args = append(pb.args, key, value)
	return pb
}

// Flag appends key followed by "true" when value is set; false flags are omitted.
func (pb *ParamBuilder) Flag(key string, value bool) *ParamBuilder {
	if !value {
		return pb
	}
	pb.args = append(pb.args, key, "true")
	return pb
}

// Build returns the accumulated parameters. The slice shares the builder's
// storage and is not copied.
func (pb *ParamBuilder) Build() []string {
	return pb.args
}
181
182// renderToolIcon returns the status icon for a tool call.
183func renderToolIcon(status ToolStatus, sty *styles.Styles) string {
184	switch status {
185	case ToolStatusSuccess:
186		return sty.Tool.IconSuccess.String()
187	case ToolStatusError:
188		return sty.Tool.IconError.String()
189	case ToolStatusCancelled:
190		return sty.Tool.IconCancelled.String()
191	default:
192		return sty.Tool.IconPending.String()
193	}
194}
195
196// renderToolHeader builds the tool header line: "● ToolName params..."
197func renderToolHeader(ctx *ToolCallContext, name string, width int, params ...string) string {
198	sty := ctx.Styles
199	icon := renderToolIcon(ctx.Status(), sty)
200
201	var toolName string
202	if ctx.IsNested {
203		toolName = sty.Tool.NameNested.Render(name)
204	} else {
205		toolName = sty.Tool.NameNormal.Render(name)
206	}
207
208	prefix := fmt.Sprintf("%s %s ", icon, toolName)
209	prefixWidth := lipgloss.Width(prefix)
210	remainingWidth := width - prefixWidth
211
212	paramsStr := renderParamList(params, remainingWidth, sty)
213	return prefix + paramsStr
214}
215
216// renderParamList formats parameters as "main (key=value, ...)" with truncation.
217func renderParamList(params []string, width int, sty *styles.Styles) string {
218	if len(params) == 0 {
219		return ""
220	}
221
222	mainParam := params[0]
223	if width >= 0 && lipgloss.Width(mainParam) > width {
224		mainParam = ansi.Truncate(mainParam, width, "…")
225	}
226
227	if len(params) == 1 {
228		return sty.Tool.ParamMain.Render(mainParam)
229	}
230
231	// Build key=value pairs from remaining params.
232	otherParams := params[1:]
233	if len(otherParams)%2 != 0 {
234		otherParams = append(otherParams, "")
235	}
236
237	var parts []string
238	for i := 0; i < len(otherParams); i += 2 {
239		key := otherParams[i]
240		value := otherParams[i+1]
241		if value == "" {
242			continue
243		}
244		parts = append(parts, fmt.Sprintf("%s=%s", key, value))
245	}
246
247	if len(parts) == 0 {
248		return sty.Tool.ParamMain.Render(ansi.Truncate(mainParam, width, "…"))
249	}
250
251	partsRendered := strings.Join(parts, ", ")
252	remainingWidth := width - lipgloss.Width(partsRendered) - 3 // " ()"
253	if remainingWidth < 30 {
254		// Not enough space for params, just show main.
255		return sty.Tool.ParamMain.Render(ansi.Truncate(mainParam, width, "…"))
256	}
257
258	fullParam := fmt.Sprintf("%s (%s)", mainParam, partsRendered)
259	return sty.Tool.ParamMain.Render(ansi.Truncate(fullParam, width, "…"))
260}
261
262// renderEarlyState handles error/cancelled/pending states before content rendering.
263// Returns the rendered output and true if early state was handled.
264func renderEarlyState(ctx *ToolCallContext, header string, width int) (string, bool) {
265	sty := ctx.Styles
266
267	var msg string
268	switch ctx.Status() {
269	case ToolStatusError:
270		msg = renderToolError(ctx, width)
271	case ToolStatusCancelled:
272		msg = sty.Tool.StateCancelled.Render("Canceled.")
273	case ToolStatusAwaitingPermission:
274		msg = sty.Tool.StateWaiting.Render("Requesting permission...")
275	case ToolStatusRunning:
276		msg = sty.Tool.StateWaiting.Render("Waiting for tool response...")
277	default:
278		return "", false
279	}
280
281	msg = sty.Tool.BodyPadding.Render(msg)
282	return lipgloss.JoinVertical(lipgloss.Left, header, "", msg), true
283}
284
285// renderToolError formats an error message with ERROR tag.
286func renderToolError(ctx *ToolCallContext, width int) string {
287	sty := ctx.Styles
288	errContent := strings.ReplaceAll(ctx.Result.Content, "\n", " ")
289	errTag := sty.Tool.ErrorTag.Render("ERROR")
290	tagWidth := lipgloss.Width(errTag)
291	errContent = ansi.Truncate(errContent, width-tagWidth-3, "…")
292	return fmt.Sprintf("%s %s", errTag, sty.Tool.ErrorMessage.Render(errContent))
293}
294
295// joinHeaderBody combines header and body with proper padding.
296func joinHeaderBody(header, body string, sty *styles.Styles) string {
297	if body == "" {
298		return header
299	}
300	body = sty.Tool.BodyPadding.Render(body)
301	return lipgloss.JoinVertical(lipgloss.Left, header, "", body)
302}
303
304// renderPlainContent renders plain text with optional expansion support.
305func renderPlainContent(content string, width int, sty *styles.Styles, item *toolItem) string {
306	content = strings.ReplaceAll(content, "\r\n", "\n")
307	content = strings.ReplaceAll(content, "\t", "    ")
308	content = strings.TrimSpace(content)
309	lines := strings.Split(content, "\n")
310
311	expanded := item != nil && item.expanded
312	maxLines := responseContextHeight
313	if expanded {
314		maxLines = len(lines) // Show all
315	}
316
317	var out []string
318	for i, ln := range lines {
319		if i >= maxLines {
320			break
321		}
322		ln = " " + ln
323		if lipgloss.Width(ln) > width {
324			ln = ansi.Truncate(ln, width, "…")
325		}
326		out = append(out, sty.Tool.ContentLine.Width(width).Render(ln))
327	}
328
329	wasTruncated := len(lines) > responseContextHeight
330	if item != nil {
331		item.wasTruncated = wasTruncated
332	}
333
334	if !expanded && wasTruncated {
335		out = append(out, sty.Tool.ContentTruncation.
336			Width(width).
337			Render(fmt.Sprintf("… (%d lines) [click or space to expand]", len(lines)-responseContextHeight)))
338	}
339
340	return strings.Join(out, "\n")
341}
342
// formatNonZero formats value in base 10, or returns "" when value is zero.
// An empty string is skipped by ParamBuilder.KeyValue.
func formatNonZero(value int) string {
	if value != 0 {
		return fmt.Sprint(value)
	}
	return ""
}
350
351// renderCodeContent renders syntax-highlighted code with line numbers and optional expansion.
352func renderCodeContent(path, content string, offset, width int, sty *styles.Styles, item *toolItem) string {
353	content = strings.ReplaceAll(content, "\r\n", "\n")
354	content = strings.ReplaceAll(content, "\t", "    ")
355
356	lines := strings.Split(content, "\n")
357
358	maxLines := responseContextHeight
359	if item != nil && item.expanded {
360		maxLines = len(lines)
361	}
362
363	truncated := lines
364	if len(lines) > maxLines {
365		truncated = lines[:maxLines]
366	}
367
368	// Escape ANSI sequences in content.
369	for i, ln := range truncated {
370		truncated[i] = ansiext.Escape(ln)
371	}
372
373	// Apply syntax highlighting.
374	bg := sty.Tool.ContentCodeBg
375	highlighted, _ := common.SyntaxHighlight(sty, strings.Join(truncated, "\n"), path, bg)
376	highlightedLines := strings.Split(highlighted, "\n")
377
378	// Calculate gutter width for line numbers.
379	maxLineNum := offset + len(highlightedLines)
380	maxDigits := getDigits(maxLineNum)
381	numFmt := fmt.Sprintf("%%%dd", maxDigits)
382
383	// Calculate available width for code (accounting for gutter).
384	const numPR, numPL, codePR, codePL = 1, 1, 1, 2
385	codeWidth := width - maxDigits - numPL - numPR - 2
386
387	var out []string
388	for i, ln := range highlightedLines {
389		lineNum := sty.Base.
390			Foreground(sty.FgMuted).
391			Background(bg).
392			PaddingRight(numPR).
393			PaddingLeft(numPL).
394			Render(fmt.Sprintf(numFmt, offset+i+1))
395
396		codeLine := sty.Base.
397			Width(codeWidth).
398			Background(bg).
399			PaddingRight(codePR).
400			PaddingLeft(codePL).
401			Render(ansi.Truncate(ln, codeWidth-codePL-codePR, "…"))
402
403		out = append(out, lipgloss.JoinHorizontal(lipgloss.Left, lineNum, codeLine))
404	}
405
406	wasTruncated := len(lines) > responseContextHeight
407	if item != nil {
408		item.wasTruncated = wasTruncated
409	}
410
411	expanded := item != nil && item.expanded
412
413	if !expanded && wasTruncated {
414		msg := fmt.Sprintf(" …(%d lines) [click or space to expand]", len(lines)-responseContextHeight)
415		out = append(out, sty.Muted.Background(bg).Render(msg))
416	}
417
418	return lipgloss.JoinVertical(lipgloss.Left, out...)
419}
420
421// renderMarkdownContent renders markdown with optional expansion support.
422func renderMarkdownContent(content string, width int, sty *styles.Styles, item *toolItem) string {
423	content = strings.ReplaceAll(content, "\r\n", "\n")
424	content = strings.ReplaceAll(content, "\t", "    ")
425	content = strings.TrimSpace(content)
426
427	cappedWidth := min(width, 120)
428	renderer := common.PlainMarkdownRenderer(sty, cappedWidth)
429	rendered, err := renderer.Render(content)
430	if err != nil {
431		return renderPlainContent(content, width, sty, nil)
432	}
433
434	lines := strings.Split(rendered, "\n")
435
436	maxLines := responseContextHeight
437	if item != nil && item.expanded {
438		maxLines = len(lines)
439	}
440
441	var out []string
442	for i, ln := range lines {
443		if i >= maxLines {
444			break
445		}
446		out = append(out, ln)
447	}
448
449	wasTruncated := len(lines) > responseContextHeight
450	if item != nil {
451		item.wasTruncated = wasTruncated
452	}
453
454	expanded := item != nil && item.expanded
455
456	if !expanded && wasTruncated {
457		out = append(out, sty.Tool.ContentTruncation.
458			Width(cappedWidth-2).
459			Render(fmt.Sprintf("… (%d lines) [click or space to expand]", len(lines)-responseContextHeight)))
460	}
461
462	return sty.Tool.ContentLine.Render(strings.Join(out, "\n"))
463}
464
465// renderDiffContent renders a diff with optional expansion support.
466func renderDiffContent(file, oldContent, newContent string, width int, sty *styles.Styles, item *toolItem) string {
467	formatter := common.DiffFormatter(sty).
468		Before(file, oldContent).
469		After(file, newContent).
470		Width(width)
471
472	if width > 120 {
473		formatter = formatter.Split()
474	}
475
476	formatted := formatter.String()
477	lines := strings.Split(formatted, "\n")
478
479	wasTruncated := len(lines) > responseContextHeight
480	if item != nil {
481		item.wasTruncated = wasTruncated
482	}
483
484	expanded := item != nil && item.expanded
485
486	if !expanded && wasTruncated {
487		truncateMsg := sty.Tool.DiffTruncation.
488			Width(width).
489			Render(fmt.Sprintf("… (%d lines) [click or space to expand]", len(lines)-responseContextHeight))
490		formatted = strings.Join(lines[:responseContextHeight], "\n") + "\n" + truncateMsg
491	}
492
493	return formatted
494}
495
496// renderImageContent renders image data with optional text content.
497func renderImageContent(data, mediaType, textContent string, sty *styles.Styles) string {
498	dataSize := len(data) * 3 / 4 // Base64 to bytes approximation.
499	sizeStr := formatSize(dataSize)
500
501	loaded := sty.Tool.IconSuccess.String()
502	arrow := sty.Tool.NameNested.Render("→")
503	typeStyled := sty.Base.Render(mediaType)
504	sizeStyled := sty.Subtle.Render(sizeStr)
505
506	imageDisplay := fmt.Sprintf("%s %s %s %s", loaded, arrow, typeStyled, sizeStyled)
507
508	if strings.TrimSpace(textContent) != "" {
509		textDisplay := sty.Tool.ContentLine.Render(textContent)
510		return lipgloss.JoinVertical(lipgloss.Left, textDisplay, "", imageDisplay)
511	}
512
513	return imageDisplay
514}
515
516// renderMediaContent renders non-image media content.
517func renderMediaContent(mediaType, textContent string, sty *styles.Styles) string {
518	loaded := sty.Tool.IconSuccess.String()
519	arrow := sty.Tool.NameNested.Render("→")
520	typeStyled := sty.Base.Render(mediaType)
521	mediaDisplay := fmt.Sprintf("%s %s %s", loaded, arrow, typeStyled)
522
523	if strings.TrimSpace(textContent) != "" {
524		textDisplay := sty.Tool.ContentLine.Render(textContent)
525		return lipgloss.JoinVertical(lipgloss.Left, textDisplay, "", mediaDisplay)
526	}
527
528	return mediaDisplay
529}
530
// formatSize renders a byte count with 1024-based units. Values below 1024 are
// shown as whole bytes, values below 1024*1024 as KB, and anything larger as
// MB; KB and MB use one decimal place.
func formatSize(bytes int) string {
	const (
		kib = 1024
		mib = 1024 * kib
	)
	switch {
	case bytes >= mib:
		return fmt.Sprintf("%.1f MB", float64(bytes)/mib)
	case bytes >= kib:
		return fmt.Sprintf("%.1f KB", float64(bytes)/kib)
	default:
		return fmt.Sprintf("%d B", bytes)
	}
}
541
// getDigits returns the number of decimal digits in n, ignoring its sign.
// Zero has one digit.
func getDigits(n int) int {
	if n == 0 {
		return 1
	}
	// Divide toward zero instead of negating first. -n overflows for the most
	// negative int, leaving n negative so a "for n > 0" loop would return 0.
	digits := 0
	for ; n != 0; n /= 10 {
		digits++
	}
	return digits
}
557
558// formatTodosList formats a list of todos with status icons.
559func formatTodosList(todos []session.Todo, width int, sty *styles.Styles) string {
560	if len(todos) == 0 {
561		return ""
562	}
563
564	sorted := make([]session.Todo, len(todos))
565	copy(sorted, todos)
566	slices.SortStableFunc(sorted, func(a, b session.Todo) int {
567		return todoStatusOrder(a.Status) - todoStatusOrder(b.Status)
568	})
569
570	var lines []string
571	for _, todo := range sorted {
572		var prefix string
573		var textStyle lipgloss.Style
574
575		switch todo.Status {
576		case session.TodoStatusCompleted:
577			prefix = sty.Base.Foreground(sty.Green).Render(styles.TodoCompletedIcon) + " "
578			textStyle = sty.Base
579		case session.TodoStatusInProgress:
580			prefix = sty.Base.Foreground(sty.GreenDark).Render(styles.ArrowRightIcon) + " "
581			textStyle = sty.Base
582		default:
583			prefix = sty.Muted.Render(styles.TodoPendingIcon) + " "
584			textStyle = sty.Base
585		}
586
587		text := todo.Content
588		if todo.Status == session.TodoStatusInProgress && todo.ActiveForm != "" {
589			text = todo.ActiveForm
590		}
591
592		line := prefix + textStyle.Render(text)
593		line = ansi.Truncate(line, width, "…")
594		lines = append(lines, line)
595	}
596
597	return strings.Join(lines, "\n")
598}
599
600// todoStatusOrder returns sort order for todo statuses.
601func todoStatusOrder(s session.TodoStatus) int {
602	switch s {
603	case session.TodoStatusCompleted:
604		return 0
605	case session.TodoStatusInProgress:
606		return 1
607	default:
608		return 2
609	}
610}