1package config
2
3import (
4 "encoding/json"
5 "errors"
6 "os"
7 "sync"
8 "testing"
9 "time"
10
11 "github.com/charmbracelet/catwalk/pkg/catwalk"
12 "github.com/stretchr/testify/assert"
13 "github.com/stretchr/testify/require"
14)
15
16type mockProviderClient struct {
17 shouldFail bool
18 data []catwalk.Provider
19}
20
21func (m *mockProviderClient) GetProviders() ([]catwalk.Provider, error) {
22 if m.shouldFail {
23 return nil, errors.New("failed to load providers")
24 }
25 if len(m.data) > 0 {
26 return m.data, nil
27 }
28 return []catwalk.Provider{
29 {
30 Name: "Mock",
31 },
32 }, nil
33}
34
35func TestProvider_ProvidersReturnsFromClientIfNoCache(t *testing.T) {
36 defer func() {
37 initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
38 providerList = nil
39 }()
40 require.False(t, initialized)
41 catwalkProviderData := []catwalk.Provider{
42 {Name: "Mock1"},
43 {Name: "Mock2"},
44 }
45 client := &mockProviderClient{shouldFail: false, data: catwalkProviderData}
46
47 autoUpdateDisabled := false // this doesn't matter within this test
48 resolvedProviders, err := ProvidersWithClient(autoUpdateDisabled, client, "non-existent-cache-path")
49
50 require.NoError(t, err)
51 assert.Equal(t, catwalkProviderData, resolvedProviders)
52}
53
54func TestProvider_loadProvidersNoIssues(t *testing.T) {
55 defer func() {
56 initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
57 providerList = nil
58 }()
59 client := &mockProviderClient{shouldFail: false}
60 tmpPath := t.TempDir() + "/providers.json"
61
62 providers, err := loadProviders(false, client, tmpPath)
63 require.NoError(t, err)
64 require.NotNil(t, providers)
65 require.Len(t, providers, 1)
66
67 // check if file got saved
68 fileInfo, err := os.Stat(tmpPath)
69 require.NoError(t, err)
70 require.False(t, fileInfo.IsDir(), "Expected a file, not a directory")
71}
72
73func TestProvider_loadProvidersWithIssues(t *testing.T) {
74 defer func() {
75 initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
76 providerList = nil
77 }()
78 require.False(t, initialized)
79 client := &mockProviderClient{shouldFail: true}
80 tmpPath := t.TempDir() + "/providers.json"
81 // store providers to a temporary file
82 oldProviders := []catwalk.Provider{
83 {
84 Name: "OldProvider",
85 },
86 }
87 data, err := json.Marshal(oldProviders)
88 if err != nil {
89 t.Fatalf("Failed to marshal old providers: %v", err)
90 }
91
92 err = os.WriteFile(tmpPath, data, 0o644)
93 if err != nil {
94 t.Fatalf("Failed to write old providers to file: %v", err)
95 }
96
97 providers, err := loadProviders(false, client, tmpPath)
98 require.NoError(t, err)
99 require.NotNil(t, providers)
100 require.Len(t, providers, 1)
101 require.Equal(t, "OldProvider", providers[0].Name, "Expected to keep old provider when loading fails")
102}
103
104func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) {
105 defer func() {
106 initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
107 providerList = nil
108 }()
109 client := &mockProviderClient{shouldFail: true}
110 tmpPath := t.TempDir() + "/providers.json"
111
112 providers, err := loadProviders(false, client, tmpPath)
113 require.Error(t, err)
114 require.Nil(t, providers, "Expected nil providers when loading fails and no cache exists")
115}
116
117type dynamicMockProviderClient struct {
118 mu sync.Mutex
119 callCount int
120 providers [][]catwalk.Provider
121}
122
123func (m *dynamicMockProviderClient) GetProviders() ([]catwalk.Provider, error) {
124 m.mu.Lock()
125 defer m.mu.Unlock()
126
127 if m.callCount >= len(m.providers) {
128 return m.providers[len(m.providers)-1], nil
129 }
130
131 result := m.providers[m.callCount]
132 m.callCount++
133 return result, nil
134}
135
136func TestProvider_backgroundReloadAfterCacheUpdate(t *testing.T) {
137 defer func() {
138 initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
139 providerList = nil
140 }()
141
142 tmpPath := t.TempDir() + "/providers.json"
143
144 initialProviders := []catwalk.Provider{
145 {Name: "InitialProvider"},
146 }
147 updatedProviders := []catwalk.Provider{
148 {Name: "UpdatedProvider"},
149 {Name: "NewProvider"},
150 }
151
152 client := &dynamicMockProviderClient{
153 providers: [][]catwalk.Provider{
154 updatedProviders,
155 },
156 }
157
158 data, err := json.Marshal(initialProviders)
159 require.NoError(t, err)
160 require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
161
162 providerMu.Lock()
163 oldInitialized := initialized
164 initialized = false
165 providerMu.Unlock()
166
167 defer func() {
168 providerMu.Lock()
169 initialized = oldInitialized
170 providerMu.Unlock()
171 }()
172
173 providers, err := loadProviders(false, client, tmpPath)
174 require.NoError(t, err)
175 require.NotNil(t, providers)
176 require.Len(t, providers, 1)
177 require.Equal(t, "InitialProvider", providers[0].Name)
178
179 require.Eventually(t, func() bool {
180 reloadedProviders, err := loadProvidersFromCache(tmpPath)
181 if err != nil {
182 return false
183 }
184 return len(reloadedProviders) == 2
185 }, 2*time.Second, 50*time.Millisecond, "Background cache update should complete within 2 seconds")
186
187 reloadedProviders, err := loadProvidersFromCache(tmpPath)
188 require.NoError(t, err)
189 require.Len(t, reloadedProviders, 2)
190 require.Equal(t, "UpdatedProvider", reloadedProviders[0].Name)
191 require.Equal(t, "NewProvider", reloadedProviders[1].Name)
192
193 require.Eventually(t, func() bool {
194 providerMu.RLock()
195 defer providerMu.RUnlock()
196 return len(providerList) == 2
197 }, 2*time.Second, 50*time.Millisecond, "In-memory provider list should be reloaded")
198
199 providerMu.RLock()
200 inMemoryProviders := providerList
201 providerMu.RUnlock()
202
203 require.Len(t, inMemoryProviders, 2)
204 require.Equal(t, "UpdatedProvider", inMemoryProviders[0].Name)
205 require.Equal(t, "NewProvider", inMemoryProviders[1].Name)
206}
207
208func TestProvider_reloadProvidersThreadSafety(t *testing.T) {
209 defer func() {
210 initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
211 providerList = nil
212 }()
213
214 tmpPath := t.TempDir() + "/providers.json"
215
216 initialProviders := []catwalk.Provider{
217 {Name: "Provider1"},
218 }
219 data, err := json.Marshal(initialProviders)
220 require.NoError(t, err)
221 require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
222
223 providerMu.Lock()
224 oldList := providerList
225 oldErr := providerErr
226 oldInitialized := initialized
227 providerList = initialProviders
228 providerErr = nil
229 initialized = true
230 providerMu.Unlock()
231
232 defer func() {
233 providerMu.Lock()
234 providerList = oldList
235 providerErr = oldErr
236 initialized = oldInitialized
237 providerMu.Unlock()
238 }()
239
240 var wg sync.WaitGroup
241 for i := 0; i < 10; i++ {
242 wg.Add(1)
243 go func(iteration int) {
244 defer wg.Done()
245
246 updatedProviders := []catwalk.Provider{
247 {Name: "Provider1"},
248 {Name: "Provider2"},
249 }
250 data, err := json.Marshal(updatedProviders)
251 require.NoError(t, err)
252 require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
253
254 reloadProviders(tmpPath)
255
256 providerMu.RLock()
257 currentList := providerList
258 providerMu.RUnlock()
259
260 require.NotNil(t, currentList)
261 }(i)
262 }
263
264 wg.Wait()
265
266 providerMu.RLock()
267 finalList := providerList
268 providerMu.RUnlock()
269
270 require.Len(t, finalList, 2)
271}
272
273func TestProvider_reloadProvidersWithEmptyCache(t *testing.T) {
274 defer func() {
275 initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
276 providerList = nil
277 }()
278
279 tmpPath := t.TempDir() + "/providers.json"
280
281 initialProviders := []catwalk.Provider{
282 {Name: "InitialProvider"},
283 }
284
285 providerMu.Lock()
286 oldList := providerList
287 oldErr := providerErr
288 oldInitialized := initialized
289 providerList = initialProviders
290 providerErr = nil
291 initialized = true
292 providerMu.Unlock()
293
294 defer func() {
295 providerMu.Lock()
296 providerList = oldList
297 providerErr = oldErr
298 initialized = oldInitialized
299 providerMu.Unlock()
300 }()
301
302 emptyProviders := []catwalk.Provider{}
303 data, err := json.Marshal(emptyProviders)
304 require.NoError(t, err)
305 require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
306
307 reloadProviders(tmpPath)
308
309 providerMu.RLock()
310 currentList := providerList
311 providerMu.RUnlock()
312
313 require.Len(t, currentList, 1)
314 require.Equal(t, "InitialProvider", currentList[0].Name)
315}
316
317func TestProvider_reloadProvidersWithInvalidCache(t *testing.T) {
318 defer func() {
319 initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
320 providerList = nil
321 }()
322
323 tmpPath := t.TempDir() + "/providers.json"
324
325 initialProviders := []catwalk.Provider{
326 {Name: "InitialProvider"},
327 }
328
329 providerMu.Lock()
330 oldList := providerList
331 oldErr := providerErr
332 oldInitialized := initialized
333 providerList = initialProviders
334 providerErr = nil
335 initialized = true
336 providerMu.Unlock()
337
338 defer func() {
339 providerMu.Lock()
340 providerList = oldList
341 providerErr = oldErr
342 initialized = oldInitialized
343 providerMu.Unlock()
344 }()
345
346 require.NoError(t, os.WriteFile(tmpPath, []byte("invalid json"), 0o644))
347
348 reloadProviders(tmpPath)
349
350 providerMu.RLock()
351 currentList := providerList
352 providerMu.RUnlock()
353
354 require.Len(t, currentList, 1)
355 require.Equal(t, "InitialProvider", currentList[0].Name)
356}