provider_test.go

  1package config
  2
  3import (
  4	"encoding/json"
  5	"os"
  6	"path/filepath"
  7	"sync"
  8	"testing"
  9
 10	"github.com/charmbracelet/catwalk/pkg/catwalk"
 11	"github.com/stretchr/testify/require"
 12)
 13
 14func resetProviderState() {
 15	providerOnce = sync.Once{}
 16	providerList = nil
 17	providerErr = nil
 18	catwalkSyncer = &catwalkSync{}
 19	hyperSyncer = &hyperSync{}
 20}
 21
 22func TestProviders_Integration_AutoUpdateDisabled(t *testing.T) {
 23	tmpDir := t.TempDir()
 24	t.Setenv("XDG_DATA_HOME", tmpDir)
 25
 26	// Use a test-specific instance to avoid global state interference.
 27	testCatwalkSyncer := &catwalkSync{}
 28	testHyperSyncer := &hyperSync{}
 29
 30	originalCatwalSyncer := catwalkSyncer
 31	originalHyperSyncer := hyperSyncer
 32	defer func() {
 33		catwalkSyncer = originalCatwalSyncer
 34		hyperSyncer = originalHyperSyncer
 35	}()
 36
 37	catwalkSyncer = testCatwalkSyncer
 38	hyperSyncer = testHyperSyncer
 39
 40	resetProviderState()
 41	defer resetProviderState()
 42
 43	cfg := &Config{
 44		Options: &Options{
 45			DisableProviderAutoUpdate: true,
 46		},
 47	}
 48
 49	providers, err := Providers(cfg)
 50	require.NoError(t, err)
 51	require.NotNil(t, providers)
 52	require.Greater(t, len(providers), 5, "Expected embedded providers")
 53}
 54
 55func TestProviders_Integration_WithMockClients(t *testing.T) {
 56	tmpDir := t.TempDir()
 57	t.Setenv("XDG_DATA_HOME", tmpDir)
 58
 59	// Create fresh syncers for this test.
 60	testCatwalkSyncer := &catwalkSync{}
 61	testHyperSyncer := &hyperSync{}
 62
 63	// Initialize with mock clients.
 64	mockCatwalkClient := &mockCatwalkClient{
 65		providers: []catwalk.Provider{
 66			{Name: "Provider1", ID: "p1"},
 67			{Name: "Provider2", ID: "p2"},
 68		},
 69	}
 70	mockHyperClient := &mockHyperClient{
 71		provider: catwalk.Provider{
 72			Name: "Hyper",
 73			ID:   "hyper",
 74			Models: []catwalk.Model{
 75				{ID: "hyper-1", Name: "Hyper Model"},
 76			},
 77		},
 78	}
 79
 80	catwalkPath := tmpDir + "/crush/providers.json"
 81	hyperPath := tmpDir + "/crush/hyper.json"
 82
 83	testCatwalkSyncer.Init(mockCatwalkClient, catwalkPath, true)
 84	testHyperSyncer.Init(mockHyperClient, hyperPath, true)
 85
 86	// Get providers from each syncer.
 87	catwalkProviders, err := testCatwalkSyncer.Get(t.Context())
 88	require.NoError(t, err)
 89	require.Len(t, catwalkProviders, 2)
 90
 91	hyperProvider, err := testHyperSyncer.Get(t.Context())
 92	require.NoError(t, err)
 93	require.Equal(t, "Hyper", hyperProvider.Name)
 94
 95	// Verify total.
 96	allProviders := append(catwalkProviders, hyperProvider)
 97	require.Len(t, allProviders, 3)
 98}
 99
