1package chat
2
3import (
4 "encoding/json"
5 "fmt"
6 "slices"
7 "strings"
8
9 "charm.land/bubbles/v2/key"
10 tea "charm.land/bubbletea/v2"
11 "charm.land/lipgloss/v2"
12 "github.com/charmbracelet/crush/internal/ansiext"
13 "github.com/charmbracelet/crush/internal/message"
14 "github.com/charmbracelet/crush/internal/session"
15 "github.com/charmbracelet/crush/internal/ui/common"
16 "github.com/charmbracelet/crush/internal/ui/common/anim"
17 "github.com/charmbracelet/crush/internal/ui/styles"
18 "github.com/charmbracelet/x/ansi"
19)
20
// responseContextHeight is how many output lines a tool result shows
// before the remainder is collapsed behind an "expand" hint.
const responseContextHeight = 10

// ToolStatus enumerates the lifecycle states a tool call is rendered in.
type ToolStatus int

// Tool call states; values start at zero in declaration order.
const (
	// ToolStatusAwaitingPermission: permission was requested but not granted.
	ToolStatusAwaitingPermission ToolStatus = iota
	// ToolStatusRunning: no result yet and nothing else applies.
	ToolStatusRunning
	// ToolStatusSuccess: a result arrived and is not flagged as an error.
	ToolStatusSuccess
	// ToolStatusError: a result arrived flagged as an error.
	ToolStatusError
	// ToolStatusCancelled: the call was cancelled.
	ToolStatusCancelled
)
34
// ToolCallContext provides the context needed for rendering a tool call.
//
// Status derives the displayed lifecycle state from these fields.
type ToolCallContext struct {
	// Call is the tool invocation; the renderers in this file read its
	// Name, ID and Finished fields.
	Call message.ToolCall
	// Result is the tool's result, if any. It only counts as present when
	// it is non-nil and has a non-empty ToolCallID (see HasResult).
	Result *message.ToolResult
	// Cancelled marks the call as cancelled; Status checks it first.
	Cancelled bool
	// PermissionRequested && !PermissionGranted, with no result yet, is
	// reported as ToolStatusAwaitingPermission.
	PermissionRequested bool
	PermissionGranted   bool
	// IsNested selects the nested name style and a smaller animation.
	IsNested bool
	// Styles is the style set used when rendering this call.
	Styles *styles.Styles

	// NestedCalls holds child tool calls (set via SetNestedCalls for
	// agent tools).
	NestedCalls []ToolCallContext
}
47
48// Status returns the current status of the tool call.
49func (ctx *ToolCallContext) Status() ToolStatus {
50 if ctx.Cancelled {
51 return ToolStatusCancelled
52 }
53 if ctx.HasResult() {
54 if ctx.Result.IsError {
55 return ToolStatusError
56 }
57 return ToolStatusSuccess
58 }
59 if ctx.PermissionRequested && !ctx.PermissionGranted {
60 return ToolStatusAwaitingPermission
61 }
62 return ToolStatusRunning
63}
64
65// HasResult returns true if the tool call has a completed result.
66func (ctx *ToolCallContext) HasResult() bool {
67 return ctx.Result != nil && ctx.Result.ToolCallID != ""
68}
69
70// toolStyles provides common FocusStylable and HighlightStylable implementations.
71type toolStyles struct {
72 sty *styles.Styles
73}
74
75func (s toolStyles) FocusStyle() lipgloss.Style {
76 return s.sty.Chat.Message.ToolCallFocused
77}
78
79func (s toolStyles) BlurStyle() lipgloss.Style {
80 return s.sty.Chat.Message.ToolCallBlurred
81}
82
83func (s toolStyles) HighlightStyle() lipgloss.Style {
84 return s.sty.TextSelection
85}
86
// toolItem provides common base functionality for all tool items: identity,
// expand/collapse handling, the spinner animation and the mutable tool call
// context.
type toolItem struct {
	toolStyles
	// id is the tool call ID captured from ctx.Call.ID in newToolItem.
	id string
	// ctx is the tool call context this item renders.
	ctx ToolCallContext
	// expanded is toggled by HandleMouseClick/HandleKeyPress; when true the
	// content renderers show every line instead of the first
	// responseContextHeight lines.
	expanded bool
	// wasTruncated is written by the content renderers when the last render
	// had more than responseContextHeight lines; expand toggles are ignored
	// while it is false.
	wasTruncated bool
	// spinning gates whether animation steps are processed (updateAnimation).
	spinning bool
	// anim is the "Working" animation created in newToolItem; several code
	// paths nil-check it.
	anim *anim.Anim
}
97
98// newToolItem creates a new toolItem with the given context.
99func newToolItem(ctx ToolCallContext) toolItem {
100 animSize := 15
101 if ctx.IsNested {
102 animSize = 10
103 }
104
105 t := toolItem{
106 toolStyles: toolStyles{sty: ctx.Styles},
107 id: ctx.Call.ID,
108 ctx: ctx,
109 spinning: shouldSpin(ctx),
110 anim: anim.New(anim.Settings{
111 Size: animSize,
112 Label: "Working",
113 GradColorA: ctx.Styles.Primary,
114 GradColorB: ctx.Styles.Secondary,
115 LabelColor: ctx.Styles.FgBase,
116 CycleColors: true,
117 }),
118 }
119
120 return t
121}
122
123// shouldSpin returns true if the tool should show animation.
124func shouldSpin(ctx ToolCallContext) bool {
125 return !ctx.Call.Finished && !ctx.Cancelled
126}
127
128// ID implements Identifiable.
129func (t *toolItem) ID() string {
130 return t.id
131}
132
133// HandleMouseClick implements list.MouseClickable.
134func (t *toolItem) HandleMouseClick(btn ansi.MouseButton, x, y int) bool {
135 if btn != ansi.MouseLeft || !t.wasTruncated {
136 return false
137 }
138
139 t.expanded = !t.expanded
140 return true
141}
142
143// HandleKeyPress implements list.KeyPressable.
144func (t *toolItem) HandleKeyPress(msg tea.KeyPressMsg) bool {
145 if !t.wasTruncated {
146 return false
147 }
148
149 if key.Matches(msg, key.NewBinding(key.WithKeys("space"))) {
150 t.expanded = !t.expanded
151 return true
152 }
153
154 return false
155}
156
157// updateAnimation handles animation updates and returns true if changed.
158func (t *toolItem) updateAnimation(msg tea.Msg) (tea.Cmd, bool) {
159 if !t.spinning || t.anim == nil {
160 return nil, false
161 }
162
163 switch msg.(type) {
164 case anim.StepMsg:
165 updatedAnim, cmd := t.anim.Update(msg)
166 t.anim = updatedAnim
167 return cmd, cmd != nil
168 }
169
170 return nil, false
171}
172
173// InitAnimation initializes and starts the animation.
174func (t *toolItem) InitAnimation() tea.Cmd {
175 t.spinning = shouldSpin(t.ctx)
176 return t.anim.Init()
177}
178
179// SetResult updates the tool call with a result.
180func (t *toolItem) SetResult(result message.ToolResult) {
181 t.ctx.Result = &result
182 t.ctx.Call.Finished = true
183 t.spinning = false
184}
185
186// SetCancelled marks the tool call as cancelled.
187func (t *toolItem) SetCancelled() {
188 t.ctx.Cancelled = true
189 t.spinning = false
190}
191
192// UpdateCall updates the tool call data.
193func (t *toolItem) UpdateCall(call message.ToolCall) {
194 t.ctx.Call = call
195 if call.Finished {
196 t.spinning = false
197 }
198}
199
200// SetNestedCalls sets the nested tool calls for agent tools.
201func (t *toolItem) SetNestedCalls(calls []ToolCallContext) {
202 t.ctx.NestedCalls = calls
203}
204
205// Context returns the current tool call context.
206func (t *toolItem) Context() *ToolCallContext {
207 return &t.ctx
208}
209
210// renderPending returns the pending state view with animation.
211func (t *toolItem) renderPending() string {
212 icon := t.sty.Tool.IconPending.Render()
213
214 var toolName string
215 if t.ctx.IsNested {
216 toolName = t.sty.Tool.NameNested.Render(prettifyToolName(t.ctx.Call.Name))
217 } else {
218 toolName = t.sty.Tool.NameNormal.Render(prettifyToolName(t.ctx.Call.Name))
219 }
220
221 var animView string
222 if t.anim != nil {
223 animView = t.anim.View()
224 }
225
226 return fmt.Sprintf("%s %s %s", icon, toolName, animView)
227}
228
// unmarshalParams decodes the JSON tool input into target, which must be a
// pointer. Any encoding/json error is returned unchanged.
func unmarshalParams(input string, target any) error {
	raw := []byte(input)
	return json.Unmarshal(raw, target)
}
233
// ParamBuilder accumulates the parameters shown after a tool name in a
// header: an optional main value followed by key/value pairs. Empty values
// are skipped, so optional parameters can be chained unconditionally.
//
// NOTE(review): Main skips an empty value, so a following KeyValue key
// would land in the first ("main") slot that renderParamList reads.
type ParamBuilder struct {
	args []string
}

