provider_test.go

  1package config
  2
  3import (
  4	"encoding/json"
  5	"errors"
  6	"os"
  7	"sync"
  8	"testing"
  9	"time"
 10
 11	"github.com/charmbracelet/catwalk/pkg/catwalk"
 12	"github.com/stretchr/testify/assert"
 13	"github.com/stretchr/testify/require"
 14)
 15
 16type mockProviderClient struct {
 17	shouldFail bool
 18	data       []catwalk.Provider
 19}
 20
 21func (m *mockProviderClient) GetProviders() ([]catwalk.Provider, error) {
 22	if m.shouldFail {
 23		return nil, errors.New("failed to load providers")
 24	}
 25	if len(m.data) > 0 {
 26		return m.data, nil
 27	}
 28	return []catwalk.Provider{
 29		{
 30			Name: "Mock",
 31		},
 32	}, nil
 33}
 34
 35func TestProvider_ProvidersReturnsFromClientIfNoCache(t *testing.T) {
 36	defer func() {
 37		initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
 38	}()
 39	require.False(t, initialized)
 40	catwalkProviderData := []catwalk.Provider{
 41		{Name: "Mock1"},
 42		{Name: "Mock2"},
 43	}
 44	client := &mockProviderClient{shouldFail: false, data: catwalkProviderData}
 45
 46	autoUpdateDisabled := false // this doesn't matter within this test
 47	resolvedProviders, err := ProvidersWithClient(autoUpdateDisabled, client, "non-existent-cache-path")
 48
 49	require.NoError(t, err)
 50	assert.Equal(t, catwalkProviderData, resolvedProviders)
 51}
 52
 53func TestProvider_loadProvidersNoIssues(t *testing.T) {
 54	defer func() {
 55		initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
 56	}()
 57	client := &mockProviderClient{shouldFail: false}
 58	tmpPath := t.TempDir() + "/providers.json"
 59
 60	providers, err := loadProviders(false, client, tmpPath)
 61	require.NoError(t, err)
 62	require.NotNil(t, providers)
 63	require.Len(t, providers, 1)
 64
 65	// check if file got saved
 66	fileInfo, err := os.Stat(tmpPath)
 67	require.NoError(t, err)
 68	require.False(t, fileInfo.IsDir(), "Expected a file, not a directory")
 69}
 70
 71func TestProvider_loadProvidersWithIssues(t *testing.T) {
 72	defer func() {
 73		initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
 74	}()
 75	require.False(t, initialized)
 76	client := &mockProviderClient{shouldFail: true}
 77	tmpPath := t.TempDir() + "/providers.json"
 78	// store providers to a temporary file
 79	oldProviders := []catwalk.Provider{
 80		{
 81			Name: "OldProvider",
 82		},
 83	}
 84	data, err := json.Marshal(oldProviders)
 85	if err != nil {
 86		t.Fatalf("Failed to marshal old providers: %v", err)
 87	}
 88
 89	err = os.WriteFile(tmpPath, data, 0o644)
 90	if err != nil {
 91		t.Fatalf("Failed to write old providers to file: %v", err)
 92	}
 93
 94	providers, err := loadProviders(false, client, tmpPath)
 95	require.NoError(t, err)
 96	require.NotNil(t, providers)
 97	require.Len(t, providers, 1)
 98	require.Equal(t, "OldProvider", providers[0].Name, "Expected to keep old provider when loading fails")
 99}
