provider_test.go

  1package config
  2
  3import (
  4	"encoding/json"
  5	"errors"
  6	"os"
  7	"sync"
  8	"testing"
  9	"time"
 10
 11	"github.com/charmbracelet/catwalk/pkg/catwalk"
 12	"github.com/stretchr/testify/assert"
 13	"github.com/stretchr/testify/require"
 14)
 15
 16type mockProviderClient struct {
 17	shouldFail bool
 18	data       []catwalk.Provider
 19}
 20
 21func (m *mockProviderClient) GetProviders() ([]catwalk.Provider, error) {
 22	if m.shouldFail {
 23		return nil, errors.New("failed to load providers")
 24	}
 25	if len(m.data) > 0 {
 26		return m.data, nil
 27	}
 28	return []catwalk.Provider{
 29		{
 30			Name: "Mock",
 31		},
 32	}, nil
 33}
 34
 35func TestProvider_ProvidersReturnsFromClientIfNoCache(t *testing.T) {
 36	defer func() {
 37		initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
 38		providerList = nil
 39	}()
 40	require.False(t, initialized)
 41	catwalkProviderData := []catwalk.Provider{
 42		{Name: "Mock1"},
 43		{Name: "Mock2"},
 44	}
 45	client := &mockProviderClient{shouldFail: false, data: catwalkProviderData}
 46
 47	autoUpdateDisabled := false // this doesn't matter within this test
 48	resolvedProviders, err := ProvidersWithClient(autoUpdateDisabled, client, "non-existent-cache-path")
 49
 50	require.NoError(t, err)
 51	assert.Equal(t, catwalkProviderData, resolvedProviders)
 52}
 53
 54func TestProvider_loadProvidersNoIssues(t *testing.T) {
 55	defer func() {
 56		initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
 57		providerList = nil
 58	}()
 59	client := &mockProviderClient{shouldFail: false}
 60	tmpPath := t.TempDir() + "/providers.json"
 61
 62	providers, err := loadProviders(false, client, tmpPath)
 63	require.NoError(t, err)
 64	require.NotNil(t, providers)
 65	require.Len(t, providers, 1)
 66
 67	// check if file got saved
 68	fileInfo, err := os.Stat(tmpPath)
 69	require.NoError(t, err)
 70	require.False(t, fileInfo.IsDir(), "Expected a file, not a directory")
 71}
 72
 73func TestProvider_loadProvidersWithIssues(t *testing.T) {
 74	defer func() {
 75		initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
 76		providerList = nil
 77	}()
 78	require.False(t, initialized)
 79	client := &mockProviderClient{shouldFail: true}
 80	tmpPath := t.TempDir() + "/providers.json"
 81	// store providers to a temporary file
 82	oldProviders := []catwalk.Provider{
 83		{
 84			Name: "OldProvider",
 85		},
 86	}
 87	data, err := json.Marshal(oldProviders)
 88	if err != nil {
 89		t.Fatalf("Failed to marshal old providers: %v", err)
 90	}
 91
 92	err = os.WriteFile(tmpPath, data, 0o644)
 93	if err != nil {
 94		t.Fatalf("Failed to write old providers to file: %v", err)
 95	}
 96
 97	providers, err := loadProviders(false, client, tmpPath)
 98	require.NoError(t, err)
 99	require.NotNil(t, providers)
100	require.Len(t, providers, 1)
101	require.Equal(t, "OldProvider", providers[0].Name, "Expected to keep old provider when loading fails")
102}
103
104func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) {
105	defer func() {
106		initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
107		providerList = nil
108	}()
109	client := &mockProviderClient{shouldFail: true}
110	tmpPath := t.TempDir() + "/providers.json"
111
112	providers, err := loadProviders(false, client, tmpPath)
113	require.Error(t, err)
114	require.Nil(t, providers, "Expected nil providers when loading fails and no cache exists")
115}
116
117type dynamicMockProviderClient struct {
118	mu        sync.Mutex
119	callCount int
120	providers [][]catwalk.Provider
121}
122
123func (m *dynamicMockProviderClient) GetProviders() ([]catwalk.Provider, error) {
124	m.mu.Lock()
125	defer m.mu.Unlock()
126
127	if m.callCount >= len(m.providers) {
128		return m.providers[len(m.providers)-1], nil
129	}
130
131	result := m.providers[m.callCount]
132	m.callCount++
133	return result, nil
134}
135
136func TestProvider_backgroundReloadAfterCacheUpdate(t *testing.T) {
137	defer func() {
138		initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
139		providerList = nil
140	}()
141
142	tmpPath := t.TempDir() + "/providers.json"
143
144	initialProviders := []catwalk.Provider{
145		{Name: "InitialProvider"},
146	}
147	updatedProviders := []catwalk.Provider{
148		{Name: "UpdatedProvider"},
149		{Name: "NewProvider"},
150	}
151
152	client := &dynamicMockProviderClient{
153		providers: [][]catwalk.Provider{
154			updatedProviders,
155		},
156	}
157
158	data, err := json.Marshal(initialProviders)
159	require.NoError(t, err)
160	require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
161
162	providerMu.Lock()
163	oldInitialized := initialized
164	initialized = false
165	providerMu.Unlock()
166
167	defer func() {
168		providerMu.Lock()
169		initialized = oldInitialized
170		providerMu.Unlock()
171	}()
172
173	providers, err := loadProviders(false, client, tmpPath)
174	require.NoError(t, err)
175	require.NotNil(t, providers)
176	require.Len(t, providers, 1)
177	require.Equal(t, "InitialProvider", providers[0].Name)
178
179	require.Eventually(t, func() bool {
180		reloadedProviders, err := loadProvidersFromCache(tmpPath)
181		if err != nil {
182			return false
183		}
184		return len(reloadedProviders) == 2
185	}, 2*time.Second, 50*time.Millisecond, "Background cache update should complete within 2 seconds")
186
187	reloadedProviders, err := loadProvidersFromCache(tmpPath)
188	require.NoError(t, err)
189	require.Len(t, reloadedProviders, 2)
190	require.Equal(t, "UpdatedProvider", reloadedProviders[0].Name)
191	require.Equal(t, "NewProvider", reloadedProviders[1].Name)
192
193	require.Eventually(t, func() bool {
194		providerMu.RLock()
195		defer providerMu.RUnlock()
196		return len(providerList) == 2
197	}, 2*time.Second, 50*time.Millisecond, "In-memory provider list should be reloaded")
198
199	providerMu.RLock()
200	inMemoryProviders := providerList
201	providerMu.RUnlock()
202
203	require.Len(t, inMemoryProviders, 2)
204	require.Equal(t, "UpdatedProvider", inMemoryProviders[0].Name)
205	require.Equal(t, "NewProvider", inMemoryProviders[1].Name)
206}
207
208func TestProvider_reloadProvidersThreadSafety(t *testing.T) {
209	defer func() {
210		initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
211		providerList = nil
212	}()
213
214	tmpPath := t.TempDir() + "/providers.json"
215
216	initialProviders := []catwalk.Provider{
217		{Name: "Provider1"},
218	}
219	data, err := json.Marshal(initialProviders)
220	require.NoError(t, err)
221	require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
222
223	providerMu.Lock()
224	oldList := providerList
225	oldErr := providerErr
226	oldInitialized := initialized
227	providerList = initialProviders
228	providerErr = nil
229	initialized = true
230	providerMu.Unlock()
231
232	defer func() {
233		providerMu.Lock()
234		providerList = oldList
235		providerErr = oldErr
236		initialized = oldInitialized
237		providerMu.Unlock()
238	}()
239
240	var wg sync.WaitGroup
241	for i := 0; i < 10; i++ {
242		wg.Add(1)
243		go func(iteration int) {
244			defer wg.Done()
245
246			updatedProviders := []catwalk.Provider{
247				{Name: "Provider1"},
248				{Name: "Provider2"},
249			}
250			data, err := json.Marshal(updatedProviders)
251			require.NoError(t, err)
252			require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
253
254			reloadProviders(tmpPath)
255
256			providerMu.RLock()
257			currentList := providerList
258			providerMu.RUnlock()
259
260			require.NotNil(t, currentList)
261		}(i)
262	}
263
264	wg.Wait()
265
266	providerMu.RLock()
267	finalList := providerList
268	providerMu.RUnlock()
269
270	require.Len(t, finalList, 2)
271}
272
273func TestProvider_reloadProvidersWithEmptyCache(t *testing.T) {
274	defer func() {
275		initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
276		providerList = nil
277	}()
278
279	tmpPath := t.TempDir() + "/providers.json"
280
281	initialProviders := []catwalk.Provider{
282		{Name: "InitialProvider"},
283	}
284
285	providerMu.Lock()
286	oldList := providerList
287	oldErr := providerErr
288	oldInitialized := initialized
289	providerList = initialProviders
290	providerErr = nil
291	initialized = true
292	providerMu.Unlock()
293
294	defer func() {
295		providerMu.Lock()
296		providerList = oldList
297		providerErr = oldErr
298		initialized = oldInitialized
299		providerMu.Unlock()
300	}()
301
302	emptyProviders := []catwalk.Provider{}
303	data, err := json.Marshal(emptyProviders)
304	require.NoError(t, err)
305	require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
306
307	reloadProviders(tmpPath)
308
309	providerMu.RLock()
310	currentList := providerList
311	providerMu.RUnlock()
312
313	require.Len(t, currentList, 1)
314	require.Equal(t, "InitialProvider", currentList[0].Name)
315}
316
317func TestProvider_reloadProvidersWithInvalidCache(t *testing.T) {
318	defer func() {
319		initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
320		providerList = nil
321	}()
322
323	tmpPath := t.TempDir() + "/providers.json"
324
325	initialProviders := []catwalk.Provider{
326		{Name: "InitialProvider"},
327	}
328
329	providerMu.Lock()
330	oldList := providerList
331	oldErr := providerErr
332	oldInitialized := initialized
333	providerList = initialProviders
334	providerErr = nil
335	initialized = true
336	providerMu.Unlock()
337
338	defer func() {
339		providerMu.Lock()
340		providerList = oldList
341		providerErr = oldErr
342		initialized = oldInitialized
343		providerMu.Unlock()
344	}()
345
346	require.NoError(t, os.WriteFile(tmpPath, []byte("invalid json"), 0o644))
347
348	reloadProviders(tmpPath)
349
350	providerMu.RLock()
351	currentList := providerList
352	providerMu.RUnlock()
353
354	require.Len(t, currentList, 1)
355	require.Equal(t, "InitialProvider", currentList[0].Name)
356}