1package config
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "os"
8 "testing"
9
10 "github.com/charmbracelet/catwalk/pkg/catwalk"
11 "github.com/stretchr/testify/require"
12)
13
// mockHyperClient is a test double handed to hyperSync.Init as the client in
// the tests in this file. Its Get method returns the canned provider and err
// and counts how many times it was invoked.
type mockHyperClient struct {
	// provider is the value Get returns.
	provider catwalk.Provider
	// err is the error Get returns alongside provider.
	err error
	// callCount is incremented on every Get call.
	callCount int
}
19
20func (m *mockHyperClient) Get(ctx context.Context, etag string) (catwalk.Provider, error) {
21 m.callCount++
22 return m.provider, m.err
23}
24
25func TestHyperSync_Init(t *testing.T) {
26 t.Parallel()
27
28 syncer := &hyperSync{}
29 client := &mockHyperClient{}
30 path := "/tmp/hyper.json"
31
32 syncer.Init(client, path, true)
33
34 require.True(t, syncer.init.Load())
35 require.Equal(t, client, syncer.client)
36 require.Equal(t, path, syncer.cache.path)
37}
38
39func TestHyperSync_GetPanicIfNotInit(t *testing.T) {
40 t.Parallel()
41
42 syncer := &hyperSync{}
43 require.Panics(t, func() {
44 _, _ = syncer.Get(t.Context())
45 })
46}
47
48func TestHyperSync_GetFreshProvider(t *testing.T) {
49 t.Parallel()
50
51 syncer := &hyperSync{}
52 client := &mockHyperClient{
53 provider: catwalk.Provider{
54 Name: "Hyper",
55 ID: "hyper",
56 Models: []catwalk.Model{
57 {ID: "model-1", Name: "Model 1"},
58 },
59 },
60 }
61 path := t.TempDir() + "/hyper.json"
62
63 syncer.Init(client, path, true)
64
65 provider, err := syncer.Get(t.Context())
66 require.NoError(t, err)
67 require.Equal(t, "Hyper", provider.Name)
68 require.Equal(t, 1, client.callCount)
69
70 // Verify cache was written.
71 fileInfo, err := os.Stat(path)
72 require.NoError(t, err)
73 require.False(t, fileInfo.IsDir())
74}
75
76func TestHyperSync_GetNotModifiedUsesCached(t *testing.T) {
77 t.Parallel()
78
79 tmpDir := t.TempDir()
80 path := tmpDir + "/hyper.json"
81
82 // Create cache file.
83 cachedProvider := catwalk.Provider{
84 Name: "Cached Hyper",
85 ID: "hyper",
86 }
87 data, err := json.Marshal(cachedProvider)
88 require.NoError(t, err)
89 require.NoError(t, os.WriteFile(path, data, 0o644))
90
91 syncer := &hyperSync{}
92 client := &mockHyperClient{
93 err: catwalk.ErrNotModified,
94 }
95
96 syncer.Init(client, path, true)
97
98 provider, err := syncer.Get(t.Context())
99 require.NoError(t, err)
100 require.Equal(t, "Cached Hyper", provider.Name)
101 require.Equal(t, 1, client.callCount)
102}
103
104func TestHyperSync_GetClientError(t *testing.T) {
105 t.Parallel()
106
107 tmpDir := t.TempDir()
108 path := tmpDir + "/hyper.json"
109
110 syncer := &hyperSync{}
111 client := &mockHyperClient{
112 err: errors.New("network error"),
113 }
114
115 syncer.Init(client, path, true)
116
117 provider, err := syncer.Get(t.Context())
118 require.NoError(t, err) // Should fall back to embedded.
119 require.Equal(t, "Charm Hyper", provider.Name)
120 require.Equal(t, catwalk.InferenceProvider("hyper"), provider.ID)
121}
122
123func TestHyperSync_GetEmptyCache(t *testing.T) {
124 t.Parallel()
125
126 tmpDir := t.TempDir()
127 path := tmpDir + "/hyper.json"
128
129 syncer := &hyperSync{}
130 client := &mockHyperClient{
131 provider: catwalk.Provider{
132 Name: "Fresh Hyper",
133 ID: "hyper",
134 Models: []catwalk.Model{
135 {ID: "model-1", Name: "Model 1"},
136 },
137 },
138 }
139
140 syncer.Init(client, path, true)
141
142 provider, err := syncer.Get(t.Context())
143 require.NoError(t, err)
144 require.Equal(t, "Fresh Hyper", provider.Name)
145}
146
147func TestHyperSync_GetCalledMultipleTimesUsesOnce(t *testing.T) {
148 t.Parallel()
149
150 syncer := &hyperSync{}
151 client := &mockHyperClient{
152 provider: catwalk.Provider{
153 Name: "Hyper",
154 ID: "hyper",
155 Models: []catwalk.Model{
156 {ID: "model-1", Name: "Model 1"},
157 },
158 },
159 }
160 path := t.TempDir() + "/hyper.json"
161
162 syncer.Init(client, path, true)
163
164 // Call Get multiple times.
165 provider1, err1 := syncer.Get(t.Context())
166 require.NoError(t, err1)
167 require.Equal(t, "Hyper", provider1.Name)
168
169 provider2, err2 := syncer.Get(t.Context())
170 require.NoError(t, err2)
171 require.Equal(t, "Hyper", provider2.Name)
172
173 // Client should only be called once due to sync.Once.
174 require.Equal(t, 1, client.callCount)
175}
176
177func TestHyperSync_GetCacheStoreError(t *testing.T) {
178 t.Parallel()
179
180 // Create a file where we want a directory, causing mkdir to fail.
181 tmpDir := t.TempDir()
182 blockingFile := tmpDir + "/blocking"
183 require.NoError(t, os.WriteFile(blockingFile, []byte("block"), 0o644))
184
185 // Try to create cache in a subdirectory under the blocking file.
186 path := blockingFile + "/subdir/hyper.json"
187
188 syncer := &hyperSync{}
189 client := &mockHyperClient{
190 provider: catwalk.Provider{
191 Name: "Hyper",
192 ID: "hyper",
193 Models: []catwalk.Model{
194 {ID: "model-1", Name: "Model 1"},
195 },
196 },
197 }
198
199 syncer.Init(client, path, true)
200
201 provider, err := syncer.Get(t.Context())
202 require.Error(t, err)
203 require.Contains(t, err.Error(), "failed to create directory for provider cache")
204 require.Equal(t, "Hyper", provider.Name) // Provider is still returned.
205}