From f7beb12689a337dade9d9c9bef575947188bcab7 Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Tue, 21 Apr 2026 10:00:41 -0400 Subject: [PATCH] feat: generally render output that looks like a diff as a diff (#2607) For example, the GitHub MCP server can fetch diffs. This makes those diffs render as such, rather than just showing their raw output. --- internal/diffdetect/detect.go | 36 ++ internal/diffdetect/detect_test.go | 152 +++++++ internal/ui/chat/docker_mcp.go | 15 +- internal/ui/chat/generic.go | 29 +- internal/ui/chat/mcp.go | 49 +-- internal/ui/chat/mcp_test.go | 406 +++++++++++++++++++ internal/ui/chat/tool_result_content.go | 59 +++ internal/ui/chat/tool_result_content_test.go | 116 ++++++ internal/ui/chat/tools.go | 2 +- internal/ui/chat/unified_diff.go | 171 ++++++++ internal/ui/diffview/diffview.go | 43 ++ internal/ui/diffview/style.go | 17 + internal/ui/styles/styles.go | 8 + 13 files changed, 1015 insertions(+), 88 deletions(-) create mode 100644 internal/diffdetect/detect.go create mode 100644 internal/diffdetect/detect_test.go create mode 100644 internal/ui/chat/mcp_test.go create mode 100644 internal/ui/chat/tool_result_content.go create mode 100644 internal/ui/chat/tool_result_content_test.go create mode 100644 internal/ui/chat/unified_diff.go diff --git a/internal/diffdetect/detect.go b/internal/diffdetect/detect.go new file mode 100644 index 0000000000000000000000000000000000000000..213803e6b3f3754491dd3953ca82b100090def89 --- /dev/null +++ b/internal/diffdetect/detect.go @@ -0,0 +1,36 @@ +package diffdetect + +import "strings" + +// Signal describes which unified-diff markers were found while scanning text. +type Signal struct { + HasHunk bool + HasFileHeader bool + HasGitHeader bool +} + +// Inspect scans content for unified-diff markers. +func Inspect(content string) Signal { + var signal Signal + for line := range strings.SplitSeq(content, "\n") { + if strings.HasPrefix(line, "@@") { + signal.HasHunk = true + } + if strings.HasPrefix(line, "--- ") || strings.HasPrefix(line, "+++ ") { + signal.HasFileHeader = true + } + if strings.HasPrefix(line, "diff --git ") { + signal.HasGitHeader = true + } + } + return signal +} + +// IsUnifiedDiff reports whether content appears to be a unified diff. +func IsUnifiedDiff(content string) bool { + signal := Inspect(content) + if signal.HasGitHeader && signal.HasFileHeader { + return true + } + return signal.HasHunk && signal.HasFileHeader +} diff --git a/internal/diffdetect/detect_test.go b/internal/diffdetect/detect_test.go new file mode 100644 index 0000000000000000000000000000000000000000..5d2cb50a5e50ff93a19596d92639abf44e59f2bd --- /dev/null +++ b/internal/diffdetect/detect_test.go @@ -0,0 +1,152 @@ +package diffdetect + +import "testing" + +func TestInspect(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + content string + want Signal + }{ + { + name: "git unified diff", + content: `diff --git a/main.go b/main.go +--- a/main.go ++++ b/main.go +@@ -1,2 +1,3 @@ + package main ++import "fmt" +`, + want: Signal{HasHunk: true, HasFileHeader: true, HasGitHeader: true}, + }, + { + name: "non-git unified diff", + content: `--- old.c ++++ old.c +@@ -1 +1 @@ +-old ++new +`, + want: Signal{HasHunk: true, HasFileHeader: true, HasGitHeader: false}, + }, + { + name: "plain text", + content: "hello world", + want: Signal{}, + }, + { + name: "hunk only", + content: "@@ -1 +1 @@\n-old\n+new\n", + want: Signal{HasHunk: true}, + }, + { + name: "headers only", + content: "--- a/file\n+++ b/file\n", + want: Signal{HasFileHeader: true}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := Inspect(tt.content) + if got != tt.want { + t.Errorf("Inspect() = %+v, want %+v", got, tt.want) + } + }) + } +} + +func TestIsUnifiedDiff(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + content string + want bool + }{ + { + name: "github-style multi-file diff", + content: `diff --git a/one.txt b/one.txt +--- a/one.txt ++++ b/one.txt +@@ -1 +1 @@ +-a ++b +diff --git a/two.txt b/two.txt +--- a/two.txt ++++ b/two.txt +@@ -1 +1 @@ +-c ++d +`, + want: true, + }, + { + name: "non-git unified patch", + content: `--- a/old.c ++++ b/old.c +@@ -1,3 +1,4 @@ + #include +-int main() { ++int main(int argc, char **argv) { + return 0; + } +`, + want: true, + }, + { + name: "new file from dev null", + content: `--- /dev/null ++++ newfile.txt +@@ -0,0 +1,2 @@ ++hello ++world +`, + want: true, + }, + { + name: "markdown false positive candidate", + content: `- Item one +- Item two ++ Bonus item +- Item three +`, + want: false, + }, + { + name: "headers without hunk", + content: `--- a/somefile.txt ++++ b/somefile.txt +Just some content here +No hunk markers at all +`, + want: false, + }, + { + name: "hunk without headers", + content: `@@ -1,3 +1,4 @@ + some line +-another line ++changed line +`, + want: false, + }, + { + name: "empty", + content: "", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := IsUnifiedDiff(tt.content); got != tt.want { + t.Errorf("IsUnifiedDiff() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/ui/chat/docker_mcp.go b/internal/ui/chat/docker_mcp.go index 57cd9da55f83e63279413d9801337236290d5cdc..7f731130380e7a6c5e33b5d617e6951c94627e7b 100644 --- a/internal/ui/chat/docker_mcp.go +++ b/internal/ui/chat/docker_mcp.go @@ -138,20 +138,7 @@ func (d *DockerMCPToolRenderContext) RenderTool(sty *styles.Styles, width int, o // Handle text content. if opts.Result.Content != "" { - var body string - var result json.RawMessage - if err := json.Unmarshal([]byte(opts.Result.Content), &result); err == nil { - prettyResult, err := json.MarshalIndent(result, "", " ") - if err == nil { - body = sty.Tool.Body.Render(toolOutputCodeContent(sty, "result.json", string(prettyResult), 0, bodyWidth, opts.ExpandedContent)) - } else { - body = sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.ExpandedContent)) - } - } else if looksLikeMarkdown(opts.Result.Content) { - body = sty.Tool.Body.Render(toolOutputCodeContent(sty, "result.md", opts.Result.Content, 0, bodyWidth, opts.ExpandedContent)) - } else { - body = sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.ExpandedContent)) - } + body := renderToolResultTextContent(sty, opts.Result.Content, toolResultContentWidths{Body: bodyWidth, Diff: cappedWidth}, opts.ExpandedContent) parts = append(parts, body) } diff --git a/internal/ui/chat/generic.go b/internal/ui/chat/generic.go index ae4c99758cf7ea9d311019a9e46676c0f565620e..20a80d77f39190ebdaf91dc2773d54930ddd1688 100644 --- a/internal/ui/chat/generic.go +++ b/internal/ui/chat/generic.go @@ -5,7 +5,6 @@ import ( "strings" "github.com/charmbracelet/crush/internal/message" - "github.com/charmbracelet/crush/internal/stringext" "github.com/charmbracelet/crush/internal/ui/styles" ) @@ -32,7 +31,7 @@ type GenericToolRenderContext struct{} // RenderTool implements the [ToolRenderer] interface. func (g *GenericToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string { cappedWidth := cappedMessageWidth(width) - name := genericPrettyName(opts.ToolCall.Name) + name := humanizedToolName(opts.ToolCall.Name) if opts.IsPending() { return pendingTool(sty, name, opts.Anim, opts.Compact) @@ -64,35 +63,11 @@ func (g *GenericToolRenderContext) RenderTool(sty *styles.Styles, width int, opt bodyWidth := cappedWidth - toolBodyLeftPaddingTotal - // Handle image data. if opts.Result.Data != "" && strings.HasPrefix(opts.Result.MIMEType, "image/") { body := sty.Tool.Body.Render(toolOutputImageContent(sty, opts.Result.Data, opts.Result.MIMEType)) return joinToolParts(header, body) } - // Try to parse result as JSON for pretty display. - var result json.RawMessage - var body string - if err := json.Unmarshal([]byte(opts.Result.Content), &result); err == nil { - prettyResult, err := json.MarshalIndent(result, "", " ") - if err == nil { - body = sty.Tool.Body.Render(toolOutputCodeContent(sty, "result.json", string(prettyResult), 0, bodyWidth, opts.ExpandedContent)) - } else { - body = sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.ExpandedContent)) - } - } else if looksLikeMarkdown(opts.Result.Content) { - body = sty.Tool.Body.Render(toolOutputCodeContent(sty, "result.md", opts.Result.Content, 0, bodyWidth, opts.ExpandedContent)) - } else { - body = sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.ExpandedContent)) - } - + body := renderToolResultTextContent(sty, opts.Result.Content, toolResultContentWidths{Body: bodyWidth, Diff: cappedWidth}, opts.ExpandedContent) return joinToolParts(header, body) } - -// genericPrettyName converts a snake_case or kebab-case tool name to a -// human-readable title case name. -func genericPrettyName(name string) string { - name = strings.ReplaceAll(name, "_", " ") - name = strings.ReplaceAll(name, "-", " ") - return stringext.Capitalize(name) -} diff --git a/internal/ui/chat/mcp.go b/internal/ui/chat/mcp.go index 33d72d6007f5d159d9c1983f09fb25b5e8586388..6f624e130f27f7de99a7b5a482da77537a2c7fe1 100644 --- a/internal/ui/chat/mcp.go +++ b/internal/ui/chat/mcp.go @@ -6,7 +6,6 @@ import ( "strings" "github.com/charmbracelet/crush/internal/message" - "github.com/charmbracelet/crush/internal/stringext" "github.com/charmbracelet/crush/internal/ui/styles" ) @@ -37,8 +36,8 @@ func (b *MCPToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *T if len(toolNameParts) != 3 { return toolErrorContent(sty, &message.ToolResult{Content: "Invalid tool name"}, cappedWidth) } - mcpName := prettyName(toolNameParts[1]) - toolName := prettyName(toolNameParts[2]) + mcpName := humanizedToolName(toolNameParts[1]) + toolName := humanizedToolName(toolNameParts[2]) mcpName = sty.Tool.MCPName.Render(mcpName) toolName = sty.Tool.MCPToolName.Render(toolName) @@ -74,48 +73,6 @@ func (b *MCPToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *T } bodyWidth := cappedWidth - toolBodyLeftPaddingTotal - // see if the result is json - var result json.RawMessage - var body string - if err := json.Unmarshal([]byte(opts.Result.Content), &result); err == nil { - prettyResult, err := json.MarshalIndent(result, "", " ") - if err == nil { - body = sty.Tool.Body.Render(toolOutputCodeContent(sty, "result.json", string(prettyResult), 0, bodyWidth, opts.ExpandedContent)) - } else { - body = sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.ExpandedContent)) - } - } else if looksLikeMarkdown(opts.Result.Content) { - body = sty.Tool.Body.Render(toolOutputCodeContent(sty, "result.md", opts.Result.Content, 0, bodyWidth, opts.ExpandedContent)) - } else { - body = sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.ExpandedContent)) - } + body := renderToolResultTextContent(sty, opts.Result.Content, toolResultContentWidths{Body: bodyWidth, Diff: cappedWidth}, opts.ExpandedContent) return joinToolParts(header, body) } - -func prettyName(name string) string { - name = strings.ReplaceAll(name, "_", " ") - name = strings.ReplaceAll(name, "-", " ") - return stringext.Capitalize(name) -} - -// looksLikeMarkdown checks if content appears to be markdown by looking for -// common markdown patterns. -func looksLikeMarkdown(content string) bool { - patterns := []string{ - "# ", // headers - "## ", // headers - "**", // bold - "```", // code fence - "- ", // unordered list - "1. ", // ordered list - "> ", // blockquote - "---", // horizontal rule - "***", // horizontal rule - } - for _, p := range patterns { - if strings.Contains(content, p) { - return true - } - } - return false -} diff --git a/internal/ui/chat/mcp_test.go b/internal/ui/chat/mcp_test.go new file mode 100644 index 0000000000000000000000000000000000000000..8e317fe398893fb403647c1da9b7b2cde10d7b04 --- /dev/null +++ b/internal/ui/chat/mcp_test.go @@ -0,0 +1,406 @@ +package chat + +import ( + "strings" + "testing" +) + +func TestLooksLikeDiff(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + content string + want bool + }{ + { + name: "simple unified diff", + content: `diff --git a/main.go b/main.go +--- a/main.go ++++ b/main.go +@@ -1,5 +1,6 @@ + package main + ++import "fmt" ++ + func main() { +- println("hello") ++ fmt.Println("hello") + } +`, + want: true, + }, + { + name: "plain text", + content: "This is just some plain text with no diff markers.", + want: false, + }, + { + name: "empty string", + content: "", + want: false, + }, + { + name: "markdown with headers", + content: `# Title + +Some content here. + +## Subtitle + +More content with **bold** text. +`, + want: false, + }, + { + name: "diff with mixed content", + content: `diff --git a/file.txt b/file.txt +--- a/file.txt ++++ b/file.txt +@@ -1 +1 @@ +-old line ++new line +`, + want: true, + }, + { + name: "only plus/minus without hunk or headers", + content: `Hello world +--- +This is not really a diff +Just some text with a few symbols ++ another line +More regular content here +And even more content +`, + want: false, + }, + { + name: "GitHub PR diff format", + content: `diff --git a/src/app.ts b/src/app.ts +index abc1234..def5678 100644 +--- a/src/app.ts ++++ b/src/app.ts +@@ -10,6 +10,8 @@ function handleRequest() { + const data = getData(); ++ validate(data); ++ log(data); + return process(data); + } +`, + want: true, + }, + { + name: "non-git unified patch with hunk and headers", + content: `--- a/old.c ++++ b/old.c +@@ -1,3 +1,4 @@ + #include +-int main() { ++int main(int argc, char **argv) { + return 0; + } +`, + want: true, + }, + { + name: "file headers without hunk markers", + content: `--- a/somefile.txt ++++ b/somefile.txt +Just some content here +No hunk markers at all +`, + want: false, + }, + { + name: "hunk markers without file headers", + content: `@@ -1,3 +1,4 @@ + some line +-another line ++changed line +`, + want: false, + }, + { + name: "markdown list with plus signs", + content: `- Item one +- Item two ++ Bonus item +- Item three +`, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := looksLikeDiff(tt.content) + if got != tt.want { + t.Errorf("looksLikeDiff() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestParseUnifiedDiff(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want []parsedDiffFile + }{ + { + name: "simple diff with additions and removals", + input: `diff --git a/main.go b/main.go +--- a/main.go ++++ b/main.go +@@ -1,5 +1,6 @@ + package main + ++import "fmt" ++ + func main() { +- println("hello") ++ fmt.Println("hello") + } +`, + want: []parsedDiffFile{ + { + path: "main.go", + before: "package main\n\nfunc main() {\n println(\"hello\")\n}", + after: "package main\n\nimport \"fmt\"\n\nfunc main() {\n fmt.Println(\"hello\")\n}", + }, + }, + }, + { + name: "new file creation", + input: `diff --git a/newfile.go b/newfile.go +new file mode 100644 +--- /dev/null ++++ b/newfile.go +@@ -0,0 +1,3 @@ ++package main ++ ++func main() {} +`, + want: []parsedDiffFile{ + { + path: "newfile.go", + before: "", + after: "package main\n\nfunc main() {}", + }, + }, + }, + { + name: "file deletion", + input: `diff --git a/oldfile.go b/oldfile.go +deleted file mode 100644 +--- a/oldfile.go ++++ /dev/null +@@ -1,3 +0,0 @@ +-package main +- +-func main() {} +`, + want: []parsedDiffFile{ + { + path: "oldfile.go", + before: "package main\n\nfunc main() {}", + after: "", + }, + }, + }, + { + name: "non-diff content", + input: "Just some regular text", + want: nil, + }, + { + name: "diff with timestamp in header", + input: `diff --git a/config.yml b/config.yml +--- a/config.yml 2024-01-15 10:30:00 ++++ b/config.yml 2024-01-15 10:31:00 +@@ -1,3 +1,4 @@ + name: myapp +-version: 1.0 ++version: 1.1 ++debug: true +`, + want: []parsedDiffFile{ + { + path: "config.yml", + before: "name: myapp\nversion: 1.0", + after: "name: myapp\nversion: 1.1\ndebug: true", + }, + }, + }, + { + name: "multi-file diff", + input: `diff --git a/one.txt b/one.txt +--- a/one.txt ++++ b/one.txt +@@ -1,3 +1,3 @@ + line one +-line two ++line two updated + line three +diff --git a/two.txt b/two.txt +--- a/two.txt ++++ b/two.txt +@@ -1,2 +1,3 @@ + alpha ++beta + gamma +`, + want: []parsedDiffFile{ + { + path: "one.txt", + before: "line one\nline two\nline three", + after: "line one\nline two updated\nline three", + }, + { + path: "two.txt", + before: "alpha\ngamma", + after: "alpha\nbeta\ngamma", + }, + }, + }, + { + name: "non-git unified patch", + input: `--- old.c ++++ old.c +@@ -1,3 +1,4 @@ + #include +-int main() { ++int main(int argc, char **argv) { + return 0; + } +`, + want: []parsedDiffFile{ + { + path: "old.c", + before: "#include \nint main() {\n return 0;\n}", + after: "#include \nint main(int argc, char **argv) {\n return 0;\n}", + }, + }, + }, + { + name: "non-git new file from /dev/null", + input: `--- /dev/null ++++ newfile.txt +@@ -0,0 +1,2 @@ ++hello ++world +`, + want: []parsedDiffFile{ + { + path: "newfile.txt", + before: "", + after: "hello\nworld", + }, + }, + }, + { + name: "non-git new file with only +++ header", + input: `+++ brand_new.go +@@ -0,0 +1,3 @@ ++package main ++ ++func main() {} +`, + want: []parsedDiffFile{ + { + path: "brand_new.go", + before: "", + after: "package main\n\nfunc main() {}", + }, + }, + }, + { + name: "multi-hunk single file", + input: `diff --git a/big.go b/big.go +--- a/big.go ++++ b/big.go +@@ -1,4 +1,5 @@ + package main ++import "os" + + func init() { +@@ -10,3 +11,3 @@ +- println("done") ++ fmt.Println("done") + } +`, + want: []parsedDiffFile{ + { + path: "big.go", + before: "package main\n\nfunc init() {\n println(\"done\")\n}", + after: "package main\nimport \"os\"\n\nfunc init() {\n fmt.Println(\"done\")\n}", + }, + }, + }, + { + name: "hunk content starting with header-like prefixes", + input: `diff --git a/file.txt b/file.txt +--- a/file.txt ++++ b/file.txt +@@ -1,3 +1,3 @@ +---- tricky +++++ newer + keep +`, + want: []parsedDiffFile{ + { + path: "file.txt", + before: "--- tricky\nkeep", + after: "+++ newer\nkeep", + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := parseUnifiedDiff(tt.input) + if len(got) != len(tt.want) { + t.Errorf("parseUnifiedDiff() returned %d files, want %d", len(got), len(tt.want)) + return + } + for i, w := range tt.want { + if got[i].path != w.path { + t.Errorf("parseUnifiedDiff()[%d].path = %q, want %q", i, got[i].path, w.path) + } + if got[i].before != w.before { + t.Errorf("parseUnifiedDiff()[%d].before = %q, want %q", i, got[i].before, w.before) + } + if got[i].after != w.after { + t.Errorf("parseUnifiedDiff()[%d].after = %q, want %q", i, got[i].after, w.after) + } + } + }) + } +} + +func TestLooksLikeDiffVersusMarkdown(t *testing.T) { + t.Parallel() + + // A unified diff should be detected as a diff, not markdown, + // even though it contains "-" which could match markdown patterns. + diffContent := strings.Join([]string{ + "diff --git a/README.md b/README.md", + "--- a/README.md", + "+++ b/README.md", + "@@ -1,3 +1,3 @@", + " # Title", + "-Old subtitle", + "+New subtitle", + " Some content", + }, "\n") + + if !looksLikeDiff(diffContent) { + t.Error("looksLikeDiff() should detect unified diff") + } +} diff --git a/internal/ui/chat/tool_result_content.go b/internal/ui/chat/tool_result_content.go new file mode 100644 index 0000000000000000000000000000000000000000..1ca77d75ee0636991366224e9dc564cecead36b2 --- /dev/null +++ b/internal/ui/chat/tool_result_content.go @@ -0,0 +1,59 @@ +package chat + +import ( + "encoding/json" + "strings" + + "github.com/charmbracelet/crush/internal/diffdetect" + "github.com/charmbracelet/crush/internal/stringext" + "github.com/charmbracelet/crush/internal/ui/styles" +) + +type toolResultContentWidths struct { + Body int + Diff int +} + +func humanizedToolName(name string) string { + name = strings.ReplaceAll(name, "_", " ") + name = strings.ReplaceAll(name, "-", " ") + return stringext.Capitalize(name) +} + +func looksLikeMarkdown(content string) bool { + patterns := []string{ + "# ", + "## ", + "**", + "```", + "- ", + "1. ", + "> ", + "---", + "***", + } + for _, p := range patterns { + if strings.Contains(content, p) { + return true + } + } + return false +} + +func renderToolResultTextContent(sty *styles.Styles, content string, widths toolResultContentWidths, expanded bool) string { + var result json.RawMessage + if err := json.Unmarshal([]byte(content), &result); err == nil { + prettyResult, err := json.MarshalIndent(result, "", " ") + if err == nil { + return sty.Tool.Body.Render(toolOutputCodeContent(sty, "result.json", string(prettyResult), 0, widths.Body, expanded)) + } + return sty.Tool.Body.Render(toolOutputPlainContent(sty, content, widths.Body, expanded)) + } + if diffdetect.IsUnifiedDiff(content) { + return toolOutputDiffContentFromUnified(sty, content, widths.Diff, expanded) + } + if looksLikeMarkdown(content) { + return sty.Tool.Body.Render(toolOutputCodeContent(sty, "result.md", content, 0, widths.Body, expanded)) + } + return sty.Tool.Body.Render(toolOutputPlainContent(sty, content, widths.Body, expanded)) +} diff --git a/internal/ui/chat/tool_result_content_test.go b/internal/ui/chat/tool_result_content_test.go new file mode 100644 index 0000000000000000000000000000000000000000..5020d81bce6d514cd2dffcfb042ccae7711fcb50 --- /dev/null +++ b/internal/ui/chat/tool_result_content_test.go @@ -0,0 +1,116 @@ +package chat + +import ( + "encoding/json" + "testing" + + "github.com/charmbracelet/crush/internal/ui/styles" +) + +func TestHumanizedToolName(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want string + }{ + {name: "snake case", input: "mcp_github_get", want: "Mcp Github Get"}, + {name: "kebab case", input: "web-fetch", want: "Web Fetch"}, + {name: "mixed", input: "job_output-tool", want: "Job Output Tool"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := humanizedToolName(tt.input); got != tt.want { + t.Fatalf("humanizedToolName() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestLooksLikeMarkdown(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + content string + want bool + }{ + {name: "header", content: "# Title", want: true}, + {name: "code fence", content: "```go\nfmt.Println(\"x\")\n```", want: true}, + {name: "plain", content: "hello world", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := looksLikeMarkdown(tt.content); got != tt.want { + t.Fatalf("looksLikeMarkdown() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestRenderToolResultTextContent(t *testing.T) { + t.Parallel() + + sty := styles.DefaultStyles() + styPtr := &sty + widths := toolResultContentWidths{Body: 80, Diff: 82} + + t.Run("json branch", func(t *testing.T) { + t.Parallel() + content := `{"a":1}` + var result json.RawMessage + if err := json.Unmarshal([]byte(content), &result); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + prettyResult, err := json.MarshalIndent(result, "", " ") + if err != nil { + t.Fatalf("json.MarshalIndent() error = %v", err) + } + expected := styPtr.Tool.Body.Render(toolOutputCodeContent(styPtr, "result.json", string(prettyResult), 0, widths.Body, false)) + got := renderToolResultTextContent(styPtr, content, widths, false) + if got != expected { + t.Fatal("renderToolResultTextContent() did not choose JSON rendering") + } + }) + + t.Run("diff branch before markdown", func(t *testing.T) { + t.Parallel() + content := `diff --git a/README.md b/README.md +--- a/README.md ++++ b/README.md +@@ -1 +1 @@ +-# Old ++# New +` + expected := toolOutputDiffContentFromUnified(styPtr, content, widths.Diff, false) + got := renderToolResultTextContent(styPtr, content, widths, false) + if got != expected { + t.Fatal("renderToolResultTextContent() did not choose diff rendering") + } + }) + + t.Run("markdown branch", func(t *testing.T) { + t.Parallel() + content := "# Title\n\nBody" + expected := styPtr.Tool.Body.Render(toolOutputCodeContent(styPtr, "result.md", content, 0, widths.Body, false)) + got := renderToolResultTextContent(styPtr, content, widths, false) + if got != expected { + t.Fatal("renderToolResultTextContent() did not choose markdown rendering") + } + }) + + t.Run("plain branch", func(t *testing.T) { + t.Parallel() + content := "plain text" + expected := styPtr.Tool.Body.Render(toolOutputPlainContent(styPtr, content, widths.Body, false)) + got := renderToolResultTextContent(styPtr, content, widths, false) + if got != expected { + t.Fatal("renderToolResultTextContent() did not choose plain rendering") + } + }) +} diff --git a/internal/ui/chat/tools.go b/internal/ui/chat/tools.go index f91ad8ebd8725c2e2c15a4b9968bb51226c4db12..871ba4348d20d88f056d951660c14b4aafabf03f 100644 --- a/internal/ui/chat/tools.go +++ b/internal/ui/chat/tools.go @@ -1406,6 +1406,6 @@ func prettifyToolName(name string) string { case tools.WriteToolName: return "Write" default: - return genericPrettyName(name) + return humanizedToolName(name) } } diff --git a/internal/ui/chat/unified_diff.go b/internal/ui/chat/unified_diff.go new file mode 100644 index 0000000000000000000000000000000000000000..71cdf5b6d2b86f4ab93e3b26cf3e4669029adf26 --- /dev/null +++ b/internal/ui/chat/unified_diff.go @@ -0,0 +1,171 @@ +package chat + +import ( + "fmt" + "strings" + + "github.com/charmbracelet/crush/internal/diffdetect" + "github.com/charmbracelet/crush/internal/ui/common" + "github.com/charmbracelet/crush/internal/ui/styles" +) + +type parsedDiffFile struct { + path string + before string + after string +} + +func looksLikeDiff(content string) bool { + return diffdetect.IsUnifiedDiff(content) +} + +func parseUnifiedDiff(content string) []parsedDiffFile { + type fileBuilder struct { + path string + before strings.Builder + after strings.Builder + } + + var files []fileBuilder + currentIdx := -1 + inHunk := false + lines := strings.Split(content, "\n") + + for i, line := range lines { + if strings.HasPrefix(line, "diff --git ") { + inHunk = false + parts := strings.SplitN(line, " ", 4) + if len(parts) >= 4 { + files = append(files, fileBuilder{path: strings.TrimPrefix(parts[3], "b/")}) + currentIdx = len(files) - 1 + } + continue + } + + if strings.HasPrefix(line, "@@") { + inHunk = true + continue + } + + if strings.HasPrefix(line, "index ") || strings.HasPrefix(line, "new file") || strings.HasPrefix(line, "deleted file") { + inHunk = false + continue + } + + nextIsPlusHeader := i+1 < len(lines) && strings.HasPrefix(lines[i+1], "+++ ") + if strings.HasPrefix(line, "--- ") && (!inHunk || nextIsPlusHeader) { + startedNewFileFromHunk := inHunk && nextIsPlusHeader + inHunk = false + p := strings.TrimPrefix(line, "--- ") + p = strings.TrimPrefix(p, "a/") + if idx := strings.Index(p, "\t"); idx >= 0 { + p = p[:idx] + } + if currentIdx < 0 || startedNewFileFromHunk { + files = append(files, fileBuilder{path: p}) + currentIdx = len(files) - 1 + continue + } + if p != "/dev/null" { + files[currentIdx].path = p + } + continue + } + + if strings.HasPrefix(line, "+++ ") && !inHunk { + p := strings.TrimPrefix(line, "+++ ") + p = strings.TrimPrefix(p, "b/") + if idx := strings.Index(p, "\t"); idx >= 0 { + p = p[:idx] + } + if currentIdx < 0 { + if p != "/dev/null" { + files = append(files, fileBuilder{path: p}) + currentIdx = len(files) - 1 + } + continue + } + if p != "/dev/null" && (files[currentIdx].path == "" || strings.HasPrefix(files[currentIdx].path, "/dev/null")) { + files[currentIdx].path = p + } + continue + } + + if currentIdx < 0 { + continue + } + + if strings.HasPrefix(line, "-") { + inHunk = true + files[currentIdx].before.WriteString(line[1:]) + files[currentIdx].before.WriteByte('\n') + continue + } + + if strings.HasPrefix(line, "+") { + inHunk = true + files[currentIdx].after.WriteString(line[1:]) + files[currentIdx].after.WriteByte('\n') + continue + } + + if strings.HasPrefix(line, " ") { + inHunk = true + lineContent := line[1:] + files[currentIdx].before.WriteString(lineContent) + files[currentIdx].before.WriteByte('\n') + files[currentIdx].after.WriteString(lineContent) + files[currentIdx].after.WriteByte('\n') + } + } + + result := make([]parsedDiffFile, 0, len(files)) + for _, f := range files { + result = append(result, parsedDiffFile{ + path: f.path, + before: strings.TrimSuffix(f.before.String(), "\n"), + after: strings.TrimSuffix(f.after.String(), "\n"), + }) + } + return result +} + +func toolOutputDiffContentFromUnified(sty *styles.Styles, content string, width int, expanded bool) string { + files := parseUnifiedDiff(content) + if len(files) == 0 { + bodyWidth := width - toolBodyLeftPaddingTotal + return sty.Tool.Body.Render(toolOutputCodeContent(sty, "result.diff", content, 0, bodyWidth, expanded)) + } + bodyWidth := width - toolBodyLeftPaddingTotal + var blocks []string + for i, f := range files { + formatter := common.DiffFormatter(sty). + Before(f.path, f.before). + After(f.path, f.after). + Width(bodyWidth) + if len(files) > 1 { + formatter = formatter.FileName(f.path) + } + if width > maxTextWidth { + formatter = formatter.Split() + } + formatted := formatter.String() + if i < len(files)-1 { + formatted += "\n" + } + blocks = append(blocks, formatted) + } + combined := strings.Join(blocks, "\n") + lines := strings.Split(combined, "\n") + maxLines := responseContextHeight + if expanded { + maxLines = len(lines) + } + if len(lines) > maxLines && !expanded { + truncMsg := sty.Tool.DiffTruncation. + Width(bodyWidth). + Render(fmt.Sprintf(assistantMessageTruncateFormat, len(lines)-maxLines)) + combined = strings.Join(lines[:maxLines], "\n") + "\n" + truncMsg + } + return sty.Tool.Body.Render(combined) +} diff --git a/internal/ui/diffview/diffview.go b/internal/ui/diffview/diffview.go index db311593e0f57384f0417ba018d5c7cf0f88df5f..8616ee31158cc0bf83acc2f834e34f06401b5358 100644 --- a/internal/ui/diffview/diffview.go +++ b/internal/ui/diffview/diffview.go @@ -38,6 +38,7 @@ type DiffView struct { layout layout before file after file + fileName string contextLines int lineNumbers bool height int @@ -112,6 +113,12 @@ func (dv *DiffView) After(path, content string) *DiffView { return dv } +// FileName sets the file name header to display above the diff. +func (dv *DiffView) FileName(name string) *DiffView { + dv.fileName = name + return dv +} + // clearCaches clears all caches when content or major settings change. func (dv *DiffView) clearCaches() { dv.cachedLexer = nil @@ -287,6 +294,7 @@ func (dv *DiffView) adjustStyles() { dv.style.EqualLine.LineNumber = setPadding(dv.style.EqualLine.LineNumber) dv.style.InsertLine.LineNumber = setPadding(dv.style.InsertLine.LineNumber) dv.style.DeleteLine.LineNumber = setPadding(dv.style.DeleteLine.LineNumber) + dv.style.Filename.LineNumber = setPadding(dv.style.Filename.LineNumber) } // detectNumDigits calculates the maximum number of digits needed for before and @@ -304,6 +312,10 @@ func (dv *DiffView) detectNumDigits() { func (dv *DiffView) detectTotalLines() { dv.totalLines = 0 + if dv.fileName != "" { + dv.totalLines++ + } + switch dv.layout { case layoutUnified: for _, h := range dv.unified.Hunks { @@ -415,6 +427,20 @@ func (dv *DiffView) renderUnified() string { outer: for i, h := range dv.unified.Hunks { + // Render file name header before the first hunk. + if i == 0 && dv.fileName != "" { + if shouldWrite() { + ls := dv.style.Filename + if dv.lineNumbers { + b.WriteString(ls.LineNumber.Render(pad("…", dv.beforeNumDigits))) + b.WriteString(ls.LineNumber.Render(pad("…", dv.afterNumDigits))) + } + content := ansi.Truncate(" "+dv.fileName, dv.fullCodeWidth, "…") + b.WriteString(ls.Code.Width(dv.fullCodeWidth).Render(content)) + b.WriteString("\n") + } + printedLines++ + } if shouldWrite() { ls := dv.style.DividerLine if dv.lineNumbers { @@ -525,6 +551,23 @@ func (dv *DiffView) renderSplit() string { outer: for i, h := range dv.splitHunks { + // Render file name header before the first hunk. + if i == 0 && dv.fileName != "" { + if shouldWrite() { + ls := dv.style.Filename + if dv.lineNumbers { + b.WriteString(ls.LineNumber.Render(pad("…", dv.beforeNumDigits))) + } + content := ansi.Truncate(" "+dv.fileName, dv.fullCodeWidth, "…") + b.WriteString(ls.Code.Width(dv.fullCodeWidth).Render(content)) + if dv.lineNumbers { + b.WriteString(ls.LineNumber.Render(pad("…", dv.afterNumDigits))) + } + b.WriteString(ls.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render(" ")) + b.WriteRune('\n') + } + printedLines++ + } if shouldWrite() { ls := dv.style.DividerLine if dv.lineNumbers { diff --git a/internal/ui/diffview/style.go b/internal/ui/diffview/style.go index 25fd08ac68c7b160bbdcfa61df737b48f12cb625..f0bb9b65e6ce3ff443d085eb9eb7d84dec434657 100644 --- a/internal/ui/diffview/style.go +++ b/internal/ui/diffview/style.go @@ -21,6 +21,7 @@ type Style struct { EqualLine LineStyle InsertLine LineStyle DeleteLine LineStyle + Filename LineStyle } // DefaultLightStyle provides a default light theme style for the diff view. @@ -70,6 +71,14 @@ func DefaultLightStyle() Style { Foreground(charmtone.Pepper). Background(lipgloss.Color("#ffebee")), }, + Filename: LineStyle{ + LineNumber: lipgloss.NewStyle(). + Foreground(charmtone.Iron). + Background(charmtone.Thunder), + Code: lipgloss.NewStyle(). + Foreground(charmtone.Iron). + Background(charmtone.Thunder), + }, } } @@ -120,5 +129,13 @@ func DefaultDarkStyle() Style { Foreground(charmtone.Salt). Background(lipgloss.Color("#3a3030")), }, + Filename: LineStyle{ + LineNumber: lipgloss.NewStyle(). + Foreground(charmtone.Smoke). + Background(charmtone.Sapphire), + Code: lipgloss.NewStyle(). + Foreground(charmtone.Smoke). + Background(charmtone.Sapphire), + }, } } diff --git a/internal/ui/styles/styles.go b/internal/ui/styles/styles.go index 12c5c99e0e2b9619777d64b746c053e0bd3e165b..4dec2a15d8e3d0493b69035eaa5d6e19c062548e 100644 --- a/internal/ui/styles/styles.go +++ b/internal/ui/styles/styles.go @@ -1056,6 +1056,14 @@ func DefaultStyles() Styles { Code: lipgloss.NewStyle(). Background(lipgloss.Color("#383030")), }, + Filename: diffview.LineStyle{ + LineNumber: lipgloss.NewStyle(). + Foreground(fgHalfMuted). + Background(bgBaseLighter), + Code: lipgloss.NewStyle(). + Foreground(fgHalfMuted). + Background(bgBaseLighter), + }, } s.FilePicker = filepicker.Styles{