1package diffview
2
3import (
4 "fmt"
5 "os"
6 "strconv"
7 "strings"
8
9 "github.com/aymanbagabas/go-udiff"
10 "github.com/aymanbagabas/go-udiff/myers"
11 "github.com/charmbracelet/lipgloss/v2"
12 "github.com/charmbracelet/x/ansi"
13)
14
const (
	// lineNumPadding is the number of blank columns placed on each side of a
	// line number in the line-number gutter.
	lineNumPadding = 1

	// leadingSymbolsSize is the number of columns reserved in front of every
	// code line for its change marker ("+ " or "- ").
	leadingSymbolsSize = 2
)
19
// file is one side of a diff. path is handed to udiff.ToUnifiedDiff as the
// name of that side, and content is the complete text that gets compared.
type file struct {
	path, content string
}
24
// layout selects how a DiffView arranges the two sides of a diff.
type layout int

// The layout values start at 1, so the zero value of layout does not name a
// real layout and can be recognized as "unset".
const (
	layoutUnified layout = 1 // a single code column with "-"/"+" rows interleaved
	layoutSplit   layout = 2 // side by side: "before" on the left, "after" on the right
)
31
// DiffView represents a view for displaying differences between two files.
//
// Configure it with New and the chainable setter methods (Before, After,
// Split, Width, ...), then render it with String.
type DiffView struct {
	// Configuration, set through the chainable setter methods.
	layout       layout // layoutUnified or layoutSplit
	before       file   // original ("before") side of the diff
	after        file   // modified ("after") side of the diff
	contextLines int    // passed to udiff.ToUnifiedDiff
	lineNumbers  bool   // whether the line-number gutters are rendered
	highlight    bool   // set by SyntaxHightlight; NOTE(review): not read by the code visible here — confirm where it is consumed
	height       int    // if > 0, String clips to this many rows and pads shorter output with filler rows
	width        int    // if > 0, the view is fitted and clipped to this width; otherwise sized to its content
	style        Style

	// Result of computeDiff, cached while isComputed is true.
	isComputed bool
	err        error
	unified    udiff.UnifiedDiff
	edits      []udiff.Edit

	// Split-layout hunks built from unified.Hunks (same order) by
	// convertDiffToSplit.
	splitHunks []splitHunk

	// Layout measurements computed by String before rendering.
	codeWidth       int  // width of the code text, excluding the leading "+ "/"- " symbols
	fullCodeWidth   int  // with leading symbols
	extraColOnAfter bool // add extra column on after panel
	beforeNumDigits int  // digits reserved for "before" line numbers
	afterNumDigits  int  // digits reserved for "after" line numbers
}
57
58// New creates a new DiffView with default settings.
59func New() *DiffView {
60 dv := &DiffView{
61 layout: layoutUnified,
62 contextLines: udiff.DefaultContextLines,
63 lineNumbers: true,
64 }
65 if lipgloss.HasDarkBackground(os.Stdin, os.Stdout) {
66 dv.style = DefaultDarkStyle
67 } else {
68 dv.style = DefaultLightStyle
69 }
70 return dv
71}
72
73// Unified sets the layout of the DiffView to unified.
74func (dv *DiffView) Unified() *DiffView {
75 dv.layout = layoutUnified
76 return dv
77}
78
79// Split sets the layout of the DiffView to split (side-by-side).
80func (dv *DiffView) Split() *DiffView {
81 dv.layout = layoutSplit
82 return dv
83}
84
85// Before sets the "before" file for the DiffView.
86func (dv *DiffView) Before(path, content string) *DiffView {
87 dv.before = file{path: path, content: content}
88 return dv
89}
90
91// After sets the "after" file for the DiffView.
92func (dv *DiffView) After(path, content string) *DiffView {
93 dv.after = file{path: path, content: content}
94 return dv
95}
96
97// ContextLines sets the number of context lines for the DiffView.
98func (dv *DiffView) ContextLines(contextLines int) *DiffView {
99 dv.contextLines = contextLines
100 return dv
101}
102
103// Style sets the style for the DiffView.
104func (dv *DiffView) Style(style Style) *DiffView {
105 dv.style = style
106 return dv
107}
108
109// LineNumbers sets whether to display line numbers in the DiffView.
110func (dv *DiffView) LineNumbers(lineNumbers bool) *DiffView {
111 dv.lineNumbers = lineNumbers
112 return dv
113}
114
115// SyntaxHightlight sets whether to enable syntax highlighting in the DiffView.
116func (dv *DiffView) SyntaxHightlight(highlight bool) *DiffView {
117 dv.highlight = highlight
118 return dv
119}
120
121// Height sets the height of the DiffView.
122func (dv *DiffView) Height(height int) *DiffView {
123 dv.height = height
124 return dv
125}
126
127// Width sets the width of the DiffView.
128func (dv *DiffView) Width(width int) *DiffView {
129 dv.width = width
130 return dv
131}
132
133// String returns the string representation of the DiffView.
134func (dv *DiffView) String() string {
135 if err := dv.computeDiff(); err != nil {
136 return err.Error()
137 }
138 dv.convertDiffToSplit()
139 dv.adjustStyles()
140 dv.detectNumDigits()
141
142 if dv.width <= 0 {
143 dv.detectCodeWidth()
144 } else {
145 dv.resizeCodeWidth()
146 }
147
148 style := lipgloss.NewStyle()
149 if dv.width > 0 {
150 style = style.MaxWidth(dv.width)
151 }
152 if dv.height > 0 {
153 style = style.MaxHeight(dv.height)
154 }
155
156 switch dv.layout {
157 case layoutUnified:
158 return style.Render(strings.TrimSuffix(dv.renderUnified(), "\n"))
159 case layoutSplit:
160 return style.Render(strings.TrimSuffix(dv.renderSplit(), "\n"))
161 default:
162 panic("unknown diffview layout")
163 }
164}
165
166// computeDiff computes the differences between the "before" and "after" files.
167func (dv *DiffView) computeDiff() error {
168 if dv.isComputed {
169 return dv.err
170 }
171 dv.isComputed = true
172 dv.edits = myers.ComputeEdits( //nolint:staticcheck
173 dv.before.content,
174 dv.after.content,
175 )
176 dv.unified, dv.err = udiff.ToUnifiedDiff(
177 dv.before.path,
178 dv.after.path,
179 dv.before.content,
180 dv.edits,
181 dv.contextLines,
182 )
183 return dv.err
184}
185
186// convertDiffToSplit converts the unified diff to a split diff if the layout is
187// set to split.
188func (dv *DiffView) convertDiffToSplit() {
189 if dv.layout != layoutSplit {
190 return
191 }
192
193 dv.splitHunks = make([]splitHunk, len(dv.unified.Hunks))
194 for i, h := range dv.unified.Hunks {
195 dv.splitHunks[i] = hunkToSplit(h)
196 }
197}
198
199// adjustStyles adjusts adds padding and alignment to the styles.
200func (dv *DiffView) adjustStyles() {
201 dv.style.MissingLine.LineNumber = setPadding(dv.style.MissingLine.LineNumber)
202 dv.style.DividerLine.LineNumber = setPadding(dv.style.DividerLine.LineNumber)
203 dv.style.EqualLine.LineNumber = setPadding(dv.style.EqualLine.LineNumber)
204 dv.style.InsertLine.LineNumber = setPadding(dv.style.InsertLine.LineNumber)
205 dv.style.DeleteLine.LineNumber = setPadding(dv.style.DeleteLine.LineNumber)
206}
207
208// detectNumDigits calculates the maximum number of digits needed for before and
209// after line numbers.
210func (dv *DiffView) detectNumDigits() {
211 dv.beforeNumDigits = 0
212 dv.afterNumDigits = 0
213
214 for _, h := range dv.unified.Hunks {
215 dv.beforeNumDigits = max(dv.beforeNumDigits, len(strconv.Itoa(h.FromLine+len(h.Lines))))
216 dv.afterNumDigits = max(dv.afterNumDigits, len(strconv.Itoa(h.ToLine+len(h.Lines))))
217 }
218}
219
220func setPadding(s lipgloss.Style) lipgloss.Style {
221 return s.Padding(0, lineNumPadding).Align(lipgloss.Right)
222}
223
224// detectCodeWidth calculates the maximum width of code lines in the diff view.
225func (dv *DiffView) detectCodeWidth() {
226 switch dv.layout {
227 case layoutUnified:
228 dv.detectUnifiedCodeWidth()
229 case layoutSplit:
230 dv.detectSplitCodeWidth()
231 }
232 dv.fullCodeWidth = dv.codeWidth + leadingSymbolsSize
233}
234
235// detectUnifiedCodeWidth calculates the maximum width of code lines in a
236// unified diff.
237func (dv *DiffView) detectUnifiedCodeWidth() {
238 dv.codeWidth = 0
239
240 for _, h := range dv.unified.Hunks {
241 shownLines := ansi.StringWidth(dv.hunkLineFor(h))
242
243 for _, l := range h.Lines {
244 lineWidth := ansi.StringWidth(strings.TrimSuffix(l.Content, "\n")) + 1
245 dv.codeWidth = max(dv.codeWidth, lineWidth, shownLines)
246 }
247 }
248}
249
250// detectSplitCodeWidth calculates the maximum width of code lines in a
251// split diff.
252func (dv *DiffView) detectSplitCodeWidth() {
253 dv.codeWidth = 0
254
255 for i, h := range dv.splitHunks {
256 shownLines := ansi.StringWidth(dv.hunkLineFor(dv.unified.Hunks[i]))
257
258 for _, l := range h.lines {
259 if l.before != nil {
260 codeWidth := ansi.StringWidth(strings.TrimSuffix(l.before.Content, "\n")) + 1
261 dv.codeWidth = max(dv.codeWidth, codeWidth, shownLines)
262 }
263 if l.after != nil {
264 codeWidth := ansi.StringWidth(strings.TrimSuffix(l.after.Content, "\n")) + 1
265 dv.codeWidth = max(dv.codeWidth, codeWidth, shownLines)
266 }
267 }
268 }
269}
270
271// resizeCodeWidth resizes the code width to fit within the specified width.
272func (dv *DiffView) resizeCodeWidth() {
273 fullNumWidth := dv.beforeNumDigits + dv.afterNumDigits
274 fullNumWidth += lineNumPadding * 4 // left and right padding for both line numbers
275
276 switch dv.layout {
277 case layoutUnified:
278 dv.codeWidth = dv.width - fullNumWidth - leadingSymbolsSize
279 case layoutSplit:
280 remainingWidth := dv.width - fullNumWidth - leadingSymbolsSize*2
281 dv.codeWidth = remainingWidth / 2
282 dv.extraColOnAfter = isOdd(remainingWidth)
283 }
284
285 dv.fullCodeWidth = dv.codeWidth + leadingSymbolsSize
286}
287
// renderUnified renders the unified diff view as a string.
//
// For every hunk it writes a divider row with the hunk header from hunkLineFor,
// truncated to fullCodeWidth. It then writes one row per hunk line:
//   - equal rows show both line numbers;
//   - insert rows leave the "before" number blank;
//   - delete rows leave the "after" number blank.
//
// Code text is truncated to codeWidth with a trailing "…". If dv.height
// exceeds the number of rows written, blank MissingLine rows are appended
// until it is reached. Every row, including the last, ends with '\n'.
func (dv *DiffView) renderUnified() string {
	var b strings.Builder

	// Caps each rendered code cell at the full code column width.
	fullContentStyle := lipgloss.NewStyle().MaxWidth(dv.fullCodeWidth)
	printedLines := 0

	for _, h := range dv.unified.Hunks {
		// Divider row: "…" in both number gutters, then the hunk header.
		if dv.lineNumbers {
			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
		}
		content := ansi.Truncate(dv.hunkLineFor(h), dv.fullCodeWidth, "…")
		b.WriteString(dv.style.DividerLine.Code.Width(dv.fullCodeWidth).Render(content))
		b.WriteRune('\n')
		printedLines++

		// Running line numbers for each side, starting at the hunk's first line.
		beforeLine := h.FromLine
		afterLine := h.ToLine

		for _, l := range h.Lines {
			content := strings.TrimSuffix(l.Content, "\n")
			content = ansi.Truncate(content, dv.codeWidth, "…")

			// NOTE(review): equal rows prefix a single space and span
			// fullCodeWidth, while insert/delete rows prefix a two-column
			// symbol and then codeWidth columns. Confirm the code text lines up
			// across row kinds as intended (it may depend on the configured
			// styles).
			switch l.Kind {
			case udiff.Equal:
				if dv.lineNumbers {
					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
				}
				b.WriteString(fullContentStyle.Render(
					dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render(" " + content),
				))
				beforeLine++
				afterLine++
			case udiff.Insert:
				if dv.lineNumbers {
					b.WriteString(dv.style.InsertLine.LineNumber.Render(pad(" ", dv.beforeNumDigits)))
					b.WriteString(dv.style.InsertLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
				}
				b.WriteString(fullContentStyle.Render(
					dv.style.InsertLine.Symbol.Render("+ ") +
						dv.style.InsertLine.Code.Width(dv.codeWidth).Render(content),
				))
				afterLine++
			case udiff.Delete:
				if dv.lineNumbers {
					b.WriteString(dv.style.DeleteLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
					b.WriteString(dv.style.DeleteLine.LineNumber.Render(pad(" ", dv.afterNumDigits)))
				}
				b.WriteString(fullContentStyle.Render(
					dv.style.DeleteLine.Symbol.Render("- ") +
						dv.style.DeleteLine.Code.Width(dv.codeWidth).Render(content),
				))
				beforeLine++
			}
			b.WriteRune('\n')

			printedLines++
		}
	}

	// Filler rows up to dv.height (no-op when height <= 0).
	for printedLines < dv.height {
		if dv.lineNumbers {
			b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.beforeNumDigits)))
			b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.afterNumDigits)))
		}
		b.WriteString(dv.style.MissingLine.Code.Width(dv.fullCodeWidth).Render(" "))
		b.WriteRune('\n')
		printedLines++
	}

	return b.String()
}
362
// renderSplit renders the split (side-by-side) diff view as a string.
//
// Each row has a "before" panel on the left and an "after" panel on the right,
// each preceded by its own line-number gutter when line numbers are enabled.
// For every hunk it writes a divider row (the hunk header on the left, a blank
// divider on the right). It then writes one row per entry of the hunk's lines.
// A side with no line (nil) is drawn with the MissingLine style.
//
// The "after" panel is one column wider when dv.extraColOnAfter is set. If
// dv.height exceeds the number of rows written, filler rows are appended until
// it is reached. Every row, including the last, ends with '\n'.
func (dv *DiffView) renderSplit() string {
	var b strings.Builder

	// Cap each panel's code cell at its column width.
	beforeFullContentStyle := lipgloss.NewStyle().MaxWidth(dv.fullCodeWidth)
	afterFullContentStyle := lipgloss.NewStyle().MaxWidth(dv.fullCodeWidth + btoi(dv.extraColOnAfter))
	printedLines := 0

	for i, h := range dv.splitHunks {
		// Divider row. splitHunks[i] was built from unified.Hunks[i], so the
		// header text comes from the matching unified hunk.
		if dv.lineNumbers {
			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
		}
		content := ansi.Truncate(dv.hunkLineFor(dv.unified.Hunks[i]), dv.fullCodeWidth, "…")
		b.WriteString(dv.style.DividerLine.Code.Width(dv.fullCodeWidth).Render(content))
		if dv.lineNumbers {
			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
		}
		b.WriteString(dv.style.DividerLine.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render(" "))
		b.WriteRune('\n')
		printedLines++

		// Running line numbers for each side, starting at the hunk's first line.
		beforeLine := h.fromLine
		afterLine := h.toLine

		for _, l := range h.lines {
			var beforeContent string
			var afterContent string
			if l.before != nil {
				beforeContent = strings.TrimSuffix(l.before.Content, "\n")
				beforeContent = ansi.Truncate(beforeContent, dv.codeWidth, "…")
			}
			if l.after != nil {
				afterContent = strings.TrimSuffix(l.after.Content, "\n")
				afterContent = ansi.Truncate(afterContent, dv.codeWidth+btoi(dv.extraColOnAfter), "…")
			}

			// Left ("before") panel.
			switch {
			case l.before == nil:
				if dv.lineNumbers {
					b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.beforeNumDigits)))
				}
				b.WriteString(beforeFullContentStyle.Render(
					dv.style.MissingLine.Code.Width(dv.fullCodeWidth).Render(" "),
				))
			case l.before.Kind == udiff.Equal:
				if dv.lineNumbers {
					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
				}
				b.WriteString(beforeFullContentStyle.Render(
					dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render(" " + beforeContent),
				))
				beforeLine++
			case l.before.Kind == udiff.Delete:
				if dv.lineNumbers {
					b.WriteString(dv.style.DeleteLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
				}
				b.WriteString(beforeFullContentStyle.Render(
					dv.style.DeleteLine.Symbol.Render("- ") +
						dv.style.DeleteLine.Code.Width(dv.codeWidth).Render(beforeContent),
				))
				beforeLine++
			}

			// Right ("after") panel, one column wider when extraColOnAfter is set.
			switch {
			case l.after == nil:
				if dv.lineNumbers {
					b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.afterNumDigits)))
				}
				b.WriteString(afterFullContentStyle.Render(
					dv.style.MissingLine.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render(" "),
				))
			case l.after.Kind == udiff.Equal:
				if dv.lineNumbers {
					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
				}
				b.WriteString(afterFullContentStyle.Render(
					dv.style.EqualLine.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render(" " + afterContent),
				))
				afterLine++
			case l.after.Kind == udiff.Insert:
				if dv.lineNumbers {
					b.WriteString(dv.style.InsertLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
				}
				b.WriteString(afterFullContentStyle.Render(
					dv.style.InsertLine.Symbol.Render("+ ") +
						dv.style.InsertLine.Code.Width(dv.codeWidth+btoi(dv.extraColOnAfter)).Render(afterContent),
				))
				afterLine++
			}

			b.WriteRune('\n')

			printedLines++
		}
	}

	// Filler rows up to dv.height (no-op when height <= 0).
	// NOTE(review): these filler rows print "…" in the line-number gutters,
	// whereas renderUnified's filler rows print blanks. Confirm which is
	// intended.
	for printedLines < dv.height {
		if dv.lineNumbers {
			b.WriteString(dv.style.MissingLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
		}
		b.WriteString(dv.style.MissingLine.Code.Width(dv.fullCodeWidth).Render(" "))
		if dv.lineNumbers {
			b.WriteString(dv.style.MissingLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
		}
		b.WriteString(dv.style.MissingLine.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render(" "))
		b.WriteRune('\n')
		printedLines++
	}

	return b.String()
}
474
475// hunkLineFor formats the header line for a hunk in the unified diff view.
476func (dv *DiffView) hunkLineFor(h *udiff.Hunk) string {
477 beforeShownLines, afterShownLines := dv.hunkShownLines(h)
478
479 return fmt.Sprintf(
480 " @@ -%d,%d +%d,%d @@ ",
481 h.FromLine,
482 beforeShownLines,
483 h.ToLine,
484 afterShownLines,
485 )
486}
487
488// hunkShownLines calculates the number of lines shown in a hunk for both before
489// and after versions.
490func (dv *DiffView) hunkShownLines(h *udiff.Hunk) (before, after int) {
491 for _, l := range h.Lines {
492 switch l.Kind {
493 case udiff.Equal:
494 before++
495 after++
496 case udiff.Insert:
497 after++
498 case udiff.Delete:
499 before++
500 }
501 }
502 return
503}