// NewParamBuilder returns an empty builder with room for a few entries.
func NewParamBuilder() *ParamBuilder {
	return &ParamBuilder{args: make([]string, 0, 4)}
}

// Main records the main (first positional) parameter when it is non-empty.
func (pb *ParamBuilder) Main(value string) *ParamBuilder {
	if value == "" {
		return pb
	}
	pb.args = append(pb.args, value)
	return pb
}

// KeyValue records key and value as a pair when value is non-empty.
func (pb *ParamBuilder) KeyValue(key, value string) *ParamBuilder {
	if value == "" {
		return pb
	}
	pb.args = append(pb.args, key, value)
	return pb
}

// Flag records key with the value "true" when value is set; false flags
// are omitted.
func (pb *ParamBuilder) Flag(key string, value bool) *ParamBuilder {
	if !value {
		return pb
	}
	pb.args = append(pb.args, key, "true")
	return pb
}

// Build returns the accumulated parameters. The slice is the builder's own
// storage, not a copy.
func (pb *ParamBuilder) Build() []string { return pb.args }
272
273// renderToolIcon returns the status icon for a tool call.
274func renderToolIcon(status ToolStatus, sty *styles.Styles) string {
275 switch status {
276 case ToolStatusSuccess:
277 return sty.Tool.IconSuccess.String()
278 case ToolStatusError:
279 return sty.Tool.IconError.String()
280 case ToolStatusCancelled:
281 return sty.Tool.IconCancelled.String()
282 default:
283 return sty.Tool.IconPending.String()
284 }
285}
286
287// renderToolHeader builds the tool header line: "● ToolName params..."
288func renderToolHeader(ctx *ToolCallContext, name string, width int, params ...string) string {
289 sty := ctx.Styles
290 icon := renderToolIcon(ctx.Status(), sty)
291
292 var toolName string
293 if ctx.IsNested {
294 toolName = sty.Tool.NameNested.Render(name)
295 } else {
296 toolName = sty.Tool.NameNormal.Render(name)
297 }
298
299 prefix := fmt.Sprintf("%s %s ", icon, toolName)
300 prefixWidth := lipgloss.Width(prefix)
301 remainingWidth := width - prefixWidth
302
303 paramsStr := renderParamList(params, remainingWidth, sty)
304 return prefix + paramsStr
305}
306
307// renderParamList formats parameters as "main (key=value, ...)" with truncation.
308func renderParamList(params []string, width int, sty *styles.Styles) string {
309 if len(params) == 0 {
310 return ""
311 }
312
313 mainParam := params[0]
314 if width >= 0 && lipgloss.Width(mainParam) > width {
315 mainParam = ansi.Truncate(mainParam, width, "…")
316 }
317
318 if len(params) == 1 {
319 return sty.Tool.ParamMain.Render(mainParam)
320 }
321
322 // Build key=value pairs from remaining params.
323 otherParams := params[1:]
324 if len(otherParams)%2 != 0 {
325 otherParams = append(otherParams, "")
326 }
327
328 var parts []string
329 for i := 0; i < len(otherParams); i += 2 {
330 key := otherParams[i]
331 value := otherParams[i+1]
332 if value == "" {
333 continue
334 }
335 parts = append(parts, fmt.Sprintf("%s=%s", key, value))
336 }
337
338 if len(parts) == 0 {
339 return sty.Tool.ParamMain.Render(ansi.Truncate(mainParam, width, "…"))
340 }
341
342 partsRendered := strings.Join(parts, ", ")
343 remainingWidth := width - lipgloss.Width(partsRendered) - 3 // " ()"
344 if remainingWidth < 30 {
345 // Not enough space for params, just show main.
346 return sty.Tool.ParamMain.Render(ansi.Truncate(mainParam, width, "…"))
347 }
348
349 fullParam := fmt.Sprintf("%s (%s)", mainParam, partsRendered)
350 return sty.Tool.ParamMain.Render(ansi.Truncate(fullParam, width, "…"))
351}
352
353// renderEarlyState handles error/cancelled/pending states before content rendering.
354// Returns the rendered output and true if early state was handled.
355func renderEarlyState(ctx *ToolCallContext, header string, width int) (string, bool) {
356 sty := ctx.Styles
357
358 var msg string
359 switch ctx.Status() {
360 case ToolStatusError:
361 msg = renderToolError(ctx, width)
362 case ToolStatusCancelled:
363 msg = sty.Tool.StateCancelled.Render("Canceled.")
364 case ToolStatusAwaitingPermission:
365 msg = sty.Tool.StateWaiting.Render("Requesting permission...")
366 case ToolStatusRunning:
367 msg = sty.Tool.StateWaiting.Render("Waiting for tool response...")
368 default:
369 return "", false
370 }
371
372 msg = sty.Tool.BodyPadding.Render(msg)
373 return lipgloss.JoinVertical(lipgloss.Left, header, "", msg), true
374}
375
376// renderToolError formats an error message with ERROR tag.
377func renderToolError(ctx *ToolCallContext, width int) string {
378 sty := ctx.Styles
379 errContent := strings.ReplaceAll(ctx.Result.Content, "\n", " ")
380 errTag := sty.Tool.ErrorTag.Render("ERROR")
381 tagWidth := lipgloss.Width(errTag)
382 errContent = ansi.Truncate(errContent, width-tagWidth-3, "…")
383 return fmt.Sprintf("%s %s", errTag, sty.Tool.ErrorMessage.Render(errContent))
384}
385
386// joinHeaderBody combines header and body with proper padding.
387func joinHeaderBody(header, body string, sty *styles.Styles) string {
388 if body == "" {
389 return header
390 }
391 body = sty.Tool.BodyPadding.Render(body)
392 return lipgloss.JoinVertical(lipgloss.Left, header, "", body)
393}
394
395// renderPlainContent renders plain text with optional expansion support.
396func renderPlainContent(content string, width int, sty *styles.Styles, item *toolItem) string {
397 content = strings.ReplaceAll(content, "\r\n", "\n")
398 content = strings.ReplaceAll(content, "\t", " ")
399 content = strings.TrimSpace(content)
400 lines := strings.Split(content, "\n")
401
402 expanded := item != nil && item.expanded
403 maxLines := responseContextHeight
404 if expanded {
405 maxLines = len(lines) // Show all
406 }
407
408 var out []string
409 for i, ln := range lines {
410 if i >= maxLines {
411 break
412 }
413 ln = " " + ln
414 if lipgloss.Width(ln) > width {
415 ln = ansi.Truncate(ln, width, "…")
416 }
417 out = append(out, sty.Tool.ContentLine.Width(width).Render(ln))
418 }
419
420 wasTruncated := len(lines) > responseContextHeight
421 if item != nil {
422 item.wasTruncated = wasTruncated
423 }
424
425 if !expanded && wasTruncated {
426 out = append(out, sty.Tool.ContentTruncation.
427 Width(width).
428 Render(fmt.Sprintf("… (%d lines) [click or space to expand]", len(lines)-responseContextHeight)))
429 }
430
431 return strings.Join(out, "\n")
432}
433
// formatNonZero formats value in base 10, or returns "" for zero. An empty
// string is what ParamBuilder skips, so zero-valued parameters disappear.
func formatNonZero(value int) string {
	if value != 0 {
		return fmt.Sprint(value)
	}
	return ""
}
441
442// renderCodeContent renders syntax-highlighted code with line numbers and optional expansion.
443func renderCodeContent(path, content string, offset, width int, sty *styles.Styles, item *toolItem) string {
444 content = strings.ReplaceAll(content, "\r\n", "\n")
445 content = strings.ReplaceAll(content, "\t", " ")
446
447 lines := strings.Split(content, "\n")
448
449 maxLines := responseContextHeight
450 if item != nil && item.expanded {
451 maxLines = len(lines)
452 }
453
454 truncated := lines
455 if len(lines) > maxLines {
456 truncated = lines[:maxLines]
457 }
458
459 // Escape ANSI sequences in content.
460 for i, ln := range truncated {
461 truncated[i] = ansiext.Escape(ln)
462 }
463
464 // Apply syntax highlighting.
465 bg := sty.Tool.ContentCodeBg
466 highlighted, _ := common.SyntaxHighlight(sty, strings.Join(truncated, "\n"), path, bg)
467 highlightedLines := strings.Split(highlighted, "\n")
468
469 // Calculate gutter width for line numbers.
470 maxLineNum := offset + len(highlightedLines)
471 maxDigits := getDigits(maxLineNum)
472 numFmt := fmt.Sprintf("%%%dd", maxDigits)
473
474 // Calculate available width for code (accounting for gutter).
475 const numPR, numPL, codePR, codePL = 1, 1, 1, 2
476 codeWidth := width - maxDigits - numPL - numPR - 2
477
478 var out []string
479 for i, ln := range highlightedLines {
480 lineNum := sty.Base.
481 Foreground(sty.FgMuted).
482 Background(bg).
483 PaddingRight(numPR).
484 PaddingLeft(numPL).
485 Render(fmt.Sprintf(numFmt, offset+i+1))
486
487 codeLine := sty.Base.
488 Width(codeWidth).
489 Background(bg).
490 PaddingRight(codePR).
491 PaddingLeft(codePL).
492 Render(ansi.Truncate(ln, codeWidth-codePL-codePR, "…"))
493
494 out = append(out, lipgloss.JoinHorizontal(lipgloss.Left, lineNum, codeLine))
495 }
496
497 wasTruncated := len(lines) > responseContextHeight
498 if item != nil {
499 item.wasTruncated = wasTruncated
500 }
501
502 expanded := item != nil && item.expanded
503
504 if !expanded && wasTruncated {
505 msg := fmt.Sprintf(" …(%d lines) [click or space to expand]", len(lines)-responseContextHeight)
506 out = append(out, sty.Muted.Background(bg).Render(msg))
507 }
508
509 return lipgloss.JoinVertical(lipgloss.Left, out...)
510}
511
512// renderMarkdownContent renders markdown with optional expansion support.
513func renderMarkdownContent(content string, width int, sty *styles.Styles, item *toolItem) string {
514 content = strings.ReplaceAll(content, "\r\n", "\n")
515 content = strings.ReplaceAll(content, "\t", " ")
516 content = strings.TrimSpace(content)
517
518 cappedWidth := min(width, 120)
519 renderer := common.PlainMarkdownRenderer(sty, cappedWidth)
520 rendered, err := renderer.Render(content)
521 if err != nil {
522 return renderPlainContent(content, width, sty, nil)
523 }
524
525 lines := strings.Split(rendered, "\n")
526
527 maxLines := responseContextHeight
528 if item != nil && item.expanded {
529 maxLines = len(lines)
530 }
531
532 var out []string
533 for i, ln := range lines {
534 if i >= maxLines {
535 break
536 }
537 out = append(out, ln)
538 }
539
540 wasTruncated := len(lines) > responseContextHeight
541 if item != nil {
542 item.wasTruncated = wasTruncated
543 }
544
545 expanded := item != nil && item.expanded
546
547 if !expanded && wasTruncated {
548 out = append(out, sty.Tool.ContentTruncation.
549 Width(cappedWidth-2).
550 Render(fmt.Sprintf("… (%d lines) [click or space to expand]", len(lines)-responseContextHeight)))
551 }
552
553 return sty.Tool.ContentLine.Render(strings.Join(out, "\n"))
554}
555
556// renderDiffContent renders a diff with optional expansion support.
557func renderDiffContent(file, oldContent, newContent string, width int, sty *styles.Styles, item *toolItem) string {
558 formatter := common.DiffFormatter(sty).
559 Before(file, oldContent).
560 After(file, newContent).
561 Width(width)
562
563 if width > 120 {
564 formatter = formatter.Split()
565 }
566
567 formatted := formatter.String()
568 lines := strings.Split(formatted, "\n")
569
570 wasTruncated := len(lines) > responseContextHeight
571 if item != nil {
572 item.wasTruncated = wasTruncated
573 }
574
575 expanded := item != nil && item.expanded
576
577 if !expanded && wasTruncated {
578 truncateMsg := sty.Tool.DiffTruncation.
579 Width(width).
580 Render(fmt.Sprintf("… (%d lines) [click or space to expand]", len(lines)-responseContextHeight))
581 formatted = strings.Join(lines[:responseContextHeight], "\n") + "\n" + truncateMsg
582 }
583
584 return formatted
585}
586
587// renderImageContent renders image data with optional text content.
588func renderImageContent(data, mediaType, textContent string, sty *styles.Styles) string {
589 dataSize := len(data) * 3 / 4 // Base64 to bytes approximation.
590 sizeStr := formatSize(dataSize)
591
592 loaded := sty.Tool.IconSuccess.String()
593 arrow := sty.Tool.NameNested.Render("→")
594 typeStyled := sty.Base.Render(mediaType)
595 sizeStyled := sty.Subtle.Render(sizeStr)
596
597 imageDisplay := fmt.Sprintf("%s %s %s %s", loaded, arrow, typeStyled, sizeStyled)
598
599 if strings.TrimSpace(textContent) != "" {
600 textDisplay := sty.Tool.ContentLine.Render(textContent)
601 return lipgloss.JoinVertical(lipgloss.Left, textDisplay, "", imageDisplay)
602 }
603
604 return imageDisplay
605}
606
607// renderMediaContent renders non-image media content.
608func renderMediaContent(mediaType, textContent string, sty *styles.Styles) string {
609 loaded := sty.Tool.IconSuccess.String()
610 arrow := sty.Tool.NameNested.Render("→")
611 typeStyled := sty.Base.Render(mediaType)
612 mediaDisplay := fmt.Sprintf("%s %s %s", loaded, arrow, typeStyled)
613
614 if strings.TrimSpace(textContent) != "" {
615 textDisplay := sty.Tool.ContentLine.Render(textContent)
616 return lipgloss.JoinVertical(lipgloss.Left, textDisplay, "", mediaDisplay)
617 }
618
619 return mediaDisplay
620}
621
// formatSize renders a byte count in binary units: whole bytes below 1 KiB,
// otherwise KB or MB with one decimal place.
func formatSize(bytes int) string {
	const (
		kib = 1024
		mib = 1024 * kib
	)
	switch {
	case bytes < kib:
		return fmt.Sprintf("%d B", bytes)
	case bytes < mib:
		return fmt.Sprintf("%.1f KB", float64(bytes)/kib)
	default:
		return fmt.Sprintf("%.1f MB", float64(bytes)/mib)
	}
}
632
// getDigits returns the number of decimal digits in n, ignoring its sign.
// Zero has one digit.
//
// It divides toward zero rather than negating n first, so the most
// negative int (whose negation overflows to itself) is counted correctly
// instead of reporting 0 digits.
func getDigits(n int) int {
	digits := 1
	for n <= -10 || n >= 10 {
		n /= 10
		digits++
	}
	return digits
}
648
649// formatTodosList formats a list of todos with status icons.
650func formatTodosList(todos []session.Todo, width int, sty *styles.Styles) string {
651 if len(todos) == 0 {
652 return ""
653 }
654
655 sorted := make([]session.Todo, len(todos))
656 copy(sorted, todos)
657 slices.SortStableFunc(sorted, func(a, b session.Todo) int {
658 return todoStatusOrder(a.Status) - todoStatusOrder(b.Status)
659 })
660
661 var lines []string
662 for _, todo := range sorted {
663 var prefix string
664 var textStyle lipgloss.Style
665
666 switch todo.Status {
667 case session.TodoStatusCompleted:
668 prefix = sty.Base.Foreground(sty.Green).Render(styles.TodoCompletedIcon) + " "
669 textStyle = sty.Base
670 case session.TodoStatusInProgress:
671 prefix = sty.Base.Foreground(sty.GreenDark).Render(styles.ArrowRightIcon) + " "
672 textStyle = sty.Base
673 default:
674 prefix = sty.Muted.Render(styles.TodoPendingIcon) + " "
675 textStyle = sty.Base
676 }
677
678 text := todo.Content
679 if todo.Status == session.TodoStatusInProgress && todo.ActiveForm != "" {
680 text = todo.ActiveForm
681 }
682
683 line := prefix + textStyle.Render(text)
684 line = ansi.Truncate(line, width, "…")
685 lines = append(lines, line)
686 }
687
688 return strings.Join(lines, "\n")
689}
690
691// todoStatusOrder returns sort order for todo statuses.
692func todoStatusOrder(s session.TodoStatus) int {
693 switch s {
694 case session.TodoStatusCompleted:
695 return 0
696 case session.TodoStatusInProgress:
697 return 1
698 default:
699 return 2
700 }
701}