Detailed changes
@@ -325,6 +325,11 @@ func loadConfig(cwd string, debug bool) (*Config, error) {
mergeMCPs(cfg, globalCfg, localConfig)
mergeLSPs(cfg, globalCfg, localConfig)
+ // Validate the final configuration
+ if err := cfg.Validate(); err != nil {
+ return cfg, fmt.Errorf("configuration validation failed: %w", err)
+ }
+
return cfg, nil
}
@@ -503,7 +508,7 @@ func mergeAgents(base, global, local *Config) {
base.Agents[agentID] = newAgent
} else {
baseAgent := base.Agents[agentID]
-
+
// Special handling for known agents - only allow model changes
if agentID == AgentCoder || agentID == AgentTask {
if newAgent.Model != "" {
@@ -533,10 +538,10 @@ func mergeAgents(base, global, local *Config) {
} else if baseAgent.Model == "" {
baseAgent.Model = LargeModel // Default fallback
}
-
+
// Boolean fields - always update (including false values)
baseAgent.Disabled = newAgent.Disabled
-
+
// Slice/Map fields - update if provided (including empty slices/maps)
if newAgent.AllowedTools != nil {
baseAgent.AllowedTools = newAgent.AllowedTools
@@ -552,7 +557,7 @@ func mergeAgents(base, global, local *Config) {
baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...)
}
}
-
+
base.Agents[agentID] = baseAgent
}
}
@@ -666,6 +671,9 @@ func defaultConfigBasedOnEnv() *Config {
ContextPaths: defaultContextPaths,
},
Providers: make(map[provider.InferenceProvider]ProviderConfig),
+ Agents: make(map[AgentID]Agent),
+ LSP: make(map[string]LSPConfig),
+ MCP: make(map[string]MCP),
}
providers := Providers()
@@ -681,9 +689,7 @@ func defaultConfigBasedOnEnv() *Config {
baseURL := p.APIEndpoint
if strings.HasPrefix(baseURL, "$") {
envVar := strings.TrimPrefix(baseURL, "$")
- if url := os.Getenv(envVar); url != "" {
- baseURL = url
- }
+ baseURL = os.Getenv(envVar)
}
providerConfig.BaseURL = baseURL
for _, model := range p.Models {
@@ -871,3 +877,507 @@ func UpdatePreferredModel(modelType ModelType, model PreferredModel) error {
}
return nil
}
+
+// ValidationError represents a configuration validation error
+type ValidationError struct {
+ Field string
+ Message string
+}
+
+func (e ValidationError) Error() string {
+ return fmt.Sprintf("validation error in %s: %s", e.Field, e.Message)
+}
+
+// ValidationErrors represents multiple validation errors
+type ValidationErrors []ValidationError
+
+func (e ValidationErrors) Error() string {
+ if len(e) == 0 {
+ return "no validation errors"
+ }
+ if len(e) == 1 {
+ return e[0].Error()
+ }
+
+ var messages []string
+ for _, err := range e {
+ messages = append(messages, err.Error())
+ }
+ return fmt.Sprintf("multiple validation errors: %s", strings.Join(messages, "; "))
+}
+
+// HasErrors returns true if there are any validation errors
+func (e ValidationErrors) HasErrors() bool {
+ return len(e) > 0
+}
+
+// Add appends a new validation error
+func (e *ValidationErrors) Add(field, message string) {
+ *e = append(*e, ValidationError{Field: field, Message: message})
+}
+
+// Validate performs comprehensive validation of the configuration
+func (c *Config) Validate() error {
+ var errors ValidationErrors
+
+ // Validate providers
+ c.validateProviders(&errors)
+
+ // Validate models
+ c.validateModels(&errors)
+
+ // Validate agents
+ c.validateAgents(&errors)
+
+ // Validate options
+ c.validateOptions(&errors)
+
+ // Validate MCP configurations
+ c.validateMCPs(&errors)
+
+ // Validate LSP configurations
+ c.validateLSPs(&errors)
+
+ // Validate cross-references
+ c.validateCrossReferences(&errors)
+
+ // Validate completeness
+ c.validateCompleteness(&errors)
+
+ if errors.HasErrors() {
+ return errors
+ }
+
+ return nil
+}
+
+// validateProviders validates all provider configurations
+func (c *Config) validateProviders(errors *ValidationErrors) {
+ if c.Providers == nil {
+ c.Providers = make(map[provider.InferenceProvider]ProviderConfig)
+ }
+
+ knownProviders := provider.KnownProviders()
+ validTypes := []provider.Type{
+ provider.TypeOpenAI,
+ provider.TypeAnthropic,
+ provider.TypeGemini,
+ provider.TypeAzure,
+ provider.TypeBedrock,
+ provider.TypeVertexAI,
+ provider.TypeXAI,
+ }
+
+ for providerID, providerConfig := range c.Providers {
+ fieldPrefix := fmt.Sprintf("providers.%s", providerID)
+
+ // Validate API key for non-disabled providers
+ if !providerConfig.Disabled && providerConfig.APIKey == "" {
+ // Special case for AWS Bedrock and VertexAI which may use other auth methods
+ if providerID != provider.InferenceProviderBedrock && providerID != provider.InferenceProviderVertexAI {
+ errors.Add(fieldPrefix+".api_key", "API key is required for non-disabled providers")
+ }
+ }
+
+ // Validate provider type
+ validType := false
+ for _, vt := range validTypes {
+ if providerConfig.ProviderType == vt {
+ validType = true
+ break
+ }
+ }
+ if !validType {
+ errors.Add(fieldPrefix+".provider_type", fmt.Sprintf("invalid provider type: %s", providerConfig.ProviderType))
+ }
+
+ // Validate custom providers
+ isKnownProvider := false
+ for _, kp := range knownProviders {
+ if providerID == kp {
+ isKnownProvider = true
+ break
+ }
+ }
+
+ if !isKnownProvider {
+ // Custom provider validation
+ if providerConfig.BaseURL == "" {
+ errors.Add(fieldPrefix+".base_url", "BaseURL is required for custom providers")
+ }
+ if providerConfig.ProviderType != provider.TypeOpenAI {
+ errors.Add(fieldPrefix+".provider_type", "custom providers currently only support OpenAI type")
+ }
+ }
+
+ // Validate models
+ modelIDs := make(map[string]bool)
+ for i, model := range providerConfig.Models {
+ modelFieldPrefix := fmt.Sprintf("%s.models[%d]", fieldPrefix, i)
+
+ // Check for duplicate model IDs
+ if modelIDs[model.ID] {
+ errors.Add(modelFieldPrefix+".id", fmt.Sprintf("duplicate model ID: %s", model.ID))
+ }
+ modelIDs[model.ID] = true
+
+ // Validate required model fields
+ if model.ID == "" {
+ errors.Add(modelFieldPrefix+".id", "model ID is required")
+ }
+ if model.Name == "" {
+ errors.Add(modelFieldPrefix+".name", "model name is required")
+ }
+ if model.ContextWindow <= 0 {
+ errors.Add(modelFieldPrefix+".context_window", "context window must be positive")
+ }
+ if model.DefaultMaxTokens <= 0 {
+ errors.Add(modelFieldPrefix+".default_max_tokens", "default max tokens must be positive")
+ }
+ if model.DefaultMaxTokens > model.ContextWindow {
+ errors.Add(modelFieldPrefix+".default_max_tokens", "default max tokens cannot exceed context window")
+ }
+
+ // Validate cost fields
+ if model.CostPer1MIn < 0 {
+ errors.Add(modelFieldPrefix+".cost_per_1m_in", "cost per 1M input tokens cannot be negative")
+ }
+ if model.CostPer1MOut < 0 {
+ errors.Add(modelFieldPrefix+".cost_per_1m_out", "cost per 1M output tokens cannot be negative")
+ }
+ if model.CostPer1MInCached < 0 {
+ errors.Add(modelFieldPrefix+".cost_per_1m_in_cached", "cached cost per 1M input tokens cannot be negative")
+ }
+ if model.CostPer1MOutCached < 0 {
+ errors.Add(modelFieldPrefix+".cost_per_1m_out_cached", "cached cost per 1M output tokens cannot be negative")
+ }
+ }
+
+ // Validate default model references
+ if providerConfig.DefaultLargeModel != "" {
+ if !modelIDs[providerConfig.DefaultLargeModel] {
+ errors.Add(fieldPrefix+".default_large_model", fmt.Sprintf("default large model '%s' not found in provider models", providerConfig.DefaultLargeModel))
+ }
+ }
+ if providerConfig.DefaultSmallModel != "" {
+ if !modelIDs[providerConfig.DefaultSmallModel] {
+ errors.Add(fieldPrefix+".default_small_model", fmt.Sprintf("default small model '%s' not found in provider models", providerConfig.DefaultSmallModel))
+ }
+ }
+
+ // Validate provider-specific requirements
+ c.validateProviderSpecific(providerID, providerConfig, errors)
+ }
+}
+
+// validateProviderSpecific validates provider-specific requirements
+func (c *Config) validateProviderSpecific(providerID provider.InferenceProvider, providerConfig ProviderConfig, errors *ValidationErrors) {
+ fieldPrefix := fmt.Sprintf("providers.%s", providerID)
+
+ switch providerID {
+ case provider.InferenceProviderVertexAI:
+ if !providerConfig.Disabled {
+ if providerConfig.ExtraParams == nil {
+ errors.Add(fieldPrefix+".extra_params", "VertexAI requires extra_params configuration")
+ } else {
+ if providerConfig.ExtraParams["project"] == "" {
+ errors.Add(fieldPrefix+".extra_params.project", "VertexAI requires project parameter")
+ }
+ if providerConfig.ExtraParams["location"] == "" {
+ errors.Add(fieldPrefix+".extra_params.location", "VertexAI requires location parameter")
+ }
+ }
+ }
+ case provider.InferenceProviderBedrock:
+ if !providerConfig.Disabled {
+ if providerConfig.ExtraParams == nil || providerConfig.ExtraParams["region"] == "" {
+ errors.Add(fieldPrefix+".extra_params.region", "Bedrock requires region parameter")
+ }
+ // Check for AWS credentials in environment
+ if !hasAWSCredentials() {
+ errors.Add(fieldPrefix, "Bedrock requires AWS credentials in environment")
+ }
+ }
+ }
+}
+
+// validateModels validates preferred model configurations
+func (c *Config) validateModels(errors *ValidationErrors) {
+ // Validate large model
+ if c.Models.Large.ModelID != "" || c.Models.Large.Provider != "" {
+ if c.Models.Large.ModelID == "" {
+ errors.Add("models.large.model_id", "large model ID is required when provider is set")
+ }
+ if c.Models.Large.Provider == "" {
+ errors.Add("models.large.provider", "large model provider is required when model ID is set")
+ }
+
+ // Check if provider exists and is not disabled
+ if providerConfig, exists := c.Providers[c.Models.Large.Provider]; exists {
+ if providerConfig.Disabled {
+ errors.Add("models.large.provider", "large model provider is disabled")
+ }
+
+ // Check if model exists in provider
+ modelExists := false
+ for _, model := range providerConfig.Models {
+ if model.ID == c.Models.Large.ModelID {
+ modelExists = true
+ break
+ }
+ }
+ if !modelExists {
+ errors.Add("models.large.model_id", fmt.Sprintf("large model '%s' not found in provider '%s'", c.Models.Large.ModelID, c.Models.Large.Provider))
+ }
+ } else {
+ errors.Add("models.large.provider", fmt.Sprintf("large model provider '%s' not found", c.Models.Large.Provider))
+ }
+ }
+
+ // Validate small model
+ if c.Models.Small.ModelID != "" || c.Models.Small.Provider != "" {
+ if c.Models.Small.ModelID == "" {
+ errors.Add("models.small.model_id", "small model ID is required when provider is set")
+ }
+ if c.Models.Small.Provider == "" {
+ errors.Add("models.small.provider", "small model provider is required when model ID is set")
+ }
+
+ // Check if provider exists and is not disabled
+ if providerConfig, exists := c.Providers[c.Models.Small.Provider]; exists {
+ if providerConfig.Disabled {
+ errors.Add("models.small.provider", "small model provider is disabled")
+ }
+
+ // Check if model exists in provider
+ modelExists := false
+ for _, model := range providerConfig.Models {
+ if model.ID == c.Models.Small.ModelID {
+ modelExists = true
+ break
+ }
+ }
+ if !modelExists {
+ errors.Add("models.small.model_id", fmt.Sprintf("small model '%s' not found in provider '%s'", c.Models.Small.ModelID, c.Models.Small.Provider))
+ }
+ } else {
+ errors.Add("models.small.provider", fmt.Sprintf("small model provider '%s' not found", c.Models.Small.Provider))
+ }
+ }
+}
+
+// validateAgents validates agent configurations
+func (c *Config) validateAgents(errors *ValidationErrors) {
+ if c.Agents == nil {
+ c.Agents = make(map[AgentID]Agent)
+ }
+
+ validTools := []string{
+ "bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent",
+ }
+
+ for agentID, agent := range c.Agents {
+ fieldPrefix := fmt.Sprintf("agents.%s", agentID)
+
+ // Validate agent ID consistency
+ if agent.ID != agentID {
+ errors.Add(fieldPrefix+".id", fmt.Sprintf("agent ID mismatch: expected '%s', got '%s'", agentID, agent.ID))
+ }
+
+ // Validate required fields
+ if agent.ID == "" {
+ errors.Add(fieldPrefix+".id", "agent ID is required")
+ }
+ if agent.Name == "" {
+ errors.Add(fieldPrefix+".name", "agent name is required")
+ }
+
+ // Validate model type
+ if agent.Model != LargeModel && agent.Model != SmallModel {
+ errors.Add(fieldPrefix+".model", fmt.Sprintf("invalid model type: %s (must be 'large' or 'small')", agent.Model))
+ }
+
+ // Validate allowed tools
+ if agent.AllowedTools != nil {
+ for i, tool := range agent.AllowedTools {
+ validTool := false
+ for _, vt := range validTools {
+ if tool == vt {
+ validTool = true
+ break
+ }
+ }
+ if !validTool {
+ errors.Add(fmt.Sprintf("%s.allowed_tools[%d]", fieldPrefix, i), fmt.Sprintf("unknown tool: %s", tool))
+ }
+ }
+ }
+
+ // Validate MCP references
+ if agent.AllowedMCP != nil {
+ for mcpName := range agent.AllowedMCP {
+ if _, exists := c.MCP[mcpName]; !exists {
+ errors.Add(fieldPrefix+".allowed_mcp", fmt.Sprintf("referenced MCP '%s' not found", mcpName))
+ }
+ }
+ }
+
+ // Validate LSP references
+ if agent.AllowedLSP != nil {
+ for _, lspName := range agent.AllowedLSP {
+ if _, exists := c.LSP[lspName]; !exists {
+ errors.Add(fieldPrefix+".allowed_lsp", fmt.Sprintf("referenced LSP '%s' not found", lspName))
+ }
+ }
+ }
+
+ // Validate context paths (basic path validation)
+ for i, contextPath := range agent.ContextPaths {
+ if contextPath == "" {
+ errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path cannot be empty")
+ }
+ // Check for invalid characters in path
+ if strings.Contains(contextPath, "\x00") {
+ errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path contains invalid characters")
+ }
+ }
+
+ // Validate known agents maintain their core properties
+ if agentID == AgentCoder {
+ if agent.Name != "Coder" {
+ errors.Add(fieldPrefix+".name", "coder agent name cannot be changed")
+ }
+ if agent.Description != "An agent that helps with executing coding tasks." {
+ errors.Add(fieldPrefix+".description", "coder agent description cannot be changed")
+ }
+ } else if agentID == AgentTask {
+ if agent.Name != "Task" {
+ errors.Add(fieldPrefix+".name", "task agent name cannot be changed")
+ }
+ if agent.Description != "An agent that helps with searching for context and finding implementation details." {
+ errors.Add(fieldPrefix+".description", "task agent description cannot be changed")
+ }
+ expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"}
+ if agent.AllowedTools != nil && !slices.Equal(agent.AllowedTools, expectedTools) {
+ errors.Add(fieldPrefix+".allowed_tools", "task agent allowed tools cannot be changed")
+ }
+ }
+ }
+}
+
+// validateOptions validates configuration options
+func (c *Config) validateOptions(errors *ValidationErrors) {
+ // Validate data directory
+ if c.Options.DataDirectory == "" {
+ errors.Add("options.data_directory", "data directory is required")
+ }
+
+ // Validate context paths
+ for i, contextPath := range c.Options.ContextPaths {
+ if contextPath == "" {
+ errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path cannot be empty")
+ }
+ if strings.Contains(contextPath, "\x00") {
+ errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path contains invalid characters")
+ }
+ }
+}
+
+// validateMCPs validates MCP configurations
+func (c *Config) validateMCPs(errors *ValidationErrors) {
+ if c.MCP == nil {
+ c.MCP = make(map[string]MCP)
+ }
+
+ for mcpName, mcpConfig := range c.MCP {
+ fieldPrefix := fmt.Sprintf("mcp.%s", mcpName)
+
+ // Validate MCP type
+ if mcpConfig.Type != MCPStdio && mcpConfig.Type != MCPSse {
+ errors.Add(fieldPrefix+".type", fmt.Sprintf("invalid MCP type: %s (must be 'stdio' or 'sse')", mcpConfig.Type))
+ }
+
+ // Validate based on type
+ if mcpConfig.Type == MCPStdio {
+ if mcpConfig.Command == "" {
+ errors.Add(fieldPrefix+".command", "command is required for stdio MCP")
+ }
+ } else if mcpConfig.Type == MCPSse {
+ if mcpConfig.URL == "" {
+ errors.Add(fieldPrefix+".url", "URL is required for SSE MCP")
+ }
+ }
+ }
+}
+
+// validateLSPs validates LSP configurations
+func (c *Config) validateLSPs(errors *ValidationErrors) {
+ if c.LSP == nil {
+ c.LSP = make(map[string]LSPConfig)
+ }
+
+ for lspName, lspConfig := range c.LSP {
+ fieldPrefix := fmt.Sprintf("lsp.%s", lspName)
+
+ if lspConfig.Command == "" {
+ errors.Add(fieldPrefix+".command", "command is required for LSP")
+ }
+ }
+}
+
+// validateCrossReferences validates cross-references between different config sections
+func (c *Config) validateCrossReferences(errors *ValidationErrors) {
+ // Validate that agents can use their assigned model types
+ for agentID, agent := range c.Agents {
+ fieldPrefix := fmt.Sprintf("agents.%s", agentID)
+
+ var preferredModel PreferredModel
+ switch agent.Model {
+ case LargeModel:
+ preferredModel = c.Models.Large
+ case SmallModel:
+ preferredModel = c.Models.Small
+ }
+
+ if preferredModel.Provider != "" {
+ if providerConfig, exists := c.Providers[preferredModel.Provider]; exists {
+ if providerConfig.Disabled {
+ errors.Add(fieldPrefix+".model", fmt.Sprintf("agent cannot use model type '%s' because provider '%s' is disabled", agent.Model, preferredModel.Provider))
+ }
+ }
+ }
+ }
+}
+
+// validateCompleteness validates that the configuration is complete and usable
+func (c *Config) validateCompleteness(errors *ValidationErrors) {
+ // Check for at least one valid, non-disabled provider
+ hasValidProvider := false
+ for _, providerConfig := range c.Providers {
+ if !providerConfig.Disabled {
+ hasValidProvider = true
+ break
+ }
+ }
+ if !hasValidProvider {
+ errors.Add("providers", "at least one non-disabled provider is required")
+ }
+
+ // Check that default agents exist
+ if _, exists := c.Agents[AgentCoder]; !exists {
+ errors.Add("agents", "coder agent is required")
+ }
+ if _, exists := c.Agents[AgentTask]; !exists {
+ errors.Add("agents", "task agent is required")
+ }
+
+ // Check that preferred models are set if providers exist
+ if hasValidProvider {
+ if c.Models.Large.ModelID == "" || c.Models.Large.Provider == "" {
+ errors.Add("models.large", "large preferred model must be configured when providers are available")
+ }
+ if c.Models.Small.ModelID == "" || c.Models.Small.Provider == "" {
+ errors.Add("models.small", "small preferred model must be configured when providers are available")
+ }
+ }
+}
@@ -50,6 +50,10 @@ func reset() {
instance = nil
cwd = ""
testConfigDir = ""
+
+ // Enable mock providers for all tests to avoid API calls
+ UseMockProviders = true
+ ResetProviders()
}
// Core Configuration Loading Tests
@@ -133,9 +137,29 @@ func TestLoadConfig_OnlyGlobalConfig(t *testing.T) {
globalConfig := Config{
Providers: map[provider.InferenceProvider]ProviderConfig{
provider.InferenceProviderOpenAI: {
- ID: provider.InferenceProviderOpenAI,
- APIKey: "test-key",
- ProviderType: provider.TypeOpenAI,
+ ID: provider.InferenceProviderOpenAI,
+ APIKey: "test-key",
+ ProviderType: provider.TypeOpenAI,
+ DefaultLargeModel: "gpt-4",
+ DefaultSmallModel: "gpt-3.5-turbo",
+ Models: []Model{
+ {
+ ID: "gpt-4",
+ Name: "GPT-4",
+ CostPer1MIn: 30.0,
+ CostPer1MOut: 60.0,
+ ContextWindow: 8192,
+ DefaultMaxTokens: 4096,
+ },
+ {
+ ID: "gpt-3.5-turbo",
+ Name: "GPT-3.5 Turbo",
+ CostPer1MIn: 1.0,
+ CostPer1MOut: 2.0,
+ ContextWindow: 4096,
+ DefaultMaxTokens: 4096,
+ },
+ },
},
},
Options: Options{
@@ -167,9 +191,29 @@ func TestLoadConfig_OnlyLocalConfig(t *testing.T) {
localConfig := Config{
Providers: map[provider.InferenceProvider]ProviderConfig{
provider.InferenceProviderAnthropic: {
- ID: provider.InferenceProviderAnthropic,
- APIKey: "local-key",
- ProviderType: provider.TypeAnthropic,
+ ID: provider.InferenceProviderAnthropic,
+ APIKey: "local-key",
+ ProviderType: provider.TypeAnthropic,
+ DefaultLargeModel: "claude-3-opus",
+ DefaultSmallModel: "claude-3-haiku",
+ Models: []Model{
+ {
+ ID: "claude-3-opus",
+ Name: "Claude 3 Opus",
+ CostPer1MIn: 15.0,
+ CostPer1MOut: 75.0,
+ ContextWindow: 200000,
+ DefaultMaxTokens: 4096,
+ },
+ {
+ ID: "claude-3-haiku",
+ Name: "Claude 3 Haiku",
+ CostPer1MIn: 0.25,
+ CostPer1MOut: 1.25,
+ ContextWindow: 200000,
+ DefaultMaxTokens: 4096,
+ },
+ },
},
},
Options: Options{
@@ -199,9 +243,29 @@ func TestLoadConfig_BothGlobalAndLocal(t *testing.T) {
globalConfig := Config{
Providers: map[provider.InferenceProvider]ProviderConfig{
provider.InferenceProviderOpenAI: {
- ID: provider.InferenceProviderOpenAI,
- APIKey: "global-key",
- ProviderType: provider.TypeOpenAI,
+ ID: provider.InferenceProviderOpenAI,
+ APIKey: "global-key",
+ ProviderType: provider.TypeOpenAI,
+ DefaultLargeModel: "gpt-4",
+ DefaultSmallModel: "gpt-3.5-turbo",
+ Models: []Model{
+ {
+ ID: "gpt-4",
+ Name: "GPT-4",
+ CostPer1MIn: 30.0,
+ CostPer1MOut: 60.0,
+ ContextWindow: 8192,
+ DefaultMaxTokens: 4096,
+ },
+ {
+ ID: "gpt-3.5-turbo",
+ Name: "GPT-3.5 Turbo",
+ CostPer1MIn: 1.0,
+ CostPer1MOut: 2.0,
+ ContextWindow: 4096,
+ DefaultMaxTokens: 4096,
+ },
+ },
},
},
Options: Options{
@@ -222,9 +286,29 @@ func TestLoadConfig_BothGlobalAndLocal(t *testing.T) {
APIKey: "local-key", // Override global
},
provider.InferenceProviderAnthropic: {
- ID: provider.InferenceProviderAnthropic,
- APIKey: "anthropic-key",
- ProviderType: provider.TypeAnthropic,
+ ID: provider.InferenceProviderAnthropic,
+ APIKey: "anthropic-key",
+ ProviderType: provider.TypeAnthropic,
+ DefaultLargeModel: "claude-3-opus",
+ DefaultSmallModel: "claude-3-haiku",
+ Models: []Model{
+ {
+ ID: "claude-3-opus",
+ Name: "Claude 3 Opus",
+ CostPer1MIn: 15.0,
+ CostPer1MOut: 75.0,
+ ContextWindow: 200000,
+ DefaultMaxTokens: 4096,
+ },
+ {
+ ID: "claude-3-haiku",
+ Name: "Claude 3 Haiku",
+ CostPer1MIn: 0.25,
+ CostPer1MOut: 1.25,
+ ContextWindow: 200000,
+ DefaultMaxTokens: 4096,
+ },
+ },
},
},
Options: Options{
@@ -15,6 +15,8 @@ var fur = client.New()
var (
providerOnc sync.Once // Ensures the initialization happens only once
providerList []provider.Provider
+ // UseMockProviders can be set to true in tests to avoid API calls
+ UseMockProviders bool
)
func providersPath() string {
@@ -50,6 +52,12 @@ func loadProviders() ([]provider.Provider, error) {
func Providers() []provider.Provider {
providerOnc.Do(func() {
+ // Use mock providers when testing
+ if UseMockProviders {
+ providerList = MockProviders()
+ return
+ }
+
// Try to get providers from upstream API
if providers, err := fur.GetProviders(); err == nil {
providerList = providers
@@ -67,3 +75,9 @@ func Providers() []provider.Provider {
})
return providerList
}
+
+// ResetProviders resets the provider cache. Useful for testing.
+func ResetProviders() {
+ providerOnc = sync.Once{}
+ providerList = nil
+}
@@ -0,0 +1,177 @@
+package config
+
+import (
+ "github.com/charmbracelet/crush/internal/fur/provider"
+)
+
+// MockProviders returns a mock list of providers for testing.
+// This avoids making API calls during tests and provides consistent test data.
+func MockProviders() []provider.Provider {
+ return []provider.Provider{
+ {
+ Name: "Anthropic",
+ ID: provider.InferenceProviderAnthropic,
+ APIKey: "$ANTHROPIC_API_KEY",
+ APIEndpoint: "$ANTHROPIC_API_ENDPOINT",
+ Type: provider.TypeAnthropic,
+ DefaultLargeModelID: "claude-3-opus",
+ DefaultSmallModelID: "claude-3-haiku",
+ Models: []provider.Model{
+ {
+ ID: "claude-3-opus",
+ Name: "Claude 3 Opus",
+ CostPer1MIn: 15.0,
+ CostPer1MOut: 75.0,
+ CostPer1MInCached: 18.75,
+ CostPer1MOutCached: 1.5,
+ ContextWindow: 200000,
+ DefaultMaxTokens: 4096,
+ CanReason: false,
+ SupportsImages: true,
+ },
+ {
+ ID: "claude-3-haiku",
+ Name: "Claude 3 Haiku",
+ CostPer1MIn: 0.25,
+ CostPer1MOut: 1.25,
+ CostPer1MInCached: 0.3,
+ CostPer1MOutCached: 0.03,
+ ContextWindow: 200000,
+ DefaultMaxTokens: 4096,
+ CanReason: false,
+ SupportsImages: true,
+ },
+ {
+ ID: "claude-3-5-sonnet-20241022",
+ Name: "Claude 3.5 Sonnet",
+ CostPer1MIn: 3.0,
+ CostPer1MOut: 15.0,
+ CostPer1MInCached: 3.75,
+ CostPer1MOutCached: 0.3,
+ ContextWindow: 200000,
+ DefaultMaxTokens: 8192,
+ CanReason: false,
+ SupportsImages: true,
+ },
+ {
+ ID: "claude-3-5-haiku-20241022",
+ Name: "Claude 3.5 Haiku",
+ CostPer1MIn: 0.8,
+ CostPer1MOut: 4.0,
+ CostPer1MInCached: 1.0,
+ CostPer1MOutCached: 0.08,
+ ContextWindow: 200000,
+ DefaultMaxTokens: 8192,
+ CanReason: false,
+ SupportsImages: true,
+ },
+ },
+ },
+ {
+ Name: "OpenAI",
+ ID: provider.InferenceProviderOpenAI,
+ APIKey: "$OPENAI_API_KEY",
+ APIEndpoint: "$OPENAI_API_ENDPOINT",
+ Type: provider.TypeOpenAI,
+ DefaultLargeModelID: "gpt-4",
+ DefaultSmallModelID: "gpt-3.5-turbo",
+ Models: []provider.Model{
+ {
+ ID: "gpt-4",
+ Name: "GPT-4",
+ CostPer1MIn: 30.0,
+ CostPer1MOut: 60.0,
+ CostPer1MInCached: 0.0,
+ CostPer1MOutCached: 0.0,
+ ContextWindow: 8192,
+ DefaultMaxTokens: 4096,
+ CanReason: false,
+ SupportsImages: false,
+ },
+ {
+ ID: "gpt-3.5-turbo",
+ Name: "GPT-3.5 Turbo",
+ CostPer1MIn: 1.0,
+ CostPer1MOut: 2.0,
+ CostPer1MInCached: 0.0,
+ CostPer1MOutCached: 0.0,
+ ContextWindow: 4096,
+ DefaultMaxTokens: 4096,
+ CanReason: false,
+ SupportsImages: false,
+ },
+ {
+ ID: "gpt-4-turbo",
+ Name: "GPT-4 Turbo",
+ CostPer1MIn: 10.0,
+ CostPer1MOut: 30.0,
+ CostPer1MInCached: 0.0,
+ CostPer1MOutCached: 0.0,
+ ContextWindow: 128000,
+ DefaultMaxTokens: 4096,
+ CanReason: false,
+ SupportsImages: true,
+ },
+ {
+ ID: "gpt-4o",
+ Name: "GPT-4o",
+ CostPer1MIn: 2.5,
+ CostPer1MOut: 10.0,
+ CostPer1MInCached: 0.0,
+ CostPer1MOutCached: 1.25,
+ ContextWindow: 128000,
+ DefaultMaxTokens: 16384,
+ CanReason: false,
+ SupportsImages: true,
+ },
+ {
+ ID: "gpt-4o-mini",
+ Name: "GPT-4o-mini",
+ CostPer1MIn: 0.15,
+ CostPer1MOut: 0.6,
+ CostPer1MInCached: 0.0,
+ CostPer1MOutCached: 0.075,
+ ContextWindow: 128000,
+ DefaultMaxTokens: 16384,
+ CanReason: false,
+ SupportsImages: true,
+ },
+ },
+ },
+ {
+ Name: "Google Gemini",
+ ID: provider.InferenceProviderGemini,
+ APIKey: "$GEMINI_API_KEY",
+ APIEndpoint: "$GEMINI_API_ENDPOINT",
+ Type: provider.TypeGemini,
+ DefaultLargeModelID: "gemini-2.5-pro",
+ DefaultSmallModelID: "gemini-2.5-flash",
+ Models: []provider.Model{
+ {
+ ID: "gemini-2.5-pro",
+ Name: "Gemini 2.5 Pro",
+ CostPer1MIn: 1.25,
+ CostPer1MOut: 10.0,
+ CostPer1MInCached: 1.625,
+ CostPer1MOutCached: 0.31,
+ ContextWindow: 1048576,
+ DefaultMaxTokens: 65536,
+ CanReason: true,
+ SupportsImages: true,
+ },
+ {
+ ID: "gemini-2.5-flash",
+ Name: "Gemini 2.5 Flash",
+ CostPer1MIn: 0.3,
+ CostPer1MOut: 2.5,
+ CostPer1MInCached: 0.3833,
+ CostPer1MOutCached: 0.075,
+ ContextWindow: 1048576,
+ DefaultMaxTokens: 65535,
+ CanReason: true,
+ SupportsImages: true,
+ },
+ },
+ },
+ }
+}
@@ -0,0 +1,105 @@
+package config
+
+import (
+ "testing"
+
+ "github.com/charmbracelet/crush/internal/fur/provider"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestMockProviders(t *testing.T) {
+ // Enable mock providers for testing
+ originalUseMock := UseMockProviders
+ UseMockProviders = true
+ defer func() {
+ UseMockProviders = originalUseMock
+ ResetProviders()
+ }()
+
+ // Reset providers to ensure we get fresh mock data
+ ResetProviders()
+
+ providers := Providers()
+ require.NotEmpty(t, providers, "Mock providers should not be empty")
+
+ // Verify we have the expected mock providers
+ providerIDs := make(map[provider.InferenceProvider]bool)
+ for _, p := range providers {
+ providerIDs[p.ID] = true
+ }
+
+ assert.True(t, providerIDs[provider.InferenceProviderAnthropic], "Should have Anthropic provider")
+ assert.True(t, providerIDs[provider.InferenceProviderOpenAI], "Should have OpenAI provider")
+ assert.True(t, providerIDs[provider.InferenceProviderGemini], "Should have Gemini provider")
+
+ // Verify Anthropic provider details
+ var anthropicProvider provider.Provider
+ for _, p := range providers {
+ if p.ID == provider.InferenceProviderAnthropic {
+ anthropicProvider = p
+ break
+ }
+ }
+
+ assert.Equal(t, "Anthropic", anthropicProvider.Name)
+ assert.Equal(t, provider.TypeAnthropic, anthropicProvider.Type)
+ assert.Equal(t, "claude-3-opus", anthropicProvider.DefaultLargeModelID)
+ assert.Equal(t, "claude-3-haiku", anthropicProvider.DefaultSmallModelID)
+ assert.Len(t, anthropicProvider.Models, 4, "Anthropic should have 4 models")
+
+ // Verify model details
+ var opusModel provider.Model
+ for _, m := range anthropicProvider.Models {
+ if m.ID == "claude-3-opus" {
+ opusModel = m
+ break
+ }
+ }
+
+ assert.Equal(t, "Claude 3 Opus", opusModel.Name)
+ assert.Equal(t, int64(200000), opusModel.ContextWindow)
+ assert.Equal(t, int64(4096), opusModel.DefaultMaxTokens)
+ assert.True(t, opusModel.SupportsImages)
+}
+
+func TestProvidersWithoutMock(t *testing.T) {
+ // Ensure mock is disabled
+ originalUseMock := UseMockProviders
+ UseMockProviders = false
+ defer func() {
+ UseMockProviders = originalUseMock
+ ResetProviders()
+ }()
+
+ // Reset providers to ensure we get fresh data
+ ResetProviders()
+
+ // This will try to make an actual API call or use cached data
+ providers := Providers()
+
+ // We can't guarantee what we'll get here since it depends on network/cache
+ // but we can at least verify the function doesn't panic
+ t.Logf("Got %d providers without mock", len(providers))
+}
+
+func TestResetProviders(t *testing.T) {
+ // Enable mock providers
+ UseMockProviders = true
+ defer func() {
+ UseMockProviders = false
+ ResetProviders()
+ }()
+
+ // Get providers once
+ providers1 := Providers()
+ require.NotEmpty(t, providers1)
+
+ // Reset and get again
+ ResetProviders()
+ providers2 := Providers()
+ require.NotEmpty(t, providers2)
+
+ // Should get the same mock data
+ assert.Equal(t, len(providers1), len(providers2))
+}