1package diffview
2
3import (
4 "fmt"
5 "os"
6 "strconv"
7 "strings"
8
9 "github.com/aymanbagabas/go-udiff"
10 "github.com/aymanbagabas/go-udiff/myers"
11 "github.com/charmbracelet/lipgloss/v2"
12 "github.com/charmbracelet/x/ansi"
13)
14
const (
	// lineNumPadding is the horizontal padding, in columns, applied on each
	// side of a line-number gutter.
	lineNumPadding = 1
	// leadingSymbolsSize is the width of the two-column prefix ("+ " or "- ")
	// drawn in front of inserted and deleted code.
	leadingSymbolsSize = 2
)
19
// file is one side of a diff: the name handed to udiff.ToUnifiedDiff for that
// side, plus the full text that gets diffed.
type file struct {
	path    string // name passed to udiff.ToUnifiedDiff for this side
	content string // full text of this side of the diff
}
24
// layout selects how a DiffView arranges the two sides of a diff.
type layout int

// Layout values are spelled out explicitly. The zero value is not a valid
// layout, and String panics on it.
const (
	layoutUnified layout = 1 // single column, lines interleaved
	layoutSplit   layout = 2 // side by side
)
31
// DiffView represents a view for displaying differences between two files.
//
// It is configured through chained setters (Unified, Split, Before, After,
// ContextLines, ...) and rendered with String. The fields after the first
// group are derived state that String computes, caches, or recomputes while
// rendering.
type DiffView struct {
	layout       layout // layoutUnified or layoutSplit
	before       file   // old side of the diff (left panel in the split layout)
	after        file   // new side of the diff (right panel in the split layout)
	contextLines int    // passed to udiff.ToUnifiedDiff as the number of context lines
	lineNumbers  bool   // whether the before/after line-number gutters are rendered
	highlight    bool   // set by SyntaxHightlight; NOTE(review): not read by the rendering code in this file
	height       int    // maximum rows of output; <= 0 means no limit
	width        int    // total output width; <= 0 means size to the widest line
	style        Style

	// Diff result, computed lazily by computeDiff and reused while isComputed
	// is true. err is returned on every call once set.
	isComputed bool
	err        error
	unified    udiff.UnifiedDiff
	edits      []udiff.Edit

	// One split hunk per unified hunk; rebuilt by convertDiffToSplit only for
	// the split layout.
	splitHunks []splitHunk

	codeWidth       int  // width of the code text in one cell, without the leading symbols
	fullCodeWidth   int  // with leading symbols
	extraColOnAfter bool // add extra column on after panel
	beforeNumDigits int  // digits reserved for "before" line numbers
	afterNumDigits  int  // digits reserved for "after" line numbers
}
57
// New returns a DiffView with the following defaults:
//   - the unified layout
//   - udiff's default number of context lines
//   - line numbers enabled
//   - a default style, dark or light depending on whether
//     lipgloss.HasDarkBackground reports a dark terminal background on
//     stdin/stdout
func New() *DiffView {
	style := DefaultLightStyle
	if lipgloss.HasDarkBackground(os.Stdin, os.Stdout) {
		style = DefaultDarkStyle
	}

	return &DiffView{
		layout:       layoutUnified,
		contextLines: udiff.DefaultContextLines,
		lineNumbers:  true,
		style:        style,
	}
}
72
73// Unified sets the layout of the DiffView to unified.
74func (dv *DiffView) Unified() *DiffView {
75 dv.layout = layoutUnified
76 return dv
77}
78
79// Split sets the layout of the DiffView to split (side-by-side).
80func (dv *DiffView) Split() *DiffView {
81 dv.layout = layoutSplit
82 return dv
83}
84
85// Before sets the "before" file for the DiffView.
86func (dv *DiffView) Before(path, content string) *DiffView {
87 dv.before = file{path: path, content: content}
88 return dv
89}
90
91// After sets the "after" file for the DiffView.
92func (dv *DiffView) After(path, content string) *DiffView {
93 dv.after = file{path: path, content: content}
94 return dv
95}
96
97// ContextLines sets the number of context lines for the DiffView.
98func (dv *DiffView) ContextLines(contextLines int) *DiffView {
99 dv.contextLines = contextLines
100 return dv
101}
102
103// Style sets the style for the DiffView.
104func (dv *DiffView) Style(style Style) *DiffView {
105 dv.style = style
106 return dv
107}
108
109// LineNumbers sets whether to display line numbers in the DiffView.
110func (dv *DiffView) LineNumbers(lineNumbers bool) *DiffView {
111 dv.lineNumbers = lineNumbers
112 return dv
113}
114
115// SyntaxHightlight sets whether to enable syntax highlighting in the DiffView.
116func (dv *DiffView) SyntaxHightlight(highlight bool) *DiffView {
117 dv.highlight = highlight
118 return dv
119}
120
121// Height sets the height of the DiffView.
122func (dv *DiffView) Height(height int) *DiffView {
123 dv.height = height
124 return dv
125}
126
127// Width sets the width of the DiffView.
128func (dv *DiffView) Width(width int) *DiffView {
129 dv.width = width
130 return dv
131}
132
133// String returns the string representation of the DiffView.
134func (dv *DiffView) String() string {
135 if err := dv.computeDiff(); err != nil {
136 return err.Error()
137 }
138 dv.convertDiffToSplit()
139 dv.adjustStyles()
140 dv.detectNumDigits()
141
142 if dv.width <= 0 {
143 dv.detectCodeWidth()
144 } else {
145 dv.resizeCodeWidth()
146 }
147
148 style := lipgloss.NewStyle()
149 if dv.width > 0 {
150 style = style.MaxWidth(dv.width)
151 }
152 if dv.height > 0 {
153 style = style.MaxHeight(dv.height)
154 }
155
156 switch dv.layout {
157 case layoutUnified:
158 return style.Render(strings.TrimSuffix(dv.renderUnified(), "\n"))
159 case layoutSplit:
160 return style.Render(strings.TrimSuffix(dv.renderSplit(), "\n"))
161 default:
162 panic("unknown diffview layout")
163 }
164}
165
166// computeDiff computes the differences between the "before" and "after" files.
167func (dv *DiffView) computeDiff() error {
168 if dv.isComputed {
169 return dv.err
170 }
171 dv.isComputed = true
172 dv.edits = myers.ComputeEdits( //nolint:staticcheck
173 dv.before.content,
174 dv.after.content,
175 )
176 dv.unified, dv.err = udiff.ToUnifiedDiff(
177 dv.before.path,
178 dv.after.path,
179 dv.before.content,
180 dv.edits,
181 dv.contextLines,
182 )
183 return dv.err
184}
185
186// convertDiffToSplit converts the unified diff to a split diff if the layout is
187// set to split.
188func (dv *DiffView) convertDiffToSplit() {
189 if dv.layout != layoutSplit {
190 return
191 }
192
193 dv.splitHunks = make([]splitHunk, len(dv.unified.Hunks))
194 for i, h := range dv.unified.Hunks {
195 dv.splitHunks[i] = hunkToSplit(h)
196 }
197}
198
199// adjustStyles adjusts adds padding and alignment to the styles.
200func (dv *DiffView) adjustStyles() {
201 dv.style.MissingLine.LineNumber = setPadding(dv.style.MissingLine.LineNumber)
202 dv.style.DividerLine.LineNumber = setPadding(dv.style.DividerLine.LineNumber)
203 dv.style.EqualLine.LineNumber = setPadding(dv.style.EqualLine.LineNumber)
204 dv.style.InsertLine.LineNumber = setPadding(dv.style.InsertLine.LineNumber)
205 dv.style.DeleteLine.LineNumber = setPadding(dv.style.DeleteLine.LineNumber)
206}
207
// detectNumDigits calculates how many digits the before and after line-number
// gutters need. Each count is the width of the largest line number that side
// will actually print across all hunks.
func (dv *DiffView) detectNumDigits() {
	dv.beforeNumDigits = 0
	dv.afterNumDigits = 0

	for _, h := range dv.unified.Hunks {
		// The renderers start numbering at FromLine/ToLine and advance once
		// per line shown on that side. The last printed number is therefore
		// start+count-1. The old formula, start+len(h.Lines), was off by one
		// and also counted the other side's lines, which could add a spurious
		// gutter column. Clamp at zero so a side with no lines never formats
		// as "-1".
		beforeCount, afterCount := dv.hunkShownLines(h)
		lastBefore := max(h.FromLine+beforeCount-1, 0)
		lastAfter := max(h.ToLine+afterCount-1, 0)

		dv.beforeNumDigits = max(dv.beforeNumDigits, len(strconv.Itoa(lastBefore)))
		dv.afterNumDigits = max(dv.afterNumDigits, len(strconv.Itoa(lastAfter)))
	}
}
219
220func setPadding(s lipgloss.Style) lipgloss.Style {
221 return s.Padding(0, lineNumPadding).Align(lipgloss.Right)
222}
223
224// detectCodeWidth calculates the maximum width of code lines in the diff view.
225func (dv *DiffView) detectCodeWidth() {
226 switch dv.layout {
227 case layoutUnified:
228 dv.detectUnifiedCodeWidth()
229 case layoutSplit:
230 dv.detectSplitCodeWidth()
231 }
232 dv.fullCodeWidth = dv.codeWidth + leadingSymbolsSize
233}
234
// detectUnifiedCodeWidth sets codeWidth to the widest value across all hunks
// that contain at least one line. For each such hunk it considers every
// line's display width (without the trailing newline) plus one, and the
// width of the hunk's header.
func (dv *DiffView) detectUnifiedCodeWidth() {
	widest := 0
	for _, hunk := range dv.unified.Hunks {
		headerWidth := ansi.StringWidth(dv.hunkLineFor(hunk))
		for _, line := range hunk.Lines {
			code := strings.TrimSuffix(line.Content, "\n")
			widest = max(widest, ansi.StringWidth(code)+1, headerWidth)
		}
	}
	dv.codeWidth = widest
}
249
// detectSplitCodeWidth sets codeWidth to the widest value across both panels
// of every split hunk. Each present "before"/"after" line counts as its
// display width (without the trailing newline) plus one, alongside the width
// of the corresponding unified hunk's header.
func (dv *DiffView) detectSplitCodeWidth() {
	widest := 0
	for idx := range dv.splitHunks {
		headerWidth := ansi.StringWidth(dv.hunkLineFor(dv.unified.Hunks[idx]))
		measure := func(content string) {
			lineWidth := ansi.StringWidth(strings.TrimSuffix(content, "\n")) + 1
			widest = max(widest, lineWidth, headerWidth)
		}
		for _, row := range dv.splitHunks[idx].lines {
			if row.before != nil {
				measure(row.before.Content)
			}
			if row.after != nil {
				measure(row.after.Content)
			}
		}
	}
	dv.codeWidth = widest
}
270
271// resizeCodeWidth resizes the code width to fit within the specified width.
272func (dv *DiffView) resizeCodeWidth() {
273 fullNumWidth := dv.beforeNumDigits + dv.afterNumDigits
274 fullNumWidth += lineNumPadding * 4 // left and right padding for both line numbers
275
276 switch dv.layout {
277 case layoutUnified:
278 dv.codeWidth = dv.width - fullNumWidth - leadingSymbolsSize
279 case layoutSplit:
280 remainingWidth := dv.width - fullNumWidth - leadingSymbolsSize*2
281 dv.codeWidth = remainingWidth / 2
282 dv.extraColOnAfter = isOdd(remainingWidth)
283 }
284
285 dv.fullCodeWidth = dv.codeWidth + leadingSymbolsSize
286}
287
// renderUnified renders the unified diff view as a string.
//
// Each hunk starts with a divider row holding the hunk header. One row
// follows per diff line, made of the before/after line-number gutters (when
// enabled) and the code. Inserted code is prefixed with "+ ", deleted code
// with "- ", and unchanged code with a single space.
//
// When a height is set and the diff does not fit, rendering stops with an
// ellipsis row. If the diff is shorter, filler rows are appended up to the
// height.
func (dv *DiffView) renderUnified() string {
	var b strings.Builder

	fullContentStyle := lipgloss.NewStyle().MaxWidth(dv.fullCodeWidth)
	printedLines := 0

outer:
	for i, h := range dv.unified.Hunks {
		if dv.lineNumbers {
			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
		}
		content := ansi.Truncate(dv.hunkLineFor(h), dv.fullCodeWidth, "…")
		b.WriteString(dv.style.DividerLine.Code.Width(dv.fullCodeWidth).Render(content))
		b.WriteRune('\n')
		printedLines++

		beforeLine := h.FromLine
		afterLine := h.ToLine
		isLastHunk := i+1 == len(dv.unified.Hunks)

		for j, l := range h.Lines {
			// Print an ellipsis if there isn't enough space left for the rest
			// of the diff. Use >= rather than ==: when a hunk header lands on
			// the last row, == never matches again. The loop would then render
			// every remaining line, only for String's MaxHeight to clip it.
			hasReachedHeight := dv.height > 0 && printedLines+1 >= dv.height
			isLastLine := j+1 == len(h.Lines)
			if hasReachedHeight && (!isLastHunk || !isLastLine) {
				if dv.lineNumbers {
					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
				}
				b.WriteString(fullContentStyle.Render(
					dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render(" …"),
				))
				b.WriteRune('\n')
				break outer
			}

			content := strings.TrimSuffix(l.Content, "\n")
			content = ansi.Truncate(content, dv.codeWidth, "…")

			switch l.Kind {
			case udiff.Equal:
				if dv.lineNumbers {
					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
				}
				b.WriteString(fullContentStyle.Render(
					dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render(" " + content),
				))
				beforeLine++
				afterLine++
			case udiff.Insert:
				if dv.lineNumbers {
					b.WriteString(dv.style.InsertLine.LineNumber.Render(pad(" ", dv.beforeNumDigits)))
					b.WriteString(dv.style.InsertLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
				}
				b.WriteString(fullContentStyle.Render(
					dv.style.InsertLine.Symbol.Render("+ ") +
						dv.style.InsertLine.Code.Width(dv.codeWidth).Render(content),
				))
				afterLine++
			case udiff.Delete:
				if dv.lineNumbers {
					b.WriteString(dv.style.DeleteLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
					b.WriteString(dv.style.DeleteLine.LineNumber.Render(pad(" ", dv.afterNumDigits)))
				}
				b.WriteString(fullContentStyle.Render(
					dv.style.DeleteLine.Symbol.Render("- ") +
						dv.style.DeleteLine.Code.Width(dv.codeWidth).Render(content),
				))
				beforeLine++
			}
			b.WriteRune('\n')

			printedLines++
		}
	}

	// Pad with filler rows up to the requested height.
	for printedLines < dv.height {
		if dv.lineNumbers {
			b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.beforeNumDigits)))
			b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.afterNumDigits)))
		}
		b.WriteString(dv.style.MissingLine.Code.Width(dv.fullCodeWidth).Render(" "))
		b.WriteRune('\n')
		printedLines++
	}

	return b.String()
}
379
// renderSplit renders the split (side-by-side) diff view as a string.
//
// Each row holds the "before" panel on the left and the "after" panel on the
// right. Each panel is preceded by its own line-number gutter when line
// numbers are enabled. Each hunk starts with a divider row that carries the
// hunk header in the left panel. A side with no counterpart line is drawn
// with the MissingLine style. When extraColOnAfter is set, every "after" cell
// is one column wider, so the two panels together fill an odd total width.
//
// When a height is set and the diff does not fit, rendering stops with an
// ellipsis row. If the diff is shorter, filler rows are appended up to the
// height.
func (dv *DiffView) renderSplit() string {
	var b strings.Builder

	// The "after" panel absorbs the odd leftover column, if any.
	extraCol := btoi(dv.extraColOnAfter)
	afterFullCodeWidth := dv.fullCodeWidth + extraCol
	afterCodeWidth := dv.codeWidth + extraCol

	beforeFullContentStyle := lipgloss.NewStyle().MaxWidth(dv.fullCodeWidth)
	afterFullContentStyle := lipgloss.NewStyle().MaxWidth(afterFullCodeWidth)
	printedLines := 0

outer:
	for i, h := range dv.splitHunks {
		if dv.lineNumbers {
			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
		}
		content := ansi.Truncate(dv.hunkLineFor(dv.unified.Hunks[i]), dv.fullCodeWidth, "…")
		b.WriteString(dv.style.DividerLine.Code.Width(dv.fullCodeWidth).Render(content))
		if dv.lineNumbers {
			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
		}
		b.WriteString(dv.style.DividerLine.Code.Width(afterFullCodeWidth).Render(" "))
		b.WriteRune('\n')
		printedLines++

		beforeLine := h.fromLine
		afterLine := h.toLine
		isLastHunk := i+1 == len(dv.splitHunks)

		for j, l := range h.lines {
			// Print an ellipsis if there isn't enough space left for the rest
			// of the diff. Use >= rather than ==: when a hunk header lands on
			// the last row, == never matches again. The loop would then render
			// every remaining line, only for String's MaxHeight to clip it.
			hasReachedHeight := dv.height > 0 && printedLines+1 >= dv.height
			isLastLine := j+1 == len(h.lines)
			if hasReachedHeight && (!isLastHunk || !isLastLine) {
				if dv.lineNumbers {
					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
				}
				b.WriteString(beforeFullContentStyle.Render(
					dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render(" …"),
				))
				if dv.lineNumbers {
					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
				}
				// Give this cell the same width as every other "after" cell.
				// Without the extra column the row would be one column short
				// whenever extraColOnAfter is set.
				b.WriteString(afterFullContentStyle.Render(
					dv.style.EqualLine.Code.Width(afterFullCodeWidth).Render(" …"),
				))
				b.WriteRune('\n')
				break outer
			}

			var beforeContent string
			var afterContent string
			if l.before != nil {
				beforeContent = strings.TrimSuffix(l.before.Content, "\n")
				beforeContent = ansi.Truncate(beforeContent, dv.codeWidth, "…")
			}
			if l.after != nil {
				afterContent = strings.TrimSuffix(l.after.Content, "\n")
				afterContent = ansi.Truncate(afterContent, afterCodeWidth, "…")
			}

			// Left ("before") panel.
			switch {
			case l.before == nil:
				if dv.lineNumbers {
					b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.beforeNumDigits)))
				}
				b.WriteString(beforeFullContentStyle.Render(
					dv.style.MissingLine.Code.Width(dv.fullCodeWidth).Render(" "),
				))
			case l.before.Kind == udiff.Equal:
				if dv.lineNumbers {
					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
				}
				b.WriteString(beforeFullContentStyle.Render(
					dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render(" " + beforeContent),
				))
				beforeLine++
			case l.before.Kind == udiff.Delete:
				if dv.lineNumbers {
					b.WriteString(dv.style.DeleteLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
				}
				b.WriteString(beforeFullContentStyle.Render(
					dv.style.DeleteLine.Symbol.Render("- ") +
						dv.style.DeleteLine.Code.Width(dv.codeWidth).Render(beforeContent),
				))
				beforeLine++
			}

			// Right ("after") panel.
			switch {
			case l.after == nil:
				if dv.lineNumbers {
					b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.afterNumDigits)))
				}
				b.WriteString(afterFullContentStyle.Render(
					dv.style.MissingLine.Code.Width(afterFullCodeWidth).Render(" "),
				))
			case l.after.Kind == udiff.Equal:
				if dv.lineNumbers {
					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
				}
				b.WriteString(afterFullContentStyle.Render(
					dv.style.EqualLine.Code.Width(afterFullCodeWidth).Render(" " + afterContent),
				))
				afterLine++
			case l.after.Kind == udiff.Insert:
				if dv.lineNumbers {
					b.WriteString(dv.style.InsertLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
				}
				b.WriteString(afterFullContentStyle.Render(
					dv.style.InsertLine.Symbol.Render("+ ") +
						dv.style.InsertLine.Code.Width(afterCodeWidth).Render(afterContent),
				))
				afterLine++
			}

			b.WriteRune('\n')

			printedLines++
		}
	}

	// Pad with filler rows up to the requested height.
	for printedLines < dv.height {
		if dv.lineNumbers {
			b.WriteString(dv.style.MissingLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
		}
		b.WriteString(dv.style.MissingLine.Code.Width(dv.fullCodeWidth).Render(" "))
		if dv.lineNumbers {
			b.WriteString(dv.style.MissingLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
		}
		b.WriteString(dv.style.MissingLine.Code.Width(afterFullCodeWidth).Render(" "))
		b.WriteRune('\n')
		printedLines++
	}

	return b.String()
}
513
// hunkLineFor builds the divider text shown at the top of a hunk, in the form
// " @@ -from,count +to,count @@ ". The counts are the lines shown on each
// side, as computed by hunkShownLines.
func (dv *DiffView) hunkLineFor(h *udiff.Hunk) string {
	fromCount, toCount := dv.hunkShownLines(h)
	fromRange := fmt.Sprintf("-%d,%d", h.FromLine, fromCount)
	toRange := fmt.Sprintf("+%d,%d", h.ToLine, toCount)
	return " @@ " + fromRange + " " + toRange + " @@ "
}
526
527// hunkShownLines calculates the number of lines shown in a hunk for both before
528// and after versions.
529func (dv *DiffView) hunkShownLines(h *udiff.Hunk) (before, after int) {
530 for _, l := range h.Lines {
531 switch l.Kind {
532 case udiff.Equal:
533 before++
534 after++
535 case udiff.Insert:
536 after++
537 case udiff.Delete:
538 before++
539 }
540 }
541 return
542}