diffview.go

  1package diffview
  2
  3import (
  4	"fmt"
  5	"os"
  6	"strconv"
  7	"strings"
  8
  9	"github.com/aymanbagabas/go-udiff"
 10	"github.com/aymanbagabas/go-udiff/myers"
 11	"github.com/charmbracelet/lipgloss/v2"
 12	"github.com/charmbracelet/x/ansi"
 13)
 14
 15const (
 16	leadingSymbolsSize = 2
 17	lineNumPadding     = 1
 18)
 19
 20type file struct {
 21	path    string
 22	content string
 23}
 24
 25type layout int
 26
 27const (
 28	layoutUnified layout = iota + 1
 29	layoutSplit
 30)
 31
 32// DiffView represents a view for displaying differences between two files.
 33type DiffView struct {
 34	layout       layout
 35	before       file
 36	after        file
 37	contextLines int
 38	lineNumbers  bool
 39	highlight    bool
 40	height       int
 41	width        int
 42	style        Style
 43
 44	isComputed bool
 45	err        error
 46	unified    udiff.UnifiedDiff
 47	edits      []udiff.Edit
 48
 49	splitHunks []splitHunk
 50
 51	codeWidth       int
 52	fullCodeWidth   int // with leading symbols
 53	beforeNumDigits int
 54	afterNumDigits  int
 55}
 56
 57// New creates a new DiffView with default settings.
 58func New() *DiffView {
 59	dv := &DiffView{
 60		layout:       layoutUnified,
 61		contextLines: udiff.DefaultContextLines,
 62		lineNumbers:  true,
 63	}
 64	if lipgloss.HasDarkBackground(os.Stdin, os.Stdout) {
 65		dv.style = DefaultDarkStyle
 66	} else {
 67		dv.style = DefaultLightStyle
 68	}
 69	return dv
 70}
 71
 72// Unified sets the layout of the DiffView to unified.
 73func (dv *DiffView) Unified() *DiffView {
 74	dv.layout = layoutUnified
 75	return dv
 76}
 77
 78// Split sets the layout of the DiffView to split (side-by-side).
 79func (dv *DiffView) Split() *DiffView {
 80	dv.layout = layoutSplit
 81	return dv
 82}
 83
 84// Before sets the "before" file for the DiffView.
 85func (dv *DiffView) Before(path, content string) *DiffView {
 86	dv.before = file{path: path, content: content}
 87	return dv
 88}
 89
 90// After sets the "after" file for the DiffView.
 91func (dv *DiffView) After(path, content string) *DiffView {
 92	dv.after = file{path: path, content: content}
 93	return dv
 94}
 95
 96// ContextLines sets the number of context lines for the DiffView.
 97func (dv *DiffView) ContextLines(contextLines int) *DiffView {
 98	dv.contextLines = contextLines
 99	return dv
100}
101
102// Style sets the style for the DiffView.
103func (dv *DiffView) Style(style Style) *DiffView {
104	dv.style = style
105	return dv
106}
107
108// LineNumbers sets whether to display line numbers in the DiffView.
109func (dv *DiffView) LineNumbers(lineNumbers bool) *DiffView {
110	dv.lineNumbers = lineNumbers
111	return dv
112}
113
114// SyntaxHightlight sets whether to enable syntax highlighting in the DiffView.
115func (dv *DiffView) SyntaxHightlight(highlight bool) *DiffView {
116	dv.highlight = highlight
117	return dv
118}
119
120// Height sets the height of the DiffView.
121func (dv *DiffView) Height(height int) *DiffView {
122	dv.height = height
123	return dv
124}
125
126// Width sets the width of the DiffView.
127func (dv *DiffView) Width(width int) *DiffView {
128	dv.width = width
129	return dv
130}
131
132// String returns the string representation of the DiffView.
133func (dv *DiffView) String() string {
134	if err := dv.computeDiff(); err != nil {
135		return err.Error()
136	}
137	dv.convertDiffToSplit()
138	dv.adjustStyles()
139	dv.detectNumDigits()
140
141	if dv.width <= 0 {
142		dv.detectCodeWidth()
143	} else {
144		dv.resizeCodeWidth()
145	}
146
147	switch dv.layout {
148	case layoutUnified:
149		return dv.renderUnified()
150	case layoutSplit:
151		return dv.renderSplit()
152	default:
153		panic("unknown diffview layout")
154	}
155}
156
157// computeDiff computes the differences between the "before" and "after" files.
158func (dv *DiffView) computeDiff() error {
159	if dv.isComputed {
160		return dv.err
161	}
162	dv.isComputed = true
163	dv.edits = myers.ComputeEdits( //nolint:staticcheck
164		dv.before.content,
165		dv.after.content,
166	)
167	dv.unified, dv.err = udiff.ToUnifiedDiff(
168		dv.before.path,
169		dv.after.path,
170		dv.before.content,
171		dv.edits,
172		dv.contextLines,
173	)
174	return dv.err
175}
176
177// convertDiffToSplit converts the unified diff to a split diff if the layout is
178// set to split.
179func (dv *DiffView) convertDiffToSplit() {
180	if dv.layout != layoutSplit {
181		return
182	}
183
184	dv.splitHunks = make([]splitHunk, len(dv.unified.Hunks))
185	for i, h := range dv.unified.Hunks {
186		dv.splitHunks[i] = hunkToSplit(h)
187	}
188}
189
190// adjustStyles adjusts adds padding and alignment to the styles.
191func (dv *DiffView) adjustStyles() {
192	dv.style.MissingLine.LineNumber = setPadding(dv.style.MissingLine.LineNumber)
193	dv.style.DividerLine.LineNumber = setPadding(dv.style.DividerLine.LineNumber)
194	dv.style.EqualLine.LineNumber = setPadding(dv.style.EqualLine.LineNumber)
195	dv.style.InsertLine.LineNumber = setPadding(dv.style.InsertLine.LineNumber)
196	dv.style.DeleteLine.LineNumber = setPadding(dv.style.DeleteLine.LineNumber)
197}
198
199// detectNumDigits calculates the maximum number of digits needed for before and
200// after line numbers.
201func (dv *DiffView) detectNumDigits() {
202	dv.beforeNumDigits = 0
203	dv.afterNumDigits = 0
204
205	for _, h := range dv.unified.Hunks {
206		dv.beforeNumDigits = max(dv.beforeNumDigits, len(strconv.Itoa(h.FromLine+len(h.Lines))))
207		dv.afterNumDigits = max(dv.afterNumDigits, len(strconv.Itoa(h.ToLine+len(h.Lines))))
208	}
209}
210
211func setPadding(s lipgloss.Style) lipgloss.Style {
212	return s.Padding(0, lineNumPadding).Align(lipgloss.Right)
213}
214
215// detectCodeWidth calculates the maximum width of code lines in the diff view.
216func (dv *DiffView) detectCodeWidth() {
217	switch dv.layout {
218	case layoutUnified:
219		dv.detectUnifiedCodeWidth()
220	case layoutSplit:
221		dv.detectSplitCodeWidth()
222	}
223	dv.fullCodeWidth = dv.codeWidth + leadingSymbolsSize
224}
225
226// detectUnifiedCodeWidth calculates the maximum width of code lines in a
227// unified diff.
228func (dv *DiffView) detectUnifiedCodeWidth() {
229	dv.codeWidth = 0
230
231	for _, h := range dv.unified.Hunks {
232		shownLines := ansi.StringWidth(dv.hunkLineFor(h))
233
234		for _, l := range h.Lines {
235			lineWidth := ansi.StringWidth(strings.TrimSuffix(l.Content, "\n")) + 1
236			dv.codeWidth = max(dv.codeWidth, lineWidth, shownLines)
237		}
238	}
239}
240
241// detectSplitCodeWidth calculates the maximum width of code lines in a
242// split diff.
243func (dv *DiffView) detectSplitCodeWidth() {
244	dv.codeWidth = 0
245
246	for i, h := range dv.splitHunks {
247		shownLines := ansi.StringWidth(dv.hunkLineFor(dv.unified.Hunks[i]))
248
249		for _, l := range h.lines {
250			if l.before != nil {
251				codeWidth := ansi.StringWidth(strings.TrimSuffix(l.before.Content, "\n")) + 1
252				dv.codeWidth = max(dv.codeWidth, codeWidth, shownLines)
253			}
254			if l.after != nil {
255				codeWidth := ansi.StringWidth(strings.TrimSuffix(l.after.Content, "\n")) + 1
256				dv.codeWidth = max(dv.codeWidth, codeWidth, shownLines)
257			}
258		}
259	}
260}
261
262// resizeCodeWidth resizes the code width to fit within the specified width.
263func (dv *DiffView) resizeCodeWidth() {
264	fullNumWidth := dv.beforeNumDigits + dv.afterNumDigits
265	fullNumWidth += lineNumPadding * 4 // left and right padding for both line numbers
266
267	switch dv.layout {
268	case layoutUnified:
269		dv.codeWidth = dv.width - fullNumWidth - leadingSymbolsSize
270	case layoutSplit:
271		dv.codeWidth = (dv.width - fullNumWidth - leadingSymbolsSize*2) / 2
272	}
273
274	dv.fullCodeWidth = dv.codeWidth + leadingSymbolsSize
275}
276
277// renderUnified renders the unified diff view as a string.
278func (dv *DiffView) renderUnified() string {
279	var b strings.Builder
280
281	for _, h := range dv.unified.Hunks {
282		if dv.lineNumbers {
283			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
284			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
285		}
286		content := ansi.Truncate(dv.hunkLineFor(h), dv.fullCodeWidth, "…")
287		b.WriteString(dv.style.DividerLine.Code.Width(dv.fullCodeWidth).Render(content))
288		b.WriteRune('\n')
289
290		beforeLine := h.FromLine
291		afterLine := h.ToLine
292
293		for _, l := range h.Lines {
294			content := strings.TrimSuffix(l.Content, "\n")
295			content = ansi.Truncate(content, dv.codeWidth, "…")
296
297			switch l.Kind {
298			case udiff.Equal:
299				if dv.lineNumbers {
300					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
301					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
302				}
303				b.WriteString(dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render("  " + content))
304				beforeLine++
305				afterLine++
306			case udiff.Insert:
307				if dv.lineNumbers {
308					b.WriteString(dv.style.InsertLine.LineNumber.Render(pad(" ", dv.beforeNumDigits)))
309					b.WriteString(dv.style.InsertLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
310				}
311				b.WriteString(dv.style.InsertLine.Symbol.Render("+ "))
312				b.WriteString(dv.style.InsertLine.Code.Width(dv.codeWidth).Render(content))
313				afterLine++
314			case udiff.Delete:
315				if dv.lineNumbers {
316					b.WriteString(dv.style.DeleteLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
317					b.WriteString(dv.style.DeleteLine.LineNumber.Render(pad(" ", dv.afterNumDigits)))
318				}
319				b.WriteString(dv.style.DeleteLine.Symbol.Render("- "))
320				b.WriteString(dv.style.DeleteLine.Code.Width(dv.codeWidth).Render(content))
321				beforeLine++
322			}
323			b.WriteRune('\n')
324		}
325	}
326
327	return b.String()
328}
329
330// renderSplit renders the split (side-by-side) diff view as a string.
331func (dv *DiffView) renderSplit() string {
332	var b strings.Builder
333
334	for i, h := range dv.splitHunks {
335		if dv.lineNumbers {
336			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.beforeNumDigits)))
337		}
338		content := ansi.Truncate(dv.hunkLineFor(dv.unified.Hunks[i]), dv.fullCodeWidth, "…")
339		b.WriteString(dv.style.DividerLine.Code.Width(dv.fullCodeWidth).Render(content))
340		if dv.lineNumbers {
341			b.WriteString(dv.style.DividerLine.LineNumber.Render(pad("…", dv.afterNumDigits)))
342		}
343		b.WriteString(dv.style.DividerLine.Code.Width(dv.fullCodeWidth).Render(" "))
344		b.WriteRune('\n')
345
346		beforeLine := h.fromLine
347		afterLine := h.toLine
348
349		for _, l := range h.lines {
350			var beforeContent string
351			var afterContent string
352			if l.before != nil {
353				beforeContent = strings.TrimSuffix(l.before.Content, "\n")
354				beforeContent = ansi.Truncate(beforeContent, dv.codeWidth, "…")
355			}
356			if l.after != nil {
357				afterContent = strings.TrimSuffix(l.after.Content, "\n")
358				afterContent = ansi.Truncate(afterContent, dv.codeWidth, "…")
359			}
360
361			switch {
362			case l.before == nil:
363				if dv.lineNumbers {
364					b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.beforeNumDigits)))
365				}
366				b.WriteString(dv.style.MissingLine.Code.Width(dv.fullCodeWidth).Render("  "))
367			case l.before.Kind == udiff.Equal:
368				if dv.lineNumbers {
369					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
370				}
371				b.WriteString(dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render("  " + beforeContent))
372				beforeLine++
373			case l.before.Kind == udiff.Delete:
374				if dv.lineNumbers {
375					b.WriteString(dv.style.DeleteLine.LineNumber.Render(pad(beforeLine, dv.beforeNumDigits)))
376				}
377				b.WriteString(dv.style.DeleteLine.Symbol.Render("- "))
378				b.WriteString(dv.style.DeleteLine.Code.Width(dv.codeWidth).Render(beforeContent))
379				beforeLine++
380			}
381
382			switch {
383			case l.after == nil:
384				if dv.lineNumbers {
385					b.WriteString(dv.style.MissingLine.LineNumber.Render(pad(" ", dv.afterNumDigits)))
386				}
387				b.WriteString(dv.style.MissingLine.Code.Width(dv.fullCodeWidth).Render("  "))
388			case l.after.Kind == udiff.Equal:
389				if dv.lineNumbers {
390					b.WriteString(dv.style.EqualLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
391				}
392				b.WriteString(dv.style.EqualLine.Code.Width(dv.fullCodeWidth).Render("  " + afterContent))
393				afterLine++
394			case l.after.Kind == udiff.Insert:
395				if dv.lineNumbers {
396					b.WriteString(dv.style.InsertLine.LineNumber.Render(pad(afterLine, dv.afterNumDigits)))
397				}
398				b.WriteString(dv.style.InsertLine.Symbol.Render("+ "))
399				b.WriteString(dv.style.InsertLine.Code.Width(dv.codeWidth).Render(afterContent))
400				afterLine++
401			}
402
403			b.WriteRune('\n')
404		}
405	}
406
407	return b.String()
408}
409
410// hunkLineFor formats the header line for a hunk in the unified diff view.
411func (dv *DiffView) hunkLineFor(h *udiff.Hunk) string {
412	beforeShownLines, afterShownLines := dv.hunkShownLines(h)
413
414	return fmt.Sprintf(
415		"  @@ -%d,%d +%d,%d @@ ",
416		h.FromLine,
417		beforeShownLines,
418		h.ToLine,
419		afterShownLines,
420	)
421}
422
423// hunkShownLines calculates the number of lines shown in a hunk for both before
424// and after versions.
425func (dv *DiffView) hunkShownLines(h *udiff.Hunk) (before, after int) {
426	for _, l := range h.Lines {
427		switch l.Kind {
428		case udiff.Equal:
429			before++
430			after++
431		case udiff.Insert:
432			after++
433		case udiff.Delete:
434			before++
435		}
436	}
437	return
438}