diffview.go

  1package diffview
  2
  3import (
  4	"fmt"
  5	"os"
  6	"strconv"
  7	"strings"
  8
  9	"github.com/aymanbagabas/go-udiff"
 10	"github.com/aymanbagabas/go-udiff/myers"
 11	"github.com/charmbracelet/lipgloss/v2"
 12	"github.com/charmbracelet/x/ansi"
 13)
 14
const (
	// leadingSymbolsSize is the width, in columns, of the two-character
	// marker ("+ ", "- " or two spaces) written in front of each code line.
	leadingSymbolsSize = 2
	// lineNumPadding is the horizontal padding applied to each side of a
	// line-number cell.
	lineNumPadding     = 1
)
 19
// file holds one side of the comparison: a path and the text content.
type file struct {
	// path is passed to udiff.ToUnifiedDiff as the name of this side.
	path    string
	// content is the text handed to the diff algorithm.
	content string
}
 24
// layout selects how the diff is rendered.
type layout int

const (
	// layoutUnified renders both versions in a single code column. Values
	// start at 1, so the zero value of layout is not a valid layout.
	layoutUnified layout = iota + 1
	// layoutSplit renders the before and after versions side by side.
	layoutSplit
)
 31
// DiffView represents a view for displaying differences between two files.
//
// It is configured through its chainable setter methods and rendered with
// String. Fields after the configuration group are caches and measurements
// that are filled in while rendering.
type DiffView struct {
	// Configuration supplied by the caller.
	layout       layout
	before       file
	after        file
	contextLines int
	lineNumbers  bool
	// highlight is set by SyntaxHightlight; it is not read by any code
	// visible in this part of the file.
	highlight    bool
	// height and width constrain the rendered output only when positive.
	height       int
	width        int
	style        Style

	// Cached diff result; isComputed guards recomputation in computeDiff.
	isComputed bool
	err        error
	unified    udiff.UnifiedDiff
	edits      []udiff.Edit

	// splitHunks is rebuilt from unified.Hunks when the layout is split.
	splitHunks []splitHunk

	// Measurements derived during String.
	codeWidth       int
	fullCodeWidth   int  // with leading symbols
	extraColOnAfter bool // add extra column on after panel
	beforeNumDigits int
	afterNumDigits  int
}
 57
 58// New creates a new DiffView with default settings.
 59func New() *DiffView {
 60	dv := &DiffView{
 61		layout:       layoutUnified,
 62		contextLines: udiff.DefaultContextLines,
 63		lineNumbers:  true,
 64	}
 65	if lipgloss.HasDarkBackground(os.Stdin, os.Stdout) {
 66		dv.style = DefaultDarkStyle
 67	} else {
 68		dv.style = DefaultLightStyle
 69	}
 70	return dv
 71}
 72
// Unified sets the layout of the DiffView to unified, where both versions
// share a single code column. It returns the receiver so calls can be chained.
func (dv *DiffView) Unified() *DiffView {
	dv.layout = layoutUnified
	return dv
}
 78
