diffview.go

  1package diffview
  2
  3import (
  4	"fmt"
  5	"os"
  6	"strconv"
  7	"strings"
  8
  9	"github.com/aymanbagabas/go-udiff"
 10	"github.com/aymanbagabas/go-udiff/myers"
 11	"github.com/charmbracelet/lipgloss/v2"
 12	"github.com/charmbracelet/x/ansi"
 13)
 14
 15const (
 16	leadingSymbolsSize = 2
 17	lineNumPadding     = 1
 18)
 19
 20type file struct {
 21	path    string
 22	content string
 23}
 24
 25type layout int
 26
 27const (
 28	layoutUnified layout = iota + 1
 29	layoutSplit
 30)
 31
 32// DiffView represents a view for displaying differences between two files.
 33type DiffView struct {
 34	layout       layout
 35	before       file
 36	after        file
 37	contextLines int
 38	lineNumbers  bool
 39	highlight    bool
 40	height       int
 41	width        int
 42	style        Style
 43
 44	isComputed bool
 45	err        error
 46	unified    udiff.UnifiedDiff
 47	edits      []udiff.Edit
 48
 49	splitHunks []splitHunk
 50
 51	codeWidth       int
 52	fullCodeWidth   int  // with leading symbols
 53	extraColOnAfter bool // add extra column on after panel
 54	beforeNumDigits int
 55	afterNumDigits  int
 56}
 57
 58// New creates a new DiffView with default settings.
 59func New() *DiffView {
 60	dv := &DiffView{
 61		layout:       layoutUnified,
 62		contextLines: udiff.DefaultContextLines,
 63		lineNumbers:  true,
 64	}
 65	if lipgloss.HasDarkBackground(os.Stdin, os.Stdout) {
 66		dv.style = DefaultDarkStyle
 67	} else {
 68		dv.style = DefaultLightStyle
 69	}
 70	return dv
 71}
 72
 73// Unified sets the layout of the DiffView to unified.
 74func (dv *DiffView) Unified() *DiffView {
 75	dv.layout = layoutUnified
 76	return dv
 77}
 78
 79// Split sets the layout of the DiffView to split (side-by-side).
 80func (dv *DiffView) Split() *DiffView {
 81	dv.layout = layoutSplit
 82	return dv
 83}
 84
 85// Before sets the "before" file for the DiffView.
 86func (dv *DiffView) Before(path, content string) *DiffView {
 87	dv.before = file{path: path, content: content}
 88	return dv
 89}
 90
 91// After sets the "after" file for the DiffView.
 92func (dv *DiffView) After(path, content string) *DiffView {
 93	dv.after = file{path: path, content: content}
 94	return dv
 95}
 96
 97// ContextLines sets the number of context lines for the DiffView.
 98func (dv *DiffView) ContextLines(contextLines int) *DiffView {
 99	dv.contextLines = contextLines
100	return dv
101}
102
103// Style sets the style for the DiffView.
104func (dv *DiffView) Style(style Style) *DiffView {
105	dv.style = style
106	return dv
107}
108
109// LineNumbers sets whether to display line numbers in the DiffView.
110func (dv *DiffView) LineNumbers(lineNumbers bool) *DiffView {
111	dv.lineNumbers = lineNumbers
112	return dv
113}
114
115// SyntaxHightlight sets whether to enable syntax highlighting in the DiffView.
116func (dv *DiffView) SyntaxHightlight(highlight bool) *DiffView {
117	dv.highlight = highlight
118	return dv
119}
120
121// Height sets the height of the DiffView.
122func (dv *DiffView) Height(height int) *DiffView {
123	dv.height = height
124	return dv
125}
126
127// Width sets the width of the DiffView.
128func (dv *DiffView) Width(width int) *DiffView {
129	dv.width = width
130	return dv
131}
132
133// String returns the string representation of the DiffView.
134func (dv *DiffView) String() string {
135	if err := dv.computeDiff(); err != nil {
136		return err.Error()
137	}
138	dv.convertDiffToSplit()
139	dv.adjustStyles()
140	dv.detectNumDigits()
141
142	if dv.width <= 0 {
143		dv.detectCodeWidth()
144	} else {
145		dv.resizeCodeWidth()
146	}
147
148	style := lipgloss.NewStyle()
149	if dv.width > 0 {
150		style = style.MaxWidth(dv.width)
151	}
152	if dv.height > 0 {
153		style = style.MaxHeight(dv.height)
154	}
155
156	switch dv.layout {
157	case layoutUnified:
158		return style.Render(strings.TrimSuffix(dv.renderUnified(), "\n"))
159	case layoutSplit:
160		return style.Render(strings.TrimSuffix(dv.renderSplit(), "\n"))
161	default:
162		panic("unknown diffview layout")
163	}
164}
165
166// computeDiff computes the differences between the "before" and "after" files.
167func (dv *DiffView) computeDiff() error {
168	if dv.isComputed {
169		return dv.err
170	}
171	dv.isComputed = true
172	dv.edits = myers.ComputeEdits( //nolint:staticcheck
173		dv.before.content,
174		dv.after.content,
175	)
176	dv.unified, dv.err = udiff.ToUnifiedDiff(
177		dv.before.path,
178		dv.after.path,
179		dv.before.content,
180		dv.edits,
181		dv.contextLines,
182	)
183	return dv.err
184}
185
186// convertDiffToSplit converts the unified diff to a split diff if the layout is
187// set to split.
188func (dv *DiffView) convertDiffToSplit() {
189	if dv.layout != layoutSplit {
190		return
191	}
192
193	dv.splitHunks = make([]splitHunk, len(dv.unified.Hunks))
194	for i, h := range dv.unified.Hunks {
195		dv.splitHunks[i] = hunkToSplit(h)
196	}
197}
198
199// adjustStyles adjusts adds padding and alignment to the styles.
200func (dv *DiffView) adjustStyles() {
201	dv.style.MissingLine.LineNumber = setPadding(dv.style.MissingLine.LineNumber)
202	dv.style.DividerLine.LineNumber = setPadding(dv.style.DividerLine.LineNumber)
203	dv.style.EqualLine.LineNumber = setPadding(dv.style.EqualLine.LineNumber)
204	dv.style.InsertLine.LineNumber = setPadding(dv.style.InsertLine.LineNumber)
205	dv.style.DeleteLine.LineNumber = setPadding(dv.style.DeleteLine.LineNumber)
206}
207
208// detectNumDigits calculates the maximum number of digits needed for before and
209// after line numbers.
210func (dv *DiffView) detectNumDigits() {
211	dv.beforeNumDigits = 0
212	dv.afterNumDigits = 0
213
214	for _, h := range dv.unified.Hunks {
215		dv.beforeNumDigits = max(dv.beforeNumDigits, len(strconv.Itoa(h.FromLine+len(h.Lines))))
216		dv.afterNumDigits = max(dv.afterNumDigits, len(strconv.Itoa(h.ToLine+len(h.Lines))))
217	}
218}
219
220func setPadding(s lipgloss.Style) lipgloss.Style {
221	return s.Padding(0, lineNumPadding).Align(lipgloss.Right)
222}
223
224// detectCodeWidth calculates the maximum width of code lines in the diff view.
225func (dv *DiffView) detectCodeWidth() {
226	switch dv.layout {
227	case layoutUnified:
228		dv.detectUnifiedCodeWidth()
229	case layoutSplit:
230		dv.detectSplitCodeWidth()
231	}
232	dv.fullCodeWidth = dv.codeWidth + leadingSymbolsSize
233}
234
235// detectUnifiedCodeWidth calculates the maximum width of code lines in a
236// unified diff.
237func (dv *DiffView) detectUnifiedCodeWidth() {
238	dv.codeWidth = 0
239
240	for _, h := range dv.unified.Hunks {
241		shownLines := ansi.StringWidth(dv.hunkLineFor(h))
242
243		for _, l := range h.Lines {
244			lineWidth := ansi.StringWidth(strings.TrimSuffix(l.Content, "\n")) + 1
245			dv.codeWidth = max(dv.codeWidth, lineWidth, shownLines)
246		}
247	}
248}
249
250// detectSplitCodeWidth calculates the maximum width of code lines in a
251// split diff.
252func (dv *DiffView) detectSplitCodeWidth() {
253	dv.codeWidth = 0
254
255	for i, h := range dv.splitHunks {
256		shownLines := ansi.StringWidth(dv.hunkLineFor(dv.unified.Hunks[i]))
257
258		for _, l := range h.lines {
259			if l.before != nil {
260				codeWidth := ansi.StringWidth(strings.TrimSuffix(l.before.Content, "\n")) + 1
261				dv.codeWidth = max(dv.codeWidth, codeWidth, shownLines)
262			}
263			if l.after != nil {
264				codeWidth := ansi.StringWidth(strings.TrimSuffix(l.after.Content, "\n")) + 1
265				dv.codeWidth = max(dv.codeWidth, codeWidth, shownLines)
266			}
267		}
268	}
269}
270
271// resizeCodeWidth resizes the code width to fit within the specified width.
272func (dv *DiffView) resizeCodeWidth() {
273	fullNumWidth := dv.beforeNumDigits + dv.afterNumDigits
274	fullNumWidth += lineNumPadding * 4 // left and right padding for both line numbers
275
276	switch dv.layout {
277	case layoutUnified:
278		dv.codeWidth = dv.width - fullNumWidth - leadingSymbolsSize
279	case layoutSplit:
280		remainingWidth := dv.width - fullNumWidth - leadingSymbolsSize*2
281		dv.codeWidth = remainingWidth / 2
282		dv.extraColOnAfter = isOdd(remainingWidth)
283	}
284
285	dv.fullCodeWidth = dv.codeWidth + leadingSymbolsSize
286}
287
288// renderUnified renders the unified diff view as a string.
289func (dv *DiffView) renderUnified() string {
290	var b strings.Builder
291
292	fullContentStyle := lipgloss.NewStyle().MaxWidth(dv.fullCodeWidth)
293	printedLines := 0
294
295outer:
296	for i, h := range dv.unified.Hunks {
297		if dv.lineNumbers {
298			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
299			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
300		}
301		content := ansi.Truncate(dv.hunkLineFor(h), dv.fullCodeWidth, "…")
302		b.WriteString(dv.style.DividerLine.Code.Width(dv.fullCodeWidth).Render(content))
303		b.WriteRune('\n')
304		printedLines++
305
306		beforeLine := h.FromLine
307		afterLine := h.ToLine
308
309		for j, l := range h.Lines {
310			// print ellipis if we don't have enough space to print the rest of the diff
311			hasReachedHeight := dv.height > 0 && printedLines+1 == dv.height
312			isLastHunk := i+1 == len(dv.unified.Hunks)
313			isLastLine := j+1 == len(h.Lines)
314			if hasReachedHeight && (!isLastHunk || !isLastLine) {
315				if dv.lineNumbers {
316					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
317					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
318				}
319				b.WriteString(fullContentStyle.Render(
320					dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render("  …"),
321				))
322				b.WriteRune('\n')
323				break outer
324			}
325
326			content := strings.TrimSuffix(l.Content, "\n")
327			content = ansi.Truncate(content, dv.codeWidth, "…")
328
329			switch l.Kind {
330			case udiff.Equal:
331				if dv.lineNumbers {
332					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
333					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
334				}
335				b.WriteString(fullContentStyle.Render(
336					dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render("  " + content),
337				))
338				beforeLine++
339				afterLine++
340			case udiff.Insert:
341				if dv.lineNumbers {
342					b.WriteString(dv.style.InsertLine.LineNumber.Render(pad(" ", dv.beforeNumDigits)))
343					b.WriteString(dv.style.InsertLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
344				}
345				b.WriteString(fullContentStyle.Render(
346					dv.style.InsertLine.Symbol.Render("+ ") +
347						dv.style.InsertLine.Code.Width(dv.codeWidth).Render(content),
348				))
349				afterLine++
350			case udiff.Delete:
351				if dv.lineNumbers {
352					b.WriteString(dv.style.DeleteLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
353					b.WriteString(dv.style.DeleteLine.LineNumber.Render(pad(" ", dv.afterNumDigits)))
354				}
355				b.WriteString(fullContentStyle.Render(
356					dv.style.DeleteLine.Symbol.Render("- ") +
357						dv.style.DeleteLine.Code.Width(dv.codeWidth).Render(content),
358				))
359				beforeLine++
360			}
361			b.WriteRune('\n')
362
363			printedLines++
364		}
365	}
366
367	for printedLines < dv.height {
368		if dv.lineNumbers {
369			b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.beforeNumDigits)))
370			b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.afterNumDigits)))
371		}
372		b.WriteString(dv.style.MissingLine.Code.Width(dv.fullCodeWidth).Render("  "))
373		b.WriteRune('\n')
374		printedLines++
375	}
376
377	return b.String()
378}
379
380// renderSplit renders the split (side-by-side) diff view as a string.
381func (dv *DiffView) renderSplit() string {
382	var b strings.Builder
383
384	beforeFullContentStyle := lipgloss.NewStyle().MaxWidth(dv.fullCodeWidth)
385	afterFullContentStyle := lipgloss.NewStyle().MaxWidth(dv.fullCodeWidth + btoi(dv.extraColOnAfter))
386	printedLines := 0
387
388outer:
389	for i, h := range dv.splitHunks {
390		if dv.lineNumbers {
391			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
392		}
393		content := ansi.Truncate(dv.hunkLineFor(dv.unified.Hunks[i]), dv.fullCodeWidth, "…")
394		b.WriteString(dv.style.DividerLine.Code.Width(dv.fullCodeWidth).Render(content))
395		if dv.lineNumbers {
396			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
397		}
398		b.WriteString(dv.style.DividerLine.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render(" "))
399		b.WriteRune('\n')
400		printedLines++
401
402		beforeLine := h.fromLine
403		afterLine := h.toLine
404
405		for j, l := range h.lines {
406			// print ellipis if we don't have enough space to print the rest of the diff
407			hasReachedHeight := dv.height > 0 && printedLines+1 == dv.height
408			isLastHunk := i+1 == len(dv.unified.Hunks)
409			isLastLine := j+1 == len(h.lines)
410			if hasReachedHeight && (!isLastHunk || !isLastLine) {
411				if dv.lineNumbers {
412					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
413				}
414				b.WriteString(beforeFullContentStyle.Render(
415					dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render("  …"),
416				))
417				if dv.lineNumbers {
418					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
419				}
420				b.WriteString(afterFullContentStyle.Render(
421					dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render("  …"),
422				))
423				b.WriteRune('\n')
424				break outer
425			}
426
427			var beforeContent string
428			var afterContent string
429			if l.before != nil {
430				beforeContent = strings.TrimSuffix(l.before.Content, "\n")
431				beforeContent = ansi.Truncate(beforeContent, dv.codeWidth, "…")
432			}
433			if l.after != nil {
434				afterContent = strings.TrimSuffix(l.after.Content, "\n")
435				afterContent = ansi.Truncate(afterContent, dv.codeWidth+btoi(dv.extraColOnAfter), "…")
436			}
437
438			switch {
439			case l.before == nil:
440				if dv.lineNumbers {
441					b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.beforeNumDigits)))
442				}
443				b.WriteString(beforeFullContentStyle.Render(
444					dv.style.MissingLine.Code.Width(dv.fullCodeWidth).Render("  "),
445				))
446			case l.before.Kind == udiff.Equal:
447				if dv.lineNumbers {
448					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
449				}
450				b.WriteString(beforeFullContentStyle.Render(
451					dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render("  " + beforeContent),
452				))
453				beforeLine++
454			case l.before.Kind == udiff.Delete:
455				if dv.lineNumbers {
456					b.WriteString(dv.style.DeleteLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
457				}
458				b.WriteString(beforeFullContentStyle.Render(
459					dv.style.DeleteLine.Symbol.Render("- ") +
460						dv.style.DeleteLine.Code.Width(dv.codeWidth).Render(beforeContent),
461				))
462				beforeLine++
463			}
464
465			switch {
466			case l.after == nil:
467				if dv.lineNumbers {
468					b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.afterNumDigits)))
469				}
470				b.WriteString(afterFullContentStyle.Render(
471					dv.style.MissingLine.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render("  "),
472				))
473			case l.after.Kind == udiff.Equal:
474				if dv.lineNumbers {
475					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
476				}
477				b.WriteString(afterFullContentStyle.Render(
478					dv.style.EqualLine.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render("  " + afterContent),
479				))
480				afterLine++
481			case l.after.Kind == udiff.Insert:
482				if dv.lineNumbers {
483					b.WriteString(dv.style.InsertLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
484				}
485				b.WriteString(afterFullContentStyle.Render(
486					dv.style.InsertLine.Symbol.Render("+ ") +
487						dv.style.InsertLine.Code.Width(dv.codeWidth+btoi(dv.extraColOnAfter)).Render(afterContent),
488				))
489				afterLine++
490			}
491
492			b.WriteRune('\n')
493
494			printedLines++
495		}
496	}
497
498	for printedLines < dv.height {
499		if dv.lineNumbers {
500			b.WriteString(dv.style.MissingLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
501		}
502		b.WriteString(dv.style.MissingLine.Code.Width(dv.fullCodeWidth).Render(" "))
503		if dv.lineNumbers {
504			b.WriteString(dv.style.MissingLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
505		}
506		b.WriteString(dv.style.MissingLine.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render(" "))
507		b.WriteRune('\n')
508		printedLines++
509	}
510
511	return b.String()
512}
513
514// hunkLineFor formats the header line for a hunk in the unified diff view.
515func (dv *DiffView) hunkLineFor(h *udiff.Hunk) string {
516	beforeShownLines, afterShownLines := dv.hunkShownLines(h)
517
518	return fmt.Sprintf(
519		"  @@ -%d,%d +%d,%d @@ ",
520		h.FromLine,
521		beforeShownLines,
522		h.ToLine,
523		afterShownLines,
524	)
525}
526
527// hunkShownLines calculates the number of lines shown in a hunk for both before
528// and after versions.
529func (dv *DiffView) hunkShownLines(h *udiff.Hunk) (before, after int) {
530	for _, l := range h.Lines {
531		switch l.Kind {
532		case udiff.Equal:
533			before++
534			after++
535		case udiff.Insert:
536			after++
537		case udiff.Delete:
538			before++
539		}
540	}
541	return
542}