1package config
2
3import (
4 "testing"
5
6 "github.com/charmbracelet/crush/internal/fur/provider"
7 "github.com/stretchr/testify/assert"
8 "github.com/stretchr/testify/require"
9)
10
11func TestMockProviders(t *testing.T) {
12 // Enable mock providers for testing
13 originalUseMock := UseMockProviders
14 UseMockProviders = true
15 defer func() {
16 UseMockProviders = originalUseMock
17 ResetProviders()
18 }()
19
20 // Reset providers to ensure we get fresh mock data
21 ResetProviders()
22
23 providers := Providers()
24 require.NotEmpty(t, providers, "Mock providers should not be empty")
25
26 // Verify we have the expected mock providers
27 providerIDs := make(map[provider.InferenceProvider]bool)
28 for _, p := range providers {
29 providerIDs[p.ID] = true
30 }
31
32 assert.True(t, providerIDs[provider.InferenceProviderAnthropic], "Should have Anthropic provider")
33 assert.True(t, providerIDs[provider.InferenceProviderOpenAI], "Should have OpenAI provider")
34 assert.True(t, providerIDs[provider.InferenceProviderGemini], "Should have Gemini provider")
35
36 // Verify Anthropic provider details
37 var anthropicProvider provider.Provider
38 for _, p := range providers {
39 if p.ID == provider.InferenceProviderAnthropic {
40 anthropicProvider = p
41 break
42 }
43 }
44
45 assert.Equal(t, "Anthropic", anthropicProvider.Name)
46 assert.Equal(t, provider.TypeAnthropic, anthropicProvider.Type)
47 assert.Equal(t, "claude-3-opus", anthropicProvider.DefaultLargeModelID)
48 assert.Equal(t, "claude-3-haiku", anthropicProvider.DefaultSmallModelID)
49 assert.Len(t, anthropicProvider.Models, 4, "Anthropic should have 4 models")
50
51 // Verify model details
52 var opusModel provider.Model
53 for _, m := range anthropicProvider.Models {
54 if m.ID == "claude-3-opus" {
55 opusModel = m
56 break
57 }
58 }
59
60 assert.Equal(t, "Claude 3 Opus", opusModel.Name)
61 assert.Equal(t, int64(200000), opusModel.ContextWindow)
62 assert.Equal(t, int64(4096), opusModel.DefaultMaxTokens)
63 assert.True(t, opusModel.SupportsImages)
64}
65
66func TestProvidersWithoutMock(t *testing.T) {
67 // Ensure mock is disabled
68 originalUseMock := UseMockProviders
69 UseMockProviders = false
70 defer func() {
71 UseMockProviders = originalUseMock
72 ResetProviders()
73 }()
74
75 // Reset providers to ensure we get fresh data
76 ResetProviders()
77
78 // This will try to make an actual API call or use cached data
79 providers := Providers()
80
81 // We can't guarantee what we'll get here since it depends on network/cache
82 // but we can at least verify the function doesn't panic
83 t.Logf("Got %d providers without mock", len(providers))
84}
85
86func TestResetProviders(t *testing.T) {
87 // Enable mock providers
88 UseMockProviders = true
89 defer func() {
90 UseMockProviders = false
91 ResetProviders()
92 }()
93
94 // Get providers once
95 providers1 := Providers()
96 require.NotEmpty(t, providers1)
97
98 // Reset and get again
99 ResetProviders()
100 providers2 := Providers()
101 require.NotEmpty(t, providers2)
102
103 // Should get the same mock data
104 assert.Equal(t, len(providers1), len(providers2))
105}