// Split sets the layout of the DiffView to split (side-by-side), with the
// before version on the left and the after version on the right. It returns
// the receiver so calls can be chained.
func (dv *DiffView) Split() *DiffView {
	dv.layout = layoutSplit
	return dv
}
 84
 85// Before sets the "before" file for the DiffView.
 86func (dv *DiffView) Before(path, content string) *DiffView {
 87	dv.before = file{path: path, content: content}
 88	return dv
 89}
 90
 91// After sets the "after" file for the DiffView.
 92func (dv *DiffView) After(path, content string) *DiffView {
 93	dv.after = file{path: path, content: content}
 94	return dv
 95}
 96
 97// ContextLines sets the number of context lines for the DiffView.
 98func (dv *DiffView) ContextLines(contextLines int) *DiffView {
 99	dv.contextLines = contextLines
100	return dv
101}
102
// Style sets the style for the DiffView, replacing the one chosen by New.
// String later applies line-number padding and right alignment to the stored
// style. It returns the receiver so calls can be chained.
func (dv *DiffView) Style(style Style) *DiffView {
	dv.style = style
	return dv
}
108
// LineNumbers sets whether to display line numbers in the DiffView. When
// false, the renderers omit the line-number cells entirely. It returns the
// receiver so calls can be chained.
func (dv *DiffView) LineNumbers(lineNumbers bool) *DiffView {
	dv.lineNumbers = lineNumbers
	return dv
}
114
// SyntaxHightlight sets whether to enable syntax highlighting in the DiffView.
// It records the flag and returns the receiver so calls can be chained.
//
// NOTE(review): the method name is misspelled ("Hightlight"), and the stored
// flag is not consulted by any code visible in this part of the file.
func (dv *DiffView) SyntaxHightlight(highlight bool) *DiffView {
	dv.highlight = highlight
	return dv
}
120
// Height sets the height of the DiffView. A positive height caps the rendered
// output at that many lines, and the renderers append blank rows until that
// many rows have been produced. A value of zero or less leaves the height
// unconstrained. It returns the receiver so calls can be chained.
func (dv *DiffView) Height(height int) *DiffView {
	dv.height = height
	return dv
}
126
// Width sets the width of the DiffView. A positive width makes String size the
// code columns to fit it and caps the output at that width. With a value of
// zero or less, the code width is derived from the widest line in the diff.
// It returns the receiver so calls can be chained.
func (dv *DiffView) Width(width int) *DiffView {
	dv.width = width
	return dv
}
132
133// String returns the string representation of the DiffView.
134func (dv *DiffView) String() string {
135	if err := dv.computeDiff(); err != nil {
136		return err.Error()
137	}
138	dv.convertDiffToSplit()
139	dv.adjustStyles()
140	dv.detectNumDigits()
141
142	if dv.width <= 0 {
143		dv.detectCodeWidth()
144	} else {
145		dv.resizeCodeWidth()
146	}
147
148	style := lipgloss.NewStyle()
149	if dv.width > 0 {
150		style = style.MaxWidth(dv.width)
151	}
152	if dv.height > 0 {
153		style = style.MaxHeight(dv.height)
154	}
155
156	switch dv.layout {
157	case layoutUnified:
158		return style.Render(strings.TrimSuffix(dv.renderUnified(), "\n"))
159	case layoutSplit:
160		return style.Render(strings.TrimSuffix(dv.renderSplit(), "\n"))
161	default:
162		panic("unknown diffview layout")
163	}
164}
165
166// computeDiff computes the differences between the "before" and "after" files.
167func (dv *DiffView) computeDiff() error {
168	if dv.isComputed {
169		return dv.err
170	}
171	dv.isComputed = true
172	dv.edits = myers.ComputeEdits( //nolint:staticcheck
173		dv.before.content,
174		dv.after.content,
175	)
176	dv.unified, dv.err = udiff.ToUnifiedDiff(
177		dv.before.path,
178		dv.after.path,
179		dv.before.content,
180		dv.edits,
181		dv.contextLines,
182	)
183	return dv.err
184}
185
186// convertDiffToSplit converts the unified diff to a split diff if the layout is
187// set to split.
188func (dv *DiffView) convertDiffToSplit() {
189	if dv.layout != layoutSplit {
190		return
191	}
192
193	dv.splitHunks = make([]splitHunk, len(dv.unified.Hunks))
194	for i, h := range dv.unified.Hunks {
195		dv.splitHunks[i] = hunkToSplit(h)
196	}
197}
198
199// adjustStyles adjusts adds padding and alignment to the styles.
200func (dv *DiffView) adjustStyles() {
201	dv.style.MissingLine.LineNumber = setPadding(dv.style.MissingLine.LineNumber)
202	dv.style.DividerLine.LineNumber = setPadding(dv.style.DividerLine.LineNumber)
203	dv.style.EqualLine.LineNumber = setPadding(dv.style.EqualLine.LineNumber)
204	dv.style.InsertLine.LineNumber = setPadding(dv.style.InsertLine.LineNumber)
205	dv.style.DeleteLine.LineNumber = setPadding(dv.style.DeleteLine.LineNumber)
206}
207
208// detectNumDigits calculates the maximum number of digits needed for before and
209// after line numbers.
210func (dv *DiffView) detectNumDigits() {
211	dv.beforeNumDigits = 0
212	dv.afterNumDigits = 0
213
214	for _, h := range dv.unified.Hunks {
215		dv.beforeNumDigits = max(dv.beforeNumDigits, len(strconv.Itoa(h.FromLine+len(h.Lines))))
216		dv.afterNumDigits = max(dv.afterNumDigits, len(strconv.Itoa(h.ToLine+len(h.Lines))))
217	}
218}
219
220func setPadding(s lipgloss.Style) lipgloss.Style {
221	return s.Padding(0, lineNumPadding).Align(lipgloss.Right)
222}
223
224// detectCodeWidth calculates the maximum width of code lines in the diff view.
225func (dv *DiffView) detectCodeWidth() {
226	switch dv.layout {
227	case layoutUnified:
228		dv.detectUnifiedCodeWidth()
229	case layoutSplit:
230		dv.detectSplitCodeWidth()
231	}
232	dv.fullCodeWidth = dv.codeWidth + leadingSymbolsSize
233}
234
235// detectUnifiedCodeWidth calculates the maximum width of code lines in a
236// unified diff.
237func (dv *DiffView) detectUnifiedCodeWidth() {
238	dv.codeWidth = 0
239
240	for _, h := range dv.unified.Hunks {
241		shownLines := ansi.StringWidth(dv.hunkLineFor(h))
242
243		for _, l := range h.Lines {
244			lineWidth := ansi.StringWidth(strings.TrimSuffix(l.Content, "\n")) + 1
245			dv.codeWidth = max(dv.codeWidth, lineWidth, shownLines)
246		}
247	}
248}
249
250// detectSplitCodeWidth calculates the maximum width of code lines in a
251// split diff.
252func (dv *DiffView) detectSplitCodeWidth() {
253	dv.codeWidth = 0
254
255	for i, h := range dv.splitHunks {
256		shownLines := ansi.StringWidth(dv.hunkLineFor(dv.unified.Hunks[i]))
257
258		for _, l := range h.lines {
259			if l.before != nil {
260				codeWidth := ansi.StringWidth(strings.TrimSuffix(l.before.Content, "\n")) + 1
261				dv.codeWidth = max(dv.codeWidth, codeWidth, shownLines)
262			}
263			if l.after != nil {
264				codeWidth := ansi.StringWidth(strings.TrimSuffix(l.after.Content, "\n")) + 1
265				dv.codeWidth = max(dv.codeWidth, codeWidth, shownLines)
266			}
267		}
268	}
269}
270
271// resizeCodeWidth resizes the code width to fit within the specified width.
272func (dv *DiffView) resizeCodeWidth() {
273	fullNumWidth := dv.beforeNumDigits + dv.afterNumDigits
274	fullNumWidth += lineNumPadding * 4 // left and right padding for both line numbers
275
276	switch dv.layout {
277	case layoutUnified:
278		dv.codeWidth = dv.width - fullNumWidth - leadingSymbolsSize
279	case layoutSplit:
280		remainingWidth := dv.width - fullNumWidth - leadingSymbolsSize*2
281		dv.codeWidth = remainingWidth / 2
282		dv.extraColOnAfter = isOdd(remainingWidth)
283	}
284
285	dv.fullCodeWidth = dv.codeWidth + leadingSymbolsSize
286}
287
// renderUnified renders the unified diff view as a string.
//
// Each hunk produces a divider row with the hunk header, followed by one row
// per line. A row consists of the optional before and after line-number cells
// and the code cell. Code is truncated to codeWidth with an ellipsis. If
// fewer rows than dv.height were produced, blank rows are appended until that
// many rows exist. Every row, including the last, ends with a newline.
func (dv *DiffView) renderUnified() string {
	var b strings.Builder

	// Caps the symbol+code cell of each row at the full code width.
	fullContentStyle := lipgloss.NewStyle().MaxWidth(dv.fullCodeWidth)
	printedLines := 0

	for _, h := range dv.unified.Hunks {
		// Divider row: ellipses in the number cells, then the hunk header.
		if dv.lineNumbers {
			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
		}
		content := ansi.Truncate(dv.hunkLineFor(h), dv.fullCodeWidth, "…")
		b.WriteString(dv.style.DividerLine.Code.Width(dv.fullCodeWidth).Render(content))
		b.WriteRune('\n')
		printedLines++

		// Running line numbers for each side, advanced as lines are emitted.
		beforeLine := h.FromLine
		afterLine := h.ToLine

		for _, l := range h.Lines {
			content := strings.TrimSuffix(l.Content, "\n")
			content = ansi.Truncate(content, dv.codeWidth, "…")

			switch l.Kind {
			case udiff.Equal:
				// Unchanged line: both numbers shown, two-space marker.
				if dv.lineNumbers {
					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
				}
				b.WriteString(fullContentStyle.Render(
					dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render("  " + content),
				))
				beforeLine++
				afterLine++
			case udiff.Insert:
				// Added line: only the after number is shown.
				if dv.lineNumbers {
					b.WriteString(dv.style.InsertLine.LineNumber.Render(pad(" ", dv.beforeNumDigits)))
					b.WriteString(dv.style.InsertLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
				}
				b.WriteString(fullContentStyle.Render(
					dv.style.InsertLine.Symbol.Render("+ ") +
						dv.style.InsertLine.Code.Width(dv.codeWidth).Render(content),
				))
				afterLine++
			case udiff.Delete:
				// Removed line: only the before number is shown.
				if dv.lineNumbers {
					b.WriteString(dv.style.DeleteLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
					b.WriteString(dv.style.DeleteLine.LineNumber.Render(pad(" ", dv.afterNumDigits)))
				}
				b.WriteString(fullContentStyle.Render(
					dv.style.DeleteLine.Symbol.Render("- ") +
						dv.style.DeleteLine.Code.Width(dv.codeWidth).Render(content),
				))
				beforeLine++
			}
			b.WriteRune('\n')

			printedLines++
		}
	}

	// Pad with blank rows until dv.height rows have been produced.
	for printedLines < dv.height {
		if dv.lineNumbers {
			b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.beforeNumDigits)))
			b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.afterNumDigits)))
		}
		b.WriteString(dv.style.MissingLine.Code.Width(dv.fullCodeWidth).Render("  "))
		b.WriteRune('\n')
		printedLines++
	}

	return b.String()
}
362
363// renderSplit renders the split (side-by-side) diff view as a string.
364func (dv *DiffView) renderSplit() string {
365	var b strings.Builder
366
367	beforeFullContentStyle := lipgloss.NewStyle().MaxWidth(dv.fullCodeWidth)
368	afterFullContentStyle := lipgloss.NewStyle().MaxWidth(dv.fullCodeWidth + btoi(dv.extraColOnAfter))
369	printedLines := 0
370
371	for i, h := range dv.splitHunks {
372		if dv.lineNumbers {
373			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
374		}
375		content := ansi.Truncate(dv.hunkLineFor(dv.unified.Hunks[i]), dv.fullCodeWidth, "…")
376		b.WriteString(dv.style.DividerLine.Code.Width(dv.fullCodeWidth).Render(content))
377		if dv.lineNumbers {
378			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
379		}
380		b.WriteString(dv.style.DividerLine.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render(" "))
381		b.WriteRune('\n')
382		printedLines++
383
384		beforeLine := h.fromLine
385		afterLine := h.toLine
386
387		for _, l := range h.lines {
388			var beforeContent string
389			var afterContent string
390			if l.before != nil {
391				beforeContent = strings.TrimSuffix(l.before.Content, "\n")
392				beforeContent = ansi.Truncate(beforeContent, dv.codeWidth, "…")
393			}
394			if l.after != nil {
395				afterContent = strings.TrimSuffix(l.after.Content, "\n")
396				afterContent = ansi.Truncate(afterContent, dv.codeWidth+btoi(dv.extraColOnAfter), "…")
397			}
398
399			switch {
400			case l.before == nil:
401				if dv.lineNumbers {
402					b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.beforeNumDigits)))
403				}
404				b.WriteString(beforeFullContentStyle.Render(
405					dv.style.MissingLine.Code.Width(dv.fullCodeWidth).Render("  "),
406				))
407			case l.before.Kind == udiff.Equal:
408				if dv.lineNumbers {
409					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
410				}
411				b.WriteString(beforeFullContentStyle.Render(
412					dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render("  " + beforeContent),
413				))
414				beforeLine++
415			case l.before.Kind == udiff.Delete:
416				if dv.lineNumbers {
417					b.WriteString(dv.style.DeleteLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
418				}
419				b.WriteString(beforeFullContentStyle.Render(
420					dv.style.DeleteLine.Symbol.Render("- ") +
421						dv.style.DeleteLine.Code.Width(dv.codeWidth).Render(beforeContent),
422				))
423				beforeLine++
424			}
425
426			switch {
427			case l.after == nil:
428				if dv.lineNumbers {
429					b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.afterNumDigits)))
430				}
431				b.WriteString(afterFullContentStyle.Render(
432					dv.style.MissingLine.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render("  "),
433				))
434			case l.after.Kind == udiff.Equal:
435				if dv.lineNumbers {
436					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
437				}
438				b.WriteString(afterFullContentStyle.Render(
439					dv.style.EqualLine.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render("  " + afterContent),
440				))
441				afterLine++
442			case l.after.Kind == udiff.Insert:
443				if dv.lineNumbers {
444					b.WriteString(dv.style.InsertLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
445				}
446				b.WriteString(afterFullContentStyle.Render(
447					dv.style.InsertLine.Symbol.Render("+ ") +
448						dv.style.InsertLine.Code.Width(dv.codeWidth+btoi(dv.extraColOnAfter)).Render(afterContent),
449				))
450				afterLine++
451			}
452
453			b.WriteRune('\n')
454
455			printedLines++
456		}
457	}
458
459	for printedLines < dv.height {
460		if dv.lineNumbers {
461			b.WriteString(dv.style.MissingLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
462		}
463		b.WriteString(dv.style.MissingLine.Code.Width(dv.fullCodeWidth).Render(" "))
464		if dv.lineNumbers {
465			b.WriteString(dv.style.MissingLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
466		}
467		b.WriteString(dv.style.MissingLine.Code.Width(dv.fullCodeWidth + btoi(dv.extraColOnAfter)).Render(" "))
468		b.WriteRune('\n')
469		printedLines++
470	}
471
472	return b.String()
473}
474
475// hunkLineFor formats the header line for a hunk in the unified diff view.
476func (dv *DiffView) hunkLineFor(h *udiff.Hunk) string {
477	beforeShownLines, afterShownLines := dv.hunkShownLines(h)
478
479	return fmt.Sprintf(
480		"  @@ -%d,%d +%d,%d @@ ",
481		h.FromLine,
482		beforeShownLines,
483		h.ToLine,
484		afterShownLines,
485	)
486}
487
488// hunkShownLines calculates the number of lines shown in a hunk for both before
489// and after versions.
490func (dv *DiffView) hunkShownLines(h *udiff.Hunk) (before, after int) {
491	for _, l := range h.Lines {
492		switch l.Kind {
493		case udiff.Equal:
494			before++
495			after++
496		case udiff.Insert:
497			after++
498		case udiff.Delete:
499			before++
500		}
501	}
502	return
503}