From e16fc3deeb917fdc6b7b19e3e84557fe67c00f73 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 27 Jun 2025 17:46:25 +0200 Subject: [PATCH] feat: add config validation and provider system with mock support --- internal/config/config.go | 524 ++++++++++++++++++++++++++++++- internal/config/config_test.go | 108 ++++++- internal/config/provider.go | 14 + internal/config/provider_mock.go | 177 +++++++++++ internal/config/provider_test.go | 105 +++++++ 5 files changed, 909 insertions(+), 19 deletions(-) create mode 100644 internal/config/provider_mock.go create mode 100644 internal/config/provider_test.go diff --git a/internal/config/config.go b/internal/config/config.go index bddef684d9e1c45a5ed165cff000c3cb1d8302e2..e33aab02a492e8a1a4c55554fe5a3656d101ec1e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -325,6 +325,11 @@ func loadConfig(cwd string, debug bool) (*Config, error) { mergeMCPs(cfg, globalCfg, localConfig) mergeLSPs(cfg, globalCfg, localConfig) + // Validate the final configuration + if err := cfg.Validate(); err != nil { + return cfg, fmt.Errorf("configuration validation failed: %w", err) + } + return cfg, nil } @@ -503,7 +508,7 @@ func mergeAgents(base, global, local *Config) { base.Agents[agentID] = newAgent } else { baseAgent := base.Agents[agentID] - + // Special handling for known agents - only allow model changes if agentID == AgentCoder || agentID == AgentTask { if newAgent.Model != "" { @@ -533,10 +538,10 @@ func mergeAgents(base, global, local *Config) { } else if baseAgent.Model == "" { baseAgent.Model = LargeModel // Default fallback } - + // Boolean fields - always update (including false values) baseAgent.Disabled = newAgent.Disabled - + // Slice/Map fields - update if provided (including empty slices/maps) if newAgent.AllowedTools != nil { baseAgent.AllowedTools = newAgent.AllowedTools @@ -552,7 +557,7 @@ func mergeAgents(base, global, local *Config) { baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...) } } - + base.Agents[agentID] = baseAgent } } @@ -666,6 +671,9 @@ func defaultConfigBasedOnEnv() *Config { ContextPaths: defaultContextPaths, }, Providers: make(map[provider.InferenceProvider]ProviderConfig), + Agents: make(map[AgentID]Agent), + LSP: make(map[string]LSPConfig), + MCP: make(map[string]MCP), } providers := Providers() @@ -681,9 +689,7 @@ func defaultConfigBasedOnEnv() *Config { baseURL := p.APIEndpoint if strings.HasPrefix(baseURL, "$") { envVar := strings.TrimPrefix(baseURL, "$") - if url := os.Getenv(envVar); url != "" { - baseURL = url - } + baseURL = os.Getenv(envVar) } providerConfig.BaseURL = baseURL for _, model := range p.Models { @@ -871,3 +877,507 @@ func UpdatePreferredModel(modelType ModelType, model PreferredModel) error { } return nil } + +// ValidationError represents a configuration validation error +type ValidationError struct { + Field string + Message string +} + +func (e ValidationError) Error() string { + return fmt.Sprintf("validation error in %s: %s", e.Field, e.Message) +} + +// ValidationErrors represents multiple validation errors +type ValidationErrors []ValidationError + +func (e ValidationErrors) Error() string { + if len(e) == 0 { + return "no validation errors" + } + if len(e) == 1 { + return e[0].Error() + } + + var messages []string + for _, err := range e { + messages = append(messages, err.Error()) + } + return fmt.Sprintf("multiple validation errors: %s", strings.Join(messages, "; ")) +} + +// HasErrors returns true if there are any validation errors +func (e ValidationErrors) HasErrors() bool { + return len(e) > 0 +} + +// Add appends a new validation error +func (e *ValidationErrors) Add(field, message string) { + *e = append(*e, ValidationError{Field: field, Message: message}) +} + +// Validate performs comprehensive validation of the configuration +func (c *Config) Validate() error { + var errors ValidationErrors + + // Validate providers + c.validateProviders(&errors) + + // Validate models + c.validateModels(&errors) + + // Validate agents + c.validateAgents(&errors) + + // Validate options + c.validateOptions(&errors) + + // Validate MCP configurations + c.validateMCPs(&errors) + + // Validate LSP configurations + c.validateLSPs(&errors) + + // Validate cross-references + c.validateCrossReferences(&errors) + + // Validate completeness + c.validateCompleteness(&errors) + + if errors.HasErrors() { + return errors + } + + return nil +} + +// validateProviders validates all provider configurations +func (c *Config) validateProviders(errors *ValidationErrors) { + if c.Providers == nil { + c.Providers = make(map[provider.InferenceProvider]ProviderConfig) + } + + knownProviders := provider.KnownProviders() + validTypes := []provider.Type{ + provider.TypeOpenAI, + provider.TypeAnthropic, + provider.TypeGemini, + provider.TypeAzure, + provider.TypeBedrock, + provider.TypeVertexAI, + provider.TypeXAI, + } + + for providerID, providerConfig := range c.Providers { + fieldPrefix := fmt.Sprintf("providers.%s", providerID) + + // Validate API key for non-disabled providers + if !providerConfig.Disabled && providerConfig.APIKey == "" { + // Special case for AWS Bedrock and VertexAI which may use other auth methods + if providerID != provider.InferenceProviderBedrock && providerID != provider.InferenceProviderVertexAI { + errors.Add(fieldPrefix+".api_key", "API key is required for non-disabled providers") + } + } + + // Validate provider type + validType := false + for _, vt := range validTypes { + if providerConfig.ProviderType == vt { + validType = true + break + } + } + if !validType { + errors.Add(fieldPrefix+".provider_type", fmt.Sprintf("invalid provider type: %s", providerConfig.ProviderType)) + } + + // Validate custom providers + isKnownProvider := false + for _, kp := range knownProviders { + if providerID == kp { + isKnownProvider = true + break + } + } + + if !isKnownProvider { + // Custom provider validation + if providerConfig.BaseURL == "" { + errors.Add(fieldPrefix+".base_url", "BaseURL is required for custom providers") + } + if providerConfig.ProviderType != provider.TypeOpenAI { + errors.Add(fieldPrefix+".provider_type", "custom providers currently only support OpenAI type") + } + } + + // Validate models + modelIDs := make(map[string]bool) + for i, model := range providerConfig.Models { + modelFieldPrefix := fmt.Sprintf("%s.models[%d]", fieldPrefix, i) + + // Check for duplicate model IDs + if modelIDs[model.ID] { + errors.Add(modelFieldPrefix+".id", fmt.Sprintf("duplicate model ID: %s", model.ID)) + } + modelIDs[model.ID] = true + + // Validate required model fields + if model.ID == "" { + errors.Add(modelFieldPrefix+".id", "model ID is required") + } + if model.Name == "" { + errors.Add(modelFieldPrefix+".name", "model name is required") + } + if model.ContextWindow <= 0 { + errors.Add(modelFieldPrefix+".context_window", "context window must be positive") + } + if model.DefaultMaxTokens <= 0 { + errors.Add(modelFieldPrefix+".default_max_tokens", "default max tokens must be positive") + } + if model.DefaultMaxTokens > model.ContextWindow { + errors.Add(modelFieldPrefix+".default_max_tokens", "default max tokens cannot exceed context window") + } + + // Validate cost fields + if model.CostPer1MIn < 0 { + errors.Add(modelFieldPrefix+".cost_per_1m_in", "cost per 1M input tokens cannot be negative") + } + if model.CostPer1MOut < 0 { + errors.Add(modelFieldPrefix+".cost_per_1m_out", "cost per 1M output tokens cannot be negative") + } + if model.CostPer1MInCached < 0 { + errors.Add(modelFieldPrefix+".cost_per_1m_in_cached", "cached cost per 1M input tokens cannot be negative") + } + if model.CostPer1MOutCached < 0 { + errors.Add(modelFieldPrefix+".cost_per_1m_out_cached", "cached cost per 1M output tokens cannot be negative") + } + } + + // Validate default model references + if providerConfig.DefaultLargeModel != "" { + if !modelIDs[providerConfig.DefaultLargeModel] { + errors.Add(fieldPrefix+".default_large_model", fmt.Sprintf("default large model '%s' not found in provider models", providerConfig.DefaultLargeModel)) + } + } + if providerConfig.DefaultSmallModel != "" { + if !modelIDs[providerConfig.DefaultSmallModel] { + errors.Add(fieldPrefix+".default_small_model", fmt.Sprintf("default small model '%s' not found in provider models", providerConfig.DefaultSmallModel)) + } + } + + // Validate provider-specific requirements + c.validateProviderSpecific(providerID, providerConfig, errors) + } +} + +// validateProviderSpecific validates provider-specific requirements +func (c *Config) validateProviderSpecific(providerID provider.InferenceProvider, providerConfig ProviderConfig, errors *ValidationErrors) { + fieldPrefix := fmt.Sprintf("providers.%s", providerID) + + switch providerID { + case provider.InferenceProviderVertexAI: + if !providerConfig.Disabled { + if providerConfig.ExtraParams == nil { + errors.Add(fieldPrefix+".extra_params", "VertexAI requires extra_params configuration") + } else { + if providerConfig.ExtraParams["project"] == "" { + errors.Add(fieldPrefix+".extra_params.project", "VertexAI requires project parameter") + } + if providerConfig.ExtraParams["location"] == "" { + errors.Add(fieldPrefix+".extra_params.location", "VertexAI requires location parameter") + } + } + } + case provider.InferenceProviderBedrock: + if !providerConfig.Disabled { + if providerConfig.ExtraParams == nil || providerConfig.ExtraParams["region"] == "" { + errors.Add(fieldPrefix+".extra_params.region", "Bedrock requires region parameter") + } + // Check for AWS credentials in environment + if !hasAWSCredentials() { + errors.Add(fieldPrefix, "Bedrock requires AWS credentials in environment") + } + } + } +} + +// validateModels validates preferred model configurations +func (c *Config) validateModels(errors *ValidationErrors) { + // Validate large model + if c.Models.Large.ModelID != "" || c.Models.Large.Provider != "" { + if c.Models.Large.ModelID == "" { + errors.Add("models.large.model_id", "large model ID is required when provider is set") + } + if c.Models.Large.Provider == "" { + errors.Add("models.large.provider", "large model provider is required when model ID is set") + } + + // Check if provider exists and is not disabled + if providerConfig, exists := c.Providers[c.Models.Large.Provider]; exists { + if providerConfig.Disabled { + errors.Add("models.large.provider", "large model provider is disabled") + } + + // Check if model exists in provider + modelExists := false + for _, model := range providerConfig.Models { + if model.ID == c.Models.Large.ModelID { + modelExists = true + break + } + } + if !modelExists { + errors.Add("models.large.model_id", fmt.Sprintf("large model '%s' not found in provider '%s'", c.Models.Large.ModelID, c.Models.Large.Provider)) + } + } else { + errors.Add("models.large.provider", fmt.Sprintf("large model provider '%s' not found", c.Models.Large.Provider)) + } + } + + // Validate small model + if c.Models.Small.ModelID != "" || c.Models.Small.Provider != "" { + if c.Models.Small.ModelID == "" { + errors.Add("models.small.model_id", "small model ID is required when provider is set") + } + if c.Models.Small.Provider == "" { + errors.Add("models.small.provider", "small model provider is required when model ID is set") + } + + // Check if provider exists and is not disabled + if providerConfig, exists := c.Providers[c.Models.Small.Provider]; exists { + if providerConfig.Disabled { + errors.Add("models.small.provider", "small model provider is disabled") + } + + // Check if model exists in provider + modelExists := false + for _, model := range providerConfig.Models { + if model.ID == c.Models.Small.ModelID { + modelExists = true + break + } + } + if !modelExists { + errors.Add("models.small.model_id", fmt.Sprintf("small model '%s' not found in provider '%s'", c.Models.Small.ModelID, c.Models.Small.Provider)) + } + } else { + errors.Add("models.small.provider", fmt.Sprintf("small model provider '%s' not found", c.Models.Small.Provider)) + } + } +} + +// validateAgents validates agent configurations +func (c *Config) validateAgents(errors *ValidationErrors) { + if c.Agents == nil { + c.Agents = make(map[AgentID]Agent) + } + + validTools := []string{ + "bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent", + } + + for agentID, agent := range c.Agents { + fieldPrefix := fmt.Sprintf("agents.%s", agentID) + + // Validate agent ID consistency + if agent.ID != agentID { + errors.Add(fieldPrefix+".id", fmt.Sprintf("agent ID mismatch: expected '%s', got '%s'", agentID, agent.ID)) + } + + // Validate required fields + if agent.ID == "" { + errors.Add(fieldPrefix+".id", "agent ID is required") + } + if agent.Name == "" { + errors.Add(fieldPrefix+".name", "agent name is required") + } + + // Validate model type + if agent.Model != LargeModel && agent.Model != SmallModel { + errors.Add(fieldPrefix+".model", fmt.Sprintf("invalid model type: %s (must be 'large' or 'small')", agent.Model)) + } + + // Validate allowed tools + if agent.AllowedTools != nil { + for i, tool := range agent.AllowedTools { + validTool := false + for _, vt := range validTools { + if tool == vt { + validTool = true + break + } + } + if !validTool { + errors.Add(fmt.Sprintf("%s.allowed_tools[%d]", fieldPrefix, i), fmt.Sprintf("unknown tool: %s", tool)) + } + } + } + + // Validate MCP references + if agent.AllowedMCP != nil { + for mcpName := range agent.AllowedMCP { + if _, exists := c.MCP[mcpName]; !exists { + errors.Add(fieldPrefix+".allowed_mcp", fmt.Sprintf("referenced MCP '%s' not found", mcpName)) + } + } + } + + // Validate LSP references + if agent.AllowedLSP != nil { + for _, lspName := range agent.AllowedLSP { + if _, exists := c.LSP[lspName]; !exists { + errors.Add(fieldPrefix+".allowed_lsp", fmt.Sprintf("referenced LSP '%s' not found", lspName)) + } + } + } + + // Validate context paths (basic path validation) + for i, contextPath := range agent.ContextPaths { + if contextPath == "" { + errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path cannot be empty") + } + // Check for invalid characters in path + if strings.Contains(contextPath, "\x00") { + errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path contains invalid characters") + } + } + + // Validate known agents maintain their core properties + if agentID == AgentCoder { + if agent.Name != "Coder" { + errors.Add(fieldPrefix+".name", "coder agent name cannot be changed") + } + if agent.Description != "An agent that helps with executing coding tasks." { + errors.Add(fieldPrefix+".description", "coder agent description cannot be changed") + } + } else if agentID == AgentTask { + if agent.Name != "Task" { + errors.Add(fieldPrefix+".name", "task agent name cannot be changed") + } + if agent.Description != "An agent that helps with searching for context and finding implementation details." { + errors.Add(fieldPrefix+".description", "task agent description cannot be changed") + } + expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"} + if agent.AllowedTools != nil && !slices.Equal(agent.AllowedTools, expectedTools) { + errors.Add(fieldPrefix+".allowed_tools", "task agent allowed tools cannot be changed") + } + } + } +} + +// validateOptions validates configuration options +func (c *Config) validateOptions(errors *ValidationErrors) { + // Validate data directory + if c.Options.DataDirectory == "" { + errors.Add("options.data_directory", "data directory is required") + } + + // Validate context paths + for i, contextPath := range c.Options.ContextPaths { + if contextPath == "" { + errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path cannot be empty") + } + if strings.Contains(contextPath, "\x00") { + errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path contains invalid characters") + } + } +} + +// validateMCPs validates MCP configurations +func (c *Config) validateMCPs(errors *ValidationErrors) { + if c.MCP == nil { + c.MCP = make(map[string]MCP) + } + + for mcpName, mcpConfig := range c.MCP { + fieldPrefix := fmt.Sprintf("mcp.%s", mcpName) + + // Validate MCP type + if mcpConfig.Type != MCPStdio && mcpConfig.Type != MCPSse { + errors.Add(fieldPrefix+".type", fmt.Sprintf("invalid MCP type: %s (must be 'stdio' or 'sse')", mcpConfig.Type)) + } + + // Validate based on type + if mcpConfig.Type == MCPStdio { + if mcpConfig.Command == "" { + errors.Add(fieldPrefix+".command", "command is required for stdio MCP") + } + } else if mcpConfig.Type == MCPSse { + if mcpConfig.URL == "" { + errors.Add(fieldPrefix+".url", "URL is required for SSE MCP") + } + } + } +} + +// validateLSPs validates LSP configurations +func (c *Config) validateLSPs(errors *ValidationErrors) { + if c.LSP == nil { + c.LSP = make(map[string]LSPConfig) + } + + for lspName, lspConfig := range c.LSP { + fieldPrefix := fmt.Sprintf("lsp.%s", lspName) + + if lspConfig.Command == "" { + errors.Add(fieldPrefix+".command", "command is required for LSP") + } + } +} + +// validateCrossReferences validates cross-references between different config sections +func (c *Config) validateCrossReferences(errors *ValidationErrors) { + // Validate that agents can use their assigned model types + for agentID, agent := range c.Agents { + fieldPrefix := fmt.Sprintf("agents.%s", agentID) + + var preferredModel PreferredModel + switch agent.Model { + case LargeModel: + preferredModel = c.Models.Large + case SmallModel: + preferredModel = c.Models.Small + } + + if preferredModel.Provider != "" { + if providerConfig, exists := c.Providers[preferredModel.Provider]; exists { + if providerConfig.Disabled { + errors.Add(fieldPrefix+".model", fmt.Sprintf("agent cannot use model type '%s' because provider '%s' is disabled", agent.Model, preferredModel.Provider)) + } + } + } + } +} + +// validateCompleteness validates that the configuration is complete and usable +func (c *Config) validateCompleteness(errors *ValidationErrors) { + // Check for at least one valid, non-disabled provider + hasValidProvider := false + for _, providerConfig := range c.Providers { + if !providerConfig.Disabled { + hasValidProvider = true + break + } + } + if !hasValidProvider { + errors.Add("providers", "at least one non-disabled provider is required") + } + + // Check that default agents exist + if _, exists := c.Agents[AgentCoder]; !exists { + errors.Add("agents", "coder agent is required") + } + if _, exists := c.Agents[AgentTask]; !exists { + errors.Add("agents", "task agent is required") + } + + // Check that preferred models are set if providers exist + if hasValidProvider { + if c.Models.Large.ModelID == "" || c.Models.Large.Provider == "" { + errors.Add("models.large", "large preferred model must be configured when providers are available") + } + if c.Models.Small.ModelID == "" || c.Models.Small.Provider == "" { + errors.Add("models.small", "small preferred model must be configured when providers are available") + } + } +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 2f8b4a8acd8be9d979a663a53c28788a66ffd396..f69b3c84701b07c6df948c7abca6e37a65e3c69e 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -50,6 +50,10 @@ func reset() { instance = nil cwd = "" testConfigDir = "" + + // Enable mock providers for all tests to avoid API calls + UseMockProviders = true + ResetProviders() } // Core Configuration Loading Tests @@ -133,9 +137,29 @@ func TestLoadConfig_OnlyGlobalConfig(t *testing.T) { globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "test-key", - ProviderType: provider.TypeOpenAI, + ID: provider.InferenceProviderOpenAI, + APIKey: "test-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + CostPer1MIn: 30.0, + CostPer1MOut: 60.0, + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + CostPer1MIn: 1.0, + CostPer1MOut: 2.0, + ContextWindow: 4096, + DefaultMaxTokens: 4096, + }, + }, }, }, Options: Options{ @@ -167,9 +191,29 @@ func TestLoadConfig_OnlyLocalConfig(t *testing.T) { localConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderAnthropic: { - ID: provider.InferenceProviderAnthropic, - APIKey: "local-key", - ProviderType: provider.TypeAnthropic, + ID: provider.InferenceProviderAnthropic, + APIKey: "local-key", + ProviderType: provider.TypeAnthropic, + DefaultLargeModel: "claude-3-opus", + DefaultSmallModel: "claude-3-haiku", + Models: []Model{ + { + ID: "claude-3-opus", + Name: "Claude 3 Opus", + CostPer1MIn: 15.0, + CostPer1MOut: 75.0, + ContextWindow: 200000, + DefaultMaxTokens: 4096, + }, + { + ID: "claude-3-haiku", + Name: "Claude 3 Haiku", + CostPer1MIn: 0.25, + CostPer1MOut: 1.25, + ContextWindow: 200000, + DefaultMaxTokens: 4096, + }, + }, }, }, Options: Options{ @@ -199,9 +243,29 @@ func TestLoadConfig_BothGlobalAndLocal(t *testing.T) { globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "global-key", - ProviderType: provider.TypeOpenAI, + ID: provider.InferenceProviderOpenAI, + APIKey: "global-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + CostPer1MIn: 30.0, + CostPer1MOut: 60.0, + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + CostPer1MIn: 1.0, + CostPer1MOut: 2.0, + ContextWindow: 4096, + DefaultMaxTokens: 4096, + }, + }, }, }, Options: Options{ @@ -222,9 +286,29 @@ func TestLoadConfig_BothGlobalAndLocal(t *testing.T) { APIKey: "local-key", // Override global }, provider.InferenceProviderAnthropic: { - ID: provider.InferenceProviderAnthropic, - APIKey: "anthropic-key", - ProviderType: provider.TypeAnthropic, + ID: provider.InferenceProviderAnthropic, + APIKey: "anthropic-key", + ProviderType: provider.TypeAnthropic, + DefaultLargeModel: "claude-3-opus", + DefaultSmallModel: "claude-3-haiku", + Models: []Model{ + { + ID: "claude-3-opus", + Name: "Claude 3 Opus", + CostPer1MIn: 15.0, + CostPer1MOut: 75.0, + ContextWindow: 200000, + DefaultMaxTokens: 4096, + }, + { + ID: "claude-3-haiku", + Name: "Claude 3 Haiku", + CostPer1MIn: 0.25, + CostPer1MOut: 1.25, + ContextWindow: 200000, + DefaultMaxTokens: 4096, + }, + }, }, }, Options: Options{ diff --git a/internal/config/provider.go b/internal/config/provider.go index 4c2b61ff6d5d86f62a8a1833a6ea91b500bbc7b0..09e3b0e3fc84b9e2688ccc4d2559604aca83ddfc 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -15,6 +15,8 @@ var fur = client.New() var ( providerOnc sync.Once // Ensures the initialization happens only once providerList []provider.Provider + // UseMockProviders can be set to true in tests to avoid API calls + UseMockProviders bool ) func providersPath() string { @@ -50,6 +52,12 @@ func loadProviders() ([]provider.Provider, error) { func Providers() []provider.Provider { providerOnc.Do(func() { + // Use mock providers when testing + if UseMockProviders { + providerList = MockProviders() + return + } + // Try to get providers from upstream API if providers, err := fur.GetProviders(); err == nil { providerList = providers @@ -67,3 +75,9 @@ func Providers() []provider.Provider { }) return providerList } + +// ResetProviders resets the provider cache. Useful for testing. +func ResetProviders() { + providerOnc = sync.Once{} + providerList = nil +} diff --git a/internal/config/provider_mock.go b/internal/config/provider_mock.go new file mode 100644 index 0000000000000000000000000000000000000000..86b87768b95246654e176ca5f40af5aef249c23f --- /dev/null +++ b/internal/config/provider_mock.go @@ -0,0 +1,177 @@ +package config + +import ( + "github.com/charmbracelet/crush/internal/fur/provider" +) + +// MockProviders returns a mock list of providers for testing. +// This avoids making API calls during tests and provides consistent test data. +func MockProviders() []provider.Provider { + return []provider.Provider{ + { + Name: "Anthropic", + ID: provider.InferenceProviderAnthropic, + APIKey: "$ANTHROPIC_API_KEY", + APIEndpoint: "$ANTHROPIC_API_ENDPOINT", + Type: provider.TypeAnthropic, + DefaultLargeModelID: "claude-3-opus", + DefaultSmallModelID: "claude-3-haiku", + Models: []provider.Model{ + { + ID: "claude-3-opus", + Name: "Claude 3 Opus", + CostPer1MIn: 15.0, + CostPer1MOut: 75.0, + CostPer1MInCached: 18.75, + CostPer1MOutCached: 1.5, + ContextWindow: 200000, + DefaultMaxTokens: 4096, + CanReason: false, + SupportsImages: true, + }, + { + ID: "claude-3-haiku", + Name: "Claude 3 Haiku", + CostPer1MIn: 0.25, + CostPer1MOut: 1.25, + CostPer1MInCached: 0.3, + CostPer1MOutCached: 0.03, + ContextWindow: 200000, + DefaultMaxTokens: 4096, + CanReason: false, + SupportsImages: true, + }, + { + ID: "claude-3-5-sonnet-20241022", + Name: "Claude 3.5 Sonnet", + CostPer1MIn: 3.0, + CostPer1MOut: 15.0, + CostPer1MInCached: 3.75, + CostPer1MOutCached: 0.3, + ContextWindow: 200000, + DefaultMaxTokens: 8192, + CanReason: false, + SupportsImages: true, + }, + { + ID: "claude-3-5-haiku-20241022", + Name: "Claude 3.5 Haiku", + CostPer1MIn: 0.8, + CostPer1MOut: 4.0, + CostPer1MInCached: 1.0, + CostPer1MOutCached: 0.08, + ContextWindow: 200000, + DefaultMaxTokens: 8192, + CanReason: false, + SupportsImages: true, + }, + }, + }, + { + Name: "OpenAI", + ID: provider.InferenceProviderOpenAI, + APIKey: "$OPENAI_API_KEY", + APIEndpoint: "$OPENAI_API_ENDPOINT", + Type: provider.TypeOpenAI, + DefaultLargeModelID: "gpt-4", + DefaultSmallModelID: "gpt-3.5-turbo", + Models: []provider.Model{ + { + ID: "gpt-4", + Name: "GPT-4", + CostPer1MIn: 30.0, + CostPer1MOut: 60.0, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.0, + ContextWindow: 8192, + DefaultMaxTokens: 4096, + CanReason: false, + SupportsImages: false, + }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + CostPer1MIn: 1.0, + CostPer1MOut: 2.0, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.0, + ContextWindow: 4096, + DefaultMaxTokens: 4096, + CanReason: false, + SupportsImages: false, + }, + { + ID: "gpt-4-turbo", + Name: "GPT-4 Turbo", + CostPer1MIn: 10.0, + CostPer1MOut: 30.0, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.0, + ContextWindow: 128000, + DefaultMaxTokens: 4096, + CanReason: false, + SupportsImages: true, + }, + { + ID: "gpt-4o", + Name: "GPT-4o", + CostPer1MIn: 2.5, + CostPer1MOut: 10.0, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 1.25, + ContextWindow: 128000, + DefaultMaxTokens: 16384, + CanReason: false, + SupportsImages: true, + }, + { + ID: "gpt-4o-mini", + Name: "GPT-4o-mini", + CostPer1MIn: 0.15, + CostPer1MOut: 0.6, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.075, + ContextWindow: 128000, + DefaultMaxTokens: 16384, + CanReason: false, + SupportsImages: true, + }, + }, + }, + { + Name: "Google Gemini", + ID: provider.InferenceProviderGemini, + APIKey: "$GEMINI_API_KEY", + APIEndpoint: "$GEMINI_API_ENDPOINT", + Type: provider.TypeGemini, + DefaultLargeModelID: "gemini-2.5-pro", + DefaultSmallModelID: "gemini-2.5-flash", + Models: []provider.Model{ + { + ID: "gemini-2.5-pro", + Name: "Gemini 2.5 Pro", + CostPer1MIn: 1.25, + CostPer1MOut: 10.0, + CostPer1MInCached: 1.625, + CostPer1MOutCached: 0.31, + ContextWindow: 1048576, + DefaultMaxTokens: 65536, + CanReason: true, + SupportsImages: true, + }, + { + ID: "gemini-2.5-flash", + Name: "Gemini 2.5 Flash", + CostPer1MIn: 0.3, + CostPer1MOut: 2.5, + CostPer1MInCached: 0.3833, + CostPer1MOutCached: 0.075, + ContextWindow: 1048576, + DefaultMaxTokens: 65535, + CanReason: true, + SupportsImages: true, + }, + }, + }, + } +} diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go new file mode 100644 index 0000000000000000000000000000000000000000..92547ff2925699d8519c33656395d3979a095b35 --- /dev/null +++ b/internal/config/provider_test.go @@ -0,0 +1,105 @@ +package config + +import ( + "testing" + + "github.com/charmbracelet/crush/internal/fur/provider" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMockProviders(t *testing.T) { + // Enable mock providers for testing + originalUseMock := UseMockProviders + UseMockProviders = true + defer func() { + UseMockProviders = originalUseMock + ResetProviders() + }() + + // Reset providers to ensure we get fresh mock data + ResetProviders() + + providers := Providers() + require.NotEmpty(t, providers, "Mock providers should not be empty") + + // Verify we have the expected mock providers + providerIDs := make(map[provider.InferenceProvider]bool) + for _, p := range providers { + providerIDs[p.ID] = true + } + + assert.True(t, providerIDs[provider.InferenceProviderAnthropic], "Should have Anthropic provider") + assert.True(t, providerIDs[provider.InferenceProviderOpenAI], "Should have OpenAI provider") + assert.True(t, providerIDs[provider.InferenceProviderGemini], "Should have Gemini provider") + + // Verify Anthropic provider details + var anthropicProvider provider.Provider + for _, p := range providers { + if p.ID == provider.InferenceProviderAnthropic { + anthropicProvider = p + break + } + } + + assert.Equal(t, "Anthropic", anthropicProvider.Name) + assert.Equal(t, provider.TypeAnthropic, anthropicProvider.Type) + assert.Equal(t, "claude-3-opus", anthropicProvider.DefaultLargeModelID) + assert.Equal(t, "claude-3-haiku", anthropicProvider.DefaultSmallModelID) + assert.Len(t, anthropicProvider.Models, 4, "Anthropic should have 4 models") + + // Verify model details + var opusModel provider.Model + for _, m := range anthropicProvider.Models { + if m.ID == "claude-3-opus" { + opusModel = m + break + } + } + + assert.Equal(t, "Claude 3 Opus", opusModel.Name) + assert.Equal(t, int64(200000), opusModel.ContextWindow) + assert.Equal(t, int64(4096), opusModel.DefaultMaxTokens) + assert.True(t, opusModel.SupportsImages) +} + +func TestProvidersWithoutMock(t *testing.T) { + // Ensure mock is disabled + originalUseMock := UseMockProviders + UseMockProviders = false + defer func() { + UseMockProviders = originalUseMock + ResetProviders() + }() + + // Reset providers to ensure we get fresh data + ResetProviders() + + // This will try to make an actual API call or use cached data + providers := Providers() + + // We can't guarantee what we'll get here since it depends on network/cache + // but we can at least verify the function doesn't panic + t.Logf("Got %d providers without mock", len(providers)) +} + +func TestResetProviders(t *testing.T) { + // Enable mock providers + UseMockProviders = true + defer func() { + UseMockProviders = false + ResetProviders() + }() + + // Get providers once + providers1 := Providers() + require.NotEmpty(t, providers1) + + // Reset and get again + ResetProviders() + providers2 := Providers() + require.NotEmpty(t, providers2) + + // Should get the same mock data + assert.Equal(t, len(providers1), len(providers2)) +}