100func TestProviders_Integration_WithCachedData(t *testing.T) {
101	tmpDir := t.TempDir()
102	t.Setenv("XDG_DATA_HOME", tmpDir)
103
104	// Create cache files.
105	catwalkPath := tmpDir + "/crush/providers.json"
106	hyperPath := tmpDir + "/crush/hyper.json"
107
108	require.NoError(t, os.MkdirAll(tmpDir+"/crush", 0o755))
109
110	// Write Catwalk cache.
111	catwalkProviders := []catwalk.Provider{
112		{Name: "Cached1", ID: "c1"},
113		{Name: "Cached2", ID: "c2"},
114	}
115	data, err := json.Marshal(catwalkProviders)
116	require.NoError(t, err)
117	require.NoError(t, os.WriteFile(catwalkPath, data, 0o644))
118
119	// Write Hyper cache.
120	hyperProvider := catwalk.Provider{
121		Name: "Cached Hyper",
122		ID:   "hyper",
123	}
124	data, err = json.Marshal(hyperProvider)
125	require.NoError(t, err)
126	require.NoError(t, os.WriteFile(hyperPath, data, 0o644))
127
128	// Create fresh syncers.
129	testCatwalkSyncer := &catwalkSync{}
130	testHyperSyncer := &hyperSync{}
131
132	// Mock clients that return ErrNotModified.
133	mockCatwalkClient := &mockCatwalkClient{
134		err: catwalk.ErrNotModified,
135	}
136	mockHyperClient := &mockHyperClient{
137		err: catwalk.ErrNotModified,
138	}
139
140	testCatwalkSyncer.Init(mockCatwalkClient, catwalkPath, true)
141	testHyperSyncer.Init(mockHyperClient, hyperPath, true)
142
143	// Get providers - should use cached.
144	catwalkResult, err := testCatwalkSyncer.Get(t.Context())
145	require.NoError(t, err)
146	require.Len(t, catwalkResult, 2)
147	require.Equal(t, "Cached1", catwalkResult[0].Name)
148
149	hyperResult, err := testHyperSyncer.Get(t.Context())
150	require.NoError(t, err)
151	require.Equal(t, "Cached Hyper", hyperResult.Name)
152}
153
154func TestProviders_Integration_CatwalkFailsHyperSucceeds(t *testing.T) {
155	tmpDir := t.TempDir()
156	t.Setenv("XDG_DATA_HOME", tmpDir)
157
158	testCatwalkSyncer := &catwalkSync{}
159	testHyperSyncer := &hyperSync{}
160
161	// Catwalk fails, Hyper succeeds.
162	mockCatwalkClient := &mockCatwalkClient{
163		err: catwalk.ErrNotModified, // Will use embedded.
164	}
165	mockHyperClient := &mockHyperClient{
166		provider: catwalk.Provider{
167			Name: "Hyper",
168			ID:   "hyper",
169			Models: []catwalk.Model{
170				{ID: "hyper-1", Name: "Hyper Model"},
171			},
172		},
173	}
174
175	catwalkPath := tmpDir + "/crush/providers.json"
176	hyperPath := tmpDir + "/crush/hyper.json"
177
178	testCatwalkSyncer.Init(mockCatwalkClient, catwalkPath, true)
179	testHyperSyncer.Init(mockHyperClient, hyperPath, true)
180
181	catwalkResult, err := testCatwalkSyncer.Get(t.Context())
182	require.NoError(t, err)
183	require.NotEmpty(t, catwalkResult) // Should have embedded.
184
185	hyperResult, err := testHyperSyncer.Get(t.Context())
186	require.NoError(t, err)
187	require.Equal(t, "Hyper", hyperResult.Name)
188}
189
190func TestProviders_Integration_BothFail(t *testing.T) {
191	tmpDir := t.TempDir()
192	t.Setenv("XDG_DATA_HOME", tmpDir)
193
194	testCatwalkSyncer := &catwalkSync{}
195	testHyperSyncer := &hyperSync{}
196
197	// Both fail.
198	mockCatwalkClient := &mockCatwalkClient{
199		err: catwalk.ErrNotModified,
200	}
201	mockHyperClient := &mockHyperClient{
202		provider: catwalk.Provider{}, // Empty provider.
203	}
204
205	catwalkPath := tmpDir + "/crush/providers.json"
206	hyperPath := tmpDir + "/crush/hyper.json"
207
208	testCatwalkSyncer.Init(mockCatwalkClient, catwalkPath, true)
209	testHyperSyncer.Init(mockHyperClient, hyperPath, true)
210
211	catwalkResult, err := testCatwalkSyncer.Get(t.Context())
212	require.NoError(t, err)
213	require.NotEmpty(t, catwalkResult) // Should fall back to embedded.
214
215	hyperResult, err := testHyperSyncer.Get(t.Context())
216	require.NoError(t, err)
217	require.Equal(t, "Charm Hyper", hyperResult.Name) // Falls back to embedded when no models.
218}
219
220func TestCache_StoreAndGet(t *testing.T) {
221	t.Parallel()
222
223	tmpDir := t.TempDir()
224	cachePath := tmpDir + "/test.json"
225
226	cache := newCache[[]catwalk.Provider](cachePath)
227
228	providers := []catwalk.Provider{
229		{Name: "Provider1", ID: "p1"},
230		{Name: "Provider2", ID: "p2"},
231	}
232
233	// Store.
234	err := cache.Store(providers)
235	require.NoError(t, err)
236
237	// Get.
238	result, etag, err := cache.Get()
239	require.NoError(t, err)
240	require.Len(t, result, 2)
241	require.Equal(t, "Provider1", result[0].Name)
242	require.NotEmpty(t, etag)
243}
244
245func TestCache_GetNonExistent(t *testing.T) {
246	t.Parallel()
247
248	tmpDir := t.TempDir()
249	cachePath := tmpDir + "/nonexistent.json"
250
251	cache := newCache[[]catwalk.Provider](cachePath)
252
253	_, _, err := cache.Get()
254	require.Error(t, err)
255	require.Contains(t, err.Error(), "failed to read provider cache file")
256}
257
258func TestCache_GetInvalidJSON(t *testing.T) {
259	t.Parallel()
260
261	tmpDir := t.TempDir()
262	cachePath := tmpDir + "/invalid.json"
263
264	require.NoError(t, os.WriteFile(cachePath, []byte("invalid json"), 0o644))
265
266	cache := newCache[[]catwalk.Provider](cachePath)
267
268	_, _, err := cache.Get()
269	require.Error(t, err)
270	require.Contains(t, err.Error(), "failed to unmarshal provider data from cache")
271}
272
273func TestCachePathFor(t *testing.T) {
274	tests := []struct {
275		name        string
276		xdgDataHome string
277		expected    string
278	}{
279		{
280			name:        "with XDG_DATA_HOME",
281			xdgDataHome: "/custom/data",
282			expected:    "/custom/data/crush/providers.json",
283		},
284		{
285			name:        "without XDG_DATA_HOME",
286			xdgDataHome: "",
287			expected:    "", // Will use platform-specific default.
288		},
289	}
290
291	for _, tt := range tests {
292		t.Run(tt.name, func(t *testing.T) {
293			if tt.xdgDataHome != "" {
294				t.Setenv("XDG_DATA_HOME", tt.xdgDataHome)
295			} else {
296				t.Setenv("XDG_DATA_HOME", "")
297			}
298
299			result := cachePathFor("providers")
300			if tt.expected != "" {
301				require.Equal(t, tt.expected, filepath.ToSlash(result))
302			} else {
303				require.Contains(t, result, "crush")
304				require.Contains(t, result, "providers.json")
305			}
306		})
307	}
308}