provider_test.go

  1package config
  2
  3import (
  4	"testing"
  5
  6	"github.com/charmbracelet/crush/internal/fur/provider"
  7	"github.com/stretchr/testify/assert"
  8	"github.com/stretchr/testify/require"
  9)
 10
 11func TestMockProviders(t *testing.T) {
 12	// Enable mock providers for testing
 13	originalUseMock := UseMockProviders
 14	UseMockProviders = true
 15	defer func() {
 16		UseMockProviders = originalUseMock
 17		ResetProviders()
 18	}()
 19
 20	// Reset providers to ensure we get fresh mock data
 21	ResetProviders()
 22
 23	providers := Providers()
 24	require.NotEmpty(t, providers, "Mock providers should not be empty")
 25
 26	// Verify we have the expected mock providers
 27	providerIDs := make(map[provider.InferenceProvider]bool)
 28	for _, p := range providers {
 29		providerIDs[p.ID] = true
 30	}
 31
 32	assert.True(t, providerIDs[provider.InferenceProviderAnthropic], "Should have Anthropic provider")
 33	assert.True(t, providerIDs[provider.InferenceProviderOpenAI], "Should have OpenAI provider")
 34	assert.True(t, providerIDs[provider.InferenceProviderGemini], "Should have Gemini provider")
 35
 36	// Verify Anthropic provider details
 37	var anthropicProvider provider.Provider
 38	for _, p := range providers {
 39		if p.ID == provider.InferenceProviderAnthropic {
 40			anthropicProvider = p
 41			break
 42		}
 43	}
 44
 45	assert.Equal(t, "Anthropic", anthropicProvider.Name)
 46	assert.Equal(t, provider.TypeAnthropic, anthropicProvider.Type)
 47	assert.Equal(t, "claude-3-opus", anthropicProvider.DefaultLargeModelID)
 48	assert.Equal(t, "claude-3-haiku", anthropicProvider.DefaultSmallModelID)
 49	assert.Len(t, anthropicProvider.Models, 4, "Anthropic should have 4 models")
 50
 51	// Verify model details
 52	var opusModel provider.Model
 53	for _, m := range anthropicProvider.Models {
 54		if m.ID == "claude-3-opus" {
 55			opusModel = m
 56			break
 57		}
 58	}
 59
 60	assert.Equal(t, "Claude 3 Opus", opusModel.Name)
 61	assert.Equal(t, int64(200000), opusModel.ContextWindow)
 62	assert.Equal(t, int64(4096), opusModel.DefaultMaxTokens)
 63	assert.True(t, opusModel.SupportsImages)
 64}
 65
 66func TestProvidersWithoutMock(t *testing.T) {
 67	// Ensure mock is disabled
 68	originalUseMock := UseMockProviders
 69	UseMockProviders = false
 70	defer func() {
 71		UseMockProviders = originalUseMock
 72		ResetProviders()
 73	}()
 74
 75	// Reset providers to ensure we get fresh data
 76	ResetProviders()
 77
 78	// This will try to make an actual API call or use cached data
 79	providers := Providers()
 80
 81	// We can't guarantee what we'll get here since it depends on network/cache
 82	// but we can at least verify the function doesn't panic
 83	t.Logf("Got %d providers without mock", len(providers))
 84}
 85
 86func TestResetProviders(t *testing.T) {
 87	// Enable mock providers
 88	UseMockProviders = true
 89	defer func() {
 90		UseMockProviders = false
 91		ResetProviders()
 92	}()
 93
 94	// Get providers once
 95	providers1 := Providers()
 96	require.NotEmpty(t, providers1)
 97
 98	// Reset and get again
 99	ResetProviders()
100	providers2 := Providers()
101	require.NotEmpty(t, providers2)
102
103	// Should get the same mock data
104	assert.Equal(t, len(providers1), len(providers2))
105}