Detailed changes
@@ -0,0 +1,36 @@
+package diffdetect
+
+import "strings"
+
+// Signal describes which unified-diff markers were found while scanning text.
+type Signal struct {
+ HasHunk bool
+ HasFileHeader bool
+ HasGitHeader bool
+}
+
+// Inspect scans content for unified-diff markers.
+func Inspect(content string) Signal {
+ var signal Signal
+ for line := range strings.SplitSeq(content, "\n") {
+ if strings.HasPrefix(line, "@@") {
+ signal.HasHunk = true
+ }
+ if strings.HasPrefix(line, "--- ") || strings.HasPrefix(line, "+++ ") {
+ signal.HasFileHeader = true
+ }
+ if strings.HasPrefix(line, "diff --git ") {
+ signal.HasGitHeader = true
+ }
+ }
+ return signal
+}
+
+// IsUnifiedDiff reports whether content appears to be a unified diff.
+func IsUnifiedDiff(content string) bool {
+ signal := Inspect(content)
+ if signal.HasGitHeader && signal.HasFileHeader {
+ return true
+ }
+ return signal.HasHunk && signal.HasFileHeader
+}
@@ -0,0 +1,152 @@
+package diffdetect
+
+import "testing"
+
+func TestInspect(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ content string
+ want Signal
+ }{
+ {
+ name: "git unified diff",
+ content: `diff --git a/main.go b/main.go
+--- a/main.go
++++ b/main.go
+@@ -1,2 +1,3 @@
+ package main
++import "fmt"
+`,
+ want: Signal{HasHunk: true, HasFileHeader: true, HasGitHeader: true},
+ },
+ {
+ name: "non-git unified diff",
+ content: `--- old.c
++++ old.c
+@@ -1 +1 @@
+-old
++new
+`,
+ want: Signal{HasHunk: true, HasFileHeader: true, HasGitHeader: false},
+ },
+ {
+ name: "plain text",
+ content: "hello world",
+ want: Signal{},
+ },
+ {
+ name: "hunk only",
+ content: "@@ -1 +1 @@\n-old\n+new\n",
+ want: Signal{HasHunk: true},
+ },
+ {
+ name: "headers only",
+ content: "--- a/file\n+++ b/file\n",
+ want: Signal{HasFileHeader: true},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got := Inspect(tt.content)
+ if got != tt.want {
+ t.Errorf("Inspect() = %+v, want %+v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestIsUnifiedDiff(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ content string
+ want bool
+ }{
+ {
+ name: "github-style multi-file diff",
+ content: `diff --git a/one.txt b/one.txt
+--- a/one.txt
++++ b/one.txt
+@@ -1 +1 @@
+-a
++b
+diff --git a/two.txt b/two.txt
+--- a/two.txt
++++ b/two.txt
+@@ -1 +1 @@
+-c
++d
+`,
+ want: true,
+ },
+ {
+ name: "non-git unified patch",
+ content: `--- a/old.c
++++ b/old.c
+@@ -1,3 +1,4 @@
+ #include <stdio.h>
+-int main() {
++int main(int argc, char **argv) {
+ return 0;
+ }
+`,
+ want: true,
+ },
+ {
+ name: "new file from dev null",
+ content: `--- /dev/null
++++ newfile.txt
+@@ -0,0 +1,2 @@
++hello
++world
+`,
+ want: true,
+ },
+ {
+ name: "markdown false positive candidate",
+ content: `- Item one
+- Item two
++ Bonus item
+- Item three
+`,
+ want: false,
+ },
+ {
+ name: "headers without hunk",
+ content: `--- a/somefile.txt
++++ b/somefile.txt
+Just some content here
+No hunk markers at all
+`,
+ want: false,
+ },
+ {
+ name: "hunk without headers",
+ content: `@@ -1,3 +1,4 @@
+ some line
+-another line
++changed line
+`,
+ want: false,
+ },
+ {
+ name: "empty",
+ content: "",
+ want: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ if got := IsUnifiedDiff(tt.content); got != tt.want {
+ t.Errorf("IsUnifiedDiff() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
@@ -138,20 +138,7 @@ func (d *DockerMCPToolRenderContext) RenderTool(sty *styles.Styles, width int, o
// Handle text content.
if opts.Result.Content != "" {
- var body string
- var result json.RawMessage
- if err := json.Unmarshal([]byte(opts.Result.Content), &result); err == nil {
- prettyResult, err := json.MarshalIndent(result, "", " ")
- if err == nil {
- body = sty.Tool.Body.Render(toolOutputCodeContent(sty, "result.json", string(prettyResult), 0, bodyWidth, opts.ExpandedContent))
- } else {
- body = sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.ExpandedContent))
- }
- } else if looksLikeMarkdown(opts.Result.Content) {
- body = sty.Tool.Body.Render(toolOutputCodeContent(sty, "result.md", opts.Result.Content, 0, bodyWidth, opts.ExpandedContent))
- } else {
- body = sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.ExpandedContent))
- }
+ body := renderToolResultTextContent(sty, opts.Result.Content, toolResultContentWidths{Body: bodyWidth, Diff: cappedWidth}, opts.ExpandedContent)
parts = append(parts, body)
}
@@ -5,7 +5,6 @@ import (
"strings"
"github.com/charmbracelet/crush/internal/message"
- "github.com/charmbracelet/crush/internal/stringext"
"github.com/charmbracelet/crush/internal/ui/styles"
)
@@ -32,7 +31,7 @@ type GenericToolRenderContext struct{}
// RenderTool implements the [ToolRenderer] interface.
func (g *GenericToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string {
cappedWidth := cappedMessageWidth(width)
- name := genericPrettyName(opts.ToolCall.Name)
+ name := humanizedToolName(opts.ToolCall.Name)
if opts.IsPending() {
return pendingTool(sty, name, opts.Anim, opts.Compact)
@@ -64,35 +63,11 @@ func (g *GenericToolRenderContext) RenderTool(sty *styles.Styles, width int, opt
bodyWidth := cappedWidth - toolBodyLeftPaddingTotal
- // Handle image data.
if opts.Result.Data != "" && strings.HasPrefix(opts.Result.MIMEType, "image/") {
body := sty.Tool.Body.Render(toolOutputImageContent(sty, opts.Result.Data, opts.Result.MIMEType))
return joinToolParts(header, body)
}
- // Try to parse result as JSON for pretty display.
- var result json.RawMessage
- var body string
- if err := json.Unmarshal([]byte(opts.Result.Content), &result); err == nil {
- prettyResult, err := json.MarshalIndent(result, "", " ")
- if err == nil {
- body = sty.Tool.Body.Render(toolOutputCodeContent(sty, "result.json", string(prettyResult), 0, bodyWidth, opts.ExpandedContent))
- } else {
- body = sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.ExpandedContent))
- }
- } else if looksLikeMarkdown(opts.Result.Content) {
- body = sty.Tool.Body.Render(toolOutputCodeContent(sty, "result.md", opts.Result.Content, 0, bodyWidth, opts.ExpandedContent))
- } else {
- body = sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.ExpandedContent))
- }
-
+ body := renderToolResultTextContent(sty, opts.Result.Content, toolResultContentWidths{Body: bodyWidth, Diff: cappedWidth}, opts.ExpandedContent)
return joinToolParts(header, body)
}
-
-// genericPrettyName converts a snake_case or kebab-case tool name to a
-// human-readable title case name.
-func genericPrettyName(name string) string {
- name = strings.ReplaceAll(name, "_", " ")
- name = strings.ReplaceAll(name, "-", " ")
- return stringext.Capitalize(name)
-}
@@ -6,7 +6,6 @@ import (
"strings"
"github.com/charmbracelet/crush/internal/message"
- "github.com/charmbracelet/crush/internal/stringext"
"github.com/charmbracelet/crush/internal/ui/styles"
)
@@ -37,8 +36,8 @@ func (b *MCPToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *T
if len(toolNameParts) != 3 {
return toolErrorContent(sty, &message.ToolResult{Content: "Invalid tool name"}, cappedWidth)
}
- mcpName := prettyName(toolNameParts[1])
- toolName := prettyName(toolNameParts[2])
+ mcpName := humanizedToolName(toolNameParts[1])
+ toolName := humanizedToolName(toolNameParts[2])
mcpName = sty.Tool.MCPName.Render(mcpName)
toolName = sty.Tool.MCPToolName.Render(toolName)
@@ -74,48 +73,6 @@ func (b *MCPToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *T
}
bodyWidth := cappedWidth - toolBodyLeftPaddingTotal
- // see if the result is json
- var result json.RawMessage
- var body string
- if err := json.Unmarshal([]byte(opts.Result.Content), &result); err == nil {
- prettyResult, err := json.MarshalIndent(result, "", " ")
- if err == nil {
- body = sty.Tool.Body.Render(toolOutputCodeContent(sty, "result.json", string(prettyResult), 0, bodyWidth, opts.ExpandedContent))
- } else {
- body = sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.ExpandedContent))
- }
- } else if looksLikeMarkdown(opts.Result.Content) {
- body = sty.Tool.Body.Render(toolOutputCodeContent(sty, "result.md", opts.Result.Content, 0, bodyWidth, opts.ExpandedContent))
- } else {
- body = sty.Tool.Body.Render(toolOutputPlainContent(sty, opts.Result.Content, bodyWidth, opts.ExpandedContent))
- }
+ body := renderToolResultTextContent(sty, opts.Result.Content, toolResultContentWidths{Body: bodyWidth, Diff: cappedWidth}, opts.ExpandedContent)
return joinToolParts(header, body)
}
-
-func prettyName(name string) string {
- name = strings.ReplaceAll(name, "_", " ")
- name = strings.ReplaceAll(name, "-", " ")
- return stringext.Capitalize(name)
-}
-
-// looksLikeMarkdown checks if content appears to be markdown by looking for
-// common markdown patterns.
-func looksLikeMarkdown(content string) bool {
- patterns := []string{
- "# ", // headers
- "## ", // headers
- "**", // bold
- "```", // code fence
- "- ", // unordered list
- "1. ", // ordered list
- "> ", // blockquote
- "---", // horizontal rule
- "***", // horizontal rule
- }
- for _, p := range patterns {
- if strings.Contains(content, p) {
- return true
- }
- }
- return false
-}
@@ -0,0 +1,406 @@
+package chat
+
+import (
+ "strings"
+ "testing"
+)
+
+func TestLooksLikeDiff(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ content string
+ want bool
+ }{
+ {
+ name: "simple unified diff",
+ content: `diff --git a/main.go b/main.go
+--- a/main.go
++++ b/main.go
+@@ -1,5 +1,6 @@
+ package main
+
++import "fmt"
++
+ func main() {
+- println("hello")
++ fmt.Println("hello")
+ }
+`,
+ want: true,
+ },
+ {
+ name: "plain text",
+ content: "This is just some plain text with no diff markers.",
+ want: false,
+ },
+ {
+ name: "empty string",
+ content: "",
+ want: false,
+ },
+ {
+ name: "markdown with headers",
+ content: `# Title
+
+Some content here.
+
+## Subtitle
+
+More content with **bold** text.
+`,
+ want: false,
+ },
+ {
+ name: "diff with mixed content",
+ content: `diff --git a/file.txt b/file.txt
+--- a/file.txt
++++ b/file.txt
+@@ -1 +1 @@
+-old line
++new line
+`,
+ want: true,
+ },
+ {
+ name: "only plus/minus without hunk or headers",
+ content: `Hello world
+---
+This is not really a diff
+Just some text with a few symbols
++ another line
+More regular content here
+And even more content
+`,
+ want: false,
+ },
+ {
+ name: "GitHub PR diff format",
+ content: `diff --git a/src/app.ts b/src/app.ts
+index abc1234..def5678 100644
+--- a/src/app.ts
++++ b/src/app.ts
+@@ -10,6 +10,8 @@ function handleRequest() {
+ const data = getData();
++ validate(data);
++ log(data);
+ return process(data);
+ }
+`,
+ want: true,
+ },
+ {
+ name: "non-git unified patch with hunk and headers",
+ content: `--- a/old.c
++++ b/old.c
+@@ -1,3 +1,4 @@
+ #include <stdio.h>
+-int main() {
++int main(int argc, char **argv) {
+ return 0;
+ }
+`,
+ want: true,
+ },
+ {
+ name: "file headers without hunk markers",
+ content: `--- a/somefile.txt
++++ b/somefile.txt
+Just some content here
+No hunk markers at all
+`,
+ want: false,
+ },
+ {
+ name: "hunk markers without file headers",
+ content: `@@ -1,3 +1,4 @@
+ some line
+-another line
++changed line
+`,
+ want: false,
+ },
+ {
+ name: "markdown list with plus signs",
+ content: `- Item one
+- Item two
++ Bonus item
+- Item three
+`,
+ want: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got := looksLikeDiff(tt.content)
+ if got != tt.want {
+ t.Errorf("looksLikeDiff() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestParseUnifiedDiff(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ input string
+ want []parsedDiffFile
+ }{
+ {
+ name: "simple diff with additions and removals",
+ input: `diff --git a/main.go b/main.go
+--- a/main.go
++++ b/main.go
+@@ -1,5 +1,6 @@
+ package main
+
++import "fmt"
++
+ func main() {
+- println("hello")
++ fmt.Println("hello")
+ }
+`,
+ want: []parsedDiffFile{
+ {
+ path: "main.go",
+ before: "package main\n\nfunc main() {\n println(\"hello\")\n}",
+ after: "package main\n\nimport \"fmt\"\n\nfunc main() {\n fmt.Println(\"hello\")\n}",
+ },
+ },
+ },
+ {
+ name: "new file creation",
+ input: `diff --git a/newfile.go b/newfile.go
+new file mode 100644
+--- /dev/null
++++ b/newfile.go
+@@ -0,0 +1,3 @@
++package main
++
++func main() {}
+`,
+ want: []parsedDiffFile{
+ {
+ path: "newfile.go",
+ before: "",
+ after: "package main\n\nfunc main() {}",
+ },
+ },
+ },
+ {
+ name: "file deletion",
+ input: `diff --git a/oldfile.go b/oldfile.go
+deleted file mode 100644
+--- a/oldfile.go
++++ /dev/null
+@@ -1,3 +0,0 @@
+-package main
+-
+-func main() {}
+`,
+ want: []parsedDiffFile{
+ {
+ path: "oldfile.go",
+ before: "package main\n\nfunc main() {}",
+ after: "",
+ },
+ },
+ },
+ {
+ name: "non-diff content",
+ input: "Just some regular text",
+ want: nil,
+ },
+ {
+ name: "diff with timestamp in header",
+ input: `diff --git a/config.yml b/config.yml
+--- a/config.yml 2024-01-15 10:30:00
++++ b/config.yml 2024-01-15 10:31:00
+@@ -1,3 +1,4 @@
+ name: myapp
+-version: 1.0
++version: 1.1
++debug: true
+`,
+ want: []parsedDiffFile{
+ {
+ path: "config.yml",
+ before: "name: myapp\nversion: 1.0",
+ after: "name: myapp\nversion: 1.1\ndebug: true",
+ },
+ },
+ },
+ {
+ name: "multi-file diff",
+ input: `diff --git a/one.txt b/one.txt
+--- a/one.txt
++++ b/one.txt
+@@ -1,3 +1,3 @@
+ line one
+-line two
++line two updated
+ line three
+diff --git a/two.txt b/two.txt
+--- a/two.txt
++++ b/two.txt
+@@ -1,2 +1,3 @@
+ alpha
++beta
+ gamma
+`,
+ want: []parsedDiffFile{
+ {
+ path: "one.txt",
+ before: "line one\nline two\nline three",
+ after: "line one\nline two updated\nline three",
+ },
+ {
+ path: "two.txt",
+ before: "alpha\ngamma",
+ after: "alpha\nbeta\ngamma",
+ },
+ },
+ },
+ {
+ name: "non-git unified patch",
+ input: `--- old.c
++++ old.c
+@@ -1,3 +1,4 @@
+ #include <stdio.h>
+-int main() {
++int main(int argc, char **argv) {
+ return 0;
+ }
+`,
+ want: []parsedDiffFile{
+ {
+ path: "old.c",
+ before: "#include <stdio.h>\nint main() {\n return 0;\n}",
+ after: "#include <stdio.h>\nint main(int argc, char **argv) {\n return 0;\n}",
+ },
+ },
+ },
+ {
+ name: "non-git new file from /dev/null",
+ input: `--- /dev/null
++++ newfile.txt
+@@ -0,0 +1,2 @@
++hello
++world
+`,
+ want: []parsedDiffFile{
+ {
+ path: "newfile.txt",
+ before: "",
+ after: "hello\nworld",
+ },
+ },
+ },
+ {
+ name: "non-git new file with only +++ header",
+ input: `+++ brand_new.go
+@@ -0,0 +1,3 @@
++package main
++
++func main() {}
+`,
+ want: []parsedDiffFile{
+ {
+ path: "brand_new.go",
+ before: "",
+ after: "package main\n\nfunc main() {}",
+ },
+ },
+ },
+ {
+ name: "multi-hunk single file",
+ input: `diff --git a/big.go b/big.go
+--- a/big.go
++++ b/big.go
+@@ -1,4 +1,5 @@
+ package main
++import "os"
+
+ func init() {
+@@ -10,3 +11,3 @@
+- println("done")
++ fmt.Println("done")
+ }
+`,
+ want: []parsedDiffFile{
+ {
+ path: "big.go",
+ before: "package main\n\nfunc init() {\n println(\"done\")\n}",
+ after: "package main\nimport \"os\"\n\nfunc init() {\n fmt.Println(\"done\")\n}",
+ },
+ },
+ },
+ {
+ name: "hunk content starting with header-like prefixes",
+ input: `diff --git a/file.txt b/file.txt
+--- a/file.txt
++++ b/file.txt
+@@ -1,3 +1,3 @@
+---- tricky
+++++ newer
+ keep
+`,
+ want: []parsedDiffFile{
+ {
+ path: "file.txt",
+ before: "--- tricky\nkeep",
+ after: "+++ newer\nkeep",
+ },
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got := parseUnifiedDiff(tt.input)
+ if len(got) != len(tt.want) {
+ t.Errorf("parseUnifiedDiff() returned %d files, want %d", len(got), len(tt.want))
+ return
+ }
+ for i, w := range tt.want {
+ if got[i].path != w.path {
+ t.Errorf("parseUnifiedDiff()[%d].path = %q, want %q", i, got[i].path, w.path)
+ }
+ if got[i].before != w.before {
+ t.Errorf("parseUnifiedDiff()[%d].before = %q, want %q", i, got[i].before, w.before)
+ }
+ if got[i].after != w.after {
+ t.Errorf("parseUnifiedDiff()[%d].after = %q, want %q", i, got[i].after, w.after)
+ }
+ }
+ })
+ }
+}
+
+func TestLooksLikeDiffVersusMarkdown(t *testing.T) {
+ t.Parallel()
+
+ // A unified diff should be detected as a diff, not markdown,
+ // even though it contains "-" which could match markdown patterns.
+ diffContent := strings.Join([]string{
+ "diff --git a/README.md b/README.md",
+ "--- a/README.md",
+ "+++ b/README.md",
+ "@@ -1,3 +1,3 @@",
+ " # Title",
+ "-Old subtitle",
+ "+New subtitle",
+ " Some content",
+ }, "\n")
+
+ if !looksLikeDiff(diffContent) {
+ t.Error("looksLikeDiff() should detect unified diff")
+ }
+}
@@ -0,0 +1,59 @@
+package chat
+
+import (
+ "encoding/json"
+ "strings"
+
+ "github.com/charmbracelet/crush/internal/diffdetect"
+ "github.com/charmbracelet/crush/internal/stringext"
+ "github.com/charmbracelet/crush/internal/ui/styles"
+)
+
+type toolResultContentWidths struct {
+ Body int
+ Diff int
+}
+
+func humanizedToolName(name string) string {
+ name = strings.ReplaceAll(name, "_", " ")
+ name = strings.ReplaceAll(name, "-", " ")
+ return stringext.Capitalize(name)
+}
+
+func looksLikeMarkdown(content string) bool {
+ patterns := []string{
+ "# ",
+ "## ",
+ "**",
+ "```",
+ "- ",
+ "1. ",
+ "> ",
+ "---",
+ "***",
+ }
+ for _, p := range patterns {
+ if strings.Contains(content, p) {
+ return true
+ }
+ }
+ return false
+}
+
+func renderToolResultTextContent(sty *styles.Styles, content string, widths toolResultContentWidths, expanded bool) string {
+ var result json.RawMessage
+ if err := json.Unmarshal([]byte(content), &result); err == nil {
+ prettyResult, err := json.MarshalIndent(result, "", " ")
+ if err == nil {
+ return sty.Tool.Body.Render(toolOutputCodeContent(sty, "result.json", string(prettyResult), 0, widths.Body, expanded))
+ }
+ return sty.Tool.Body.Render(toolOutputPlainContent(sty, content, widths.Body, expanded))
+ }
+ if diffdetect.IsUnifiedDiff(content) {
+ return toolOutputDiffContentFromUnified(sty, content, widths.Diff, expanded)
+ }
+ if looksLikeMarkdown(content) {
+ return sty.Tool.Body.Render(toolOutputCodeContent(sty, "result.md", content, 0, widths.Body, expanded))
+ }
+ return sty.Tool.Body.Render(toolOutputPlainContent(sty, content, widths.Body, expanded))
+}
@@ -0,0 +1,116 @@
+package chat
+
+import (
+ "encoding/json"
+ "testing"
+
+ "github.com/charmbracelet/crush/internal/ui/styles"
+)
+
+func TestHumanizedToolName(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ input string
+ want string
+ }{
+ {name: "snake case", input: "mcp_github_get", want: "Mcp Github Get"},
+ {name: "kebab case", input: "web-fetch", want: "Web Fetch"},
+ {name: "mixed", input: "job_output-tool", want: "Job Output Tool"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ if got := humanizedToolName(tt.input); got != tt.want {
+ t.Fatalf("humanizedToolName() = %q, want %q", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestLooksLikeMarkdown(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ content string
+ want bool
+ }{
+ {name: "header", content: "# Title", want: true},
+ {name: "code fence", content: "```go\nfmt.Println(\"x\")\n```", want: true},
+ {name: "plain", content: "hello world", want: false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ if got := looksLikeMarkdown(tt.content); got != tt.want {
+ t.Fatalf("looksLikeMarkdown() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestRenderToolResultTextContent(t *testing.T) {
+ t.Parallel()
+
+ sty := styles.DefaultStyles()
+ styPtr := &sty
+ widths := toolResultContentWidths{Body: 80, Diff: 82}
+
+ t.Run("json branch", func(t *testing.T) {
+ t.Parallel()
+ content := `{"a":1}`
+ var result json.RawMessage
+ if err := json.Unmarshal([]byte(content), &result); err != nil {
+ t.Fatalf("json.Unmarshal() error = %v", err)
+ }
+ prettyResult, err := json.MarshalIndent(result, "", " ")
+ if err != nil {
+ t.Fatalf("json.MarshalIndent() error = %v", err)
+ }
+ expected := styPtr.Tool.Body.Render(toolOutputCodeContent(styPtr, "result.json", string(prettyResult), 0, widths.Body, false))
+ got := renderToolResultTextContent(styPtr, content, widths, false)
+ if got != expected {
+ t.Fatal("renderToolResultTextContent() did not choose JSON rendering")
+ }
+ })
+
+ t.Run("diff branch before markdown", func(t *testing.T) {
+ t.Parallel()
+ content := `diff --git a/README.md b/README.md
+--- a/README.md
++++ b/README.md
+@@ -1 +1 @@
+-# Old
++# New
+`
+ expected := toolOutputDiffContentFromUnified(styPtr, content, widths.Diff, false)
+ got := renderToolResultTextContent(styPtr, content, widths, false)
+ if got != expected {
+ t.Fatal("renderToolResultTextContent() did not choose diff rendering")
+ }
+ })
+
+ t.Run("markdown branch", func(t *testing.T) {
+ t.Parallel()
+ content := "# Title\n\nBody"
+ expected := styPtr.Tool.Body.Render(toolOutputCodeContent(styPtr, "result.md", content, 0, widths.Body, false))
+ got := renderToolResultTextContent(styPtr, content, widths, false)
+ if got != expected {
+ t.Fatal("renderToolResultTextContent() did not choose markdown rendering")
+ }
+ })
+
+ t.Run("plain branch", func(t *testing.T) {
+ t.Parallel()
+ content := "plain text"
+ expected := styPtr.Tool.Body.Render(toolOutputPlainContent(styPtr, content, widths.Body, false))
+ got := renderToolResultTextContent(styPtr, content, widths, false)
+ if got != expected {
+ t.Fatal("renderToolResultTextContent() did not choose plain rendering")
+ }
+ })
+}
@@ -1406,6 +1406,6 @@ func prettifyToolName(name string) string {
case tools.WriteToolName:
return "Write"
default:
- return genericPrettyName(name)
+ return humanizedToolName(name)
}
}
@@ -0,0 +1,171 @@
+package chat
+
+import (
+ "fmt"
+ "strings"
+
+ "github.com/charmbracelet/crush/internal/diffdetect"
+ "github.com/charmbracelet/crush/internal/ui/common"
+ "github.com/charmbracelet/crush/internal/ui/styles"
+)
+
+type parsedDiffFile struct {
+ path string
+ before string
+ after string
+}
+
+func looksLikeDiff(content string) bool {
+ return diffdetect.IsUnifiedDiff(content)
+}
+
+func parseUnifiedDiff(content string) []parsedDiffFile {
+ type fileBuilder struct {
+ path string
+ before strings.Builder
+ after strings.Builder
+ }
+
+ var files []fileBuilder
+ currentIdx := -1
+ inHunk := false
+ lines := strings.Split(content, "\n")
+
+ for i, line := range lines {
+ if strings.HasPrefix(line, "diff --git ") {
+ inHunk = false
+ parts := strings.SplitN(line, " ", 4)
+ if len(parts) >= 4 {
+ files = append(files, fileBuilder{path: strings.TrimPrefix(parts[3], "b/")})
+ currentIdx = len(files) - 1
+ }
+ continue
+ }
+
+ if strings.HasPrefix(line, "@@") {
+ inHunk = true
+ continue
+ }
+
+ if strings.HasPrefix(line, "index ") || strings.HasPrefix(line, "new file") || strings.HasPrefix(line, "deleted file") {
+ inHunk = false
+ continue
+ }
+
+ nextIsPlusHeader := i+1 < len(lines) && strings.HasPrefix(lines[i+1], "+++ ")
+ if strings.HasPrefix(line, "--- ") && (!inHunk || nextIsPlusHeader) {
+ startedNewFileFromHunk := inHunk && nextIsPlusHeader
+ inHunk = false
+ p := strings.TrimPrefix(line, "--- ")
+ p = strings.TrimPrefix(p, "a/")
+ if idx := strings.Index(p, "\t"); idx >= 0 {
+ p = p[:idx]
+ }
+ if currentIdx < 0 || startedNewFileFromHunk {
+ files = append(files, fileBuilder{path: p})
+ currentIdx = len(files) - 1
+ continue
+ }
+ if p != "/dev/null" {
+ files[currentIdx].path = p
+ }
+ continue
+ }
+
+ if strings.HasPrefix(line, "+++ ") && !inHunk {
+ p := strings.TrimPrefix(line, "+++ ")
+ p = strings.TrimPrefix(p, "b/")
+ if idx := strings.Index(p, "\t"); idx >= 0 {
+ p = p[:idx]
+ }
+ if currentIdx < 0 {
+ if p != "/dev/null" {
+ files = append(files, fileBuilder{path: p})
+ currentIdx = len(files) - 1
+ }
+ continue
+ }
+ if p != "/dev/null" && (files[currentIdx].path == "" || strings.HasPrefix(files[currentIdx].path, "/dev/null")) {
+ files[currentIdx].path = p
+ }
+ continue
+ }
+
+ if currentIdx < 0 {
+ continue
+ }
+
+ if strings.HasPrefix(line, "-") {
+ inHunk = true
+ files[currentIdx].before.WriteString(line[1:])
+ files[currentIdx].before.WriteByte('\n')
+ continue
+ }
+
+ if strings.HasPrefix(line, "+") {
+ inHunk = true
+ files[currentIdx].after.WriteString(line[1:])
+ files[currentIdx].after.WriteByte('\n')
+ continue
+ }
+
+ if strings.HasPrefix(line, " ") {
+ inHunk = true
+ lineContent := line[1:]
+ files[currentIdx].before.WriteString(lineContent)
+ files[currentIdx].before.WriteByte('\n')
+ files[currentIdx].after.WriteString(lineContent)
+ files[currentIdx].after.WriteByte('\n')
+ }
+ }
+
+ result := make([]parsedDiffFile, 0, len(files))
+ for _, f := range files {
+ result = append(result, parsedDiffFile{
+ path: f.path,
+ before: strings.TrimSuffix(f.before.String(), "\n"),
+ after: strings.TrimSuffix(f.after.String(), "\n"),
+ })
+ }
+ return result
+}
+
+func toolOutputDiffContentFromUnified(sty *styles.Styles, content string, width int, expanded bool) string {
+ files := parseUnifiedDiff(content)
+ if len(files) == 0 {
+ bodyWidth := width - toolBodyLeftPaddingTotal
+ return sty.Tool.Body.Render(toolOutputCodeContent(sty, "result.diff", content, 0, bodyWidth, expanded))
+ }
+ bodyWidth := width - toolBodyLeftPaddingTotal
+ var blocks []string
+ for i, f := range files {
+ formatter := common.DiffFormatter(sty).
+ Before(f.path, f.before).
+ After(f.path, f.after).
+ Width(bodyWidth)
+ if len(files) > 1 {
+ formatter = formatter.FileName(f.path)
+ }
+ if width > maxTextWidth {
+ formatter = formatter.Split()
+ }
+ formatted := formatter.String()
+ if i < len(files)-1 {
+ formatted += "\n"
+ }
+ blocks = append(blocks, formatted)
+ }
+ combined := strings.Join(blocks, "\n")
+ lines := strings.Split(combined, "\n")
+ maxLines := responseContextHeight
+ if expanded {
+ maxLines = len(lines)
+ }
+ if len(lines) > maxLines && !expanded {
+ truncMsg := sty.Tool.DiffTruncation.
+ Width(bodyWidth).
+ Render(fmt.Sprintf(assistantMessageTruncateFormat, len(lines)-maxLines))
+ combined = strings.Join(lines[:maxLines], "\n") + "\n" + truncMsg
+ }
+ return sty.Tool.Body.Render(combined)
+}
@@ -38,6 +38,7 @@ type DiffView struct {
layout layout
before file
after file
+ fileName string
contextLines int
lineNumbers bool
height int
@@ -112,6 +113,12 @@ func (dv *DiffView) After(path, content string) *DiffView {
return dv
}
+// FileName sets the file name header to display above the diff.
+func (dv *DiffView) FileName(name string) *DiffView {
+ dv.fileName = name
+ return dv
+}
+
// clearCaches clears all caches when content or major settings change.
func (dv *DiffView) clearCaches() {
dv.cachedLexer = nil
@@ -287,6 +294,7 @@ func (dv *DiffView) adjustStyles() {
dv.style.EqualLine.LineNumber = setPadding(dv.style.EqualLine.LineNumber)
dv.style.InsertLine.LineNumber = setPadding(dv.style.InsertLine.LineNumber)
dv.style.DeleteLine.LineNumber = setPadding(dv.style.DeleteLine.LineNumber)
+ dv.style.Filename.LineNumber = setPadding(dv.style.Filename.LineNumber)
}
// detectNumDigits calculates the maximum number of digits needed for before and
@@ -304,6 +312,10 @@ func (dv *DiffView) detectNumDigits() {
func (dv *DiffView) detectTotalLines() {
dv.totalLines = 0
+ if dv.fileName != "" {
+ dv.totalLines++
+ }
+
switch dv.layout {
case layoutUnified:
for _, h := range dv.unified.Hunks {
@@ -415,6 +427,20 @@ func (dv *DiffView) renderUnified() string {
outer:
for i, h := range dv.unified.Hunks {
+ // Render file name header before the first hunk.
+ if i == 0 && dv.fileName != "" {
+ if shouldWrite() {
+ ls := dv.style.Filename
+ if dv.lineNumbers {
+ b.WriteString(ls.LineNumber.Render(pad("…", dv.beforeNumDigits)))
+ b.WriteString(ls.LineNumber.Render(pad("…", dv.afterNumDigits)))
+ }
+ content := ansi.Truncate(" "+dv.fileName, dv.fullCodeWidth, "…")
+ b.WriteString(ls.Code.Width(dv.fullCodeWidth).Render(content))
+ b.WriteString("\n")
+ }
+ printedLines++
+ }
if shouldWrite() {
ls := dv.style.DividerLine
if dv.lineNumbers {
@@ -525,6 +551,23 @@ func (dv *DiffView) renderSplit() string {
outer:
for i, h := range dv.splitHunks {
+ // Render file name header before the first hunk.
+ if i == 0 && dv.fileName != "" {
+ if shouldWrite() {
+ ls := dv.style.Filename
+ if dv.lineNumbers {
+ b.WriteString(ls.LineNumber.Render(pad("…", dv.beforeNumDigits)))
+ }
+ content := ansi.Truncate(" "+dv.fileName, dv.fullCodeWidth, "…")
+ b.WriteString(ls.Code.Width(dv.fullCodeWidth).Render(content))
+ if dv.lineNumbers {
+ b.WriteString(ls.LineNumber.Render(pad("…", dv.afterNumDigits)))
+ }
+ b.WriteString(ls.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render(" "))
+ b.WriteRune('\n')
+ }
+ printedLines++
+ }
if shouldWrite() {
ls := dv.style.DividerLine
if dv.lineNumbers {
@@ -21,6 +21,7 @@ type Style struct {
EqualLine LineStyle
InsertLine LineStyle
DeleteLine LineStyle
+ Filename LineStyle
}
// DefaultLightStyle provides a default light theme style for the diff view.
@@ -70,6 +71,14 @@ func DefaultLightStyle() Style {
Foreground(charmtone.Pepper).
Background(lipgloss.Color("#ffebee")),
},
+ Filename: LineStyle{
+ LineNumber: lipgloss.NewStyle().
+ Foreground(charmtone.Iron).
+ Background(charmtone.Thunder),
+ Code: lipgloss.NewStyle().
+ Foreground(charmtone.Iron).
+ Background(charmtone.Thunder),
+ },
}
}
@@ -120,5 +129,13 @@ func DefaultDarkStyle() Style {
Foreground(charmtone.Salt).
Background(lipgloss.Color("#3a3030")),
},
+ Filename: LineStyle{
+ LineNumber: lipgloss.NewStyle().
+ Foreground(charmtone.Smoke).
+ Background(charmtone.Sapphire),
+ Code: lipgloss.NewStyle().
+ Foreground(charmtone.Smoke).
+ Background(charmtone.Sapphire),
+ },
}
}
@@ -1056,6 +1056,14 @@ func DefaultStyles() Styles {
Code: lipgloss.NewStyle().
Background(lipgloss.Color("#383030")),
},
+ Filename: diffview.LineStyle{
+ LineNumber: lipgloss.NewStyle().
+ Foreground(fgHalfMuted).
+ Background(bgBaseLighter),
+ Code: lipgloss.NewStyle().
+ Foreground(fgHalfMuted).
+ Background(bgBaseLighter),
+ },
}
s.FilePicker = filepicker.Styles{