1package config
  2
  3import (
  4	"encoding/json"
  5	"testing"
  6
  7	"github.com/charmbracelet/crush/internal/fur/provider"
  8	"github.com/stretchr/testify/assert"
  9	"github.com/stretchr/testify/require"
 10)
 11
 12func TestMockProviders(t *testing.T) {
 13	// Enable mock providers for testing
 14	originalUseMock := UseMockProviders
 15	UseMockProviders = true
 16	defer func() {
 17		UseMockProviders = originalUseMock
 18		ResetProviders()
 19	}()
 20
 21	// Reset providers to ensure we get fresh mock data
 22	ResetProviders()
 23
 24	providers := Providers()
 25	require.NotEmpty(t, providers, "Mock providers should not be empty")
 26
 27	// Verify we have the expected mock providers
 28	providerIDs := make(map[provider.InferenceProvider]bool)
 29	for _, p := range providers {
 30		providerIDs[p.ID] = true
 31	}
 32
 33	assert.True(t, providerIDs[provider.InferenceProviderAnthropic], "Should have Anthropic provider")
 34	assert.True(t, providerIDs[provider.InferenceProviderOpenAI], "Should have OpenAI provider")
 35	assert.True(t, providerIDs[provider.InferenceProviderGemini], "Should have Gemini provider")
 36
 37	// Verify Anthropic provider details
 38	var anthropicProvider provider.Provider
 39	for _, p := range providers {
 40		if p.ID == provider.InferenceProviderAnthropic {
 41			anthropicProvider = p
 42			break
 43		}
 44	}
 45
 46	assert.Equal(t, "Anthropic", anthropicProvider.Name)
 47	assert.Equal(t, provider.TypeAnthropic, anthropicProvider.Type)
 48	assert.Equal(t, "claude-3-opus", anthropicProvider.DefaultLargeModelID)
 49	assert.Equal(t, "claude-3-haiku", anthropicProvider.DefaultSmallModelID)
 50	assert.Len(t, anthropicProvider.Models, 4, "Anthropic should have 4 models")
 51
 52	// Verify model details
 53	var opusModel provider.Model
 54	for _, m := range anthropicProvider.Models {
 55		if m.ID == "claude-3-opus" {
 56			opusModel = m
 57			break
 58		}
 59	}
 60
 61	assert.Equal(t, "Claude 3 Opus", opusModel.Name)
 62	assert.Equal(t, int64(200000), opusModel.ContextWindow)
 63	assert.Equal(t, int64(4096), opusModel.DefaultMaxTokens)
 64	assert.True(t, opusModel.SupportsImages)
 65}
 66
 67func TestProvidersWithoutMock(t *testing.T) {
 68	// Ensure mock is disabled
 69	originalUseMock := UseMockProviders
 70	UseMockProviders = false
 71	defer func() {
 72		UseMockProviders = originalUseMock
 73		ResetProviders()
 74	}()
 75
 76	// Reset providers to ensure we get fresh data
 77	ResetProviders()
 78
 79	// This will try to make an actual API call or use cached data
 80	providers := Providers()
 81
 82	// We can't guarantee what we'll get here since it depends on network/cache
 83	// but we can at least verify the function doesn't panic
 84	t.Logf("Got %d providers without mock", len(providers))
 85}
 86
 87func TestResetProviders(t *testing.T) {
 88	// Enable mock providers
 89	UseMockProviders = true
 90	defer func() {
 91		UseMockProviders = false
 92		ResetProviders()
 93	}()
 94
 95	// Get providers once
 96	providers1 := Providers()
 97	require.NotEmpty(t, providers1)
 98
 99	// Reset and get again
100	ResetProviders()
101	providers2 := Providers()
102	require.NotEmpty(t, providers2)
103
104	// Should get the same mock data
105	assert.Equal(t, len(providers1), len(providers2))
106}
107
108func TestReasoningEffortSupport(t *testing.T) {
109	originalUseMock := UseMockProviders
110	UseMockProviders = true
111	defer func() {
112		UseMockProviders = originalUseMock
113		ResetProviders()
114	}()
115
116	ResetProviders()
117	providers := Providers()
118	
119	var openaiProvider provider.Provider
120	for _, p := range providers {
121		if p.ID == provider.InferenceProviderOpenAI {
122			openaiProvider = p
123			break
124		}
125	}
126	require.NotEmpty(t, openaiProvider.ID)
127
128	var reasoningModel, nonReasoningModel provider.Model
129	for _, model := range openaiProvider.Models {
130		if model.CanReason && model.HasReasoningEffort {
131			reasoningModel = model
132		} else if !model.CanReason {
133			nonReasoningModel = model
134		}
135	}
136
137	require.NotEmpty(t, reasoningModel.ID)
138	assert.Equal(t, "medium", reasoningModel.DefaultReasoningEffort)
139	assert.True(t, reasoningModel.HasReasoningEffort)
140
141	require.NotEmpty(t, nonReasoningModel.ID)
142	assert.False(t, nonReasoningModel.HasReasoningEffort)
143	assert.Empty(t, nonReasoningModel.DefaultReasoningEffort)
144}
145
146func TestReasoningEffortConfigTransfer(t *testing.T) {
147	originalUseMock := UseMockProviders
148	UseMockProviders = true
149	defer func() {
150		UseMockProviders = originalUseMock
151		ResetProviders()
152	}()
153
154	ResetProviders()
155	t.Setenv("OPENAI_API_KEY", "test-openai-key")
156
157	cfg, err := Init(t.TempDir(), false)
158	require.NoError(t, err)
159
160	openaiProviderConfig, exists := cfg.Providers[provider.InferenceProviderOpenAI]
161	require.True(t, exists)
162
163	var foundReasoning, foundNonReasoning bool
164	for _, model := range openaiProviderConfig.Models {
165		if model.CanReason && model.HasReasoningEffort && model.ReasoningEffort != "" {
166			assert.Equal(t, "medium", model.ReasoningEffort)
167			assert.True(t, model.HasReasoningEffort)
168			foundReasoning = true
169		} else if !model.CanReason {
170			assert.Empty(t, model.ReasoningEffort)
171			assert.False(t, model.HasReasoningEffort)
172			foundNonReasoning = true
173		}
174	}
175
176	assert.True(t, foundReasoning, "Should find at least one reasoning model")
177	assert.True(t, foundNonReasoning, "Should find at least one non-reasoning model")
178}
179
180func TestNewProviders(t *testing.T) {
181	originalUseMock := UseMockProviders
182	UseMockProviders = true
183	defer func() {
184		UseMockProviders = originalUseMock
185		ResetProviders()
186	}()
187
188	ResetProviders()
189	providers := Providers()
190	require.NotEmpty(t, providers)
191
192	var xaiProvider, openRouterProvider provider.Provider
193	for _, p := range providers {
194		switch p.ID {
195		case provider.InferenceProviderXAI:
196			xaiProvider = p
197		case provider.InferenceProviderOpenRouter:
198			openRouterProvider = p
199		}
200	}
201
202	require.NotEmpty(t, xaiProvider.ID)
203	assert.Equal(t, "xAI", xaiProvider.Name)
204	assert.Equal(t, "grok-beta", xaiProvider.DefaultLargeModelID)
205
206	require.NotEmpty(t, openRouterProvider.ID)
207	assert.Equal(t, "OpenRouter", openRouterProvider.Name)
208	assert.Equal(t, "anthropic/claude-3.5-sonnet", openRouterProvider.DefaultLargeModelID)
209}
210
211func TestO1ModelsInMockProvider(t *testing.T) {
212	originalUseMock := UseMockProviders
213	UseMockProviders = true
214	defer func() {
215		UseMockProviders = originalUseMock
216		ResetProviders()
217	}()
218
219	ResetProviders()
220	providers := Providers()
221	
222	var openaiProvider provider.Provider
223	for _, p := range providers {
224		if p.ID == provider.InferenceProviderOpenAI {
225			openaiProvider = p
226			break
227		}
228	}
229	require.NotEmpty(t, openaiProvider.ID)
230
231	modelTests := []struct {
232		id   string
233		name string
234	}{
235		{"o1-preview", "o1-preview"},
236		{"o1-mini", "o1-mini"},
237	}
238
239	for _, test := range modelTests {
240		var model provider.Model
241		var found bool
242		for _, m := range openaiProvider.Models {
243			if m.ID == test.id {
244				model = m
245				found = true
246				break
247			}
248		}
249		require.True(t, found, "Should find %s model", test.id)
250		assert.Equal(t, test.name, model.Name)
251		assert.True(t, model.CanReason)
252		assert.True(t, model.HasReasoningEffort)
253		assert.Equal(t, "medium", model.DefaultReasoningEffort)
254	}
255}
256
257func TestPreferredModelReasoningEffort(t *testing.T) {
258	// Test that PreferredModel struct can hold reasoning effort
259	preferredModel := PreferredModel{
260		ModelID:         "o1-preview",
261		Provider:        provider.InferenceProviderOpenAI,
262		ReasoningEffort: "high",
263	}
264
265	assert.Equal(t, "o1-preview", preferredModel.ModelID)
266	assert.Equal(t, provider.InferenceProviderOpenAI, preferredModel.Provider)
267	assert.Equal(t, "high", preferredModel.ReasoningEffort)
268
269	// Test JSON marshaling/unmarshaling
270	jsonData, err := json.Marshal(preferredModel)
271	require.NoError(t, err)
272
273	var unmarshaled PreferredModel
274	err = json.Unmarshal(jsonData, &unmarshaled)
275	require.NoError(t, err)
276
277	assert.Equal(t, preferredModel.ModelID, unmarshaled.ModelID)
278	assert.Equal(t, preferredModel.Provider, unmarshaled.Provider)
279	assert.Equal(t, preferredModel.ReasoningEffort, unmarshaled.ReasoningEffort)
280}