1package models
2
3import (
4 "encoding/json"
5 "io/fs"
6 "os"
7 "path/filepath"
8 "strings"
9 "testing"
10
11 tea "github.com/charmbracelet/bubbletea/v2"
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/tui/exp/list"
15 "github.com/stretchr/testify/require"
16)
17
18// execCmdML runs a tea.Cmd through the ModelListComponent's Update loop.
19func execCmdML(t *testing.T, m *ModelListComponent, cmd tea.Cmd) {
20 t.Helper()
21 for cmd != nil {
22 msg := cmd()
23 var next tea.Cmd
24 _, next = m.Update(msg)
25 cmd = next
26 }
27}
28
// readConfigJSON loads the file at path and decodes its JSON contents into a
// generic map. The test fails immediately if the file cannot be read or does
// not parse as JSON.
//
// NOTE(review): the file is read through an os.DirFS rooted at its directory
// rather than via os.ReadFile(path); presumably this keeps a file-inclusion
// linter (e.g. gosec G304) quiet — confirm before simplifying.
func readConfigJSON(t *testing.T, path string) map[string]any {
	t.Helper()
	root := os.DirFS(filepath.Dir(path))
	raw, err := fs.ReadFile(root, filepath.Base(path))
	require.NoError(t, err)

	var cfg map[string]any
	err = json.Unmarshal(raw, &cfg)
	require.NoError(t, err)
	return cfg
}
40
// readRecentModels reads the JSON config file at path and returns its
// top-level "recent_models" object. The test fails immediately if the file
// cannot be read or parsed, or if "recent_models" is missing or is not a JSON
// object; the failure message names the file and shows the value found.
func readRecentModels(t *testing.T, path string) map[string]any {
	t.Helper()
	cfg := readConfigJSON(t, path)
	rm, ok := cfg["recent_models"].(map[string]any)
	require.Truef(t, ok, "recent_models in %s is missing or not an object: %#v", path, cfg["recent_models"])
	return rm
}
49
// TestModelList_RecentlyUsedSectionAndPrunesInvalid seeds a config with one
// valid and one invalid recent model, initializes the model list, and checks
// that only the valid entry appears as a "recent::" item, that the config in
// cfgDir keeps both seeded entries, and that the config written under dataDir
// no longer contains the invalid one.
func TestModelList_RecentlyUsedSectionAndPrunesInvalid(t *testing.T) {
	// Isolate config/data paths
	cfgDir := t.TempDir()
	dataDir := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", cfgDir)
	t.Setenv("XDG_DATA_HOME", dataDir)

	// Pre-seed config so provider auto-update is disabled and we have recents
	confPath := filepath.Join(cfgDir, "crush", "crush.json")
	require.NoError(t, os.MkdirAll(filepath.Dir(confPath), 0o755))
	initial := map[string]any{
		"options": map[string]any{
			"disable_provider_auto_update": true,
		},
		"models": map[string]any{
			"large": map[string]any{
				"model":    "m1",
				"provider": "p1",
			},
		},
		"recent_models": map[string]any{
			"large": []any{
				// valid
				map[string]any{"model": "m2", "provider": "p1"},
				// invalid -> pruned
				map[string]any{"model": "x", "provider": "unknown-provider"},
			},
		},
	}
	bts, err := json.Marshal(initial)
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(confPath, bts, 0o644))

	// Also create empty providers.json to prevent loading real providers
	dataConfDir := filepath.Join(dataDir, "crush")
	require.NoError(t, os.MkdirAll(dataConfDir, 0o755))
	emptyProviders := []byte("[]")
	require.NoError(t, os.WriteFile(filepath.Join(dataConfDir, "providers.json"), emptyProviders, 0o644))

	// Initialize global config instance (no network due to auto-update disabled)
	_, err = config.Init(cfgDir, dataDir, false)
	require.NoError(t, err)

	// Build a small provider set for the list component
	provider := catwalk.Provider{
		ID:   catwalk.InferenceProvider("p1"),
		Name: "Provider One",
		Models: []catwalk.Model{
			{ID: "m1", Name: "Model One", DefaultMaxTokens: 100},
			{ID: "m2", Name: "Model Two", DefaultMaxTokens: 100}, // recent
		},
	}

	// Create and initialize the component with our provider set
	listKeyMap := list.DefaultKeyMap()
	cmp := NewModelListComponent(listKeyMap, "Find your fave", false)
	cmp.providers = []catwalk.Provider{provider}
	execCmdML(t, cmp, cmp.Init())

	// Collect the IDs of all recent items (IDs prefixed with "recent::").
	groups := cmp.list.Groups()
	require.NotEmpty(t, groups)
	var recentIDs []string
	for _, g := range groups {
		for _, it := range g.Items {
			if id := it.ID(); strings.HasPrefix(id, "recent::") {
				recentIDs = append(recentIDs, id)
			}
		}
	}
	require.NotEmpty(t, recentIDs, "no recent items found")
	// The valid recent (p1:m2) is shown and the invalid one is not.
	require.Contains(t, recentIDs, "recent::p1:m2", "expected valid recent not found")
	require.NotContains(t, recentIDs, "recent::unknown-provider:x", "invalid recent should be pruned")

	// The original config in cfgDir must still hold both seeded entries.
	origRM := readRecentModels(t, confPath)
	origLarge, ok := origRM["large"].([]any)
	require.True(t, ok, "original recent_models.large should be an array")
	require.Len(t, origLarge, 2, "original config should be unchanged")

	// Config should be rewritten with pruned recents in dataDir
	dataConf := filepath.Join(dataDir, "crush", "crush.json")
	rm := readRecentModels(t, dataConf)
	largeAny, ok := rm["large"].([]any)
	require.True(t, ok)
	// Ensure that only valid recent(s) remain and the invalid one is removed.
	// Checked assertions make a malformed entry fail the test instead of panicking.
	found := false
	for i, v := range largeAny {
		entry, isObj := v.(map[string]any)
		require.Truef(t, isObj, "recent_models.large[%d] is not an object: %#v", i, v)
		require.NotEqual(t, "unknown-provider", entry["provider"], "invalid provider should be pruned")
		if entry["provider"] == "p1" && entry["model"] == "m2" {
			found = true
		}
	}
	require.True(t, found, "persisted recents should include p1:m2")
}
155
// TestModelList_PrunesInvalidModelWithinValidProvider seeds two recents for a
// known provider, one of which names a model that provider does not offer. It
// checks that only p1:m1 appears as a "recent::" item, that the config in
// cfgDir keeps both seeded entries, and that the config written under dataDir
// contains exactly the p1:m1 entry.
func TestModelList_PrunesInvalidModelWithinValidProvider(t *testing.T) {
	// Isolate config/data paths
	cfgDir := t.TempDir()
	dataDir := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", cfgDir)
	t.Setenv("XDG_DATA_HOME", dataDir)

	// Pre-seed config with valid provider but one invalid model
	confPath := filepath.Join(cfgDir, "crush", "crush.json")
	require.NoError(t, os.MkdirAll(filepath.Dir(confPath), 0o755))
	initial := map[string]any{
		"options": map[string]any{
			"disable_provider_auto_update": true,
		},
		"models": map[string]any{
			"large": map[string]any{
				"model":    "m1",
				"provider": "p1",
			},
		},
		"recent_models": map[string]any{
			"large": []any{
				// valid
				map[string]any{"model": "m1", "provider": "p1"},
				// invalid model
				map[string]any{"model": "missing", "provider": "p1"},
			},
		},
	}
	bts, err := json.Marshal(initial)
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(confPath, bts, 0o644))

	// Create empty providers.json
	dataConfDir := filepath.Join(dataDir, "crush")
	require.NoError(t, os.MkdirAll(dataConfDir, 0o755))
	emptyProviders := []byte("[]")
	require.NoError(t, os.WriteFile(filepath.Join(dataConfDir, "providers.json"), emptyProviders, 0o644))

	// Initialize global config instance
	_, err = config.Init(cfgDir, dataDir, false)
	require.NoError(t, err)

	// Build provider set that only includes m1, not "missing"
	provider := catwalk.Provider{
		ID:   catwalk.InferenceProvider("p1"),
		Name: "Provider One",
		Models: []catwalk.Model{
			{ID: "m1", Name: "Model One", DefaultMaxTokens: 100},
		},
	}

	// Create and initialize component
	listKeyMap := list.DefaultKeyMap()
	cmp := NewModelListComponent(listKeyMap, "Find your fave", false)
	cmp.providers = []catwalk.Provider{provider}
	execCmdML(t, cmp, cmp.Init())

	// Collect the IDs of all recent items (IDs prefixed with "recent::").
	groups := cmp.list.Groups()
	require.NotEmpty(t, groups)
	var recentIDs []string
	for _, g := range groups {
		for _, it := range g.Items {
			if id := it.ID(); strings.HasPrefix(id, "recent::") {
				recentIDs = append(recentIDs, id)
			}
		}
	}
	require.NotEmpty(t, recentIDs, "valid recent should exist")

	// Verify the valid recent is present and invalid model is not
	require.Contains(t, recentIDs, "recent::p1:m1", "valid recent p1:m1 should be present")
	require.NotContains(t, recentIDs, "recent::p1:missing", "invalid model should be pruned")

	// The original config in cfgDir must still hold both seeded entries.
	origRM := readRecentModels(t, confPath)
	origLarge, ok := origRM["large"].([]any)
	require.True(t, ok, "original recent_models.large should be an array")
	require.Len(t, origLarge, 2, "original config should be unchanged")

	// Config should be rewritten with pruned recents in dataDir
	dataConf := filepath.Join(dataDir, "crush", "crush.json")
	rm := readRecentModels(t, dataConf)
	largeAny, ok := rm["large"].([]any)
	require.True(t, ok)
	require.Len(t, largeAny, 1, "should only have one valid model")
	// Verify only p1:m1 remains; a checked assertion fails cleanly on a malformed entry.
	entry, isObj := largeAny[0].(map[string]any)
	require.Truef(t, isObj, "recent_models.large[0] is not an object: %#v", largeAny[0])
	require.Equal(t, "p1", entry["provider"])
	require.Equal(t, "m1", entry["model"])
}
256
// TestModelKey_EmptyInputs checks that modelKey returns "" when either the
// provider or the model (or both) is empty, and "p:m" for provider "p" and
// model "m".
func TestModelKey_EmptyInputs(t *testing.T) {
	cases := []struct {
		provider, model, want string
	}{
		// Empty provider
		{provider: "", model: "model", want: ""},
		// Empty model
		{provider: "provider", model: "", want: ""},
		// Both empty
		{provider: "", model: "", want: ""},
		// Valid inputs
		{provider: "p", model: "m", want: "p:m"},
	}
	for _, tc := range cases {
		require.Equal(t, tc.want, modelKey(tc.provider, tc.model))
	}
}
267
// TestModelList_AllRecentsInvalid seeds only recents whose providers are
// unknown to the list component. It checks that no "recent::" items are shown,
// that the config in cfgDir keeps both seeded entries, and that the config
// written under dataDir has no remaining "large" recents (null/absent or an
// empty array).
func TestModelList_AllRecentsInvalid(t *testing.T) {
	// Isolate config/data paths
	cfgDir := t.TempDir()
	dataDir := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", cfgDir)
	t.Setenv("XDG_DATA_HOME", dataDir)

	// Pre-seed config with only invalid recents
	confPath := filepath.Join(cfgDir, "crush", "crush.json")
	require.NoError(t, os.MkdirAll(filepath.Dir(confPath), 0o755))
	initial := map[string]any{
		"options": map[string]any{
			"disable_provider_auto_update": true,
		},
		"models": map[string]any{
			"large": map[string]any{
				"model":    "m1",
				"provider": "p1",
			},
		},
		"recent_models": map[string]any{
			"large": []any{
				map[string]any{"model": "x", "provider": "unknown1"},
				map[string]any{"model": "y", "provider": "unknown2"},
			},
		},
	}
	bts, err := json.Marshal(initial)
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(confPath, bts, 0o644))

	// Also create empty providers.json and data config
	dataConfDir := filepath.Join(dataDir, "crush")
	require.NoError(t, os.MkdirAll(dataConfDir, 0o755))
	emptyProviders := []byte("[]")
	require.NoError(t, os.WriteFile(filepath.Join(dataConfDir, "providers.json"), emptyProviders, 0o644))

	// Initialize global config instance with isolated dataDir
	_, err = config.Init(cfgDir, dataDir, false)
	require.NoError(t, err)

	// Build provider set (doesn't include unknown1 or unknown2)
	provider := catwalk.Provider{
		ID:   catwalk.InferenceProvider("p1"),
		Name: "Provider One",
		Models: []catwalk.Model{
			{ID: "m1", Name: "Model One", DefaultMaxTokens: 100},
		},
	}

	// Create and initialize component
	listKeyMap := list.DefaultKeyMap()
	cmp := NewModelListComponent(listKeyMap, "Find your fave", false)
	cmp.providers = []catwalk.Provider{provider}
	execCmdML(t, cmp, cmp.Init())

	// Verify no recent items exist in UI
	groups := cmp.list.Groups()
	require.NotEmpty(t, groups)
	var recentIDs []string
	for _, g := range groups {
		for _, it := range g.Items {
			if id := it.ID(); strings.HasPrefix(id, "recent::") {
				recentIDs = append(recentIDs, id)
			}
		}
	}
	require.Empty(t, recentIDs, "all invalid recents should be pruned, resulting in no recent section")

	// The original config in cfgDir must still hold both seeded entries.
	origRM := readRecentModels(t, confPath)
	origLarge, ok := origRM["large"].([]any)
	require.True(t, ok, "original recent_models.large should be an array")
	require.Len(t, origLarge, 2, "original config should be unchanged")

	// Config should be rewritten with empty recents in dataDir
	dataConf := filepath.Join(dataDir, "crush", "crush.json")
	rm := readRecentModels(t, dataConf)
	// When all recents are pruned, the value may be nil or an empty array.
	// Guard with an if rather than returning early, so any checks added
	// below this point are never silently skipped.
	if largeVal := rm["large"]; largeVal != nil {
		largeAny, isArr := largeVal.([]any)
		require.True(t, isArr, "large key should be nil or array")
		require.Empty(t, largeAny, "persisted recents should be empty after pruning all invalid entries")
	}
}