1package config
2
3import (
4 "encoding/json"
5 "errors"
6 "os"
7 "sync"
8 "testing"
9 "time"
10
11 "github.com/charmbracelet/catwalk/pkg/catwalk"
12 "github.com/stretchr/testify/assert"
13 "github.com/stretchr/testify/require"
14)
15
16type mockProviderClient struct {
17 shouldFail bool
18 data []catwalk.Provider
19}
20
21func (m *mockProviderClient) GetProviders() ([]catwalk.Provider, error) {
22 if m.shouldFail {
23 return nil, errors.New("failed to load providers")
24 }
25 if len(m.data) > 0 {
26 return m.data, nil
27 }
28 return []catwalk.Provider{
29 {
30 Name: "Mock",
31 },
32 }, nil
33}
34
35func TestProvider_ProvidersReturnsFromClientIfNoCache(t *testing.T) {
36 defer func() {
37 initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
38 }()
39 require.False(t, initialized)
40 catwalkProviderData := []catwalk.Provider{
41 {Name: "Mock1"},
42 {Name: "Mock2"},
43 }
44 client := &mockProviderClient{shouldFail: false, data: catwalkProviderData}
45
46 autoUpdateDisabled := false // this doesn't matter within this test
47 resolvedProviders, err := ProvidersWithClient(autoUpdateDisabled, client, "non-existent-cache-path")
48
49 require.NoError(t, err)
50 assert.Equal(t, catwalkProviderData, resolvedProviders)
51}
52
53func TestProvider_loadProvidersNoIssues(t *testing.T) {
54 defer func() {
55 initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
56 }()
57 client := &mockProviderClient{shouldFail: false}
58 tmpPath := t.TempDir() + "/providers.json"
59
60 providers, err := loadProviders(false, client, tmpPath)
61 require.NoError(t, err)
62 require.NotNil(t, providers)
63 require.Len(t, providers, 1)
64
65 // check if file got saved
66 fileInfo, err := os.Stat(tmpPath)
67 require.NoError(t, err)
68 require.False(t, fileInfo.IsDir(), "Expected a file, not a directory")
69}
70
71func TestProvider_loadProvidersWithIssues(t *testing.T) {
72 defer func() {
73 initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
74 }()
75 require.False(t, initialized)
76 client := &mockProviderClient{shouldFail: true}
77 tmpPath := t.TempDir() + "/providers.json"
78 // store providers to a temporary file
79 oldProviders := []catwalk.Provider{
80 {
81 Name: "OldProvider",
82 },
83 }
84 data, err := json.Marshal(oldProviders)
85 if err != nil {
86 t.Fatalf("Failed to marshal old providers: %v", err)
87 }
88
89 err = os.WriteFile(tmpPath, data, 0o644)
90 if err != nil {
91 t.Fatalf("Failed to write old providers to file: %v", err)
92 }
93
94 providers, err := loadProviders(false, client, tmpPath)
95 require.NoError(t, err)
96 require.NotNil(t, providers)
97 require.Len(t, providers, 1)
98 require.Equal(t, "OldProvider", providers[0].Name, "Expected to keep old provider when loading fails")
99}
100
101func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) {
102 defer func() {
103 initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
104 }()
105 client := &mockProviderClient{shouldFail: true}
106 tmpPath := t.TempDir() + "/providers.json"
107
108 providers, err := loadProviders(false, client, tmpPath)
109 require.Error(t, err)
110 require.Nil(t, providers, "Expected nil providers when loading fails and no cache exists")
111}
112
113type dynamicMockProviderClient struct {
114 mu sync.Mutex
115 callCount int
116 providers [][]catwalk.Provider
117}
118
119func (m *dynamicMockProviderClient) GetProviders() ([]catwalk.Provider, error) {
120 m.mu.Lock()
121 defer m.mu.Unlock()
122
123 if m.callCount >= len(m.providers) {
124 return m.providers[len(m.providers)-1], nil
125 }
126
127 result := m.providers[m.callCount]
128 m.callCount++
129 return result, nil
130}
131
132func TestProvider_backgroundReloadAfterCacheUpdate(t *testing.T) {
133 defer func() {
134 initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
135 providerList = nil
136 }()
137
138 tmpPath := t.TempDir() + "/providers.json"
139
140 initialProviders := []catwalk.Provider{
141 {Name: "InitialProvider"},
142 }
143 updatedProviders := []catwalk.Provider{
144 {Name: "UpdatedProvider"},
145 {Name: "NewProvider"},
146 }
147
148 client := &dynamicMockProviderClient{
149 providers: [][]catwalk.Provider{
150 updatedProviders,
151 },
152 }
153
154 data, err := json.Marshal(initialProviders)
155 require.NoError(t, err)
156 require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
157
158 providerMu.Lock()
159 oldInitialized := initialized
160 initialized = false
161 providerMu.Unlock()
162
163 defer func() {
164 providerMu.Lock()
165 initialized = oldInitialized
166 providerMu.Unlock()
167 }()
168
169 providers, err := loadProviders(false, client, tmpPath)
170 require.NoError(t, err)
171 require.NotNil(t, providers)
172 require.Len(t, providers, 1)
173 require.Equal(t, "InitialProvider", providers[0].Name)
174
175 require.Eventually(t, func() bool {
176 reloadedProviders, err := loadProvidersFromCache(tmpPath)
177 if err != nil {
178 return false
179 }
180 return len(reloadedProviders) == 2
181 }, 2*time.Second, 50*time.Millisecond, "Background cache update should complete within 2 seconds")
182
183 reloadedProviders, err := loadProvidersFromCache(tmpPath)
184 require.NoError(t, err)
185 require.Len(t, reloadedProviders, 2)
186 require.Equal(t, "UpdatedProvider", reloadedProviders[0].Name)
187 require.Equal(t, "NewProvider", reloadedProviders[1].Name)
188
189 require.Eventually(t, func() bool {
190 providerMu.RLock()
191 defer providerMu.RUnlock()
192 return len(providerList) == 2
193 }, 2*time.Second, 50*time.Millisecond, "In-memory provider list should be reloaded")
194
195 providerMu.RLock()
196 inMemoryProviders := providerList
197 providerMu.RUnlock()
198
199 require.Len(t, inMemoryProviders, 2)
200 require.Equal(t, "UpdatedProvider", inMemoryProviders[0].Name)
201 require.Equal(t, "NewProvider", inMemoryProviders[1].Name)
202}
203
204func TestProvider_reloadProvidersThreadSafety(t *testing.T) {
205 defer func() {
206 initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
207 providerList = nil
208 }()
209
210 tmpPath := t.TempDir() + "/providers.json"
211
212 initialProviders := []catwalk.Provider{
213 {Name: "Provider1"},
214 }
215 data, err := json.Marshal(initialProviders)
216 require.NoError(t, err)
217 require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
218
219 providerMu.Lock()
220 oldList := providerList
221 oldErr := providerErr
222 oldInitialized := initialized
223 providerList = initialProviders
224 providerErr = nil
225 initialized = true
226 providerMu.Unlock()
227
228 defer func() {
229 providerMu.Lock()
230 providerList = oldList
231 providerErr = oldErr
232 initialized = oldInitialized
233 providerMu.Unlock()
234 }()
235
236 var wg sync.WaitGroup
237 for i := 0; i < 10; i++ {
238 wg.Add(1)
239 go func(iteration int) {
240 defer wg.Done()
241
242 updatedProviders := []catwalk.Provider{
243 {Name: "Provider1"},
244 {Name: "Provider2"},
245 }
246 data, err := json.Marshal(updatedProviders)
247 require.NoError(t, err)
248 require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
249
250 reloadProviders(tmpPath)
251
252 providerMu.RLock()
253 currentList := providerList
254 providerMu.RUnlock()
255
256 require.NotNil(t, currentList)
257 }(i)
258 }
259
260 wg.Wait()
261
262 providerMu.RLock()
263 finalList := providerList
264 providerMu.RUnlock()
265
266 require.Len(t, finalList, 2)
267}
268
269func TestProvider_reloadProvidersWithEmptyCache(t *testing.T) {
270 defer func() {
271 initialized = false // NOTE(tauramui): should make these part of a test suite's tidy method
272 }()
273
274 tmpPath := t.TempDir() + "/providers.json"
275
276 initialProviders := []catwalk.Provider{
277 {Name: "InitialProvider"},
278 }
279
280 providerMu.Lock()
281 oldList := providerList
282 oldErr := providerErr
283 oldInitialized := initialized
284 providerList = initialProviders
285 providerErr = nil
286 initialized = true
287 providerMu.Unlock()
288
289 defer func() {
290 providerMu.Lock()
291 providerList = oldList
292 providerErr = oldErr
293 initialized = oldInitialized
294 providerMu.Unlock()
295 }()
296
297 emptyProviders := []catwalk.Provider{}
298 data, err := json.Marshal(emptyProviders)
299 require.NoError(t, err)
300 require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
301
302 reloadProviders(tmpPath)
303
304 providerMu.RLock()
305 currentList := providerList
306 providerMu.RUnlock()
307
308 require.Len(t, currentList, 1)
309 require.Equal(t, "InitialProvider", currentList[0].Name)
310}
311
312func TestProvider_reloadProvidersWithInvalidCache(t *testing.T) {
313 tmpPath := t.TempDir() + "/providers.json"
314
315 initialProviders := []catwalk.Provider{
316 {Name: "InitialProvider"},
317 }
318
319 providerMu.Lock()
320 oldList := providerList
321 oldErr := providerErr
322 oldInitialized := initialized
323 providerList = initialProviders
324 providerErr = nil
325 initialized = true
326 providerMu.Unlock()
327
328 defer func() {
329 providerMu.Lock()
330 providerList = oldList
331 providerErr = oldErr
332 initialized = oldInitialized
333 providerMu.Unlock()
334 }()
335
336 require.NoError(t, os.WriteFile(tmpPath, []byte("invalid json"), 0o644))
337
338 reloadProviders(tmpPath)
339
340 providerMu.RLock()
341 currentList := providerList
342 providerMu.RUnlock()
343
344 require.Len(t, currentList, 1)
345 require.Equal(t, "InitialProvider", currentList[0].Name)
346}