provider_test.go

  1package config
  2
  3import (
  4	"encoding/json"
  5	"errors"
  6	"os"
  7	"sync"
  8	"testing"
  9	"time"
 10
 11	"github.com/charmbracelet/catwalk/pkg/catwalk"
 12	"github.com/stretchr/testify/require"
 13)
 14
 15type mockProviderClient struct {
 16	shouldFail bool
 17}
 18
 19func (m *mockProviderClient) GetProviders() ([]catwalk.Provider, error) {
 20	if m.shouldFail {
 21		return nil, errors.New("failed to load providers")
 22	}
 23	return []catwalk.Provider{
 24		{
 25			Name: "Mock",
 26		},
 27	}, nil
 28}
 29
 30func TestProvider_loadProvidersNoIssues(t *testing.T) {
 31	client := &mockProviderClient{shouldFail: false}
 32	tmpPath := t.TempDir() + "/providers.json"
 33	cfg := &Config{}
 34	providers, err := loadProviders(false, client, tmpPath, cfg)
 35	require.NoError(t, err)
 36	require.NotNil(t, providers)
 37	require.Len(t, providers, 1)
 38
 39	// check if file got saved
 40	fileInfo, err := os.Stat(tmpPath)
 41	require.NoError(t, err)
 42	require.False(t, fileInfo.IsDir(), "Expected a file, not a directory")
 43}
 44
 45func TestProvider_loadProvidersWithIssues(t *testing.T) {
 46	client := &mockProviderClient{shouldFail: true}
 47	tmpPath := t.TempDir() + "/providers.json"
 48	// store providers to a temporary file
 49	oldProviders := []catwalk.Provider{
 50		{
 51			Name: "OldProvider",
 52		},
 53	}
 54	data, err := json.Marshal(oldProviders)
 55	if err != nil {
 56		t.Fatalf("Failed to marshal old providers: %v", err)
 57	}
 58
 59	err = os.WriteFile(tmpPath, data, 0o644)
 60	if err != nil {
 61		t.Fatalf("Failed to write old providers to file: %v", err)
 62	}
 63	cfg := &Config{}
 64	providers, err := loadProviders(false, client, tmpPath, cfg)
 65	require.NoError(t, err)
 66	require.NotNil(t, providers)
 67	require.Len(t, providers, 1)
 68	require.Equal(t, "OldProvider", providers[0].Name, "Expected to keep old provider when loading fails")
 69}
 70
 71func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) {
 72	client := &mockProviderClient{shouldFail: true}
 73	tmpPath := t.TempDir() + "/providers.json"
 74	cfg := &Config{}
 75	providers, err := loadProviders(false, client, tmpPath, cfg)
 76	require.Error(t, err)
 77	require.Nil(t, providers, "Expected nil providers when loading fails and no cache exists")
 78}
 79
 80type dynamicMockProviderClient struct {
 81	mu        sync.Mutex
 82	callCount int
 83	providers [][]catwalk.Provider
 84}
 85
 86func (m *dynamicMockProviderClient) GetProviders() ([]catwalk.Provider, error) {
 87	m.mu.Lock()
 88	defer m.mu.Unlock()
 89
 90	if m.callCount >= len(m.providers) {
 91		return m.providers[len(m.providers)-1], nil
 92	}
 93
 94	result := m.providers[m.callCount]
 95	m.callCount++
 96	return result, nil
 97}
 98
 99func TestProvider_backgroundReloadAfterCacheUpdate(t *testing.T) {
100	t.Parallel()
101
102	tmpPath := t.TempDir() + "/providers.json"
103
104	initialProviders := []catwalk.Provider{
105		{Name: "InitialProvider"},
106	}
107	updatedProviders := []catwalk.Provider{
108		{Name: "UpdatedProvider"},
109		{Name: "NewProvider"},
110	}
111
112	client := &dynamicMockProviderClient{
113		providers: [][]catwalk.Provider{
114			updatedProviders,
115		},
116	}
117
118	data, err := json.Marshal(initialProviders)
119	require.NoError(t, err)
120	require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
121
122	providerMu.Lock()
123	oldInitialized := initialized
124	initialized = false
125	providerMu.Unlock()
126
127	defer func() {
128		providerMu.Lock()
129		initialized = oldInitialized
130		providerMu.Unlock()
131	}()
132
133	cfg := &Config{}
134	providers, err := loadProviders(false, client, tmpPath, cfg)
135	require.NoError(t, err)
136	require.NotNil(t, providers)
137	require.Len(t, providers, 1)
138	require.Equal(t, "InitialProvider", providers[0].Name)
139
140	require.Eventually(t, func() bool {
141		reloadedProviders, err := loadProvidersFromCache(tmpPath)
142		if err != nil {
143			return false
144		}
145		return len(reloadedProviders) == 2
146	}, 2*time.Second, 50*time.Millisecond, "Background cache update should complete within 2 seconds")
147
148	reloadedProviders, err := loadProvidersFromCache(tmpPath)
149	require.NoError(t, err)
150	require.Len(t, reloadedProviders, 2)
151	require.Equal(t, "UpdatedProvider", reloadedProviders[0].Name)
152	require.Equal(t, "NewProvider", reloadedProviders[1].Name)
153
154	require.Eventually(t, func() bool {
155		providerMu.RLock()
156		defer providerMu.RUnlock()
157		return len(providerList) == 2
158	}, 2*time.Second, 50*time.Millisecond, "In-memory provider list should be reloaded")
159
160	providerMu.RLock()
161	inMemoryProviders := providerList
162	providerMu.RUnlock()
163
164	require.Len(t, inMemoryProviders, 2)
165	require.Equal(t, "UpdatedProvider", inMemoryProviders[0].Name)
166	require.Equal(t, "NewProvider", inMemoryProviders[1].Name)
167}
168
169func TestProvider_reloadProvidersThreadSafety(t *testing.T) {
170	tmpPath := t.TempDir() + "/providers.json"
171
172	initialProviders := []catwalk.Provider{
173		{Name: "Provider1"},
174	}
175	data, err := json.Marshal(initialProviders)
176	require.NoError(t, err)
177	require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
178
179	providerMu.Lock()
180	oldList := providerList
181	oldErr := providerErr
182	oldInitialized := initialized
183	providerList = initialProviders
184	providerErr = nil
185	initialized = true
186	providerMu.Unlock()
187
188	defer func() {
189		providerMu.Lock()
190		providerList = oldList
191		providerErr = oldErr
192		initialized = oldInitialized
193		providerMu.Unlock()
194	}()
195
196	var wg sync.WaitGroup
197	for i := 0; i < 10; i++ {
198		wg.Add(1)
199		go func(iteration int) {
200			defer wg.Done()
201
202			updatedProviders := []catwalk.Provider{
203				{Name: "Provider1"},
204				{Name: "Provider2"},
205			}
206			data, err := json.Marshal(updatedProviders)
207			require.NoError(t, err)
208			require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
209
210			reloadProviders(tmpPath)
211
212			providerMu.RLock()
213			currentList := providerList
214			providerMu.RUnlock()
215
216			require.NotNil(t, currentList)
217		}(i)
218	}
219
220	wg.Wait()
221
222	providerMu.RLock()
223	finalList := providerList
224	providerMu.RUnlock()
225
226	require.Len(t, finalList, 2)
227}
228
229func TestProvider_reloadProvidersWithEmptyCache(t *testing.T) {
230	tmpPath := t.TempDir() + "/providers.json"
231
232	initialProviders := []catwalk.Provider{
233		{Name: "InitialProvider"},
234	}
235
236	providerMu.Lock()
237	oldList := providerList
238	oldErr := providerErr
239	oldInitialized := initialized
240	providerList = initialProviders
241	providerErr = nil
242	initialized = true
243	providerMu.Unlock()
244
245	defer func() {
246		providerMu.Lock()
247		providerList = oldList
248		providerErr = oldErr
249		initialized = oldInitialized
250		providerMu.Unlock()
251	}()
252
253	emptyProviders := []catwalk.Provider{}
254	data, err := json.Marshal(emptyProviders)
255	require.NoError(t, err)
256	require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
257
258	reloadProviders(tmpPath)
259
260	providerMu.RLock()
261	currentList := providerList
262	providerMu.RUnlock()
263
264	require.Len(t, currentList, 1)
265	require.Equal(t, "InitialProvider", currentList[0].Name)
266}
267
268func TestProvider_reloadProvidersWithInvalidCache(t *testing.T) {
269	tmpPath := t.TempDir() + "/providers.json"
270
271	initialProviders := []catwalk.Provider{
272		{Name: "InitialProvider"},
273	}
274
275	providerMu.Lock()
276	oldList := providerList
277	oldErr := providerErr
278	oldInitialized := initialized
279	providerList = initialProviders
280	providerErr = nil
281	initialized = true
282	providerMu.Unlock()
283
284	defer func() {
285		providerMu.Lock()
286		providerList = oldList
287		providerErr = oldErr
288		initialized = oldInitialized
289		providerMu.Unlock()
290	}()
291
292	require.NoError(t, os.WriteFile(tmpPath, []byte("invalid json"), 0o644))
293
294	reloadProviders(tmpPath)
295
296	providerMu.RLock()
297	currentList := providerList
298	providerMu.RUnlock()
299
300	require.Len(t, currentList, 1)
301	require.Equal(t, "InitialProvider", currentList[0].Name)
302}