diffview.go

  1package diffview
  2
  3import (
  4	"fmt"
  5	"image/color"
  6	"os"
  7	"strconv"
  8	"strings"
  9
 10	"github.com/alecthomas/chroma/v2"
 11	"github.com/alecthomas/chroma/v2/lexers"
 12	"github.com/aymanbagabas/go-udiff"
 13	"github.com/aymanbagabas/go-udiff/myers"
 14	"github.com/charmbracelet/lipgloss/v2"
 15	"github.com/charmbracelet/x/ansi"
 16)
 17
 18const (
 19	leadingSymbolsSize = 2
 20	lineNumPadding     = 1
 21)
 22
 23type file struct {
 24	path    string
 25	content string
 26}
 27
 28type layout int
 29
 30const (
 31	layoutUnified layout = iota + 1
 32	layoutSplit
 33)
 34
 35// DiffView represents a view for displaying differences between two files.
 36type DiffView struct {
 37	layout       layout
 38	before       file
 39	after        file
 40	contextLines int
 41	lineNumbers  bool
 42	highlight    bool
 43	height       int
 44	width        int
 45	xOffset      int
 46	yOffset      int
 47	style        Style
 48	tabWidth     int
 49	chromaStyle  *chroma.Style
 50
 51	isComputed bool
 52	err        error
 53	unified    udiff.UnifiedDiff
 54	edits      []udiff.Edit
 55
 56	splitHunks []splitHunk
 57
 58	codeWidth       int
 59	fullCodeWidth   int  // with leading symbols
 60	extraColOnAfter bool // add extra column on after panel
 61	beforeNumDigits int
 62	afterNumDigits  int
 63}
 64
 65// New creates a new DiffView with default settings.
 66func New() *DiffView {
 67	dv := &DiffView{
 68		layout:       layoutUnified,
 69		contextLines: udiff.DefaultContextLines,
 70		lineNumbers:  true,
 71		tabWidth:     8,
 72	}
 73	if lipgloss.HasDarkBackground(os.Stdin, os.Stdout) {
 74		dv.style = DefaultDarkStyle
 75	} else {
 76		dv.style = DefaultLightStyle
 77	}
 78	return dv
 79}
 80
 81// Unified sets the layout of the DiffView to unified.
 82func (dv *DiffView) Unified() *DiffView {
 83	dv.layout = layoutUnified
 84	return dv
 85}
 86
 87// Split sets the layout of the DiffView to split (side-by-side).
 88func (dv *DiffView) Split() *DiffView {
 89	dv.layout = layoutSplit
 90	return dv
 91}
 92
 93// Before sets the "before" file for the DiffView.
 94func (dv *DiffView) Before(path, content string) *DiffView {
 95	dv.before = file{path: path, content: content}
 96	return dv
 97}
 98
 99// After sets the "after" file for the DiffView.
