1package config
2
3import (
4 "encoding/json"
5 "os"
6 "path/filepath"
7 "sync"
8 "testing"
9
10 "github.com/charmbracelet/catwalk/pkg/catwalk"
11 "github.com/stretchr/testify/require"
12)
13
14func resetProviderState() {
15 providerOnce = sync.Once{}
16 providerList = nil
17 providerErr = nil
18 catwalkSyncer = &catwalkSync{}
19 hyperSyncer = &hyperSync{}
20}
21
22func TestProviders_Integration_AutoUpdateDisabled(t *testing.T) {
23 tmpDir := t.TempDir()
24 t.Setenv("XDG_DATA_HOME", tmpDir)
25
26 // Use a test-specific instance to avoid global state interference.
27 testCatwalkSyncer := &catwalkSync{}
28 testHyperSyncer := &hyperSync{}
29
30 originalCatwalSyncer := catwalkSyncer
31 originalHyperSyncer := hyperSyncer
32 defer func() {
33 catwalkSyncer = originalCatwalSyncer
34 hyperSyncer = originalHyperSyncer
35 }()
36
37 catwalkSyncer = testCatwalkSyncer
38 hyperSyncer = testHyperSyncer
39
40 resetProviderState()
41 defer resetProviderState()
42
43 cfg := &Config{
44 Options: &Options{
45 DisableProviderAutoUpdate: true,
46 },
47 }
48
49 providers, err := Providers(cfg)
50 require.NoError(t, err)
51 require.NotNil(t, providers)
52 require.Greater(t, len(providers), 5, "Expected embedded providers")
53}
54
55func TestProviders_Integration_WithMockClients(t *testing.T) {
56 tmpDir := t.TempDir()
57 t.Setenv("XDG_DATA_HOME", tmpDir)
58
59 // Create fresh syncers for this test.
60 testCatwalkSyncer := &catwalkSync{}
61 testHyperSyncer := &hyperSync{}
62
63 // Initialize with mock clients.
64 mockCatwalkClient := &mockCatwalkClient{
65 providers: []catwalk.Provider{
66 {Name: "Provider1", ID: "p1"},
67 {Name: "Provider2", ID: "p2"},
68 },
69 }
70 mockHyperClient := &mockHyperClient{
71 provider: catwalk.Provider{
72 Name: "Hyper",
73 ID: "hyper",
74 Models: []catwalk.Model{
75 {ID: "hyper-1", Name: "Hyper Model"},
76 },
77 },
78 }
79
80 catwalkPath := tmpDir + "/crush/providers.json"
81 hyperPath := tmpDir + "/crush/hyper.json"
82
83 testCatwalkSyncer.Init(mockCatwalkClient, catwalkPath, true)
84 testHyperSyncer.Init(mockHyperClient, hyperPath, true)
85
86 // Get providers from each syncer.
87 catwalkProviders, err := testCatwalkSyncer.Get(t.Context())
88 require.NoError(t, err)
89 require.Len(t, catwalkProviders, 2)
90
91 hyperProvider, err := testHyperSyncer.Get(t.Context())
92 require.NoError(t, err)
93 require.Equal(t, "Hyper", hyperProvider.Name)
94
95 // Verify total.
96 allProviders := append(catwalkProviders, hyperProvider)
97 require.Len(t, allProviders, 3)
98}
99
100func TestProviders_Integration_WithCachedData(t *testing.T) {
101 tmpDir := t.TempDir()
102 t.Setenv("XDG_DATA_HOME", tmpDir)
103
104 // Create cache files.
105 catwalkPath := tmpDir + "/crush/providers.json"
106 hyperPath := tmpDir + "/crush/hyper.json"
107
108 require.NoError(t, os.MkdirAll(tmpDir+"/crush", 0o755))
109
110 // Write Catwalk cache.
111 catwalkProviders := []catwalk.Provider{
112 {Name: "Cached1", ID: "c1"},
113 {Name: "Cached2", ID: "c2"},
114 }
115 data, err := json.Marshal(catwalkProviders)
116 require.NoError(t, err)
117 require.NoError(t, os.WriteFile(catwalkPath, data, 0o644))
118
119 // Write Hyper cache.
120 hyperProvider := catwalk.Provider{
121 Name: "Cached Hyper",
122 ID: "hyper",
123 }
124 data, err = json.Marshal(hyperProvider)
125 require.NoError(t, err)
126 require.NoError(t, os.WriteFile(hyperPath, data, 0o644))
127
128 // Create fresh syncers.
129 testCatwalkSyncer := &catwalkSync{}
130 testHyperSyncer := &hyperSync{}
131
132 // Mock clients that return ErrNotModified.
133 mockCatwalkClient := &mockCatwalkClient{
134 err: catwalk.ErrNotModified,
135 }
136 mockHyperClient := &mockHyperClient{
137 err: catwalk.ErrNotModified,
138 }
139
140 testCatwalkSyncer.Init(mockCatwalkClient, catwalkPath, true)
141 testHyperSyncer.Init(mockHyperClient, hyperPath, true)
142
143 // Get providers - should use cached.
144 catwalkResult, err := testCatwalkSyncer.Get(t.Context())
145 require.NoError(t, err)
146 require.Len(t, catwalkResult, 2)
147 require.Equal(t, "Cached1", catwalkResult[0].Name)
148
149 hyperResult, err := testHyperSyncer.Get(t.Context())
150 require.NoError(t, err)
151 require.Equal(t, "Cached Hyper", hyperResult.Name)
152}
153
154func TestProviders_Integration_CatwalkFailsHyperSucceeds(t *testing.T) {
155 tmpDir := t.TempDir()
156 t.Setenv("XDG_DATA_HOME", tmpDir)
157
158 testCatwalkSyncer := &catwalkSync{}
159 testHyperSyncer := &hyperSync{}
160
161 // Catwalk fails, Hyper succeeds.
162 mockCatwalkClient := &mockCatwalkClient{
163 err: catwalk.ErrNotModified, // Will use embedded.
164 }
165 mockHyperClient := &mockHyperClient{
166 provider: catwalk.Provider{
167 Name: "Hyper",
168 ID: "hyper",
169 Models: []catwalk.Model{
170 {ID: "hyper-1", Name: "Hyper Model"},
171 },
172 },
173 }
174
175 catwalkPath := tmpDir + "/crush/providers.json"
176 hyperPath := tmpDir + "/crush/hyper.json"
177
178 testCatwalkSyncer.Init(mockCatwalkClient, catwalkPath, true)
179 testHyperSyncer.Init(mockHyperClient, hyperPath, true)
180
181 catwalkResult, err := testCatwalkSyncer.Get(t.Context())
182 require.NoError(t, err)
183 require.NotEmpty(t, catwalkResult) // Should have embedded.
184
185 hyperResult, err := testHyperSyncer.Get(t.Context())
186 require.NoError(t, err)
187 require.Equal(t, "Hyper", hyperResult.Name)
188}
189
190func TestProviders_Integration_BothFail(t *testing.T) {
191 tmpDir := t.TempDir()
192 t.Setenv("XDG_DATA_HOME", tmpDir)
193
194 testCatwalkSyncer := &catwalkSync{}
195 testHyperSyncer := &hyperSync{}
196
197 // Both fail.
198 mockCatwalkClient := &mockCatwalkClient{
199 err: catwalk.ErrNotModified,
200 }
201 mockHyperClient := &mockHyperClient{
202 provider: catwalk.Provider{}, // Empty provider.
203 }
204
205 catwalkPath := tmpDir + "/crush/providers.json"
206 hyperPath := tmpDir + "/crush/hyper.json"
207
208 testCatwalkSyncer.Init(mockCatwalkClient, catwalkPath, true)
209 testHyperSyncer.Init(mockHyperClient, hyperPath, true)
210
211 catwalkResult, err := testCatwalkSyncer.Get(t.Context())
212 require.NoError(t, err)
213 require.NotEmpty(t, catwalkResult) // Should fall back to embedded.
214
215 hyperResult, err := testHyperSyncer.Get(t.Context())
216 require.NoError(t, err)
217 require.Equal(t, "Charm Hyper", hyperResult.Name) // Falls back to embedded when no models.
218}
219
220func TestCache_StoreAndGet(t *testing.T) {
221 t.Parallel()
222
223 tmpDir := t.TempDir()
224 cachePath := tmpDir + "/test.json"
225
226 cache := newCache[[]catwalk.Provider](cachePath)
227
228 providers := []catwalk.Provider{
229 {Name: "Provider1", ID: "p1"},
230 {Name: "Provider2", ID: "p2"},
231 }
232
233 // Store.
234 err := cache.Store(providers)
235 require.NoError(t, err)
236
237 // Get.
238 result, etag, err := cache.Get()
239 require.NoError(t, err)
240 require.Len(t, result, 2)
241 require.Equal(t, "Provider1", result[0].Name)
242 require.NotEmpty(t, etag)
243}
244
245func TestCache_GetNonExistent(t *testing.T) {
246 t.Parallel()
247
248 tmpDir := t.TempDir()
249 cachePath := tmpDir + "/nonexistent.json"
250
251 cache := newCache[[]catwalk.Provider](cachePath)
252
253 _, _, err := cache.Get()
254 require.Error(t, err)
255 require.Contains(t, err.Error(), "failed to read provider cache file")
256}
257
258func TestCache_GetInvalidJSON(t *testing.T) {
259 t.Parallel()
260
261 tmpDir := t.TempDir()
262 cachePath := tmpDir + "/invalid.json"
263
264 require.NoError(t, os.WriteFile(cachePath, []byte("invalid json"), 0o644))
265
266 cache := newCache[[]catwalk.Provider](cachePath)
267
268 _, _, err := cache.Get()
269 require.Error(t, err)
270 require.Contains(t, err.Error(), "failed to unmarshal provider data from cache")
271}
272
273func TestCachePathFor(t *testing.T) {
274 tests := []struct {
275 name string
276 xdgDataHome string
277 expected string
278 }{
279 {
280 name: "with XDG_DATA_HOME",
281 xdgDataHome: "/custom/data",
282 expected: "/custom/data/crush/providers.json",
283 },
284 {
285 name: "without XDG_DATA_HOME",
286 xdgDataHome: "",
287 expected: "", // Will use platform-specific default.
288 },
289 }
290
291 for _, tt := range tests {
292 t.Run(tt.name, func(t *testing.T) {
293 if tt.xdgDataHome != "" {
294 t.Setenv("XDG_DATA_HOME", tt.xdgDataHome)
295 } else {
296 t.Setenv("XDG_DATA_HOME", "")
297 }
298
299 result := cachePathFor("providers")
300 if tt.expected != "" {
301 require.Equal(t, tt.expected, filepath.ToSlash(result))
302 } else {
303 require.Contains(t, result, "crush")
304 require.Contains(t, result, "providers.json")
305 }
306 })
307 }
308}