1package diffview
2
3import (
4 "fmt"
5 "os"
6 "strconv"
7 "strings"
8
9 "github.com/aymanbagabas/go-udiff"
10 "github.com/aymanbagabas/go-udiff/myers"
11 "github.com/charmbracelet/lipgloss/v2"
12 "github.com/charmbracelet/x/ansi"
13)
14
// leadingSymbolsSize is the number of columns taken by the "+ " / "- "
// marker that precedes inserted and deleted code lines.
const leadingSymbolsSize = 2

// lineNumPadding is the number of padding columns placed on each side of a
// line-number gutter.
const lineNumPadding = 1
19
// file is one side of the comparison: a path label used in the diff header
// and the full text content to diff.
type file struct {
	path, content string
}
24
// layout selects how a DiffView arranges its output.
type layout int

// Supported layouts. Values start at 1, so the zero value of layout matches
// no known layout.
const (
	layoutUnified layout = 1 // single column with interleaved +/- lines
	layoutSplit   layout = 2 // side-by-side before/after panels
)
31
32// DiffView represents a view for displaying differences between two files.
33type DiffView struct {
34 layout layout
35 before file
36 after file
37 contextLines int
38 lineNumbers bool
39 highlight bool
40 height int
41 width int
42 style Style
43
44 isComputed bool
45 err error
46 unified udiff.UnifiedDiff
47 edits []udiff.Edit
48
49 splitHunks []splitHunk
50
51 codeWidth int
52 fullCodeWidth int // with leading symbols
53 extraColOnAfter bool // add extra column on after panel
54 beforeNumDigits int
55 afterNumDigits int
56}
57
58// New creates a new DiffView with default settings.
59func New() *DiffView {
60 dv := &DiffView{
61 layout: layoutUnified,
62 contextLines: udiff.DefaultContextLines,
63 lineNumbers: true,
64 }
65 if lipgloss.HasDarkBackground(os.Stdin, os.Stdout) {
66 dv.style = DefaultDarkStyle
67 } else {
68 dv.style = DefaultLightStyle
69 }
70 return dv
71}
72
73// Unified sets the layout of the DiffView to unified.
74func (dv *DiffView) Unified() *DiffView {
75 dv.layout = layoutUnified
76 return dv
77}
78
79// Split sets the layout of the DiffView to split (side-by-side).
80func (dv *DiffView) Split() *DiffView {
81 dv.layout = layoutSplit
82 return dv
83}
84
85// Before sets the "before" file for the DiffView.
86func (dv *DiffView) Before(path, content string) *DiffView {
87 dv.before = file{path: path, content: content}
88 return dv
89}
90
91// After sets the "after" file for the DiffView.
92func (dv *DiffView) After(path, content string) *DiffView {
93 dv.after = file{path: path, content: content}
94 return dv
95}
96
97// ContextLines sets the number of context lines for the DiffView.
98func (dv *DiffView) ContextLines(contextLines int) *DiffView {
99 dv.contextLines = contextLines
100 return dv
101}
102
103// Style sets the style for the DiffView.
104func (dv *DiffView) Style(style Style) *DiffView {
105 dv.style = style
106 return dv
107}
108
109// LineNumbers sets whether to display line numbers in the DiffView.
110func (dv *DiffView) LineNumbers(lineNumbers bool) *DiffView {
111 dv.lineNumbers = lineNumbers
112 return dv
113}
114
115// SyntaxHightlight sets whether to enable syntax highlighting in the DiffView.
116func (dv *DiffView) SyntaxHightlight(highlight bool) *DiffView {
117 dv.highlight = highlight
118 return dv
119}
120
121// Height sets the height of the DiffView.
122func (dv *DiffView) Height(height int) *DiffView {
123 dv.height = height
124 return dv
125}
126
127// Width sets the width of the DiffView.
128func (dv *DiffView) Width(width int) *DiffView {
129 dv.width = width
130 return dv
131}
132
133// String returns the string representation of the DiffView.
134func (dv *DiffView) String() string {
135 if err := dv.computeDiff(); err != nil {
136 return err.Error()
137 }
138 dv.convertDiffToSplit()
139 dv.adjustStyles()
140 dv.detectNumDigits()
141
142 if dv.width <= 0 {
143 dv.detectCodeWidth()
144 } else {
145 dv.resizeCodeWidth()
146 }
147
148 style := lipgloss.NewStyle()
149 if dv.width > 0 {
150 style = style.MaxWidth(dv.width)
151 }
152 if dv.height > 0 {
153 style = style.MaxHeight(dv.height)
154 }
155
156 switch dv.layout {
157 case layoutUnified:
158 return style.Render(strings.TrimSuffix(dv.renderUnified(), "\n"))
159 case layoutSplit:
160 return style.Render(strings.TrimSuffix(dv.renderSplit(), "\n"))
161 default:
162 panic("unknown diffview layout")
163 }
164}
165
166// computeDiff computes the differences between the "before" and "after" files.
167func (dv *DiffView) computeDiff() error {
168 if dv.isComputed {
169 return dv.err
170 }
171 dv.isComputed = true
172 dv.edits = myers.ComputeEdits( //nolint:staticcheck
173 dv.before.content,
174 dv.after.content,
175 )
176 dv.unified, dv.err = udiff.ToUnifiedDiff(
177 dv.before.path,
178 dv.after.path,
179 dv.before.content,
180 dv.edits,
181 dv.contextLines,
182 )
183 return dv.err
184}
185
186// convertDiffToSplit converts the unified diff to a split diff if the layout is
187// set to split.
188func (dv *DiffView) convertDiffToSplit() {
189 if dv.layout != layoutSplit {
190 return
191 }
192
193 dv.splitHunks = make([]splitHunk, len(dv.unified.Hunks))
194 for i, h := range dv.unified.Hunks {
195 dv.splitHunks[i] = hunkToSplit(h)
196 }
197}
198
199// adjustStyles adjusts adds padding and alignment to the styles.
200func (dv *DiffView) adjustStyles() {
201 setPadding := func(s lipgloss.Style) lipgloss.Style {
202 return s.Padding(0, lineNumPadding).Align(lipgloss.Right)
203 }
204 dv.style.MissingLine.LineNumber = setPadding(dv.style.MissingLine.LineNumber)
205 dv.style.DividerLine.LineNumber = setPadding(dv.style.DividerLine.LineNumber)
206 dv.style.EqualLine.LineNumber = setPadding(dv.style.EqualLine.LineNumber)
207 dv.style.InsertLine.LineNumber = setPadding(dv.style.InsertLine.LineNumber)
208 dv.style.DeleteLine.LineNumber = setPadding(dv.style.DeleteLine.LineNumber)
209}
210
211// detectNumDigits calculates the maximum number of digits needed for before and
212// after line numbers.
213func (dv *DiffView) detectNumDigits() {
214 dv.beforeNumDigits = 0
215 dv.afterNumDigits = 0
216
217 for _, h := range dv.unified.Hunks {
218 dv.beforeNumDigits = max(dv.beforeNumDigits, len(strconv.Itoa(h.FromLine+len(h.Lines))))
219 dv.afterNumDigits = max(dv.afterNumDigits, len(strconv.Itoa(h.ToLine+len(h.Lines))))
220 }
221}
222
223// detectCodeWidth calculates the maximum width of code lines in the diff view.
224func (dv *DiffView) detectCodeWidth() {
225 switch dv.layout {
226 case layoutUnified:
227 dv.detectUnifiedCodeWidth()
228 case layoutSplit:
229 dv.detectSplitCodeWidth()
230 }
231 dv.fullCodeWidth = dv.codeWidth + leadingSymbolsSize
232}
233
234// detectUnifiedCodeWidth calculates the maximum width of code lines in a
235// unified diff.
236func (dv *DiffView) detectUnifiedCodeWidth() {
237 dv.codeWidth = 0
238
239 for _, h := range dv.unified.Hunks {
240 shownLines := ansi.StringWidth(dv.hunkLineFor(h))
241
242 for _, l := range h.Lines {
243 lineWidth := ansi.StringWidth(strings.TrimSuffix(l.Content, "\n")) + 1
244 dv.codeWidth = max(dv.codeWidth, lineWidth, shownLines)
245 }
246 }
247}
248
249// detectSplitCodeWidth calculates the maximum width of code lines in a
250// split diff.
251func (dv *DiffView) detectSplitCodeWidth() {
252 dv.codeWidth = 0
253
254 for i, h := range dv.splitHunks {
255 shownLines := ansi.StringWidth(dv.hunkLineFor(dv.unified.Hunks[i]))
256
257 for _, l := range h.lines {
258 if l.before != nil {
259 codeWidth := ansi.StringWidth(strings.TrimSuffix(l.before.Content, "\n")) + 1
260 dv.codeWidth = max(dv.codeWidth, codeWidth, shownLines)
261 }
262 if l.after != nil {
263 codeWidth := ansi.StringWidth(strings.TrimSuffix(l.after.Content, "\n")) + 1
264 dv.codeWidth = max(dv.codeWidth, codeWidth, shownLines)
265 }
266 }
267 }
268}
269
270// resizeCodeWidth resizes the code width to fit within the specified width.
271func (dv *DiffView) resizeCodeWidth() {
272 fullNumWidth := dv.beforeNumDigits + dv.afterNumDigits
273 fullNumWidth += lineNumPadding * 4 // left and right padding for both line numbers
274
275 switch dv.layout {
276 case layoutUnified:
277 dv.codeWidth = dv.width - fullNumWidth - leadingSymbolsSize
278 case layoutSplit:
279 remainingWidth := dv.width - fullNumWidth - leadingSymbolsSize*2
280 dv.codeWidth = remainingWidth / 2
281 dv.extraColOnAfter = isOdd(remainingWidth)
282 }
283
284 dv.fullCodeWidth = dv.codeWidth + leadingSymbolsSize
285}
286
// renderUnified renders the unified diff view as a string.
//
// Each hunk starts with a divider row holding the "@@ -a,b +c,d @@" header.
// It is followed by one row per diff line. With line numbers enabled, every
// row is prefixed by two gutters: the "before" line number and the "after"
// line number. The gutter for the side where a line does not exist is left
// blank. When dv.height is positive, output is cut short with an ellipsis
// row, and short output is padded with filler rows until printedLines
// reaches dv.height. Every row ends with '\n'.
func (dv *DiffView) renderUnified() string {
	var b strings.Builder

	// Caps each code cell at the full code column width (symbols + code).
	fullContentStyle := lipgloss.NewStyle().MaxWidth(dv.fullCodeWidth)
	// Rows written so far; compared against dv.height.
	printedLines := 0

outer:
	for i, h := range dv.unified.Hunks {
		// Hunk divider row.
		if dv.lineNumbers {
			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
		}
		content := ansi.Truncate(dv.hunkLineFor(h), dv.fullCodeWidth, "…")
		b.WriteString(dv.style.DividerLine.Code.Width(dv.fullCodeWidth).Render(content))
		b.WriteRune('\n')
		printedLines++

		// Line numbers advance independently: equal lines bump both, deletions
		// only the "before" side, insertions only the "after" side.
		beforeLine := h.FromLine
		afterLine := h.ToLine

		for j, l := range h.Lines {
			// Print an ellipsis row if there is not enough space left for the
			// rest of the diff.
			// NOTE(review): the check uses ==. If a hunk divider lands on the
			// last allowed row (printedLines == dv.height here), it never fires
			// again: the rest of the diff is rendered and cut off only by the
			// caller's MaxHeight, and the last visible row is a divider rather
			// than "…". The ellipsis row is also not counted in printedLines,
			// so the filler loop below writes one extra row. Confirm intended.
			hasReachedHeight := dv.height > 0 && printedLines+1 == dv.height
			isLastHunk := i+1 == len(dv.unified.Hunks)
			isLastLine := j+1 == len(h.Lines)
			if hasReachedHeight && (!isLastHunk || !isLastLine) {
				if dv.lineNumbers {
					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
				}
				b.WriteString(fullContentStyle.Render(
					dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render(" …"),
				))
				b.WriteRune('\n')
				break outer
			}

			content := strings.TrimSuffix(l.Content, "\n")
			content = ansi.Truncate(content, dv.codeWidth, "…")

			// NOTE(review): equal rows use a one-column " " prefix, while
			// insert/delete rows use the two-column "+ "/"- " symbol. Unless the
			// styles compensate, content may be offset by one column; verify
			// against rendered output.
			switch l.Kind {
			case udiff.Equal:
				if dv.lineNumbers {
					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
				}
				b.WriteString(fullContentStyle.Render(
					dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render(" " + content),
				))
				beforeLine++
				afterLine++
			case udiff.Insert:
				if dv.lineNumbers {
					b.WriteString(dv.style.InsertLine.LineNumber.Render(pad(" ", dv.beforeNumDigits)))
					b.WriteString(dv.style.InsertLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
				}
				b.WriteString(fullContentStyle.Render(
					dv.style.InsertLine.Symbol.Render("+ ") +
						dv.style.InsertLine.Code.Width(dv.codeWidth).Render(content),
				))
				afterLine++
			case udiff.Delete:
				if dv.lineNumbers {
					b.WriteString(dv.style.DeleteLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
					b.WriteString(dv.style.DeleteLine.LineNumber.Render(pad(" ", dv.afterNumDigits)))
				}
				b.WriteString(fullContentStyle.Render(
					dv.style.DeleteLine.Symbol.Render("- ") +
						dv.style.DeleteLine.Code.Width(dv.codeWidth).Render(content),
				))
				beforeLine++
			}
			b.WriteRune('\n')

			printedLines++
		}
	}

	// Pad with blank filler rows up to dv.height. When dv.height <= 0,
	// no rows are added.
	for printedLines < dv.height {
		if dv.lineNumbers {
			b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.beforeNumDigits)))
			b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.afterNumDigits)))
		}
		b.WriteString(dv.style.MissingLine.Code.Width(dv.fullCodeWidth).Render(" "))
		b.WriteRune('\n')
		printedLines++
	}

	return b.String()
}
378
379// renderSplit renders the split (side-by-side) diff view as a string.
380func (dv *DiffView) renderSplit() string {
381 var b strings.Builder
382
383 beforeFullContentStyle := lipgloss.NewStyle().MaxWidth(dv.fullCodeWidth)
384 afterFullContentStyle := lipgloss.NewStyle().MaxWidth(dv.fullCodeWidth + btoi(dv.extraColOnAfter))
385 printedLines := 0
386
387outer:
388 for i, h := range dv.splitHunks {
389 if dv.lineNumbers {
390 b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
391 }
392 content := ansi.Truncate(dv.hunkLineFor(dv.unified.Hunks[i]), dv.fullCodeWidth, "…")
393 b.WriteString(dv.style.DividerLine.Code.Width(dv.fullCodeWidth).Render(content))
394 if dv.lineNumbers {
395 b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
396 }
397 b.WriteString(dv.style.DividerLine.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render(" "))
398 b.WriteRune('\n')
399 printedLines++
400
401 beforeLine := h.fromLine
402 afterLine := h.toLine
403
404 for j, l := range h.lines {
405 // print ellipis if we don't have enough space to print the rest of the diff
406 hasReachedHeight := dv.height > 0 && printedLines+1 == dv.height
407 isLastHunk := i+1 == len(dv.unified.Hunks)
408 isLastLine := j+1 == len(h.lines)
409 if hasReachedHeight && (!isLastHunk || !isLastLine) {
410 if dv.lineNumbers {
411 b.WriteString(dv.style.EqualLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
412 }
413 b.WriteString(beforeFullContentStyle.Render(
414 dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render(" …"),
415 ))
416 if dv.lineNumbers {
417 b.WriteString(dv.style.EqualLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
418 }
419 b.WriteString(afterFullContentStyle.Render(
420 dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render(" …"),
421 ))
422 b.WriteRune('\n')
423 break outer
424 }
425
426 var beforeContent string
427 var afterContent string
428 if l.before != nil {
429 beforeContent = strings.TrimSuffix(l.before.Content, "\n")
430 beforeContent = ansi.Truncate(beforeContent, dv.codeWidth, "…")
431 }
432 if l.after != nil {
433 afterContent = strings.TrimSuffix(l.after.Content, "\n")
434 afterContent = ansi.Truncate(afterContent, dv.codeWidth+btoi(dv.extraColOnAfter), "…")
435 }
436
437 switch {
438 case l.before == nil:
439 if dv.lineNumbers {
440 b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.beforeNumDigits)))
441 }
442 b.WriteString(beforeFullContentStyle.Render(
443 dv.style.MissingLine.Code.Width(dv.fullCodeWidth).Render(" "),
444 ))
445 case l.before.Kind == udiff.Equal:
446 if dv.lineNumbers {
447 b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
448 }
449 b.WriteString(beforeFullContentStyle.Render(
450 dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render(" " + beforeContent),
451 ))
452 beforeLine++
453 case l.before.Kind == udiff.Delete:
454 if dv.lineNumbers {
455 b.WriteString(dv.style.DeleteLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
456 }
457 b.WriteString(beforeFullContentStyle.Render(
458 dv.style.DeleteLine.Symbol.Render("- ") +
459 dv.style.DeleteLine.Code.Width(dv.codeWidth).Render(beforeContent),
460 ))
461 beforeLine++
462 }
463
464 switch {
465 case l.after == nil:
466 if dv.lineNumbers {
467 b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.afterNumDigits)))
468 }
469 b.WriteString(afterFullContentStyle.Render(
470 dv.style.MissingLine.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render(" "),
471 ))
472 case l.after.Kind == udiff.Equal:
473 if dv.lineNumbers {
474 b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
475 }
476 b.WriteString(afterFullContentStyle.Render(
477 dv.style.EqualLine.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render(" " + afterContent),
478 ))
479 afterLine++
480 case l.after.Kind == udiff.Insert:
481 if dv.lineNumbers {
482 b.WriteString(dv.style.InsertLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
483 }
484 b.WriteString(afterFullContentStyle.Render(
485 dv.style.InsertLine.Symbol.Render("+ ") +
486 dv.style.InsertLine.Code.Width(dv.codeWidth+btoi(dv.extraColOnAfter)).Render(afterContent),
487 ))
488 afterLine++
489 }
490
491 b.WriteRune('\n')
492
493 printedLines++
494 }
495 }
496
497 for printedLines < dv.height {
498 if dv.lineNumbers {
499 b.WriteString(dv.style.MissingLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
500 }
501 b.WriteString(dv.style.MissingLine.Code.Width(dv.fullCodeWidth).Render(" "))
502 if dv.lineNumbers {
503 b.WriteString(dv.style.MissingLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
504 }
505 b.WriteString(dv.style.MissingLine.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render(" "))
506 b.WriteRune('\n')
507 printedLines++
508 }
509
510 return b.String()
511}
512
513// hunkLineFor formats the header line for a hunk in the unified diff view.
514func (dv *DiffView) hunkLineFor(h *udiff.Hunk) string {
515 beforeShownLines, afterShownLines := dv.hunkShownLines(h)
516
517 return fmt.Sprintf(
518 " @@ -%d,%d +%d,%d @@ ",
519 h.FromLine,
520 beforeShownLines,
521 h.ToLine,
522 afterShownLines,
523 )
524}
525
526// hunkShownLines calculates the number of lines shown in a hunk for both before
527// and after versions.
528func (dv *DiffView) hunkShownLines(h *udiff.Hunk) (before, after int) {
529 for _, l := range h.Lines {
530 switch l.Kind {
531 case udiff.Equal:
532 before++
533 after++
534 case udiff.Insert:
535 after++
536 case udiff.Delete:
537 before++
538 }
539 }
540 return
541}