diffview.go

  1package diffview
  2
  3import (
  4	"fmt"
  5	"os"
  6	"strconv"
  7	"strings"
  8
  9	"github.com/aymanbagabas/go-udiff"
 10	"github.com/aymanbagabas/go-udiff/myers"
 11	"github.com/charmbracelet/lipgloss/v2"
 12	"github.com/charmbracelet/x/ansi"
 13)
 14
// Fixed layout metrics, measured in terminal columns.
//
//   - leadingSymbolsSize is the space reserved in front of every code line
//     for the two-character marker ("+ ", "- " or two spaces).
//   - lineNumPadding is the horizontal padding applied to each side of a line
//     number cell (see adjustStyles).
const (
	leadingSymbolsSize = 2
	lineNumPadding     = 1
)
 19
// file holds one side of the comparison: a path and the text content, as
// supplied to Before or After.
type file struct {
	path    string
	content string
}
 24
// layout selects how the diff is arranged when it is rendered.
type layout int

// Supported layouts. Values start at 1, so the zero value of layout is not a
// valid layout; String panics when it meets a layout it does not know.
const (
	layoutUnified layout = iota + 1
	layoutSplit
)
 31
// DiffView represents a view for displaying differences between two files.
//
// A DiffView is configured through its chainable setter methods and turned
// into text by String.
type DiffView struct {
	// Configuration supplied by the caller through the setter methods.
	layout       layout
	before       file
	after        file
	contextLines int
	lineNumbers  bool
	highlight    bool
	height       int
	width        int
	style        Style

	// Diff computation state, written by computeDiff: isComputed records that
	// the computation already ran, err keeps the error it produced, and
	// unified and edits hold the computed diff.
	isComputed bool
	err        error
	unified    udiff.UnifiedDiff
	edits      []udiff.Edit

	// Side-by-side form of unified.Hunks, built by convertDiffToSplit when the
	// split layout is selected.
	splitHunks []splitHunk

	// Layout metrics recalculated by String before rendering.
	codeWidth       int
	fullCodeWidth   int  // with leading symbols
	extraColOnAfter bool // add extra column on after panel
	beforeNumDigits int
	afterNumDigits  int
}
 57
 58// New creates a new DiffView with default settings.
 59func New() *DiffView {
 60	dv := &DiffView{
 61		layout:       layoutUnified,
 62		contextLines: udiff.DefaultContextLines,
 63		lineNumbers:  true,
 64	}
 65	if lipgloss.HasDarkBackground(os.Stdin, os.Stdout) {
 66		dv.style = DefaultDarkStyle
 67	} else {
 68		dv.style = DefaultLightStyle
 69	}
 70	return dv
 71}
 72
// Unified sets the layout of the DiffView to unified, where both line number
// gutters are followed by a single code column.
//
// It returns the receiver so that calls can be chained.
func (dv *DiffView) Unified() *DiffView {
	dv.layout = layoutUnified
	return dv
}
 78