100func (dv *DiffView) After(path, content string) *DiffView {
101	dv.after = file{path: path, content: content}
102	return dv
103}
104
105// ContextLines sets the number of context lines for the DiffView.
106func (dv *DiffView) ContextLines(contextLines int) *DiffView {
107	dv.contextLines = contextLines
108	return dv
109}
110
111// Style sets the style for the DiffView.
112func (dv *DiffView) Style(style Style) *DiffView {
113	dv.style = style
114	return dv
115}
116
117// LineNumbers sets whether to display line numbers in the DiffView.
118func (dv *DiffView) LineNumbers(lineNumbers bool) *DiffView {
119	dv.lineNumbers = lineNumbers
120	return dv
121}
122
123// SyntaxHightlight sets whether to enable syntax highlighting in the DiffView.
124func (dv *DiffView) SyntaxHightlight(highlight bool) *DiffView {
125	dv.highlight = highlight
126	return dv
127}
128
129// Height sets the height of the DiffView.
130func (dv *DiffView) Height(height int) *DiffView {
131	dv.height = height
132	return dv
133}
134
135// Width sets the width of the DiffView.
136func (dv *DiffView) Width(width int) *DiffView {
137	dv.width = width
138	return dv
139}
140
141// XOffset sets the horizontal offset for the DiffView.
142func (dv *DiffView) XOffset(xOffset int) *DiffView {
143	dv.xOffset = xOffset
144	return dv
145}
146
147// YOffset sets the vertical offset for the DiffView.
148func (dv *DiffView) YOffset(yOffset int) *DiffView {
149	dv.yOffset = yOffset
150	return dv
151}
152
153// TabWidth sets the tab width. Only relevant for code that contains tabs, like
154// Go code.
155func (dv *DiffView) TabWidth(tabWidth int) *DiffView {
156	dv.tabWidth = tabWidth
157	return dv
158}
159
160// ChromaStyle sets the chroma style for syntax highlighting.
161// If nil, no syntax highlighting will be applied.
162func (dv *DiffView) ChromaStyle(style *chroma.Style) *DiffView {
163	dv.chromaStyle = style
164	return dv
165}
166
167// String returns the string representation of the DiffView.
168func (dv *DiffView) String() string {
169	dv.replaceTabs()
170	if err := dv.computeDiff(); err != nil {
171		return err.Error()
172	}
173	dv.convertDiffToSplit()
174	dv.adjustStyles()
175	dv.detectNumDigits()
176
177	if dv.width <= 0 {
178		dv.detectCodeWidth()
179	} else {
180		dv.resizeCodeWidth()
181	}
182
183	style := lipgloss.NewStyle()
184	if dv.width > 0 {
185		style = style.MaxWidth(dv.width)
186	}
187	if dv.height > 0 {
188		style = style.MaxHeight(dv.height)
189	}
190
191	switch dv.layout {
192	case layoutUnified:
193		return style.Render(strings.TrimSuffix(dv.renderUnified(), "\n"))
194	case layoutSplit:
195		return style.Render(strings.TrimSuffix(dv.renderSplit(), "\n"))
196	default:
197		panic("unknown diffview layout")
198	}
199}
200
201// replaceTabs replaces tabs in the before and after file contents with spaces
202// according to the specified tab width.
203func (dv *DiffView) replaceTabs() {
204	spaces := strings.Repeat(" ", dv.tabWidth)
205	dv.before.content = strings.ReplaceAll(dv.before.content, "\t", spaces)
206	dv.after.content = strings.ReplaceAll(dv.after.content, "\t", spaces)
207}
208
209// computeDiff computes the differences between the "before" and "after" files.
210func (dv *DiffView) computeDiff() error {
211	if dv.isComputed {
212		return dv.err
213	}
214	dv.isComputed = true
215	dv.edits = myers.ComputeEdits( //nolint:staticcheck
216		dv.before.content,
217		dv.after.content,
218	)
219	dv.unified, dv.err = udiff.ToUnifiedDiff(
220		dv.before.path,
221		dv.after.path,
222		dv.before.content,
223		dv.edits,
224		dv.contextLines,
225	)
226	return dv.err
227}
228
229// convertDiffToSplit converts the unified diff to a split diff if the layout is
230// set to split.
231func (dv *DiffView) convertDiffToSplit() {
232	if dv.layout != layoutSplit {
233		return
234	}
235
236	dv.splitHunks = make([]splitHunk, len(dv.unified.Hunks))
237	for i, h := range dv.unified.Hunks {
238		dv.splitHunks[i] = hunkToSplit(h)
239	}
240}
241
242// adjustStyles adjusts adds padding and alignment to the styles.
243func (dv *DiffView) adjustStyles() {
244	setPadding := func(s lipgloss.Style) lipgloss.Style {
245		return s.Padding(0, lineNumPadding).Align(lipgloss.Right)
246	}
247	dv.style.MissingLine.LineNumber = setPadding(dv.style.MissingLine.LineNumber)
248	dv.style.DividerLine.LineNumber = setPadding(dv.style.DividerLine.LineNumber)
249	dv.style.EqualLine.LineNumber = setPadding(dv.style.EqualLine.LineNumber)
250	dv.style.InsertLine.LineNumber = setPadding(dv.style.InsertLine.LineNumber)
251	dv.style.DeleteLine.LineNumber = setPadding(dv.style.DeleteLine.LineNumber)
252}
253
254// detectNumDigits calculates the maximum number of digits needed for before and
255// after line numbers.
256func (dv *DiffView) detectNumDigits() {
257	dv.beforeNumDigits = 0
258	dv.afterNumDigits = 0
259
260	for _, h := range dv.unified.Hunks {
261		dv.beforeNumDigits = max(dv.beforeNumDigits, len(strconv.Itoa(h.FromLine+len(h.Lines))))
262		dv.afterNumDigits = max(dv.afterNumDigits, len(strconv.Itoa(h.ToLine+len(h.Lines))))
263	}
264}
265
266// detectCodeWidth calculates the maximum width of code lines in the diff view.
267func (dv *DiffView) detectCodeWidth() {
268	switch dv.layout {
269	case layoutUnified:
270		dv.detectUnifiedCodeWidth()
271	case layoutSplit:
272		dv.detectSplitCodeWidth()
273	}
274	dv.fullCodeWidth = dv.codeWidth + leadingSymbolsSize
275}
276
277// detectUnifiedCodeWidth calculates the maximum width of code lines in a
278// unified diff.
279func (dv *DiffView) detectUnifiedCodeWidth() {
280	dv.codeWidth = 0
281
282	for _, h := range dv.unified.Hunks {
283		shownLines := ansi.StringWidth(dv.hunkLineFor(h))
284
285		for _, l := range h.Lines {
286			lineWidth := ansi.StringWidth(strings.TrimSuffix(l.Content, "\n")) + 1
287			dv.codeWidth = max(dv.codeWidth, lineWidth, shownLines)
288		}
289	}
290}
291
292// detectSplitCodeWidth calculates the maximum width of code lines in a
293// split diff.
294func (dv *DiffView) detectSplitCodeWidth() {
295	dv.codeWidth = 0
296
297	for i, h := range dv.splitHunks {
298		shownLines := ansi.StringWidth(dv.hunkLineFor(dv.unified.Hunks[i]))
299
300		for _, l := range h.lines {
301			if l.before != nil {
302				codeWidth := ansi.StringWidth(strings.TrimSuffix(l.before.Content, "\n")) + 1
303				dv.codeWidth = max(dv.codeWidth, codeWidth, shownLines)
304			}
305			if l.after != nil {
306				codeWidth := ansi.StringWidth(strings.TrimSuffix(l.after.Content, "\n")) + 1
307				dv.codeWidth = max(dv.codeWidth, codeWidth, shownLines)
308			}
309		}
310	}
311}
312
313// resizeCodeWidth resizes the code width to fit within the specified width.
314func (dv *DiffView) resizeCodeWidth() {
315	fullNumWidth := dv.beforeNumDigits + dv.afterNumDigits
316	fullNumWidth += lineNumPadding * 4 // left and right padding for both line numbers
317
318	switch dv.layout {
319	case layoutUnified:
320		dv.codeWidth = dv.width - fullNumWidth - leadingSymbolsSize
321	case layoutSplit:
322		remainingWidth := dv.width - fullNumWidth - leadingSymbolsSize*2
323		dv.codeWidth = remainingWidth / 2
324		dv.extraColOnAfter = isOdd(remainingWidth)
325	}
326
327	dv.fullCodeWidth = dv.codeWidth + leadingSymbolsSize
328}
329
330// renderUnified renders the unified diff view as a string.
331func (dv *DiffView) renderUnified() string {
332	var b strings.Builder
333
334	fullContentStyle := lipgloss.NewStyle().MaxWidth(dv.fullCodeWidth)
335	printedLines := -dv.yOffset
336
337	write := func(s string) {
338		if printedLines >= 0 {
339			b.WriteString(s)
340		}
341	}
342
343outer:
344	for i, h := range dv.unified.Hunks {
345		ls := dv.style.DividerLine
346		if dv.lineNumbers {
347			write(ls.LineNumber.Render(pad("…", dv.beforeNumDigits)))
348			write(ls.LineNumber.Render(pad("…", dv.afterNumDigits)))
349		}
350		content := ansi.Truncate(dv.hunkLineFor(h), dv.fullCodeWidth, "…")
351		write(ls.Code.Width(dv.fullCodeWidth).Render(content))
352		write("\n")
353		printedLines++
354
355		beforeLine := h.FromLine
356		afterLine := h.ToLine
357
358		for j, l := range h.Lines {
359			// print ellipis if we don't have enough space to print the rest of the diff
360			hasReachedHeight := dv.height > 0 && printedLines+1 == dv.height
361			isLastHunk := i+1 == len(dv.unified.Hunks)
362			isLastLine := j+1 == len(h.Lines)
363			if hasReachedHeight && (!isLastHunk || !isLastLine) {
364				ls := dv.lineStyleForType(l.Kind)
365				if dv.lineNumbers {
366					write(ls.LineNumber.Render(pad("…", dv.beforeNumDigits)))
367					write(ls.LineNumber.Render(pad("…", dv.afterNumDigits)))
368				}
369				write(fullContentStyle.Render(
370					ls.Code.Width(dv.fullCodeWidth).Render("  …"),
371				))
372				write("\n")
373				break outer
374			}
375
376			getContent := func(ls LineStyle) string {
377				content := strings.TrimSuffix(l.Content, "\n")
378				content = dv.hightlightCode(content, ls.Code.GetBackground())
379				content = ansi.GraphemeWidth.Cut(content, dv.xOffset, len(content))
380				content = ansi.Truncate(content, dv.codeWidth, "…")
381				return content
382			}
383
384			leadingEllipsis := dv.xOffset > 0 && strings.TrimSpace(content) != ""
385
386			switch l.Kind {
387			case udiff.Equal:
388				ls := dv.style.EqualLine
389				content := getContent(ls)
390				if dv.lineNumbers {
391					write(ls.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
392					write(ls.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
393				}
394				write(fullContentStyle.Render(
395					ls.Code.Width(dv.fullCodeWidth).Render(ternary(leadingEllipsis, " …", "  ") + content),
396				))
397				beforeLine++
398				afterLine++
399			case udiff.Insert:
400				ls := dv.style.InsertLine
401				content := getContent(ls)
402				if dv.lineNumbers {
403					write(ls.LineNumber.Render(pad(" ", dv.beforeNumDigits)))
404					write(ls.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
405				}
406				write(fullContentStyle.Render(
407					ls.Symbol.Render(ternary(leadingEllipsis, "+…", "+ ")) +
408						ls.Code.Width(dv.codeWidth).Render(content),
409				))
410				afterLine++
411			case udiff.Delete:
412				ls := dv.style.DeleteLine
413				content := getContent(ls)
414				if dv.lineNumbers {
415					write(ls.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
416					write(ls.LineNumber.Render(pad(" ", dv.afterNumDigits)))
417				}
418				write(fullContentStyle.Render(
419					ls.Symbol.Render(ternary(leadingEllipsis, "-…", "- ")) +
420						ls.Code.Width(dv.codeWidth).Render(content),
421				))
422				beforeLine++
423			}
424			write("\n")
425
426			printedLines++
427		}
428	}
429
430	for printedLines < dv.height {
431		ls := dv.style.MissingLine
432		if dv.lineNumbers {
433			write(ls.LineNumber.Render(pad(" ", dv.beforeNumDigits)))
434			write(ls.LineNumber.Render(pad(" ", dv.afterNumDigits)))
435		}
436		write(ls.Code.Width(dv.fullCodeWidth).Render("  "))
437		write("\n")
438		printedLines++
439	}
440
441	return b.String()
442}
443
444// renderSplit renders the split (side-by-side) diff view as a string.
445func (dv *DiffView) renderSplit() string {
446	var b strings.Builder
447
448	beforeFullContentStyle := lipgloss.NewStyle().MaxWidth(dv.fullCodeWidth)
449	afterFullContentStyle := lipgloss.NewStyle().MaxWidth(dv.fullCodeWidth + btoi(dv.extraColOnAfter))
450	printedLines := -dv.yOffset
451
452	write := func(s string) {
453		if printedLines >= 0 {
454			b.WriteString(s)
455		}
456	}
457
458outer:
459	for i, h := range dv.splitHunks {
460		ls := dv.style.DividerLine
461		if dv.lineNumbers {
462			write(ls.LineNumber.Render(pad("…", dv.beforeNumDigits)))
463		}
464		content := ansi.Truncate(dv.hunkLineFor(dv.unified.Hunks[i]), dv.fullCodeWidth, "…")
465		write(ls.Code.Width(dv.fullCodeWidth).Render(content))
466		if dv.lineNumbers {
467			write(ls.LineNumber.Render(pad("…", dv.afterNumDigits)))
468		}
469		write(ls.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render(" "))
470		write("\n")
471		printedLines++
472
473		beforeLine := h.fromLine
474		afterLine := h.toLine
475
476		for j, l := range h.lines {
477			// print ellipis if we don't have enough space to print the rest of the diff
478			hasReachedHeight := dv.height > 0 && printedLines+1 == dv.height
479			isLastHunk := i+1 == len(dv.unified.Hunks)
480			isLastLine := j+1 == len(h.lines)
481			if hasReachedHeight && (!isLastHunk || !isLastLine) {
482				ls := dv.style.MissingLine
483				if l.before != nil {
484					ls = dv.lineStyleForType(l.before.Kind)
485				}
486				if dv.lineNumbers {
487					write(ls.LineNumber.Render(pad("…", dv.beforeNumDigits)))
488				}
489				write(beforeFullContentStyle.Render(
490					ls.Code.Width(dv.fullCodeWidth).Render("  …"),
491				))
492				ls = dv.style.MissingLine
493				if l.after != nil {
494					ls = dv.lineStyleForType(l.after.Kind)
495				}
496				if dv.lineNumbers {
497					write(ls.LineNumber.Render(pad("…", dv.afterNumDigits)))
498				}
499				write(afterFullContentStyle.Render(
500					ls.Code.Width(dv.fullCodeWidth).Render("  …"),
501				))
502				write("\n")
503				break outer
504			}
505
506			getContent := func(content string, ls LineStyle) string {
507				content = strings.TrimSuffix(content, "\n")
508				content = dv.hightlightCode(content, ls.Code.GetBackground())
509				content = ansi.GraphemeWidth.Cut(content, dv.xOffset, len(content))
510				content = ansi.Truncate(content, dv.codeWidth, "…")
511				return content
512			}
513			getLeadingEllipsis := func(content string) bool {
514				return dv.xOffset > 0 && strings.TrimSpace(content) != ""
515			}
516
517			switch {
518			case l.before == nil:
519				ls := dv.style.MissingLine
520				if dv.lineNumbers {
521					write(ls.LineNumber.Render(pad(" ", dv.beforeNumDigits)))
522				}
523				write(beforeFullContentStyle.Render(
524					ls.Code.Width(dv.fullCodeWidth).Render("  "),
525				))
526			case l.before.Kind == udiff.Equal:
527				ls := dv.style.EqualLine
528				content := getContent(l.before.Content, ls)
529				leadingEllipsis := getLeadingEllipsis(content)
530				if dv.lineNumbers {
531					write(ls.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
532				}
533				write(beforeFullContentStyle.Render(
534					ls.Code.Width(dv.fullCodeWidth).Render(ternary(leadingEllipsis, " …", "  ") + content),
535				))
536				beforeLine++
537			case l.before.Kind == udiff.Delete:
538				ls := dv.style.DeleteLine
539				content := getContent(l.before.Content, ls)
540				leadingEllipsis := getLeadingEllipsis(content)
541				if dv.lineNumbers {
542					write(ls.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
543				}
544				write(beforeFullContentStyle.Render(
545					ls.Symbol.Render(ternary(leadingEllipsis, "-…", "- ")) +
546						ls.Code.Width(dv.codeWidth).Render(content),
547				))
548				beforeLine++
549			}
550
551			switch {
552			case l.after == nil:
553				ls := dv.style.MissingLine
554				if dv.lineNumbers {
555					write(ls.LineNumber.Render(pad(" ", dv.afterNumDigits)))
556				}
557				write(afterFullContentStyle.Render(
558					ls.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render("  "),
559				))
560			case l.after.Kind == udiff.Equal:
561				ls := dv.style.EqualLine
562				content := getContent(l.after.Content, ls)
563				leadingEllipsis := getLeadingEllipsis(content)
564				if dv.lineNumbers {
565					write(ls.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
566				}
567				write(afterFullContentStyle.Render(
568					ls.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render(ternary(leadingEllipsis, " …", "  ") + content),
569				))
570				afterLine++
571			case l.after.Kind == udiff.Insert:
572				ls := dv.style.InsertLine
573				content := getContent(l.after.Content, ls)
574				leadingEllipsis := getLeadingEllipsis(content)
575				if dv.lineNumbers {
576					write(ls.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
577				}
578				write(afterFullContentStyle.Render(
579					ls.Symbol.Render(ternary(leadingEllipsis, "+…", "+ ")) +
580						ls.Code.Width(dv.codeWidth+btoi(dv.extraColOnAfter)).Render(content),
581				))
582				afterLine++
583			}
584
585			write("\n")
586
587			printedLines++
588		}
589	}
590
591	for printedLines < dv.height {
592		ls := dv.style.MissingLine
593		if dv.lineNumbers {
594			write(ls.LineNumber.Render(pad(" ", dv.beforeNumDigits)))
595		}
596		write(ls.Code.Width(dv.fullCodeWidth).Render(" "))
597		if dv.lineNumbers {
598			write(ls.LineNumber.Render(pad(" ", dv.afterNumDigits)))
599		}
600		write(ls.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render(" "))
601		write("\n")
602		printedLines++
603	}
604
605	return b.String()
606}
607
608// hunkLineFor formats the header line for a hunk in the unified diff view.
609func (dv *DiffView) hunkLineFor(h *udiff.Hunk) string {
610	beforeShownLines, afterShownLines := dv.hunkShownLines(h)
611
612	return fmt.Sprintf(
613		"  @@ -%d,%d +%d,%d @@ ",
614		h.FromLine,
615		beforeShownLines,
616		h.ToLine,
617		afterShownLines,
618	)
619}
620
621// hunkShownLines calculates the number of lines shown in a hunk for both before
622// and after versions.
623func (dv *DiffView) hunkShownLines(h *udiff.Hunk) (before, after int) {
624	for _, l := range h.Lines {
625		switch l.Kind {
626		case udiff.Equal:
627			before++
628			after++
629		case udiff.Insert:
630			after++
631		case udiff.Delete:
632			before++
633		}
634	}
635	return
636}
637
638func (dv *DiffView) lineStyleForType(t udiff.OpKind) LineStyle {
639	switch t {
640	case udiff.Equal:
641		return dv.style.EqualLine
642	case udiff.Insert:
643		return dv.style.InsertLine
644	case udiff.Delete:
645		return dv.style.DeleteLine
646	default:
647		return dv.style.MissingLine
648	}
649}
650
651func (dv *DiffView) hightlightCode(source string, bgColor color.Color) string {
652	if dv.chromaStyle == nil {
653		return source
654	}
655
656	l := dv.getChromaLexer(source)
657	f := dv.getChromaFormatter(bgColor)
658
659	it, err := l.Tokenise(nil, source)
660	if err != nil {
661		return source
662	}
663
664	var b strings.Builder
665	if err := f.Format(&b, dv.chromaStyle, it); err != nil {
666		return source
667	}
668	return b.String()
669}
670
671func (dv *DiffView) getChromaLexer(source string) chroma.Lexer {
672	l := lexers.Match(dv.before.path)
673	if l == nil {
674		l = lexers.Analyse(source)
675	}
676	if l == nil {
677		l = lexers.Fallback
678	}
679	return chroma.Coalesce(l)
680}
681
682func (dv *DiffView) getChromaFormatter(gbColor color.Color) chroma.Formatter {
683	return chromaFormatter{
684		bgColor: gbColor,
685	}
686}