unified_diff.go

  1package chat
  2
import (
	"fmt"
	"strconv"
	"strings"

	"github.com/charmbracelet/crush/internal/diffdetect"
	"github.com/charmbracelet/crush/internal/ui/common"
	"github.com/charmbracelet/crush/internal/ui/styles"
)
 11
 12type parsedDiffFile struct {
 13	path   string
 14	before string
 15	after  string
 16}
 17
 18func looksLikeDiff(content string) bool {
 19	return diffdetect.IsUnifiedDiff(content)
 20}
 21
 22func parseUnifiedDiff(content string) []parsedDiffFile {
 23	type fileBuilder struct {
 24		path   string
 25		before strings.Builder
 26		after  strings.Builder
 27	}
 28
 29	var files []fileBuilder
 30	currentIdx := -1
 31	inHunk := false
 32	lines := strings.Split(content, "\n")
 33
 34	for i, line := range lines {
 35		if strings.HasPrefix(line, "diff --git ") {
 36			inHunk = false
 37			parts := strings.SplitN(line, " ", 4)
 38			if len(parts) >= 4 {
 39				files = append(files, fileBuilder{path: strings.TrimPrefix(parts[3], "b/")})
 40				currentIdx = len(files) - 1
 41			}
 42			continue
 43		}
 44
 45		if strings.HasPrefix(line, "@@") {
 46			inHunk = true
 47			continue
 48		}
 49
 50		if strings.HasPrefix(line, "index ") || strings.HasPrefix(line, "new file") || strings.HasPrefix(line, "deleted file") {
 51			inHunk = false
 52			continue
 53		}
 54
 55		nextIsPlusHeader := i+1 < len(lines) && strings.HasPrefix(lines[i+1], "+++ ")
 56		if strings.HasPrefix(line, "--- ") && (!inHunk || nextIsPlusHeader) {
 57			startedNewFileFromHunk := inHunk && nextIsPlusHeader
 58			inHunk = false
 59			p := strings.TrimPrefix(line, "--- ")
 60			p = strings.TrimPrefix(p, "a/")
 61			if idx := strings.Index(p, "\t"); idx >= 0 {
 62				p = p[:idx]
 63			}
 64			if currentIdx < 0 || startedNewFileFromHunk {
 65				files = append(files, fileBuilder{path: p})
 66				currentIdx = len(files) - 1
 67				continue
 68			}
 69			if p != "/dev/null" {
 70				files[currentIdx].path = p
 71			}
 72			continue
 73		}
 74
 75		if strings.HasPrefix(line, "+++ ") && !inHunk {
 76			p := strings.TrimPrefix(line, "+++ ")
 77			p = strings.TrimPrefix(p, "b/")
 78			if idx := strings.Index(p, "\t"); idx >= 0 {
 79				p = p[:idx]
 80			}
 81			if currentIdx < 0 {
 82				if p != "/dev/null" {
 83					files = append(files, fileBuilder{path: p})
 84					currentIdx = len(files) - 1
 85				}
 86				continue
 87			}
 88			if p != "/dev/null" && (files[currentIdx].path == "" || strings.HasPrefix(files[currentIdx].path, "/dev/null")) {
 89				files[currentIdx].path = p
 90			}
 91			continue
 92		}
 93
 94		if currentIdx < 0 {
 95			continue
 96		}
 97
 98		if strings.HasPrefix(line, "-") {
 99			inHunk = true
100			files[currentIdx].before.WriteString(line[1:])
101			files[currentIdx].before.WriteByte('\n')
102			continue
103		}
104
105		if strings.HasPrefix(line, "+") {
106			inHunk = true
107			files[currentIdx].after.WriteString(line[1:])
108			files[currentIdx].after.WriteByte('\n')
109			continue
110		}
111
112		if strings.HasPrefix(line, " ") {
113			inHunk = true
114			lineContent := line[1:]
115			files[currentIdx].before.WriteString(lineContent)
116			files[currentIdx].before.WriteByte('\n')
117			files[currentIdx].after.WriteString(lineContent)
118			files[currentIdx].after.WriteByte('\n')
119		}
120	}
121
122	result := make([]parsedDiffFile, 0, len(files))
123	for _, f := range files {
124		result = append(result, parsedDiffFile{
125			path:   f.path,
126			before: strings.TrimSuffix(f.before.String(), "\n"),
127			after:  strings.TrimSuffix(f.after.String(), "\n"),
128		})
129	}
130	return result
131}
132
133func toolOutputDiffContentFromUnified(sty *styles.Styles, content string, width int, expanded bool) string {
134	files := parseUnifiedDiff(content)
135	if len(files) == 0 {
136		bodyWidth := width - toolBodyLeftPaddingTotal
137		return sty.Tool.Body.Render(toolOutputCodeContent(sty, "result.diff", content, 0, bodyWidth, expanded))
138	}
139	bodyWidth := width - toolBodyLeftPaddingTotal
140	var blocks []string
141	for i, f := range files {
142		formatter := common.DiffFormatter(sty).
143			Before(f.path, f.before).
144			After(f.path, f.after).
145			Width(bodyWidth)
146		if len(files) > 1 {
147			formatter = formatter.FileName(f.path)
148		}
149		if width > maxTextWidth {
150			formatter = formatter.Split()
151		}
152		formatted := formatter.String()
153		if i < len(files)-1 {
154			formatted += "\n"
155		}
156		blocks = append(blocks, formatted)
157	}
158	combined := strings.Join(blocks, "\n")
159	lines := strings.Split(combined, "\n")
160	maxLines := responseContextHeight
161	if expanded {
162		maxLines = len(lines)
163	}
164	if len(lines) > maxLines && !expanded {
165		truncMsg := sty.Tool.DiffTruncation.
166			Width(bodyWidth).
167			Render(fmt.Sprintf(assistantMessageTruncateFormat, len(lines)-maxLines))
168		combined = strings.Join(lines[:maxLines], "\n") + "\n" + truncMsg
169	}
170	return sty.Tool.Body.Render(combined)
171}