100
101func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) {
102	defer func() {
103		initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
104	}()
105	client := &mockProviderClient{shouldFail: true}
106	tmpPath := t.TempDir() + "/providers.json"
107
108	providers, err := loadProviders(false, client, tmpPath)
109	require.Error(t, err)
110	require.Nil(t, providers, "Expected nil providers when loading fails and no cache exists")
111}
112
113type dynamicMockProviderClient struct {
114	mu        sync.Mutex
115	callCount int
116	providers [][]catwalk.Provider
117}
118
119func (m *dynamicMockProviderClient) GetProviders() ([]catwalk.Provider, error) {
120	m.mu.Lock()
121	defer m.mu.Unlock()
122
123	if m.callCount >= len(m.providers) {
124		return m.providers[len(m.providers)-1], nil
125	}
126
127	result := m.providers[m.callCount]
128	m.callCount++
129	return result, nil
130}
131
132func TestProvider_backgroundReloadAfterCacheUpdate(t *testing.T) {
133	defer func() {
134		initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
135		providerList = nil
136	}()
137
138	tmpPath := t.TempDir() + "/providers.json"
139
140	initialProviders := []catwalk.Provider{
141		{Name: "InitialProvider"},
142	}
143	updatedProviders := []catwalk.Provider{
144		{Name: "UpdatedProvider"},
145		{Name: "NewProvider"},
146	}
147
148	client := &dynamicMockProviderClient{
149		providers: [][]catwalk.Provider{
150			updatedProviders,
151		},
152	}
153
154	data, err := json.Marshal(initialProviders)
155	require.NoError(t, err)
156	require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
157
158	providerMu.Lock()
159	oldInitialized := initialized
160	initialized = false
161	providerMu.Unlock()
162
163	defer func() {
164		providerMu.Lock()
165		initialized = oldInitialized
166		providerMu.Unlock()
167	}()
168
169	providers, err := loadProviders(false, client, tmpPath)
170	require.NoError(t, err)
171	require.NotNil(t, providers)
172	require.Len(t, providers, 1)
173	require.Equal(t, "InitialProvider", providers[0].Name)
174
175	require.Eventually(t, func() bool {
176		reloadedProviders, err := loadProvidersFromCache(tmpPath)
177		if err != nil {
178			return false
179		}
180		return len(reloadedProviders) == 2
181	}, 2*time.Second, 50*time.Millisecond, "Background cache update should complete within 2 seconds")
182
183	reloadedProviders, err := loadProvidersFromCache(tmpPath)
184	require.NoError(t, err)
185	require.Len(t, reloadedProviders, 2)
186	require.Equal(t, "UpdatedProvider", reloadedProviders[0].Name)
187	require.Equal(t, "NewProvider", reloadedProviders[1].Name)
188
189	require.Eventually(t, func() bool {
190		providerMu.RLock()
191		defer providerMu.RUnlock()
192		return len(providerList) == 2
193	}, 2*time.Second, 50*time.Millisecond, "In-memory provider list should be reloaded")
194
195	providerMu.RLock()
196	inMemoryProviders := providerList
197	providerMu.RUnlock()
198
199	require.Len(t, inMemoryProviders, 2)
200	require.Equal(t, "UpdatedProvider", inMemoryProviders[0].Name)
201	require.Equal(t, "NewProvider", inMemoryProviders[1].Name)
202}
203
204func TestProvider_reloadProvidersThreadSafety(t *testing.T) {
205	defer func() {
206		initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
207		providerList = nil
208	}()
209
210	tmpPath := t.TempDir() + "/providers.json"
211
212	initialProviders := []catwalk.Provider{
213		{Name: "Provider1"},
214	}
215	data, err := json.Marshal(initialProviders)
216	require.NoError(t, err)
217	require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
218
219	providerMu.Lock()
220	oldList := providerList
221	oldErr := providerErr
222	oldInitialized := initialized
223	providerList = initialProviders
224	providerErr = nil
225	initialized = true
226	providerMu.Unlock()
227
228	defer func() {
229		providerMu.Lock()
230		providerList = oldList
231		providerErr = oldErr
232		initialized = oldInitialized
233		providerMu.Unlock()
234	}()
235
236	var wg sync.WaitGroup
237	for i := 0; i < 10; i++ {
238		wg.Add(1)
239		go func(iteration int) {
240			defer wg.Done()
241
242			updatedProviders := []catwalk.Provider{
243				{Name: "Provider1"},
244				{Name: "Provider2"},
245			}
246			data, err := json.Marshal(updatedProviders)
247			require.NoError(t, err)
248			require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
249
250			reloadProviders(tmpPath)
251
252			providerMu.RLock()
253			currentList := providerList
254			providerMu.RUnlock()
255
256			require.NotNil(t, currentList)
257		}(i)
258	}
259
260	wg.Wait()
261
262	providerMu.RLock()
263	finalList := providerList
264	providerMu.RUnlock()
265
266	require.Len(t, finalList, 2)
267}
268
269func TestProvider_reloadProvidersWithEmptyCache(t *testing.T) {
270	defer func() {
271		initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
272	}()
273
274	tmpPath := t.TempDir() + "/providers.json"
275
276	initialProviders := []catwalk.Provider{
277		{Name: "InitialProvider"},
278	}
279
280	providerMu.Lock()
281	oldList := providerList
282	oldErr := providerErr
283	oldInitialized := initialized
284	providerList = initialProviders
285	providerErr = nil
286	initialized = true
287	providerMu.Unlock()
288
289	defer func() {
290		providerMu.Lock()
291		providerList = oldList
292		providerErr = oldErr
293		initialized = oldInitialized
294		providerMu.Unlock()
295	}()
296
297	emptyProviders := []catwalk.Provider{}
298	data, err := json.Marshal(emptyProviders)
299	require.NoError(t, err)
300	require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
301
302	reloadProviders(tmpPath)
303
304	providerMu.RLock()
305	currentList := providerList
306	providerMu.RUnlock()
307
308	require.Len(t, currentList, 1)
309	require.Equal(t, "InitialProvider", currentList[0].Name)
310}
311
312func TestProvider_reloadProvidersWithInvalidCache(t *testing.T) {
313	tmpPath := t.TempDir() + "/providers.json"
314
315	initialProviders := []catwalk.Provider{
316		{Name: "InitialProvider"},
317	}
318
319	providerMu.Lock()
320	oldList := providerList
321	oldErr := providerErr
322	oldInitialized := initialized
323	providerList = initialProviders
324	providerErr = nil
325	initialized = true
326	providerMu.Unlock()
327
328	defer func() {
329		providerMu.Lock()
330		providerList = oldList
331		providerErr = oldErr
332		initialized = oldInitialized
333		providerMu.Unlock()
334	}()
335
336	require.NoError(t, os.WriteFile(tmpPath, []byte("invalid json"), 0o644))
337
338	reloadProviders(tmpPath)
339
340	providerMu.RLock()
341	currentList := providerList
342	providerMu.RUnlock()
343
344	require.Len(t, currentList, 1)
345	require.Equal(t, "InitialProvider", currentList[0].Name)
346}