1package config
2
3import (
4 "encoding/json"
5 "testing"
6
7 "github.com/charmbracelet/crush/internal/fur/provider"
8 "github.com/stretchr/testify/assert"
9 "github.com/stretchr/testify/require"
10)
11
12func TestMockProviders(t *testing.T) {
13 // Enable mock providers for testing
14 originalUseMock := UseMockProviders
15 UseMockProviders = true
16 defer func() {
17 UseMockProviders = originalUseMock
18 ResetProviders()
19 }()
20
21 // Reset providers to ensure we get fresh mock data
22 ResetProviders()
23
24 providers := Providers()
25 require.NotEmpty(t, providers, "Mock providers should not be empty")
26
27 // Verify we have the expected mock providers
28 providerIDs := make(map[provider.InferenceProvider]bool)
29 for _, p := range providers {
30 providerIDs[p.ID] = true
31 }
32
33 assert.True(t, providerIDs[provider.InferenceProviderAnthropic], "Should have Anthropic provider")
34 assert.True(t, providerIDs[provider.InferenceProviderOpenAI], "Should have OpenAI provider")
35 assert.True(t, providerIDs[provider.InferenceProviderGemini], "Should have Gemini provider")
36
37 // Verify Anthropic provider details
38 var anthropicProvider provider.Provider
39 for _, p := range providers {
40 if p.ID == provider.InferenceProviderAnthropic {
41 anthropicProvider = p
42 break
43 }
44 }
45
46 assert.Equal(t, "Anthropic", anthropicProvider.Name)
47 assert.Equal(t, provider.TypeAnthropic, anthropicProvider.Type)
48 assert.Equal(t, "claude-3-opus", anthropicProvider.DefaultLargeModelID)
49 assert.Equal(t, "claude-3-haiku", anthropicProvider.DefaultSmallModelID)
50 assert.Len(t, anthropicProvider.Models, 4, "Anthropic should have 4 models")
51
52 // Verify model details
53 var opusModel provider.Model
54 for _, m := range anthropicProvider.Models {
55 if m.ID == "claude-3-opus" {
56 opusModel = m
57 break
58 }
59 }
60
61 assert.Equal(t, "Claude 3 Opus", opusModel.Name)
62 assert.Equal(t, int64(200000), opusModel.ContextWindow)
63 assert.Equal(t, int64(4096), opusModel.DefaultMaxTokens)
64 assert.True(t, opusModel.SupportsImages)
65}
66
67func TestProvidersWithoutMock(t *testing.T) {
68 // Ensure mock is disabled
69 originalUseMock := UseMockProviders
70 UseMockProviders = false
71 defer func() {
72 UseMockProviders = originalUseMock
73 ResetProviders()
74 }()
75
76 // Reset providers to ensure we get fresh data
77 ResetProviders()
78
79 // This will try to make an actual API call or use cached data
80 providers := Providers()
81
82 // We can't guarantee what we'll get here since it depends on network/cache
83 // but we can at least verify the function doesn't panic
84 t.Logf("Got %d providers without mock", len(providers))
85}
86
87func TestResetProviders(t *testing.T) {
88 // Enable mock providers
89 UseMockProviders = true
90 defer func() {
91 UseMockProviders = false
92 ResetProviders()
93 }()
94
95 // Get providers once
96 providers1 := Providers()
97 require.NotEmpty(t, providers1)
98
99 // Reset and get again
100 ResetProviders()
101 providers2 := Providers()
102 require.NotEmpty(t, providers2)
103
104 // Should get the same mock data
105 assert.Equal(t, len(providers1), len(providers2))
106}
107
108func TestReasoningEffortSupport(t *testing.T) {
109 originalUseMock := UseMockProviders
110 UseMockProviders = true
111 defer func() {
112 UseMockProviders = originalUseMock
113 ResetProviders()
114 }()
115
116 ResetProviders()
117 providers := Providers()
118
119 var openaiProvider provider.Provider
120 for _, p := range providers {
121 if p.ID == provider.InferenceProviderOpenAI {
122 openaiProvider = p
123 break
124 }
125 }
126 require.NotEmpty(t, openaiProvider.ID)
127
128 var reasoningModel, nonReasoningModel provider.Model
129 for _, model := range openaiProvider.Models {
130 if model.CanReason && model.HasReasoningEffort {
131 reasoningModel = model
132 } else if !model.CanReason {
133 nonReasoningModel = model
134 }
135 }
136
137 require.NotEmpty(t, reasoningModel.ID)
138 assert.Equal(t, "medium", reasoningModel.DefaultReasoningEffort)
139 assert.True(t, reasoningModel.HasReasoningEffort)
140
141 require.NotEmpty(t, nonReasoningModel.ID)
142 assert.False(t, nonReasoningModel.HasReasoningEffort)
143 assert.Empty(t, nonReasoningModel.DefaultReasoningEffort)
144}
145
146func TestReasoningEffortConfigTransfer(t *testing.T) {
147 originalUseMock := UseMockProviders
148 UseMockProviders = true
149 defer func() {
150 UseMockProviders = originalUseMock
151 ResetProviders()
152 }()
153
154 ResetProviders()
155 t.Setenv("OPENAI_API_KEY", "test-openai-key")
156
157 cfg, err := Init(t.TempDir(), false)
158 require.NoError(t, err)
159
160 openaiProviderConfig, exists := cfg.Providers[provider.InferenceProviderOpenAI]
161 require.True(t, exists)
162
163 var foundReasoning, foundNonReasoning bool
164 for _, model := range openaiProviderConfig.Models {
165 if model.CanReason && model.HasReasoningEffort && model.ReasoningEffort != "" {
166 assert.Equal(t, "medium", model.ReasoningEffort)
167 assert.True(t, model.HasReasoningEffort)
168 foundReasoning = true
169 } else if !model.CanReason {
170 assert.Empty(t, model.ReasoningEffort)
171 assert.False(t, model.HasReasoningEffort)
172 foundNonReasoning = true
173 }
174 }
175
176 assert.True(t, foundReasoning, "Should find at least one reasoning model")
177 assert.True(t, foundNonReasoning, "Should find at least one non-reasoning model")
178}
179
180func TestNewProviders(t *testing.T) {
181 originalUseMock := UseMockProviders
182 UseMockProviders = true
183 defer func() {
184 UseMockProviders = originalUseMock
185 ResetProviders()
186 }()
187
188 ResetProviders()
189 providers := Providers()
190 require.NotEmpty(t, providers)
191
192 var xaiProvider, openRouterProvider provider.Provider
193 for _, p := range providers {
194 switch p.ID {
195 case provider.InferenceProviderXAI:
196 xaiProvider = p
197 case provider.InferenceProviderOpenRouter:
198 openRouterProvider = p
199 }
200 }
201
202 require.NotEmpty(t, xaiProvider.ID)
203 assert.Equal(t, "xAI", xaiProvider.Name)
204 assert.Equal(t, "grok-beta", xaiProvider.DefaultLargeModelID)
205
206 require.NotEmpty(t, openRouterProvider.ID)
207 assert.Equal(t, "OpenRouter", openRouterProvider.Name)
208 assert.Equal(t, "anthropic/claude-3.5-sonnet", openRouterProvider.DefaultLargeModelID)
209}
210
211func TestO1ModelsInMockProvider(t *testing.T) {
212 originalUseMock := UseMockProviders
213 UseMockProviders = true
214 defer func() {
215 UseMockProviders = originalUseMock
216 ResetProviders()
217 }()
218
219 ResetProviders()
220 providers := Providers()
221
222 var openaiProvider provider.Provider
223 for _, p := range providers {
224 if p.ID == provider.InferenceProviderOpenAI {
225 openaiProvider = p
226 break
227 }
228 }
229 require.NotEmpty(t, openaiProvider.ID)
230
231 modelTests := []struct {
232 id string
233 name string
234 }{
235 {"o1-preview", "o1-preview"},
236 {"o1-mini", "o1-mini"},
237 }
238
239 for _, test := range modelTests {
240 var model provider.Model
241 var found bool
242 for _, m := range openaiProvider.Models {
243 if m.ID == test.id {
244 model = m
245 found = true
246 break
247 }
248 }
249 require.True(t, found, "Should find %s model", test.id)
250 assert.Equal(t, test.name, model.Name)
251 assert.True(t, model.CanReason)
252 assert.True(t, model.HasReasoningEffort)
253 assert.Equal(t, "medium", model.DefaultReasoningEffort)
254 }
255}
256
257func TestPreferredModelReasoningEffort(t *testing.T) {
258 // Test that PreferredModel struct can hold reasoning effort
259 preferredModel := PreferredModel{
260 ModelID: "o1-preview",
261 Provider: provider.InferenceProviderOpenAI,
262 ReasoningEffort: "high",
263 }
264
265 assert.Equal(t, "o1-preview", preferredModel.ModelID)
266 assert.Equal(t, provider.InferenceProviderOpenAI, preferredModel.Provider)
267 assert.Equal(t, "high", preferredModel.ReasoningEffort)
268
269 // Test JSON marshaling/unmarshaling
270 jsonData, err := json.Marshal(preferredModel)
271 require.NoError(t, err)
272
273 var unmarshaled PreferredModel
274 err = json.Unmarshal(jsonData, &unmarshaled)
275 require.NoError(t, err)
276
277 assert.Equal(t, preferredModel.ModelID, unmarshaled.ModelID)
278 assert.Equal(t, preferredModel.Provider, unmarshaled.Provider)
279 assert.Equal(t, preferredModel.ReasoningEffort, unmarshaled.ReasoningEffort)
280}