recent_models_test.go

  1package config
  2
  3import (
  4	"encoding/json"
  5	"io/fs"
  6	"os"
  7	"path/filepath"
  8	"testing"
  9
 10	"github.com/stretchr/testify/require"
 11)
 12
 13// readConfigJSON reads and unmarshals the JSON config file at path.
 14func readConfigJSON(t *testing.T, path string) map[string]any {
 15	t.Helper()
 16	baseDir := filepath.Dir(path)
 17	fileName := filepath.Base(path)
 18	b, err := fs.ReadFile(os.DirFS(baseDir), fileName)
 19	require.NoError(t, err)
 20	var out map[string]any
 21	require.NoError(t, json.Unmarshal(b, &out))
 22	return out
 23}
 24
 25// readRecentModels reads the recent_models section from the config file.
 26func readRecentModels(t *testing.T, path string) map[string]any {
 27	t.Helper()
 28	out := readConfigJSON(t, path)
 29	rm, ok := out["recent_models"].(map[string]any)
 30	require.True(t, ok)
 31	return rm
 32}
 33
 34func newTestService(t *testing.T) (*Service, string) {
 35	t.Helper()
 36	dir := t.TempDir()
 37	cfg := &Config{}
 38	cfg.setDefaults(dir, "")
 39	storePath := filepath.Join(dir, "config.json")
 40	svc := &Service{
 41		cfg:        cfg,
 42		store:      NewFileStore(storePath),
 43		workingDir: dir,
 44	}
 45	return svc, storePath
 46}
 47
 48func TestRecordRecentModel_AddsAndPersists(t *testing.T) {
 49	t.Parallel()
 50
 51	svc, storePath := newTestService(t)
 52	cfg := svc.cfg
 53
 54	err := svc.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"})
 55	require.NoError(t, err)
 56
 57	// in-memory state
 58	require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 1)
 59	require.Equal(t, "openai", cfg.RecentModels[SelectedModelTypeLarge][0].Provider)
 60	require.Equal(t, "gpt-4o", cfg.RecentModels[SelectedModelTypeLarge][0].Model)
 61
 62	// persisted state
 63	rm := readRecentModels(t, storePath)
 64	large, ok := rm[string(SelectedModelTypeLarge)].([]any)
 65	require.True(t, ok)
 66	require.Len(t, large, 1)
 67	item, ok := large[0].(map[string]any)
 68	require.True(t, ok)
 69	require.Equal(t, "openai", item["provider"])
 70	require.Equal(t, "gpt-4o", item["model"])
 71}
 72
 73func TestRecordRecentModel_DedupeAndMoveToFront(t *testing.T) {
 74	t.Parallel()
 75
 76	svc, _ := newTestService(t)
 77	cfg := svc.cfg
 78
 79	// Add two entries
 80	require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"}))
 81	require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "anthropic", Model: "claude"}))
 82	// Re-add first; should move to front and not duplicate
 83	require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"}))
 84
 85	got := cfg.RecentModels[SelectedModelTypeLarge]
 86	require.Len(t, got, 2)
 87	require.Equal(t, SelectedModel{Provider: "openai", Model: "gpt-4o"}, got[0])
 88	require.Equal(t, SelectedModel{Provider: "anthropic", Model: "claude"}, got[1])
 89}
 90
 91func TestRecordRecentModel_TrimsToMax(t *testing.T) {
 92	t.Parallel()
 93
 94	svc, storePath := newTestService(t)
 95	cfg := svc.cfg
 96
 97	// Insert 6 unique models; max is 5
 98	entries := []SelectedModel{
 99		{Provider: "p1", Model: "m1"},
100		{Provider: "p2", Model: "m2"},
101		{Provider: "p3", Model: "m3"},
102		{Provider: "p4", Model: "m4"},
103		{Provider: "p5", Model: "m5"},
104		{Provider: "p6", Model: "m6"},
105	}
106	for _, e := range entries {
107		require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, e))
108	}
109
110	// in-memory state
111	got := cfg.RecentModels[SelectedModelTypeLarge]
112	require.Len(t, got, 5)
113	// Newest first, capped at 5: p6..p2
114	require.Equal(t, SelectedModel{Provider: "p6", Model: "m6"}, got[0])
115	require.Equal(t, SelectedModel{Provider: "p5", Model: "m5"}, got[1])
116	require.Equal(t, SelectedModel{Provider: "p4", Model: "m4"}, got[2])
117	require.Equal(t, SelectedModel{Provider: "p3", Model: "m3"}, got[3])
118	require.Equal(t, SelectedModel{Provider: "p2", Model: "m2"}, got[4])
119
120	// persisted state: verify trimmed to 5 and newest-first order
121	rm := readRecentModels(t, storePath)
122	large, ok := rm[string(SelectedModelTypeLarge)].([]any)
123	require.True(t, ok)
124	require.Len(t, large, 5)
125	// Build provider:model IDs and verify order
126	var ids []string
127	for _, v := range large {
128		m := v.(map[string]any)
129		ids = append(ids, m["provider"].(string)+":"+m["model"].(string))
130	}
131	require.Equal(t, []string{"p6:m6", "p5:m5", "p4:m4", "p3:m3", "p2:m2"}, ids)
132}
133
134func TestRecordRecentModel_SkipsEmptyValues(t *testing.T) {
135	t.Parallel()
136
137	svc, storePath := newTestService(t)
138	cfg := svc.cfg
139
140	// Missing provider
141	require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "", Model: "m"}))
142	// Missing model
143	require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "p", Model: ""}))
144
145	_, ok := cfg.RecentModels[SelectedModelTypeLarge]
146	// Map may be initialized, but should have no entries
147	if ok {
148		require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 0)
149	}
150	// No file should be written (stat via fs.FS)
151	baseDir := filepath.Dir(storePath)
152	fileName := filepath.Base(storePath)
153	_, err := fs.Stat(os.DirFS(baseDir), fileName)
154	require.True(t, os.IsNotExist(err))
155}
156
157func TestRecordRecentModel_NoPersistOnNoop(t *testing.T) {
158	t.Parallel()
159
160	svc, storePath := newTestService(t)
161
162	entry := SelectedModel{Provider: "openai", Model: "gpt-4o"}
163	require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, entry))
164
165	baseDir := filepath.Dir(storePath)
166	fileName := filepath.Base(storePath)
167	before, err := fs.ReadFile(os.DirFS(baseDir), fileName)
168	require.NoError(t, err)
169
170	// Get file ModTime to verify no write occurs
171	stBefore, err := fs.Stat(os.DirFS(baseDir), fileName)
172	require.NoError(t, err)
173	beforeMod := stBefore.ModTime()
174
175	// Re-record same entry should be a no-op (no write)
176	require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, entry))
177
178	after, err := fs.ReadFile(os.DirFS(baseDir), fileName)
179	require.NoError(t, err)
180	require.Equal(t, string(before), string(after))
181
182	// Verify ModTime unchanged to ensure truly no write occurred
183	stAfter, err := fs.Stat(os.DirFS(baseDir), fileName)
184	require.NoError(t, err)
185	require.True(t, stAfter.ModTime().Equal(beforeMod), "file ModTime should not change on noop")
186}
187
188func TestUpdatePreferredModel_UpdatesRecents(t *testing.T) {
189	t.Parallel()
190
191	svc, storePath := newTestService(t)
192	cfg := svc.cfg
193
194	sel := SelectedModel{Provider: "openai", Model: "gpt-4o"}
195	require.NoError(t, svc.UpdatePreferredModel(SelectedModelTypeSmall, sel))
196
197	// in-memory
198	require.Equal(t, sel, cfg.Models[SelectedModelTypeSmall])
199	require.Len(t, cfg.RecentModels[SelectedModelTypeSmall], 1)
200
201	// persisted (read via fs.FS)
202	rm := readRecentModels(t, storePath)
203	small, ok := rm[string(SelectedModelTypeSmall)].([]any)
204	require.True(t, ok)
205	require.Len(t, small, 1)
206}
207
208func TestRecordRecentModel_TypeIsolation(t *testing.T) {
209	t.Parallel()
210
211	svc, storePath := newTestService(t)
212	cfg := svc.cfg
213
214	// Add models to both large and small types
215	largeModel := SelectedModel{Provider: "openai", Model: "gpt-4o"}
216	smallModel := SelectedModel{Provider: "anthropic", Model: "claude"}
217
218	require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, largeModel))
219	require.NoError(t, svc.recordRecentModel(SelectedModelTypeSmall, smallModel))
220
221	// in-memory: verify types maintain separate histories
222	require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 1)
223	require.Len(t, cfg.RecentModels[SelectedModelTypeSmall], 1)
224	require.Equal(t, largeModel, cfg.RecentModels[SelectedModelTypeLarge][0])
225	require.Equal(t, smallModel, cfg.RecentModels[SelectedModelTypeSmall][0])
226
227	// Add another to large, verify small unchanged
228	anotherLarge := SelectedModel{Provider: "google", Model: "gemini"}
229	require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, anotherLarge))
230
231	require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 2)
232	require.Len(t, cfg.RecentModels[SelectedModelTypeSmall], 1)
233	require.Equal(t, smallModel, cfg.RecentModels[SelectedModelTypeSmall][0])
234
235	// persisted state: verify both types exist with correct lengths and contents
236	rm := readRecentModels(t, storePath)
237
238	large, ok := rm[string(SelectedModelTypeLarge)].([]any)
239	require.True(t, ok)
240	require.Len(t, large, 2)
241	// Verify newest first for large type
242	require.Equal(t, "google", large[0].(map[string]any)["provider"])
243	require.Equal(t, "gemini", large[0].(map[string]any)["model"])
244	require.Equal(t, "openai", large[1].(map[string]any)["provider"])
245	require.Equal(t, "gpt-4o", large[1].(map[string]any)["model"])
246
247	small, ok := rm[string(SelectedModelTypeSmall)].([]any)
248	require.True(t, ok)
249	require.Len(t, small, 1)
250	require.Equal(t, "anthropic", small[0].(map[string]any)["provider"])
251	require.Equal(t, "claude", small[0].(map[string]any)["model"])
252}