// Split sets the layout of the DiffView to split (side-by-side), where the
// "before" and "after" versions each get their own gutter and code panel.
//
// It returns the receiver so that calls can be chained.
func (dv *DiffView) Split() *DiffView {
	dv.layout = layoutSplit
	return dv
}
 84
 85// Before sets the "before" file for the DiffView.
 86func (dv *DiffView) Before(path, content string) *DiffView {
 87	dv.before = file{path: path, content: content}
 88	return dv
 89}
 90
 91// After sets the "after" file for the DiffView.
 92func (dv *DiffView) After(path, content string) *DiffView {
 93	dv.after = file{path: path, content: content}
 94	return dv
 95}
 96
 97// ContextLines sets the number of context lines for the DiffView.
 98func (dv *DiffView) ContextLines(contextLines int) *DiffView {
 99	dv.contextLines = contextLines
100	return dv
101}
102
// Style sets the style for the DiffView, replacing the one chosen by New.
//
// String passes the stored style's line number styles through adjustStyles
// (padding and right alignment) every time it renders.
//
// It returns the receiver so that calls can be chained.
func (dv *DiffView) Style(style Style) *DiffView {
	dv.style = style
	return dv
}
108
// LineNumbers sets whether to display line numbers in the DiffView. New
// enables them by default.
//
// It returns the receiver so that calls can be chained.
func (dv *DiffView) LineNumbers(lineNumbers bool) *DiffView {
	dv.lineNumbers = lineNumbers
	return dv
}
114
// SyntaxHightlight sets whether to enable syntax highlighting in the DiffView.
//
// The flag is stored in the highlight field.
// NOTE(review): nothing in the rendering code of this file reads highlight;
// confirm where (or whether) it is consumed.
//
// It returns the receiver so that calls can be chained.
func (dv *DiffView) SyntaxHightlight(highlight bool) *DiffView {
	dv.highlight = highlight
	return dv
}
120
// Height sets the height of the DiffView, in rows.
//
// With a positive height String caps its output at that many rows, replacing
// the last row with an ellipsis row when the diff is longer and appending
// blank rows when it is shorter. A height of zero or less leaves the output
// uncapped and unpadded.
//
// It returns the receiver so that calls can be chained.
func (dv *DiffView) Height(height int) *DiffView {
	dv.height = height
	return dv
}
126
// Width sets the width of the DiffView, in columns.
//
// With a positive width String sizes the code columns to fit it
// (resizeCodeWidth) and caps every row at that width. A width of zero or less
// makes String size the code columns from the widest diff line instead
// (detectCodeWidth).
//
// It returns the receiver so that calls can be chained.
func (dv *DiffView) Width(width int) *DiffView {
	dv.width = width
	return dv
}
132
133// String returns the string representation of the DiffView.
134func (dv *DiffView) String() string {
135	if err := dv.computeDiff(); err != nil {
136		return err.Error()
137	}
138	dv.convertDiffToSplit()
139	dv.adjustStyles()
140	dv.detectNumDigits()
141
142	if dv.width <= 0 {
143		dv.detectCodeWidth()
144	} else {
145		dv.resizeCodeWidth()
146	}
147
148	style := lipgloss.NewStyle()
149	if dv.width > 0 {
150		style = style.MaxWidth(dv.width)
151	}
152	if dv.height > 0 {
153		style = style.MaxHeight(dv.height)
154	}
155
156	switch dv.layout {
157	case layoutUnified:
158		return style.Render(strings.TrimSuffix(dv.renderUnified(), "\n"))
159	case layoutSplit:
160		return style.Render(strings.TrimSuffix(dv.renderSplit(), "\n"))
161	default:
162		panic("unknown diffview layout")
163	}
164}
165
166// computeDiff computes the differences between the "before" and "after" files.
167func (dv *DiffView) computeDiff() error {
168	if dv.isComputed {
169		return dv.err
170	}
171	dv.isComputed = true
172	dv.edits = myers.ComputeEdits( //nolint:staticcheck
173		dv.before.content,
174		dv.after.content,
175	)
176	dv.unified, dv.err = udiff.ToUnifiedDiff(
177		dv.before.path,
178		dv.after.path,
179		dv.before.content,
180		dv.edits,
181		dv.contextLines,
182	)
183	return dv.err
184}
185
186// convertDiffToSplit converts the unified diff to a split diff if the layout is
187// set to split.
188func (dv *DiffView) convertDiffToSplit() {
189	if dv.layout != layoutSplit {
190		return
191	}
192
193	dv.splitHunks = make([]splitHunk, len(dv.unified.Hunks))
194	for i, h := range dv.unified.Hunks {
195		dv.splitHunks[i] = hunkToSplit(h)
196	}
197}
198
199// adjustStyles adjusts adds padding and alignment to the styles.
200func (dv *DiffView) adjustStyles() {
201	setPadding := func(s lipgloss.Style) lipgloss.Style {
202		return s.Padding(0, lineNumPadding).Align(lipgloss.Right)
203	}
204	dv.style.MissingLine.LineNumber = setPadding(dv.style.MissingLine.LineNumber)
205	dv.style.DividerLine.LineNumber = setPadding(dv.style.DividerLine.LineNumber)
206	dv.style.EqualLine.LineNumber = setPadding(dv.style.EqualLine.LineNumber)
207	dv.style.InsertLine.LineNumber = setPadding(dv.style.InsertLine.LineNumber)
208	dv.style.DeleteLine.LineNumber = setPadding(dv.style.DeleteLine.LineNumber)
209}
210
211// detectNumDigits calculates the maximum number of digits needed for before and
212// after line numbers.
213func (dv *DiffView) detectNumDigits() {
214	dv.beforeNumDigits = 0
215	dv.afterNumDigits = 0
216
217	for _, h := range dv.unified.Hunks {
218		dv.beforeNumDigits = max(dv.beforeNumDigits, len(strconv.Itoa(h.FromLine+len(h.Lines))))
219		dv.afterNumDigits = max(dv.afterNumDigits, len(strconv.Itoa(h.ToLine+len(h.Lines))))
220	}
221}
222
223// detectCodeWidth calculates the maximum width of code lines in the diff view.
224func (dv *DiffView) detectCodeWidth() {
225	switch dv.layout {
226	case layoutUnified:
227		dv.detectUnifiedCodeWidth()
228	case layoutSplit:
229		dv.detectSplitCodeWidth()
230	}
231	dv.fullCodeWidth = dv.codeWidth + leadingSymbolsSize
232}
233
234// detectUnifiedCodeWidth calculates the maximum width of code lines in a
235// unified diff.
236func (dv *DiffView) detectUnifiedCodeWidth() {
237	dv.codeWidth = 0
238
239	for _, h := range dv.unified.Hunks {
240		shownLines := ansi.StringWidth(dv.hunkLineFor(h))
241
242		for _, l := range h.Lines {
243			lineWidth := ansi.StringWidth(strings.TrimSuffix(l.Content, "\n")) + 1
244			dv.codeWidth = max(dv.codeWidth, lineWidth, shownLines)
245		}
246	}
247}
248
249// detectSplitCodeWidth calculates the maximum width of code lines in a
250// split diff.
251func (dv *DiffView) detectSplitCodeWidth() {
252	dv.codeWidth = 0
253
254	for i, h := range dv.splitHunks {
255		shownLines := ansi.StringWidth(dv.hunkLineFor(dv.unified.Hunks[i]))
256
257		for _, l := range h.lines {
258			if l.before != nil {
259				codeWidth := ansi.StringWidth(strings.TrimSuffix(l.before.Content, "\n")) + 1
260				dv.codeWidth = max(dv.codeWidth, codeWidth, shownLines)
261			}
262			if l.after != nil {
263				codeWidth := ansi.StringWidth(strings.TrimSuffix(l.after.Content, "\n")) + 1
264				dv.codeWidth = max(dv.codeWidth, codeWidth, shownLines)
265			}
266		}
267	}
268}
269
270// resizeCodeWidth resizes the code width to fit within the specified width.
271func (dv *DiffView) resizeCodeWidth() {
272	fullNumWidth := dv.beforeNumDigits + dv.afterNumDigits
273	fullNumWidth += lineNumPadding * 4 // left and right padding for both line numbers
274
275	switch dv.layout {
276	case layoutUnified:
277		dv.codeWidth = dv.width - fullNumWidth - leadingSymbolsSize
278	case layoutSplit:
279		remainingWidth := dv.width - fullNumWidth - leadingSymbolsSize*2
280		dv.codeWidth = remainingWidth / 2
281		dv.extraColOnAfter = isOdd(remainingWidth)
282	}
283
284	dv.fullCodeWidth = dv.codeWidth + leadingSymbolsSize
285}
286
287// renderUnified renders the unified diff view as a string.
288func (dv *DiffView) renderUnified() string {
289	var b strings.Builder
290
291	fullContentStyle := lipgloss.NewStyle().MaxWidth(dv.fullCodeWidth)
292	printedLines := 0
293
294outer:
295	for i, h := range dv.unified.Hunks {
296		if dv.lineNumbers {
297			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
298			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
299		}
300		content := ansi.Truncate(dv.hunkLineFor(h), dv.fullCodeWidth, "…")
301		b.WriteString(dv.style.DividerLine.Code.Width(dv.fullCodeWidth).Render(content))
302		b.WriteRune('\n')
303		printedLines++
304
305		beforeLine := h.FromLine
306		afterLine := h.ToLine
307
308		for j, l := range h.Lines {
309			// print ellipis if we don't have enough space to print the rest of the diff
310			hasReachedHeight := dv.height > 0 && printedLines+1 == dv.height
311			isLastHunk := i+1 == len(dv.unified.Hunks)
312			isLastLine := j+1 == len(h.Lines)
313			if hasReachedHeight && (!isLastHunk || !isLastLine) {
314				if dv.lineNumbers {
315					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
316					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
317				}
318				b.WriteString(fullContentStyle.Render(
319					dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render("  …"),
320				))
321				b.WriteRune('\n')
322				break outer
323			}
324
325			content := strings.TrimSuffix(l.Content, "\n")
326			content = ansi.Truncate(content, dv.codeWidth, "…")
327
328			switch l.Kind {
329			case udiff.Equal:
330				if dv.lineNumbers {
331					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
332					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
333				}
334				b.WriteString(fullContentStyle.Render(
335					dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render("  " + content),
336				))
337				beforeLine++
338				afterLine++
339			case udiff.Insert:
340				if dv.lineNumbers {
341					b.WriteString(dv.style.InsertLine.LineNumber.Render(pad(" ", dv.beforeNumDigits)))
342					b.WriteString(dv.style.InsertLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
343				}
344				b.WriteString(fullContentStyle.Render(
345					dv.style.InsertLine.Symbol.Render("+ ") +
346						dv.style.InsertLine.Code.Width(dv.codeWidth).Render(content),
347				))
348				afterLine++
349			case udiff.Delete:
350				if dv.lineNumbers {
351					b.WriteString(dv.style.DeleteLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
352					b.WriteString(dv.style.DeleteLine.LineNumber.Render(pad(" ", dv.afterNumDigits)))
353				}
354				b.WriteString(fullContentStyle.Render(
355					dv.style.DeleteLine.Symbol.Render("- ") +
356						dv.style.DeleteLine.Code.Width(dv.codeWidth).Render(content),
357				))
358				beforeLine++
359			}
360			b.WriteRune('\n')
361
362			printedLines++
363		}
364	}
365
366	for printedLines < dv.height {
367		if dv.lineNumbers {
368			b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.beforeNumDigits)))
369			b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.afterNumDigits)))
370		}
371		b.WriteString(dv.style.MissingLine.Code.Width(dv.fullCodeWidth).Render("  "))
372		b.WriteRune('\n')
373		printedLines++
374	}
375
376	return b.String()
377}
378
379// renderSplit renders the split (side-by-side) diff view as a string.
380func (dv *DiffView) renderSplit() string {
381	var b strings.Builder
382
383	beforeFullContentStyle := lipgloss.NewStyle().MaxWidth(dv.fullCodeWidth)
384	afterFullContentStyle := lipgloss.NewStyle().MaxWidth(dv.fullCodeWidth + btoi(dv.extraColOnAfter))
385	printedLines := 0
386
387outer:
388	for i, h := range dv.splitHunks {
389		if dv.lineNumbers {
390			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
391		}
392		content := ansi.Truncate(dv.hunkLineFor(dv.unified.Hunks[i]), dv.fullCodeWidth, "…")
393		b.WriteString(dv.style.DividerLine.Code.Width(dv.fullCodeWidth).Render(content))
394		if dv.lineNumbers {
395			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
396		}
397		b.WriteString(dv.style.DividerLine.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render(" "))
398		b.WriteRune('\n')
399		printedLines++
400
401		beforeLine := h.fromLine
402		afterLine := h.toLine
403
404		for j, l := range h.lines {
405			// print ellipis if we don't have enough space to print the rest of the diff
406			hasReachedHeight := dv.height > 0 && printedLines+1 == dv.height
407			isLastHunk := i+1 == len(dv.unified.Hunks)
408			isLastLine := j+1 == len(h.lines)
409			if hasReachedHeight && (!isLastHunk || !isLastLine) {
410				if dv.lineNumbers {
411					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
412				}
413				b.WriteString(beforeFullContentStyle.Render(
414					dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render("  …"),
415				))
416				if dv.lineNumbers {
417					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
418				}
419				b.WriteString(afterFullContentStyle.Render(
420					dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render("  …"),
421				))
422				b.WriteRune('\n')
423				break outer
424			}
425
426			var beforeContent string
427			var afterContent string
428			if l.before != nil {
429				beforeContent = strings.TrimSuffix(l.before.Content, "\n")
430				beforeContent = ansi.Truncate(beforeContent, dv.codeWidth, "…")
431			}
432			if l.after != nil {
433				afterContent = strings.TrimSuffix(l.after.Content, "\n")
434				afterContent = ansi.Truncate(afterContent, dv.codeWidth+btoi(dv.extraColOnAfter), "…")
435			}
436
437			switch {
438			case l.before == nil:
439				if dv.lineNumbers {
440					b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.beforeNumDigits)))
441				}
442				b.WriteString(beforeFullContentStyle.Render(
443					dv.style.MissingLine.Code.Width(dv.fullCodeWidth).Render("  "),
444				))
445			case l.before.Kind == udiff.Equal:
446				if dv.lineNumbers {
447					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
448				}
449				b.WriteString(beforeFullContentStyle.Render(
450					dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render("  " + beforeContent),
451				))
452				beforeLine++
453			case l.before.Kind == udiff.Delete:
454				if dv.lineNumbers {
455					b.WriteString(dv.style.DeleteLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
456				}
457				b.WriteString(beforeFullContentStyle.Render(
458					dv.style.DeleteLine.Symbol.Render("- ") +
459						dv.style.DeleteLine.Code.Width(dv.codeWidth).Render(beforeContent),
460				))
461				beforeLine++
462			}
463
464			switch {
465			case l.after == nil:
466				if dv.lineNumbers {
467					b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.afterNumDigits)))
468				}
469				b.WriteString(afterFullContentStyle.Render(
470					dv.style.MissingLine.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render("  "),
471				))
472			case l.after.Kind == udiff.Equal:
473				if dv.lineNumbers {
474					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
475				}
476				b.WriteString(afterFullContentStyle.Render(
477					dv.style.EqualLine.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render("  " + afterContent),
478				))
479				afterLine++
480			case l.after.Kind == udiff.Insert:
481				if dv.lineNumbers {
482					b.WriteString(dv.style.InsertLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
483				}
484				b.WriteString(afterFullContentStyle.Render(
485					dv.style.InsertLine.Symbol.Render("+ ") +
486						dv.style.InsertLine.Code.Width(dv.codeWidth+btoi(dv.extraColOnAfter)).Render(afterContent),
487				))
488				afterLine++
489			}
490
491			b.WriteRune('\n')
492
493			printedLines++
494		}
495	}
496
497	for printedLines < dv.height {
498		if dv.lineNumbers {
499			b.WriteString(dv.style.MissingLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
500		}
501		b.WriteString(dv.style.MissingLine.Code.Width(dv.fullCodeWidth).Render(" "))
502		if dv.lineNumbers {
503			b.WriteString(dv.style.MissingLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
504		}
505		b.WriteString(dv.style.MissingLine.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render(" "))
506		b.WriteRune('\n')
507		printedLines++
508	}
509
510	return b.String()
511}
512
513// hunkLineFor formats the header line for a hunk in the unified diff view.
514func (dv *DiffView) hunkLineFor(h *udiff.Hunk) string {
515	beforeShownLines, afterShownLines := dv.hunkShownLines(h)
516
517	return fmt.Sprintf(
518		"  @@ -%d,%d +%d,%d @@ ",
519		h.FromLine,
520		beforeShownLines,
521		h.ToLine,
522		afterShownLines,
523	)
524}
525
526// hunkShownLines calculates the number of lines shown in a hunk for both before
527// and after versions.
528func (dv *DiffView) hunkShownLines(h *udiff.Hunk) (before, after int) {
529	for _, l := range h.Lines {
530		switch l.Kind {
531		case udiff.Equal:
532			before++
533			after++
534		case udiff.Insert:
535			after++
536		case udiff.Delete:
537			before++
538		}
539	}
540	return
541}