1package chat
2
3import (
4 "fmt"
5 "strings"
6
7 "github.com/charmbracelet/crush/internal/diffdetect"
8 "github.com/charmbracelet/crush/internal/ui/common"
9 "github.com/charmbracelet/crush/internal/ui/styles"
10)
11
// parsedDiffFile is the per-file result of parseUnifiedDiff: the file's path
// plus the two sides of the change reconstructed from its hunk lines.
type parsedDiffFile struct {
	// path is the file name taken from the diff headers.
	path string
	// before is built from the removed and context lines, after from the
	// added and context lines; both are newline-joined with no trailing
	// newline.
	before, after string
}
17
18func looksLikeDiff(content string) bool {
19 return diffdetect.IsUnifiedDiff(content)
20}
21
22func parseUnifiedDiff(content string) []parsedDiffFile {
23 type fileBuilder struct {
24 path string
25 before strings.Builder
26 after strings.Builder
27 }
28
29 var files []fileBuilder
30 currentIdx := -1
31 inHunk := false
32 lines := strings.Split(content, "\n")
33
34 for i, line := range lines {
35 if strings.HasPrefix(line, "diff --git ") {
36 inHunk = false
37 parts := strings.SplitN(line, " ", 4)
38 if len(parts) >= 4 {
39 files = append(files, fileBuilder{path: strings.TrimPrefix(parts[3], "b/")})
40 currentIdx = len(files) - 1
41 }
42 continue
43 }
44
45 if strings.HasPrefix(line, "@@") {
46 inHunk = true
47 continue
48 }
49
50 if strings.HasPrefix(line, "index ") || strings.HasPrefix(line, "new file") || strings.HasPrefix(line, "deleted file") {
51 inHunk = false
52 continue
53 }
54
55 nextIsPlusHeader := i+1 < len(lines) && strings.HasPrefix(lines[i+1], "+++ ")
56 if strings.HasPrefix(line, "--- ") && (!inHunk || nextIsPlusHeader) {
57 startedNewFileFromHunk := inHunk && nextIsPlusHeader
58 inHunk = false
59 p := strings.TrimPrefix(line, "--- ")
60 p = strings.TrimPrefix(p, "a/")
61 if idx := strings.Index(p, "\t"); idx >= 0 {
62 p = p[:idx]
63 }
64 if currentIdx < 0 || startedNewFileFromHunk {
65 files = append(files, fileBuilder{path: p})
66 currentIdx = len(files) - 1
67 continue
68 }
69 if p != "/dev/null" {
70 files[currentIdx].path = p
71 }
72 continue
73 }
74
75 if strings.HasPrefix(line, "+++ ") && !inHunk {
76 p := strings.TrimPrefix(line, "+++ ")
77 p = strings.TrimPrefix(p, "b/")
78 if idx := strings.Index(p, "\t"); idx >= 0 {
79 p = p[:idx]
80 }
81 if currentIdx < 0 {
82 if p != "/dev/null" {
83 files = append(files, fileBuilder{path: p})
84 currentIdx = len(files) - 1
85 }
86 continue
87 }
88 if p != "/dev/null" && (files[currentIdx].path == "" || strings.HasPrefix(files[currentIdx].path, "/dev/null")) {
89 files[currentIdx].path = p
90 }
91 continue
92 }
93
94 if currentIdx < 0 {
95 continue
96 }
97
98 if strings.HasPrefix(line, "-") {
99 inHunk = true
100 files[currentIdx].before.WriteString(line[1:])
101 files[currentIdx].before.WriteByte('\n')
102 continue
103 }
104
105 if strings.HasPrefix(line, "+") {
106 inHunk = true
107 files[currentIdx].after.WriteString(line[1:])
108 files[currentIdx].after.WriteByte('\n')
109 continue
110 }
111
112 if strings.HasPrefix(line, " ") {
113 inHunk = true
114 lineContent := line[1:]
115 files[currentIdx].before.WriteString(lineContent)
116 files[currentIdx].before.WriteByte('\n')
117 files[currentIdx].after.WriteString(lineContent)
118 files[currentIdx].after.WriteByte('\n')
119 }
120 }
121
122 result := make([]parsedDiffFile, 0, len(files))
123 for _, f := range files {
124 result = append(result, parsedDiffFile{
125 path: f.path,
126 before: strings.TrimSuffix(f.before.String(), "\n"),
127 after: strings.TrimSuffix(f.after.String(), "\n"),
128 })
129 }
130 return result
131}
132
133func toolOutputDiffContentFromUnified(sty *styles.Styles, content string, width int, expanded bool) string {
134 files := parseUnifiedDiff(content)
135 if len(files) == 0 {
136 bodyWidth := width - toolBodyLeftPaddingTotal
137 return sty.Tool.Body.Render(toolOutputCodeContent(sty, "result.diff", content, 0, bodyWidth, expanded))
138 }
139 bodyWidth := width - toolBodyLeftPaddingTotal
140 var blocks []string
141 for i, f := range files {
142 formatter := common.DiffFormatter(sty).
143 Before(f.path, f.before).
144 After(f.path, f.after).
145 Width(bodyWidth)
146 if len(files) > 1 {
147 formatter = formatter.FileName(f.path)
148 }
149 if width > maxTextWidth {
150 formatter = formatter.Split()
151 }
152 formatted := formatter.String()
153 if i < len(files)-1 {
154 formatted += "\n"
155 }
156 blocks = append(blocks, formatted)
157 }
158 combined := strings.Join(blocks, "\n")
159 lines := strings.Split(combined, "\n")
160 maxLines := responseContextHeight
161 if expanded {
162 maxLines = len(lines)
163 }
164 if len(lines) > maxLines && !expanded {
165 truncMsg := sty.Tool.DiffTruncation.
166 Width(bodyWidth).
167 Render(fmt.Sprintf(assistantMessageTruncateFormat, len(lines)-maxLines))
168 combined = strings.Join(lines[:maxLines], "\n") + "\n" + truncMsg
169 }
170 return sty.Tool.Body.Render(combined)
171}