diffview_test.go

  1package diffview_test
  2
  3import (
  4	_ "embed"
  5	"strings"
  6	"testing"
  7
  8	"github.com/charmbracelet/x/ansi"
  9	"github.com/charmbracelet/x/exp/golden"
 10	"github.com/opencode-ai/opencode/internal/exp/diffview"
 11)
 12
// TestDefaultBefore holds the contents of testdata/TestDefault.before,
// embedded at build time. DefaultFunc and NoLineNumbersFunc use it as the
// "before" side of main.go.
//
//go:embed testdata/TestDefault.before
var TestDefaultBefore string

// TestDefaultAfter holds the contents of testdata/TestDefault.after,
// embedded at build time. DefaultFunc and NoLineNumbersFunc use it as the
// "after" side of main.go.
//
//go:embed testdata/TestDefault.after
var TestDefaultAfter string

// TestMultipleHunksBefore holds the contents of
// testdata/TestMultipleHunks.before, embedded at build time. It is the
// "before" side of main.go for MultipleHunksFunc, CustomContextLinesFunc,
// SmallWidthFunc and LargeWidthFunc.
//
//go:embed testdata/TestMultipleHunks.before
var TestMultipleHunksBefore string

// TestMultipleHunksAfter holds the contents of
// testdata/TestMultipleHunks.after, embedded at build time. It is the
// "after" side of main.go for MultipleHunksFunc, CustomContextLinesFunc,
// SmallWidthFunc and LargeWidthFunc.
//
//go:embed testdata/TestMultipleHunks.after
var TestMultipleHunksAfter string

// TestNarrowBefore holds the contents of testdata/TestNarrow.before,
// embedded at build time. NarrowFunc uses it as the "before" side of
// text.txt.
//
//go:embed testdata/TestNarrow.before
var TestNarrowBefore string

// TestNarrowAfter holds the contents of testdata/TestNarrow.after,
// embedded at build time. NarrowFunc uses it as the "after" side of
// text.txt.
//
//go:embed testdata/TestNarrow.after
var TestNarrowAfter string
 30
 31type (
 32	TestFunc  func(dv *diffview.DiffView) *diffview.DiffView
 33	TestFuncs map[string]TestFunc
 34)
 35
 36var (
 37	UnifiedFunc = func(dv *diffview.DiffView) *diffview.DiffView {
 38		return dv.Unified()
 39	}
 40	SplitFunc = func(dv *diffview.DiffView) *diffview.DiffView {
 41		return dv.Split()
 42	}
 43
 44	DefaultFunc = func(dv *diffview.DiffView) *diffview.DiffView {
 45		return dv.
 46			Before("main.go", TestDefaultBefore).
 47			After("main.go", TestDefaultAfter)
 48	}
 49	NoLineNumbersFunc = func(dv *diffview.DiffView) *diffview.DiffView {
 50		return dv.
 51			Before("main.go", TestDefaultBefore).
 52			After("main.go", TestDefaultAfter).
 53			LineNumbers(false)
 54	}
 55	MultipleHunksFunc = func(dv *diffview.DiffView) *diffview.DiffView {
 56		return dv.
 57			Before("main.go", TestMultipleHunksBefore).
 58			After("main.go", TestMultipleHunksAfter)
 59	}
 60	CustomContextLinesFunc = func(dv *diffview.DiffView) *diffview.DiffView {
 61		return dv.
 62			Before("main.go", TestMultipleHunksBefore).
 63			After("main.go", TestMultipleHunksAfter).
 64			ContextLines(4)
 65	}
 66	NarrowFunc = func(dv *diffview.DiffView) *diffview.DiffView {
 67		return dv.
 68			Before("text.txt", TestNarrowBefore).
 69			After("text.txt", TestNarrowAfter)
 70	}
 71	SmallWidthFunc = func(dv *diffview.DiffView) *diffview.DiffView {
 72		return dv.
 73			Before("main.go", TestMultipleHunksBefore).
 74			After("main.go", TestMultipleHunksAfter).
 75			Width(40)
 76	}
 77	LargeWidthFunc = func(dv *diffview.DiffView) *diffview.DiffView {
 78		return dv.
 79			Before("main.go", TestMultipleHunksBefore).
 80			After("main.go", TestMultipleHunksAfter).
 81			Width(120)
 82	}
 83
 84	LightModeFunc = func(dv *diffview.DiffView) *diffview.DiffView {
 85		return dv.Style(diffview.DefaultLightStyle)
 86	}
 87	DarkModeFunc = func(dv *diffview.DiffView) *diffview.DiffView {
 88		return dv.Style(diffview.DefaultDarkStyle)
 89	}
 90
 91	LayoutFuncs = TestFuncs{
 92		"Unified": UnifiedFunc,
 93		"Split":   SplitFunc,
 94	}
 95	BehaviorFuncs = TestFuncs{
 96		"Default":            DefaultFunc,
 97		"NoLineNumbers":      NoLineNumbersFunc,
 98		"MultipleHunks":      MultipleHunksFunc,
 99		"CustomContextLines": CustomContextLinesFunc,
100		"Narrow":             NarrowFunc,
101		"SmallWidth":         SmallWidthFunc,
102		"LargeWidth":         LargeWidthFunc,
103	}
104	ThemeFuncs = TestFuncs{
105		"LightMode": LightModeFunc,
106		"DarkMode":  DarkModeFunc,
107	}
108)
109
110func TestDiffView(t *testing.T) {
111	for layoutName, layoutFunc := range LayoutFuncs {
112		t.Run(layoutName, func(t *testing.T) {
113			for behaviorName, behaviorFunc := range BehaviorFuncs {
114				t.Run(behaviorName, func(t *testing.T) {
115					for themeName, themeFunc := range ThemeFuncs {
116						t.Run(themeName, func(t *testing.T) {
117							dv := diffview.New()
118							dv = layoutFunc(dv)
119							dv = behaviorFunc(dv)
120							dv = themeFunc(dv)
121
122							output := dv.String()
123							golden.RequireEqual(t, []byte(output))
124
125							switch behaviorName {
126							case "SmallWidth":
127								assertLineWidth(t, 40, output)
128							case "LargeWidth":
129								assertLineWidth(t, 120, output)
130							}
131						})
132					}
133				})
134			}
135		})
136	}
137}
138
139func assertLineWidth(t *testing.T, expected int, output string) {
140	var lineWidth int
141	for line := range strings.SplitSeq(output, "\n") {
142		lineWidth = max(lineWidth, ansi.StringWidth(line))
143	}
144	if lineWidth != expected {
145		t.Errorf("expected output width to be == %d, got %d", expected, lineWidth)
146	}
147}