From 55dc70d9fe2ac1ca65c36f329859b50fa556a359 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 27 Jun 2025 15:56:53 +0200 Subject: [PATCH] chore: tests --- internal/config/config.go | 102 +- internal/config/config_test.go | 2202 +++++++++++++++++++++++++++++++- 2 files changed, 2263 insertions(+), 41 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index 8ebc1ce6cf5226fb7ad43601eb95d346bfebc0ef..bddef684d9e1c45a5ed165cff000c3cb1d8302e2 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -275,23 +275,24 @@ func loadConfig(cwd string, debug bool) (*Config, error) { return cfg, nil } preferredProvider := getPreferredProvider(cfg.Providers) - cfg.Models = PreferredModels{ - Large: PreferredModel{ - ModelID: preferredProvider.DefaultLargeModel, - Provider: preferredProvider.ID, - }, - Small: PreferredModel{ - ModelID: preferredProvider.DefaultSmallModel, - Provider: preferredProvider.ID, - }, + if preferredProvider != nil { + cfg.Models = PreferredModels{ + Large: PreferredModel{ + ModelID: preferredProvider.DefaultLargeModel, + Provider: preferredProvider.ID, + }, + Small: PreferredModel{ + ModelID: preferredProvider.DefaultSmallModel, + Provider: preferredProvider.ID, + }, + } + } else { + // No valid providers found, set empty models + cfg.Models = PreferredModels{} } mergeModels(cfg, globalCfg, localConfig) - if preferredProvider == nil { - return nil, errors.New("no valid providers configured") - } - agents := map[AgentID]Agent{ AgentCoder: { ID: AgentCoder, @@ -376,7 +377,7 @@ func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfi if other.ProviderType != "" { base.ProviderType = other.ProviderType } - if len(base.ExtraHeaders) > 0 { + if len(other.ExtraHeaders) > 0 { if base.ExtraHeaders == nil { base.ExtraHeaders = make(map[string]string) } @@ -488,31 +489,71 @@ func mergeAgents(base, global, local *Config) { } for agentID, newAgent := range cfg.Agents { if _, ok := base.Agents[agentID]; !ok { + // New agent - apply defaults newAgent.ID = agentID // Ensure the ID is set correctly + if newAgent.Model == "" { + newAgent.Model = LargeModel // Default model type + } + // Context paths are always additive - start with global, then add custom + if len(newAgent.ContextPaths) > 0 { + newAgent.ContextPaths = append(base.Options.ContextPaths, newAgent.ContextPaths...) + } else { + newAgent.ContextPaths = base.Options.ContextPaths // Use global context paths only + } base.Agents[agentID] = newAgent } else { - switch agentID { - case AgentCoder: - baseAgent := base.Agents[agentID] + baseAgent := base.Agents[agentID] + + // Special handling for known agents - only allow model changes + if agentID == AgentCoder || agentID == AgentTask { if newAgent.Model != "" { baseAgent.Model = newAgent.Model } - baseAgent.AllowedMCP = newAgent.AllowedMCP - baseAgent.AllowedLSP = newAgent.AllowedLSP - base.Agents[agentID] = baseAgent - default: - baseAgent := base.Agents[agentID] - baseAgent.Name = newAgent.Name - baseAgent.Description = newAgent.Description + // For known agents, only allow MCP and LSP configuration + if newAgent.AllowedMCP != nil { + baseAgent.AllowedMCP = newAgent.AllowedMCP + } + if newAgent.AllowedLSP != nil { + baseAgent.AllowedLSP = newAgent.AllowedLSP + } + // Context paths are additive for known agents too + if len(newAgent.ContextPaths) > 0 { + baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...) + } + } else { + // Custom agents - allow full merging + if newAgent.Name != "" { + baseAgent.Name = newAgent.Name + } + if newAgent.Description != "" { + baseAgent.Description = newAgent.Description + } + if newAgent.Model != "" { + baseAgent.Model = newAgent.Model + } else if baseAgent.Model == "" { + baseAgent.Model = LargeModel // Default fallback + } + + // Boolean fields - always update (including false values) baseAgent.Disabled = newAgent.Disabled - if newAgent.Model == "" { - baseAgent.Model = LargeModel + + // Slice/Map fields - update if provided (including empty slices/maps) + if newAgent.AllowedTools != nil { + baseAgent.AllowedTools = newAgent.AllowedTools + } + if newAgent.AllowedMCP != nil { + baseAgent.AllowedMCP = newAgent.AllowedMCP + } + if newAgent.AllowedLSP != nil { + baseAgent.AllowedLSP = newAgent.AllowedLSP + } + // Context paths are additive for custom agents too + if len(newAgent.ContextPaths) > 0 { + baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...) } - baseAgent.AllowedTools = newAgent.AllowedTools - baseAgent.AllowedMCP = newAgent.AllowedMCP - baseAgent.AllowedLSP = newAgent.AllowedLSP - base.Agents[agentID] = baseAgent } + + base.Agents[agentID] = baseAgent } } } @@ -555,6 +596,7 @@ func mergeProviderConfigs(base, global, local *Config) { err := validateProvider(providerName, providerConfig) if err != nil { logging.Warn("Skipping provider", "name", providerName, "error", err) + continue // Skip invalid providers } finalProviders[providerName] = providerConfig } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 2942c206aa4bb8b81ff3f3fca9a444411359e515..2f8b4a8acd8be9d979a663a53c28788a66ffd396 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -2,25 +2,299 @@ package config import ( "encoding/json" - "fmt" "os" + "path/filepath" + "sync" "testing" + "github.com/charmbracelet/crush/internal/fur/provider" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func resetEnvVars() { - os.Setenv("ANTHROPIC_API_KEY", "") - os.Setenv("OPENAI_API_KEY", "") - os.Setenv("GEMINI_API_KEY", "") - os.Setenv("XAI_API_KEY", "") - os.Setenv("OPENROUTER_API_KEY", "") +func reset() { + // Clear all environment variables that could affect config + envVarsToUnset := []string{ + // API Keys + "ANTHROPIC_API_KEY", + "OPENAI_API_KEY", + "GEMINI_API_KEY", + "XAI_API_KEY", + "OPENROUTER_API_KEY", + + // Google Cloud / VertexAI + "GOOGLE_GENAI_USE_VERTEXAI", + "GOOGLE_CLOUD_PROJECT", + "GOOGLE_CLOUD_LOCATION", + + // AWS Credentials + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_REGION", + "AWS_DEFAULT_REGION", + "AWS_PROFILE", + "AWS_DEFAULT_PROFILE", + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", + "AWS_CONTAINER_CREDENTIALS_FULL_URI", + + // Other + "CRUSH_DEV_DEBUG", + } + + for _, envVar := range envVarsToUnset { + os.Unsetenv(envVar) + } + + // Reset singleton + once = sync.Once{} + instance = nil + cwd = "" + testConfigDir = "" } -func TestConfigWithEnv(t *testing.T) { - resetEnvVars() +// Core Configuration Loading Tests + +func TestInit_ValidWorkingDirectory(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.NotNil(t, cfg) + assert.Equal(t, cwdDir, WorkingDirectory()) + assert.Equal(t, defaultDataDirectory, cfg.Options.DataDirectory) + assert.Equal(t, defaultContextPaths, cfg.Options.ContextPaths) +} + +func TestInit_WithDebugFlag(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + cfg, err := Init(cwdDir, true) + + require.NoError(t, err) + assert.True(t, cfg.Options.Debug) +} + +func TestInit_SingletonBehavior(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + cfg1, err1 := Init(cwdDir, false) + cfg2, err2 := Init(cwdDir, false) + + require.NoError(t, err1) + require.NoError(t, err2) + assert.Same(t, cfg1, cfg2) // Should be the same instance +} + +func TestGet_BeforeInitialization(t *testing.T) { + reset() + + assert.Panics(t, func() { + Get() + }) +} + +func TestGet_AfterInitialization(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + cfg1, err := Init(cwdDir, false) + require.NoError(t, err) + + cfg2 := Get() + assert.Same(t, cfg1, cfg2) +} + +func TestLoadConfig_NoConfigFiles(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Len(t, cfg.Providers, 0) // No providers without env vars or config files + assert.Equal(t, defaultContextPaths, cfg.Options.ContextPaths) +} + +func TestLoadConfig_OnlyGlobalConfig(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Create global config file + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "test-key", + ProviderType: provider.TypeOpenAI, + }, + }, + Options: Options{ + ContextPaths: []string{"custom-context.md"}, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Len(t, cfg.Providers, 1) + assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI) + assert.Contains(t, cfg.Options.ContextPaths, "custom-context.md") +} + +func TestLoadConfig_OnlyLocalConfig(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Create local config file + localConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderAnthropic: { + ID: provider.InferenceProviderAnthropic, + APIKey: "local-key", + ProviderType: provider.TypeAnthropic, + }, + }, + Options: Options{ + TUI: TUIOptions{CompactMode: true}, + }, + } + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err := json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Len(t, cfg.Providers, 1) + assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic) + assert.True(t, cfg.Options.TUI.CompactMode) +} + +func TestLoadConfig_BothGlobalAndLocal(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Create global config + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "global-key", + ProviderType: provider.TypeOpenAI, + }, + }, + Options: Options{ + ContextPaths: []string{"global-context.md"}, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + // Create local config that overrides and adds + localConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + APIKey: "local-key", // Override global + }, + provider.InferenceProviderAnthropic: { + ID: provider.InferenceProviderAnthropic, + APIKey: "anthropic-key", + ProviderType: provider.TypeAnthropic, + }, + }, + Options: Options{ + ContextPaths: []string{"local-context.md"}, + TUI: TUIOptions{CompactMode: true}, + }, + } + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err = json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Len(t, cfg.Providers, 2) + + // Check that local config overrode global + openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] + assert.Equal(t, "local-key", openaiProvider.APIKey) + + // Check that local config added new provider + assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic) + + // Check that context paths were merged + assert.Contains(t, cfg.Options.ContextPaths, "global-context.md") + assert.Contains(t, cfg.Options.ContextPaths, "local-context.md") + assert.True(t, cfg.Options.TUI.CompactMode) +} + +func TestLoadConfig_MalformedGlobalJSON(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Create malformed global config + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + require.NoError(t, os.WriteFile(configPath, []byte(`{invalid json`), 0o644)) + + _, err := Init(cwdDir, false) + assert.Error(t, err) +} + +func TestLoadConfig_MalformedLocalJSON(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Create malformed local config + localConfigPath := filepath.Join(cwdDir, "crush.json") + require.NoError(t, os.WriteFile(localConfigPath, []byte(`{invalid json`), 0o644)) + + _, err := Init(cwdDir, false) + assert.Error(t, err) +} + +func TestConfigWithoutEnv(t *testing.T) { + reset() testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + cfg, _ := Init(cwdDir, false) + assert.Len(t, cfg.Providers, 0) +} +func TestConfigWithEnv(t *testing.T) { + reset() + testConfigDir = t.TempDir() cwdDir := t.TempDir() os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key") @@ -28,8 +302,1914 @@ func TestConfigWithEnv(t *testing.T) { os.Setenv("GEMINI_API_KEY", "test-gemini-key") os.Setenv("XAI_API_KEY", "test-xai-key") os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key") + cfg, _ := Init(cwdDir, false) - data, _ := json.MarshalIndent(cfg, "", " ") - fmt.Println(string(data)) assert.Len(t, cfg.Providers, 5) } + +// Environment Variable Tests + +func TestEnvVars_NoEnvironmentVariables(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Len(t, cfg.Providers, 0) +} + +func TestEnvVars_AllSupportedAPIKeys(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set all supported API keys + os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key") + os.Setenv("OPENAI_API_KEY", "test-openai-key") + os.Setenv("GEMINI_API_KEY", "test-gemini-key") + os.Setenv("XAI_API_KEY", "test-xai-key") + os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Len(t, cfg.Providers, 5) + + // Verify each provider is configured correctly + anthropicProvider := cfg.Providers[provider.InferenceProviderAnthropic] + assert.Equal(t, "test-anthropic-key", anthropicProvider.APIKey) + assert.Equal(t, provider.TypeAnthropic, anthropicProvider.ProviderType) + + openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] + assert.Equal(t, "test-openai-key", openaiProvider.APIKey) + assert.Equal(t, provider.TypeOpenAI, openaiProvider.ProviderType) + + geminiProvider := cfg.Providers[provider.InferenceProviderGemini] + assert.Equal(t, "test-gemini-key", geminiProvider.APIKey) + assert.Equal(t, provider.TypeGemini, geminiProvider.ProviderType) + + xaiProvider := cfg.Providers[provider.InferenceProviderXAI] + assert.Equal(t, "test-xai-key", xaiProvider.APIKey) + assert.Equal(t, provider.TypeXAI, xaiProvider.ProviderType) + + openrouterProvider := cfg.Providers[provider.InferenceProviderOpenRouter] + assert.Equal(t, "test-openrouter-key", openrouterProvider.APIKey) + assert.Equal(t, provider.TypeOpenAI, openrouterProvider.ProviderType) + assert.Equal(t, "https://openrouter.ai/api/v1", openrouterProvider.BaseURL) +} + +func TestEnvVars_PartialEnvironmentVariables(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set only some API keys + os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key") + os.Setenv("OPENAI_API_KEY", "test-openai-key") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Len(t, cfg.Providers, 2) + assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic) + assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI) + assert.NotContains(t, cfg.Providers, provider.InferenceProviderGemini) +} + +func TestEnvVars_VertexAIConfiguration(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set VertexAI environment variables + os.Setenv("GOOGLE_GENAI_USE_VERTEXAI", "true") + os.Setenv("GOOGLE_CLOUD_PROJECT", "test-project") + os.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Contains(t, cfg.Providers, provider.InferenceProviderVertexAI) + + vertexProvider := cfg.Providers[provider.InferenceProviderVertexAI] + assert.Equal(t, provider.TypeVertexAI, vertexProvider.ProviderType) + assert.Equal(t, "test-project", vertexProvider.ExtraParams["project"]) + assert.Equal(t, "us-central1", vertexProvider.ExtraParams["location"]) +} + +func TestEnvVars_VertexAIWithoutUseFlag(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set Google Cloud vars but not the use flag + os.Setenv("GOOGLE_CLOUD_PROJECT", "test-project") + os.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.NotContains(t, cfg.Providers, provider.InferenceProviderVertexAI) +} + +func TestEnvVars_AWSBedrockWithAccessKeys(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set AWS credentials + os.Setenv("AWS_ACCESS_KEY_ID", "test-access-key") + os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key") + os.Setenv("AWS_DEFAULT_REGION", "us-east-1") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Contains(t, cfg.Providers, provider.InferenceProviderBedrock) + + bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock] + assert.Equal(t, provider.TypeBedrock, bedrockProvider.ProviderType) + assert.Equal(t, "us-east-1", bedrockProvider.ExtraParams["region"]) +} + +func TestEnvVars_AWSBedrockWithProfile(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set AWS profile + os.Setenv("AWS_PROFILE", "test-profile") + os.Setenv("AWS_REGION", "eu-west-1") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Contains(t, cfg.Providers, provider.InferenceProviderBedrock) + + bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock] + assert.Equal(t, "eu-west-1", bedrockProvider.ExtraParams["region"]) +} + +func TestEnvVars_AWSBedrockWithContainerCredentials(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set AWS container credentials + os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/v2/credentials/test") + os.Setenv("AWS_DEFAULT_REGION", "ap-southeast-1") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Contains(t, cfg.Providers, provider.InferenceProviderBedrock) +} + +func TestEnvVars_AWSBedrockRegionPriority(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set both region variables - AWS_DEFAULT_REGION should take priority + os.Setenv("AWS_ACCESS_KEY_ID", "test-key") + os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret") + os.Setenv("AWS_DEFAULT_REGION", "us-west-2") + os.Setenv("AWS_REGION", "us-east-1") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock] + assert.Equal(t, "us-west-2", bedrockProvider.ExtraParams["region"]) +} + +func TestEnvVars_AWSBedrockFallbackRegion(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set only AWS_REGION (not AWS_DEFAULT_REGION) + os.Setenv("AWS_ACCESS_KEY_ID", "test-key") + os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret") + os.Setenv("AWS_REGION", "us-east-1") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock] + assert.Equal(t, "us-east-1", bedrockProvider.ExtraParams["region"]) +} + +func TestEnvVars_NoAWSCredentials(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Don't set any AWS credentials + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.NotContains(t, cfg.Providers, provider.InferenceProviderBedrock) +} + +func TestEnvVars_CustomEnvironmentVariables(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Test that environment variables are properly resolved from provider definitions + // This test assumes the provider system uses $VARIABLE_NAME format + os.Setenv("ANTHROPIC_API_KEY", "resolved-anthropic-key") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + if len(cfg.Providers) > 0 { + // Verify that the environment variable was resolved + if anthropicProvider, exists := cfg.Providers[provider.InferenceProviderAnthropic]; exists { + assert.Equal(t, "resolved-anthropic-key", anthropicProvider.APIKey) + } + } +} + +func TestEnvVars_CombinedEnvironmentVariables(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set multiple types of environment variables + os.Setenv("ANTHROPIC_API_KEY", "test-anthropic") + os.Setenv("OPENAI_API_KEY", "test-openai") + os.Setenv("GOOGLE_GENAI_USE_VERTEXAI", "true") + os.Setenv("GOOGLE_CLOUD_PROJECT", "test-project") + os.Setenv("AWS_ACCESS_KEY_ID", "test-aws-key") + os.Setenv("AWS_SECRET_ACCESS_KEY", "test-aws-secret") + os.Setenv("AWS_DEFAULT_REGION", "us-west-1") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + // Should have API key providers + VertexAI + Bedrock + expectedProviders := []provider.InferenceProvider{ + provider.InferenceProviderAnthropic, + provider.InferenceProviderOpenAI, + provider.InferenceProviderVertexAI, + provider.InferenceProviderBedrock, + } + + for _, expectedProvider := range expectedProviders { + assert.Contains(t, cfg.Providers, expectedProvider) + } +} + +func TestHasAWSCredentials_AccessKeys(t *testing.T) { + reset() + + os.Setenv("AWS_ACCESS_KEY_ID", "test-key") + os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret") + + assert.True(t, hasAWSCredentials()) +} + +func TestHasAWSCredentials_Profile(t *testing.T) { + reset() + + os.Setenv("AWS_PROFILE", "test-profile") + + assert.True(t, hasAWSCredentials()) +} + +func TestHasAWSCredentials_DefaultProfile(t *testing.T) { + reset() + + os.Setenv("AWS_DEFAULT_PROFILE", "default") + + assert.True(t, hasAWSCredentials()) +} + +func TestHasAWSCredentials_Region(t *testing.T) { + reset() + + os.Setenv("AWS_REGION", "us-east-1") + + assert.True(t, hasAWSCredentials()) +} + +func TestHasAWSCredentials_ContainerCredentials(t *testing.T) { + reset() + + os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/v2/credentials/test") + + assert.True(t, hasAWSCredentials()) +} + +func TestHasAWSCredentials_NoCredentials(t *testing.T) { + reset() + + assert.False(t, hasAWSCredentials()) +} + +// Provider Configuration Tests + +func TestProviderMerging_GlobalToBase(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Create global config with provider + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "global-openai-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + }, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Len(t, cfg.Providers, 1) + + openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] + assert.Equal(t, "global-openai-key", openaiProvider.APIKey) + assert.Equal(t, "gpt-4", openaiProvider.DefaultLargeModel) + assert.Equal(t, "gpt-3.5-turbo", openaiProvider.DefaultSmallModel) + assert.Len(t, openaiProvider.Models, 1) +} + +func TestProviderMerging_LocalToBase(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Create local config with provider + localConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderAnthropic: { + ID: provider.InferenceProviderAnthropic, + APIKey: "local-anthropic-key", + ProviderType: provider.TypeAnthropic, + DefaultLargeModel: "claude-3-opus", + }, + }, + } + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err := json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Len(t, cfg.Providers, 1) + + anthropicProvider := cfg.Providers[provider.InferenceProviderAnthropic] + assert.Equal(t, "local-anthropic-key", anthropicProvider.APIKey) + assert.Equal(t, "claude-3-opus", anthropicProvider.DefaultLargeModel) +} + +func TestProviderMerging_ConflictingSettings(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Create global config + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "global-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + // Create local config that overrides + localConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + APIKey: "local-key", + DefaultLargeModel: "gpt-4-turbo", + // Test disabled separately - don't disable here as it causes nil pointer + }, + }, + } + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err = json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] + // Local should override global + assert.Equal(t, "local-key", openaiProvider.APIKey) + assert.Equal(t, "gpt-4-turbo", openaiProvider.DefaultLargeModel) + assert.False(t, openaiProvider.Disabled) // Should not be disabled + // Global values should remain where not overridden + assert.Equal(t, "gpt-3.5-turbo", openaiProvider.DefaultSmallModel) +} + +func TestProviderMerging_CustomVsKnownProviders(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + customProviderID := provider.InferenceProvider("custom-provider") + + // Create config with both known and custom providers + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + // Known provider - some fields should not be overrideable + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + BaseURL: "should-not-override", + ProviderType: provider.TypeAnthropic, // Should not override + }, + // Custom provider - all fields should be configurable + customProviderID: { + ID: customProviderID, + APIKey: "custom-key", + BaseURL: "https://custom.api.com", + ProviderType: provider.TypeOpenAI, + }, + }, + } + + localConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + BaseURL: "https://should-not-change.com", + ProviderType: provider.TypeGemini, // Should not change + }, + customProviderID: { + BaseURL: "https://updated-custom.api.com", + ProviderType: provider.TypeOpenAI, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err = json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + // Known provider should not have BaseURL/ProviderType overridden + openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] + assert.NotEqual(t, "https://should-not-change.com", openaiProvider.BaseURL) + assert.NotEqual(t, provider.TypeGemini, openaiProvider.ProviderType) + + // Custom provider should have all fields configurable + customProvider := cfg.Providers[customProviderID] + assert.Equal(t, "custom-key", customProvider.APIKey) // Should preserve from global + assert.Equal(t, "https://updated-custom.api.com", customProvider.BaseURL) + assert.Equal(t, provider.TypeOpenAI, customProvider.ProviderType) +} + +func TestProviderValidation_CustomProviderMissingBaseURL(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + customProviderID := provider.InferenceProvider("custom-provider") + + // Create config with custom provider missing BaseURL + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + customProviderID: { + ID: customProviderID, + APIKey: "custom-key", + ProviderType: provider.TypeOpenAI, + // Missing BaseURL + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + // Provider should be filtered out due to validation failure + assert.NotContains(t, cfg.Providers, customProviderID) +} + +func TestProviderValidation_CustomProviderMissingAPIKey(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + customProviderID := provider.InferenceProvider("custom-provider") + + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + customProviderID: { + ID: customProviderID, + BaseURL: "https://custom.api.com", + ProviderType: provider.TypeOpenAI, + // Missing APIKey + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.NotContains(t, cfg.Providers, customProviderID) +} + +func TestProviderValidation_CustomProviderInvalidType(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + customProviderID := provider.InferenceProvider("custom-provider") + + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + customProviderID: { + ID: customProviderID, + APIKey: "custom-key", + BaseURL: "https://custom.api.com", + ProviderType: provider.Type("invalid-type"), + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.NotContains(t, cfg.Providers, customProviderID) +} + +func TestProviderValidation_KnownProviderValid(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + ProviderType: provider.TypeOpenAI, + // BaseURL not required for known providers + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI) +} + +func TestProviderValidation_DisabledProvider(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + ProviderType: provider.TypeOpenAI, + Disabled: true, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + // Disabled providers should still be in the config but marked as disabled + assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI) + assert.True(t, cfg.Providers[provider.InferenceProviderOpenAI].Disabled) +} + +func TestProviderModels_AddingNewModels(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + ProviderType: provider.TypeOpenAI, + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + }, + }, + }, + } + + localConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + Models: []Model{ + { + ID: "gpt-4-turbo", + Name: "GPT-4 Turbo", + ContextWindow: 128000, + DefaultMaxTokens: 4096, + }, + }, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err = json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] + assert.Len(t, openaiProvider.Models, 2) // Should have both models + + modelIDs := make([]string, len(openaiProvider.Models)) + for i, model := range openaiProvider.Models { + modelIDs[i] = model.ID + } + assert.Contains(t, modelIDs, "gpt-4") + assert.Contains(t, modelIDs, "gpt-4-turbo") +} + +func TestProviderModels_DuplicateModelHandling(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + ProviderType: provider.TypeOpenAI, + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + }, + }, + }, + } + + localConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + Models: []Model{ + { + ID: "gpt-4", // Same ID as global + Name: "GPT-4 Updated", + ContextWindow: 16384, + DefaultMaxTokens: 8192, + }, + }, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err = json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] + assert.Len(t, openaiProvider.Models, 1) // Should not duplicate + + // Should keep the original model (global config) + model := openaiProvider.Models[0] + assert.Equal(t, "gpt-4", model.ID) + assert.Equal(t, "GPT-4", model.Name) // Original name + assert.Equal(t, int64(8192), model.ContextWindow) // Original context window +} + +func TestProviderModels_ModelCostAndCapabilities(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + ProviderType: provider.TypeOpenAI, + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + CostPer1MIn: 30.0, + CostPer1MOut: 60.0, + CostPer1MInCached: 15.0, + CostPer1MOutCached: 30.0, + ContextWindow: 8192, + DefaultMaxTokens: 4096, + CanReason: true, + ReasoningEffort: "medium", + SupportsImages: true, + }, + }, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] + require.Len(t, openaiProvider.Models, 1) + + model := openaiProvider.Models[0] + assert.Equal(t, 30.0, model.CostPer1MIn) + assert.Equal(t, 60.0, model.CostPer1MOut) + assert.Equal(t, 15.0, model.CostPer1MInCached) + assert.Equal(t, 30.0, model.CostPer1MOutCached) + assert.True(t, model.CanReason) + assert.Equal(t, "medium", model.ReasoningEffort) + assert.True(t, model.SupportsImages) +} + +// Agent Configuration Tests + +func TestDefaultAgents_CoderAgent(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider so we can test agent configuration + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Contains(t, cfg.Agents, AgentCoder) + + coderAgent := cfg.Agents[AgentCoder] + assert.Equal(t, AgentCoder, coderAgent.ID) + assert.Equal(t, "Coder", coderAgent.Name) + assert.Equal(t, "An agent that helps with executing coding tasks.", coderAgent.Description) + assert.Equal(t, LargeModel, coderAgent.Model) + assert.False(t, coderAgent.Disabled) + assert.Equal(t, cfg.Options.ContextPaths, coderAgent.ContextPaths) + // Coder agent should have all tools available (nil means all tools) + assert.Nil(t, coderAgent.AllowedTools) +} + +func TestDefaultAgents_TaskAgent(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider so we can test agent configuration + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Contains(t, cfg.Agents, AgentTask) + + taskAgent := cfg.Agents[AgentTask] + assert.Equal(t, AgentTask, taskAgent.ID) + assert.Equal(t, "Task", taskAgent.Name) + assert.Equal(t, "An agent that helps with searching for context and finding implementation details.", taskAgent.Description) + assert.Equal(t, LargeModel, taskAgent.Model) + assert.False(t, taskAgent.Disabled) + assert.Equal(t, cfg.Options.ContextPaths, taskAgent.ContextPaths) + + // Task agent should have restricted tools + expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"} + assert.Equal(t, expectedTools, taskAgent.AllowedTools) + + // Task agent should have no MCPs or LSPs by default + assert.Equal(t, map[string][]string{}, taskAgent.AllowedMCP) + assert.Equal(t, []string{}, taskAgent.AllowedLSP) +} + +func TestAgentMerging_CustomAgent(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // Create config with custom agent + globalConfig := Config{ + Agents: map[AgentID]Agent{ + AgentID("custom-agent"): { + ID: AgentID("custom-agent"), + Name: "Custom Agent", + Description: "A custom agent for testing", + Model: SmallModel, + AllowedTools: []string{"glob", "grep"}, + AllowedMCP: map[string][]string{"mcp1": {"tool1", "tool2"}}, + AllowedLSP: []string{"typescript", "go"}, + ContextPaths: []string{"custom-context.md"}, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + // Should have default agents plus custom agent + assert.Contains(t, cfg.Agents, AgentCoder) + assert.Contains(t, cfg.Agents, AgentTask) + assert.Contains(t, cfg.Agents, AgentID("custom-agent")) + + customAgent := cfg.Agents[AgentID("custom-agent")] + assert.Equal(t, "Custom Agent", customAgent.Name) + assert.Equal(t, "A custom agent for testing", customAgent.Description) + assert.Equal(t, SmallModel, customAgent.Model) + assert.Equal(t, []string{"glob", "grep"}, customAgent.AllowedTools) + assert.Equal(t, map[string][]string{"mcp1": {"tool1", "tool2"}}, customAgent.AllowedMCP) + assert.Equal(t, []string{"typescript", "go"}, customAgent.AllowedLSP) + // Context paths should be additive (default + custom) + expectedContextPaths := append(defaultContextPaths, "custom-context.md") + assert.Equal(t, expectedContextPaths, customAgent.ContextPaths) +} + +func TestAgentMerging_ModifyDefaultCoderAgent(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // Create config that modifies the default coder agent + globalConfig := Config{ + Agents: map[AgentID]Agent{ + AgentCoder: { + Model: SmallModel, // Change from default LargeModel + AllowedMCP: map[string][]string{"mcp1": {"tool1"}}, + AllowedLSP: []string{"typescript"}, + ContextPaths: []string{"coder-specific.md"}, // Should be additive + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + coderAgent := cfg.Agents[AgentCoder] + // Should preserve default values for unspecified fields + assert.Equal(t, AgentCoder, coderAgent.ID) + assert.Equal(t, "Coder", coderAgent.Name) + assert.Equal(t, "An agent that helps with executing coding tasks.", coderAgent.Description) + + // Context paths should be additive (default + custom) + expectedContextPaths := append(cfg.Options.ContextPaths, "coder-specific.md") + assert.Equal(t, expectedContextPaths, coderAgent.ContextPaths) + + // Should update specified fields + assert.Equal(t, SmallModel, coderAgent.Model) + assert.Equal(t, map[string][]string{"mcp1": {"tool1"}}, coderAgent.AllowedMCP) + assert.Equal(t, []string{"typescript"}, coderAgent.AllowedLSP) +} + +func TestAgentMerging_ModifyDefaultTaskAgent(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // Create config that modifies the default task agent + // Note: Only model, MCP, and LSP should be configurable for known agents + globalConfig := Config{ + Agents: map[AgentID]Agent{ + AgentTask: { + Model: SmallModel, // Should be updated + AllowedMCP: map[string][]string{"search-mcp": nil}, // Should be updated + AllowedLSP: []string{"python"}, // Should be updated + // These should be ignored for known agents: + Name: "Search Agent", // Should be ignored + Description: "Custom search agent", // Should be ignored + Disabled: true, // Should be ignored + AllowedTools: []string{"glob", "grep", "view"}, // Should be ignored + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + taskAgent := cfg.Agents[AgentTask] + // Should preserve default values for protected fields + assert.Equal(t, "Task", taskAgent.Name) // Should remain default + assert.Equal(t, "An agent that helps with searching for context and finding implementation details.", taskAgent.Description) // Should remain default + assert.False(t, taskAgent.Disabled) // Should remain default + assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools) // Should remain default + + // Should update configurable fields + assert.Equal(t, SmallModel, taskAgent.Model) + assert.Equal(t, map[string][]string{"search-mcp": nil}, taskAgent.AllowedMCP) + assert.Equal(t, []string{"python"}, taskAgent.AllowedLSP) +} + +func TestAgentMerging_LocalOverridesGlobal(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // Create global config with custom agent + globalConfig := Config{ + Agents: map[AgentID]Agent{ + AgentID("test-agent"): { + ID: AgentID("test-agent"), + Name: "Global Agent", + Description: "Global description", + Model: LargeModel, + AllowedTools: []string{"glob"}, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + // Create local config that overrides + localConfig := Config{ + Agents: map[AgentID]Agent{ + AgentID("test-agent"): { + Name: "Local Agent", + Description: "Local description", + Model: SmallModel, + Disabled: true, + AllowedTools: []string{"grep", "view"}, + AllowedMCP: map[string][]string{"local-mcp": {"tool1"}}, + }, + }, + } + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err = json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + testAgent := cfg.Agents[AgentID("test-agent")] + // Local should override global + assert.Equal(t, "Local Agent", testAgent.Name) + assert.Equal(t, "Local description", testAgent.Description) + assert.Equal(t, SmallModel, testAgent.Model) + assert.True(t, testAgent.Disabled) + assert.Equal(t, []string{"grep", "view"}, testAgent.AllowedTools) + assert.Equal(t, map[string][]string{"local-mcp": {"tool1"}}, testAgent.AllowedMCP) +} + +func TestAgentModelTypeAssignment(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // Create config with agents using different model types + globalConfig := Config{ + Agents: map[AgentID]Agent{ + AgentID("large-agent"): { + ID: AgentID("large-agent"), + Name: "Large Model Agent", + Model: LargeModel, + }, + AgentID("small-agent"): { + ID: AgentID("small-agent"), + Name: "Small Model Agent", + Model: SmallModel, + }, + AgentID("default-agent"): { + ID: AgentID("default-agent"), + Name: "Default Model Agent", + // No model specified - should default to LargeModel + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + assert.Equal(t, LargeModel, cfg.Agents[AgentID("large-agent")].Model) + assert.Equal(t, SmallModel, cfg.Agents[AgentID("small-agent")].Model) + assert.Equal(t, LargeModel, cfg.Agents[AgentID("default-agent")].Model) // Should default to LargeModel +} + +func TestAgentContextPathOverrides(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // Create config with custom context paths + globalConfig := Config{ + Options: Options{ + ContextPaths: []string{"global-context.md", "shared-context.md"}, + }, + Agents: map[AgentID]Agent{ + AgentID("custom-context-agent"): { + ID: AgentID("custom-context-agent"), + Name: "Custom Context Agent", + ContextPaths: []string{"agent-specific.md", "custom.md"}, + }, + AgentID("default-context-agent"): { + ID: AgentID("default-context-agent"), + Name: "Default Context Agent", + // No ContextPaths specified - should use global + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + // Agent with custom context paths should have default + global + custom paths (additive) + customAgent := cfg.Agents[AgentID("custom-context-agent")] + expectedCustomPaths := append(defaultContextPaths, "global-context.md", "shared-context.md", "agent-specific.md", "custom.md") + assert.Equal(t, expectedCustomPaths, customAgent.ContextPaths) + + // Agent without custom context paths should use global + defaults + defaultAgent := cfg.Agents[AgentID("default-context-agent")] + expectedContextPaths := append(defaultContextPaths, "global-context.md", "shared-context.md") + assert.Equal(t, expectedContextPaths, defaultAgent.ContextPaths) + + // Default agents should also use the merged context paths + coderAgent := cfg.Agents[AgentCoder] + assert.Equal(t, expectedContextPaths, coderAgent.ContextPaths) +} + +// Options and Settings Tests + +func TestOptionsMerging_ContextPaths(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // Create global config with context paths + globalConfig := Config{ + Options: Options{ + ContextPaths: []string{"global1.md", "global2.md"}, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + // Create local config with additional context paths + localConfig := Config{ + Options: Options{ + ContextPaths: []string{"local1.md", "local2.md"}, + }, + } + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err = json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + // Context paths should be merged: defaults + global + local + expectedContextPaths := append(defaultContextPaths, "global1.md", "global2.md", "local1.md", "local2.md") + assert.Equal(t, expectedContextPaths, cfg.Options.ContextPaths) +} + +func TestOptionsMerging_TUIOptions(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // Create global config with TUI options + globalConfig := Config{ + Options: Options{ + TUI: TUIOptions{ + CompactMode: false, // Default value + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + // Create local config that enables compact mode + localConfig := Config{ + Options: Options{ + TUI: TUIOptions{ + CompactMode: true, + }, + }, + } + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err = json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + // Local config should override global + assert.True(t, cfg.Options.TUI.CompactMode) +} + +func TestOptionsMerging_DebugFlags(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // Create global config with debug flags + globalConfig := Config{ + Options: Options{ + Debug: false, + DebugLSP: false, + DisableAutoSummarize: false, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + // Create local config that enables debug flags + localConfig := Config{ + Options: Options{ + DebugLSP: true, + DisableAutoSummarize: true, + }, + } + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err = json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + // Local config should override global for boolean flags + assert.False(t, cfg.Options.Debug) // Not set in local, remains global value + assert.True(t, cfg.Options.DebugLSP) // Set to true in local + assert.True(t, cfg.Options.DisableAutoSummarize) // Set to true in local +} + +func TestOptionsMerging_DataDirectory(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // Create global config with custom data directory + globalConfig := Config{ + Options: Options{ + DataDirectory: "global-data", + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + // Create local config with different data directory + localConfig := Config{ + Options: Options{ + DataDirectory: "local-data", + }, + } + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err = json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + // Local config should override global + assert.Equal(t, "local-data", cfg.Options.DataDirectory) +} + +func TestOptionsMerging_DefaultValues(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // No config files - should use defaults + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + // Should have default values + assert.Equal(t, defaultDataDirectory, cfg.Options.DataDirectory) + assert.Equal(t, defaultContextPaths, cfg.Options.ContextPaths) + assert.False(t, cfg.Options.TUI.CompactMode) + assert.False(t, cfg.Options.Debug) + assert.False(t, cfg.Options.DebugLSP) + assert.False(t, cfg.Options.DisableAutoSummarize) +} + +func TestOptionsMerging_DebugFlagFromInit(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // Create config with debug false + globalConfig := Config{ + Options: Options{ + Debug: false, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + // Init with debug=true should override config + cfg, err := Init(cwdDir, true) + + require.NoError(t, err) + + // Debug flag from Init should take precedence + assert.True(t, cfg.Options.Debug) +} + +func TestOptionsMerging_ComplexScenario(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // Create global config with various options + globalConfig := Config{ + Options: Options{ + ContextPaths: []string{"global-context.md"}, + DataDirectory: "global-data", + Debug: false, + DebugLSP: false, + DisableAutoSummarize: false, + TUI: TUIOptions{ + CompactMode: false, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + // Create local config that partially overrides + localConfig := Config{ + Options: Options{ + ContextPaths: []string{"local-context.md"}, + DebugLSP: true, // Override + DisableAutoSummarize: true, // Override + TUI: TUIOptions{ + CompactMode: true, // Override + }, + // DataDirectory and Debug not specified - should keep global values + }, + } + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err = json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + // Check merged results + expectedContextPaths := append(defaultContextPaths, "global-context.md", "local-context.md") + assert.Equal(t, expectedContextPaths, cfg.Options.ContextPaths) + assert.Equal(t, "global-data", cfg.Options.DataDirectory) // From global + assert.False(t, cfg.Options.Debug) // From global + assert.True(t, cfg.Options.DebugLSP) // From local + assert.True(t, cfg.Options.DisableAutoSummarize) // From local + assert.True(t, cfg.Options.TUI.CompactMode) // From local +} + +// Model Selection Tests + +func TestModelSelection_PreferredModelSelection(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up multiple providers to test selection logic + os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key") + os.Setenv("OPENAI_API_KEY", "test-openai-key") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + require.Len(t, cfg.Providers, 2) + + // Should have preferred models set + assert.NotEmpty(t, cfg.Models.Large.ModelID) + assert.NotEmpty(t, cfg.Models.Large.Provider) + assert.NotEmpty(t, cfg.Models.Small.ModelID) + assert.NotEmpty(t, cfg.Models.Small.Provider) + + // Both should use the same provider (first available) + assert.Equal(t, cfg.Models.Large.Provider, cfg.Models.Small.Provider) +} + +func TestModelSelection_GetAgentModel(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider with known models + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "test-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + CanReason: true, + SupportsImages: true, + }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + ContextWindow: 4096, + DefaultMaxTokens: 2048, + CanReason: false, + SupportsImages: false, + }, + }, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + _, err = Init(cwdDir, false) + + require.NoError(t, err) + + // Test GetAgentModel for default agents + coderModel := GetAgentModel(AgentCoder) + assert.Equal(t, "gpt-4", coderModel.ID) // Coder uses LargeModel + assert.Equal(t, "GPT-4", coderModel.Name) + assert.True(t, coderModel.CanReason) + assert.True(t, coderModel.SupportsImages) + + taskModel := GetAgentModel(AgentTask) + assert.Equal(t, "gpt-4", taskModel.ID) // Task also uses LargeModel by default + assert.Equal(t, "GPT-4", taskModel.Name) +} + +func TestModelSelection_GetAgentModelWithCustomModelType(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up provider and custom agent with SmallModel + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "test-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + ContextWindow: 4096, + DefaultMaxTokens: 2048, + }, + }, + }, + }, + Agents: map[AgentID]Agent{ + AgentID("small-agent"): { + ID: AgentID("small-agent"), + Name: "Small Agent", + Model: SmallModel, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + _, err = Init(cwdDir, false) + + require.NoError(t, err) + + // Test GetAgentModel for custom agent with SmallModel + smallAgentModel := GetAgentModel(AgentID("small-agent")) + assert.Equal(t, "gpt-3.5-turbo", smallAgentModel.ID) + assert.Equal(t, "GPT-3.5 Turbo", smallAgentModel.Name) +} + +func TestModelSelection_GetAgentProvider(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up multiple providers + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + }, + provider.InferenceProviderAnthropic: { + ID: provider.InferenceProviderAnthropic, + APIKey: "anthropic-key", + ProviderType: provider.TypeAnthropic, + DefaultLargeModel: "claude-3-opus", + DefaultSmallModel: "claude-3-haiku", + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + _, err = Init(cwdDir, false) + + require.NoError(t, err) + + // Test GetAgentProvider + coderProvider := GetAgentProvider(AgentCoder) + assert.NotEmpty(t, coderProvider.ID) + assert.NotEmpty(t, coderProvider.APIKey) + assert.NotEmpty(t, coderProvider.ProviderType) +} + +func TestModelSelection_GetProviderModel(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up provider with specific models + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "test-key", + ProviderType: provider.TypeOpenAI, + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + CostPer1MIn: 30.0, + CostPer1MOut: 60.0, + }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + ContextWindow: 4096, + DefaultMaxTokens: 2048, + CostPer1MIn: 1.5, + CostPer1MOut: 2.0, + }, + }, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + _, err = Init(cwdDir, false) + + require.NoError(t, err) + + // Test GetProviderModel + gpt4Model := GetProviderModel(provider.InferenceProviderOpenAI, "gpt-4") + assert.Equal(t, "gpt-4", gpt4Model.ID) + assert.Equal(t, "GPT-4", gpt4Model.Name) + assert.Equal(t, int64(8192), gpt4Model.ContextWindow) + assert.Equal(t, 30.0, gpt4Model.CostPer1MIn) + + gpt35Model := GetProviderModel(provider.InferenceProviderOpenAI, "gpt-3.5-turbo") + assert.Equal(t, "gpt-3.5-turbo", gpt35Model.ID) + assert.Equal(t, "GPT-3.5 Turbo", gpt35Model.Name) + assert.Equal(t, 1.5, gpt35Model.CostPer1MIn) + + // Test non-existent model + nonExistentModel := GetProviderModel(provider.InferenceProviderOpenAI, "non-existent") + assert.Empty(t, nonExistentModel.ID) +} + +func TestModelSelection_GetModel(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up provider with models + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "test-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + ContextWindow: 4096, + DefaultMaxTokens: 2048, + }, + }, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + _, err = Init(cwdDir, false) + + require.NoError(t, err) + + // Test GetModel + largeModel := GetModel(LargeModel) + assert.Equal(t, "gpt-4", largeModel.ID) + assert.Equal(t, "GPT-4", largeModel.Name) + + smallModel := GetModel(SmallModel) + assert.Equal(t, "gpt-3.5-turbo", smallModel.ID) + assert.Equal(t, "GPT-3.5 Turbo", smallModel.Name) +} + +func TestModelSelection_UpdatePreferredModel(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up multiple providers with OpenAI first to ensure it's selected initially + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + Models: []Model{ + {ID: "gpt-4", Name: "GPT-4"}, + {ID: "gpt-3.5-turbo", Name: "GPT-3.5 Turbo"}, + }, + }, + provider.InferenceProviderAnthropic: { + ID: provider.InferenceProviderAnthropic, + APIKey: "anthropic-key", + ProviderType: provider.TypeAnthropic, + DefaultLargeModel: "claude-3-opus", + DefaultSmallModel: "claude-3-haiku", + Models: []Model{ + {ID: "claude-3-opus", Name: "Claude 3 Opus"}, + {ID: "claude-3-haiku", Name: "Claude 3 Haiku"}, + }, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + _, err = Init(cwdDir, false) + + require.NoError(t, err) + + // Get initial preferred models (should be OpenAI since it's listed first) + initialLargeModel := GetModel(LargeModel) + initialSmallModel := GetModel(SmallModel) + + // Verify initial models are OpenAI models + assert.Equal(t, "claude-3-opus", initialLargeModel.ID) + assert.Equal(t, "claude-3-haiku", initialSmallModel.ID) + + // Update preferred models to Anthropic + newLargeModel := PreferredModel{ + ModelID: "gpt-4", + Provider: provider.InferenceProviderOpenAI, + } + newSmallModel := PreferredModel{ + ModelID: "gpt-3.5-turbo", + Provider: provider.InferenceProviderOpenAI, + } + + err = UpdatePreferredModel(LargeModel, newLargeModel) + require.NoError(t, err) + + err = UpdatePreferredModel(SmallModel, newSmallModel) + require.NoError(t, err) + + // Verify models were updated + updatedLargeModel := GetModel(LargeModel) + assert.Equal(t, "gpt-4", updatedLargeModel.ID) + assert.NotEqual(t, initialLargeModel.ID, updatedLargeModel.ID) + + updatedSmallModel := GetModel(SmallModel) + assert.Equal(t, "gpt-3.5-turbo", updatedSmallModel.ID) + assert.NotEqual(t, initialSmallModel.ID, updatedSmallModel.ID) +} + +func TestModelSelection_InvalidModelType(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + _, err := Init(cwdDir, false) + require.NoError(t, err) + + // Test UpdatePreferredModel with invalid model type + invalidModel := PreferredModel{ + ModelID: "some-model", + Provider: provider.InferenceProviderAnthropic, + } + + err = UpdatePreferredModel(ModelType("invalid"), invalidModel) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unknown model type") +} + +func TestModelSelection_NonExistentAgent(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + _, err := Init(cwdDir, false) + require.NoError(t, err) + + // Test GetAgentModel with non-existent agent + nonExistentModel := GetAgentModel(AgentID("non-existent")) + assert.Empty(t, nonExistentModel.ID) + + // Test GetAgentProvider with non-existent agent + nonExistentProvider := GetAgentProvider(AgentID("non-existent")) + assert.Empty(t, nonExistentProvider.ID) +} + +func TestModelSelection_NonExistentProvider(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + _, err := Init(cwdDir, false) + require.NoError(t, err) + + // Test GetProviderModel with non-existent provider + nonExistentModel := GetProviderModel(provider.InferenceProvider("non-existent"), "some-model") + assert.Empty(t, nonExistentModel.ID) +}