hyper_test.go

  1package config
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"os"
  8	"testing"
  9
 10	"github.com/charmbracelet/catwalk/pkg/catwalk"
 11	"github.com/stretchr/testify/require"
 12)
 13
 14type mockHyperClient struct {
 15	provider  catwalk.Provider
 16	err       error
 17	callCount int
 18}
 19
 20func (m *mockHyperClient) Get(ctx context.Context, etag string) (catwalk.Provider, error) {
 21	m.callCount++
 22	return m.provider, m.err
 23}
 24
 25func TestHyperSync_Init(t *testing.T) {
 26	t.Parallel()
 27
 28	syncer := &hyperSync{}
 29	client := &mockHyperClient{}
 30	path := "/tmp/hyper.json"
 31
 32	syncer.Init(client, path, true)
 33
 34	require.True(t, syncer.init.Load())
 35	require.Equal(t, client, syncer.client)
 36	require.Equal(t, path, syncer.cache.path)
 37}
 38
 39func TestHyperSync_GetPanicIfNotInit(t *testing.T) {
 40	t.Parallel()
 41
 42	syncer := &hyperSync{}
 43	require.Panics(t, func() {
 44		_, _ = syncer.Get(t.Context())
 45	})
 46}
 47
 48func TestHyperSync_GetFreshProvider(t *testing.T) {
 49	t.Parallel()
 50
 51	syncer := &hyperSync{}
 52	client := &mockHyperClient{
 53		provider: catwalk.Provider{
 54			Name: "Hyper",
 55			ID:   "hyper",
 56			Models: []catwalk.Model{
 57				{ID: "model-1", Name: "Model 1"},
 58			},
 59		},
 60	}
 61	path := t.TempDir() + "/hyper.json"
 62
 63	syncer.Init(client, path, true)
 64
 65	provider, err := syncer.Get(t.Context())
 66	require.NoError(t, err)
 67	require.Equal(t, "Hyper", provider.Name)
 68	require.Equal(t, 1, client.callCount)
 69
 70	// Verify cache was written.
 71	fileInfo, err := os.Stat(path)
 72	require.NoError(t, err)
 73	require.False(t, fileInfo.IsDir())
 74}
 75
 76func TestHyperSync_GetNotModifiedUsesCached(t *testing.T) {
 77	t.Parallel()
 78
 79	tmpDir := t.TempDir()
 80	path := tmpDir + "/hyper.json"
 81
 82	// Create cache file.
 83	cachedProvider := catwalk.Provider{
 84		Name: "Cached Hyper",
 85		ID:   "hyper",
 86	}
 87	data, err := json.Marshal(cachedProvider)
 88	require.NoError(t, err)
 89	require.NoError(t, os.WriteFile(path, data, 0o644))
 90
 91	syncer := &hyperSync{}
 92	client := &mockHyperClient{
 93		err: catwalk.ErrNotModified,
 94	}
 95
 96	syncer.Init(client, path, true)
 97
 98	provider, err := syncer.Get(t.Context())
 99	require.NoError(t, err)
100	require.Equal(t, "Cached Hyper", provider.Name)
101	require.Equal(t, 1, client.callCount)
102}
103
104func TestHyperSync_GetClientError(t *testing.T) {
105	t.Parallel()
106
107	tmpDir := t.TempDir()
108	path := tmpDir + "/hyper.json"
109
110	syncer := &hyperSync{}
111	client := &mockHyperClient{
112		err: errors.New("network error"),
113	}
114
115	syncer.Init(client, path, true)
116
117	provider, err := syncer.Get(t.Context())
118	require.NoError(t, err) // Should fall back to embedded.
119	require.Equal(t, "Charm Hyper", provider.Name)
120	require.Equal(t, catwalk.InferenceProvider("hyper"), provider.ID)
121}
122
123func TestHyperSync_GetEmptyCache(t *testing.T) {
124	t.Parallel()
125
126	tmpDir := t.TempDir()
127	path := tmpDir + "/hyper.json"
128
129	syncer := &hyperSync{}
130	client := &mockHyperClient{
131		provider: catwalk.Provider{
132			Name: "Fresh Hyper",
133			ID:   "hyper",
134			Models: []catwalk.Model{
135				{ID: "model-1", Name: "Model 1"},
136			},
137		},
138	}
139
140	syncer.Init(client, path, true)
141
142	provider, err := syncer.Get(t.Context())
143	require.NoError(t, err)
144	require.Equal(t, "Fresh Hyper", provider.Name)
145}
146
147func TestHyperSync_GetCalledMultipleTimesUsesOnce(t *testing.T) {
148	t.Parallel()
149
150	syncer := &hyperSync{}
151	client := &mockHyperClient{
152		provider: catwalk.Provider{
153			Name: "Hyper",
154			ID:   "hyper",
155			Models: []catwalk.Model{
156				{ID: "model-1", Name: "Model 1"},
157			},
158		},
159	}
160	path := t.TempDir() + "/hyper.json"
161
162	syncer.Init(client, path, true)
163
164	// Call Get multiple times.
165	provider1, err1 := syncer.Get(t.Context())
166	require.NoError(t, err1)
167	require.Equal(t, "Hyper", provider1.Name)
168
169	provider2, err2 := syncer.Get(t.Context())
170	require.NoError(t, err2)
171	require.Equal(t, "Hyper", provider2.Name)
172
173	// Client should only be called once due to sync.Once.
174	require.Equal(t, 1, client.callCount)
175}
176
177func TestHyperSync_GetCacheStoreError(t *testing.T) {
178	t.Parallel()
179
180	// Create a file where we want a directory, causing mkdir to fail.
181	tmpDir := t.TempDir()
182	blockingFile := tmpDir + "/blocking"
183	require.NoError(t, os.WriteFile(blockingFile, []byte("block"), 0o644))
184
185	// Try to create cache in a subdirectory under the blocking file.
186	path := blockingFile + "/subdir/hyper.json"
187
188	syncer := &hyperSync{}
189	client := &mockHyperClient{
190		provider: catwalk.Provider{
191			Name: "Hyper",
192			ID:   "hyper",
193			Models: []catwalk.Model{
194				{ID: "model-1", Name: "Model 1"},
195			},
196		},
197	}
198
199	syncer.Init(client, path, true)
200
201	provider, err := syncer.Get(t.Context())
202	require.Error(t, err)
203	require.Contains(t, err.Error(), "failed to create directory for provider cache")
204	require.Equal(t, "Hyper", provider.Name) // Provider is still returned.
205}