diff --git a/cmd/logs.go b/cmd/logs.go index 4e69e14b9790dc5985bc8f5e89c33066ddce628a..e6f59384afa795541d7c6016b64551bc0488b35b 100644 --- a/cmd/logs.go +++ b/cmd/logs.go @@ -7,7 +7,7 @@ import ( "slices" "time" - "github.com/charmbracelet/crush/pkg/config" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/log/v2" "github.com/nxadm/tail" "github.com/spf13/cobra" diff --git a/cmd/schema/main.go b/cmd/schema/main.go deleted file mode 100644 index 43f361662cd5d357e4d3b736ba7b4f3af2222724..0000000000000000000000000000000000000000 --- a/cmd/schema/main.go +++ /dev/null @@ -1,155 +0,0 @@ -package main - -import ( - "encoding/json" - "fmt" - "os" - - "github.com/charmbracelet/crush/internal/config" - "github.com/invopop/jsonschema" -) - -func main() { - // Create a new reflector - r := &jsonschema.Reflector{ - // Use anonymous schemas to avoid ID conflicts - Anonymous: true, - // Expand the root struct instead of referencing it - ExpandedStruct: true, - AllowAdditionalProperties: true, - } - - // Generate schema for the main Config struct - schema := r.Reflect(&config.Config{}) - - // Enhance the schema with additional information - enhanceSchema(schema) - - // Set the schema metadata - schema.Version = "https://json-schema.org/draft/2020-12/schema" - schema.Title = "Crush Configuration" - schema.Description = "Configuration schema for the Crush application" - - // Pretty print the schema - encoder := json.NewEncoder(os.Stdout) - encoder.SetIndent("", " ") - if err := encoder.Encode(schema); err != nil { - fmt.Fprintf(os.Stderr, "Error encoding schema: %v\n", err) - os.Exit(1) - } -} - -// enhanceSchema adds additional enhancements to the generated schema -func enhanceSchema(schema *jsonschema.Schema) { - // Add provider enums - addProviderEnums(schema) - - // Add model enums - addModelEnums(schema) - - // Add tool enums - addToolEnums(schema) - - // Add default context paths - addDefaultContextPaths(schema) -} - -// addProviderEnums adds provider enums to the schema -func addProviderEnums(schema *jsonschema.Schema) { - providers := config.Providers() - var providerIDs []any - for _, p := range providers { - providerIDs = append(providerIDs, string(p.ID)) - } - - // Add to PreferredModel provider field - if schema.Definitions != nil { - if preferredModelDef, exists := schema.Definitions["PreferredModel"]; exists { - if providerProp, exists := preferredModelDef.Properties.Get("provider"); exists { - providerProp.Enum = providerIDs - } - } - - // Add to ProviderConfig ID field - if providerConfigDef, exists := schema.Definitions["ProviderConfig"]; exists { - if idProp, exists := providerConfigDef.Properties.Get("id"); exists { - idProp.Enum = providerIDs - } - } - } -} - -// addModelEnums adds model enums to the schema -func addModelEnums(schema *jsonschema.Schema) { - providers := config.Providers() - var modelIDs []any - for _, p := range providers { - for _, m := range p.Models { - modelIDs = append(modelIDs, m.ID) - } - } - - // Add to PreferredModel model_id field - if schema.Definitions != nil { - if preferredModelDef, exists := schema.Definitions["PreferredModel"]; exists { - if modelIDProp, exists := preferredModelDef.Properties.Get("model_id"); exists { - modelIDProp.Enum = modelIDs - } - } - } -} - -// addToolEnums adds tool enums to the schema -func addToolEnums(schema *jsonschema.Schema) { - tools := []any{ - "bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent", - } - - if schema.Definitions != nil { - if agentDef, exists := schema.Definitions["Agent"]; exists { - if allowedToolsProp, exists := agentDef.Properties.Get("allowed_tools"); exists { - if allowedToolsProp.Items != nil { - allowedToolsProp.Items.Enum = tools - } - } - } - } -} - -// addDefaultContextPaths adds default context paths to the schema -func addDefaultContextPaths(schema *jsonschema.Schema) { - defaultContextPaths := []any{ - ".github/copilot-instructions.md", - ".cursorrules", - ".cursor/rules/", - "CLAUDE.md", - "CLAUDE.local.md", - "GEMINI.md", - "gemini.md", - "crush.md", - "crush.local.md", - "Crush.md", - "Crush.local.md", - "CRUSH.md", - "CRUSH.local.md", - } - - if schema.Definitions != nil { - if optionsDef, exists := schema.Definitions["Options"]; exists { - if contextPathsProp, exists := optionsDef.Properties.Get("context_paths"); exists { - contextPathsProp.Default = defaultContextPaths - } - } - } - - // Also add to root properties if they exist - if schema.Properties != nil { - if optionsProp, exists := schema.Properties.Get("options"); exists { - if optionsProp.Properties != nil { - if contextPathsProp, exists := optionsProp.Properties.Get("context_paths"); exists { - contextPathsProp.Default = defaultContextPaths - } - } - } - } -} diff --git a/crush-schema.json b/crush-schema.json deleted file mode 100644 index ea356c0e585b8a243ee1110d68264c0f2301752f..0000000000000000000000000000000000000000 --- a/crush-schema.json +++ /dev/null @@ -1,700 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "Agent": { - "properties": { - "id": { - "type": "string", - "enum": [ - "coder", - "task", - "coder", - "task" - ], - "title": "Agent ID", - "description": "Unique identifier for the agent" - }, - "name": { - "type": "string", - "title": "Name", - "description": "Display name of the agent" - }, - "description": { - "type": "string", - "title": "Description", - "description": "Description of what the agent does" - }, - "disabled": { - "type": "boolean", - "title": "Disabled", - "description": "Whether this agent is disabled", - "default": false - }, - "model": { - "type": "string", - "enum": [ - "large", - "small", - "large", - "small" - ], - "title": "Model Type", - "description": "Type of model to use (large or small)" - }, - "allowed_tools": { - "items": { - "type": "string", - "enum": [ - "bash", - "edit", - "fetch", - "glob", - "grep", - "ls", - "sourcegraph", - "view", - "write", - "agent" - ] - }, - "type": "array", - "title": "Allowed Tools", - "description": "List of tools this agent is allowed to use (if nil all tools are allowed)" - }, - "allowed_mcp": { - "additionalProperties": { - "items": { - "type": "string" - }, - "type": "array" - }, - "type": "object", - "title": "Allowed MCP", - "description": "Map of MCP servers this agent can use and their allowed tools" - }, - "allowed_lsp": { - "items": { - "type": "string" - }, - "type": "array", - "title": "Allowed LSP", - "description": "List of LSP servers this agent can use (if nil all LSPs are allowed)" - }, - "context_paths": { - "items": { - "type": "string" - }, - "type": "array", - "title": "Context Paths", - "description": "Custom context paths for this agent (additive to global context paths)" - } - }, - "type": "object", - "required": [ - "model" - ] - }, - "LSPConfig": { - "properties": { - "enabled": { - "type": "boolean", - "title": "Enabled", - "description": "Whether this LSP server is enabled", - "default": true - }, - "command": { - "type": "string", - "title": "Command", - "description": "Command to execute for the LSP server" - }, - "args": { - "items": { - "type": "string" - }, - "type": "array", - "title": "Arguments", - "description": "Command line arguments for the LSP server" - }, - "options": { - "title": "Options", - "description": "LSP server specific options" - } - }, - "type": "object", - "required": [ - "command" - ] - }, - "MCP": { - "properties": { - "command": { - "type": "string", - "title": "Command", - "description": "Command to execute for stdio MCP servers" - }, - "env": { - "items": { - "type": "string" - }, - "type": "array", - "title": "Environment", - "description": "Environment variables for the MCP server" - }, - "args": { - "items": { - "type": "string" - }, - "type": "array", - "title": "Arguments", - "description": "Command line arguments for the MCP server" - }, - "type": { - "type": "string", - "enum": [ - "stdio", - "sse", - "stdio", - "sse", - "http" - ], - "title": "Type", - "description": "Type of MCP connection", - "default": "stdio" - }, - "url": { - "type": "string", - "title": "URL", - "description": "URL for SSE MCP servers" - }, - "headers": { - "additionalProperties": { - "type": "string" - }, - "type": "object", - "title": "Headers", - "description": "HTTP headers for SSE MCP servers" - } - }, - "type": "object", - "required": [ - "type" - ] - }, - "Model": { - "properties": { - "id": { - "type": "string", - "title": "Model ID", - "description": "Unique identifier for the model" - }, - "name": { - "type": "string", - "title": "Model Name", - "description": "Display name of the model" - }, - "cost_per_1m_in": { - "type": "number", - "minimum": 0, - "title": "Input Cost", - "description": "Cost per 1 million input tokens" - }, - "cost_per_1m_out": { - "type": "number", - "minimum": 0, - "title": "Output Cost", - "description": "Cost per 1 million output tokens" - }, - "cost_per_1m_in_cached": { - "type": "number", - "minimum": 0, - "title": "Cached Input Cost", - "description": "Cost per 1 million cached input tokens" - }, - "cost_per_1m_out_cached": { - "type": "number", - "minimum": 0, - "title": "Cached Output Cost", - "description": "Cost per 1 million cached output tokens" - }, - "context_window": { - "type": "integer", - "minimum": 1, - "title": "Context Window", - "description": "Maximum context window size in tokens" - }, - "default_max_tokens": { - "type": "integer", - "minimum": 1, - "title": "Default Max Tokens", - "description": "Default maximum tokens for responses" - }, - "can_reason": { - "type": "boolean", - "title": "Can Reason", - "description": "Whether the model supports reasoning capabilities" - }, - "reasoning_effort": { - "type": "string", - "title": "Reasoning Effort", - "description": "Default reasoning effort level for reasoning models" - }, - "has_reasoning_effort": { - "type": "boolean", - "title": "Has Reasoning Effort", - "description": "Whether the model supports reasoning effort configuration" - }, - "supports_attachments": { - "type": "boolean", - "title": "Supports Images", - "description": "Whether the model supports image attachments" - } - }, - "type": "object", - "required": [ - "id", - "name", - "context_window", - "default_max_tokens" - ] - }, - "Options": { - "properties": { - "context_paths": { - "items": { - "type": "string" - }, - "type": "array", - "title": "Context Paths", - "description": "List of paths to search for context files", - "default": [ - ".github/copilot-instructions.md", - ".cursorrules", - ".cursor/rules/", - "CLAUDE.md", - "CLAUDE.local.md", - "GEMINI.md", - "gemini.md", - "crush.md", - "crush.local.md", - "Crush.md", - "Crush.local.md", - "CRUSH.md", - "CRUSH.local.md" - ] - }, - "tui": { - "$ref": "#/$defs/TUIOptions", - "title": "TUI Options", - "description": "Terminal UI configuration options" - }, - "debug": { - "type": "boolean", - "title": "Debug", - "description": "Enable debug logging", - "default": false - }, - "debug_lsp": { - "type": "boolean", - "title": "Debug LSP", - "description": "Enable LSP debug logging", - "default": false - }, - "disable_auto_summarize": { - "type": "boolean", - "title": "Disable Auto Summarize", - "description": "Disable automatic conversation summarization", - "default": false - }, - "data_directory": { - "type": "string", - "title": "Data Directory", - "description": "Directory for storing application data", - "default": ".crush" - } - }, - "type": "object" - }, - "PreferredModel": { - "properties": { - "model_id": { - "type": "string", - "enum": [ - "claude-opus-4-20250514", - "claude-sonnet-4-20250514", - "claude-3-7-sonnet-20250219", - "claude-3-5-haiku-20241022", - "claude-3-5-sonnet-20240620", - "claude-3-5-sonnet-20241022", - "codex-mini-latest", - "o4-mini", - "o3", - "o3-pro", - "gpt-4.1", - "gpt-4.1-mini", - "gpt-4.1-nano", - "gpt-4.5-preview", - "o3-mini", - "gpt-4o", - "gpt-4o-mini", - "gemini-2.5-pro", - "gemini-2.5-flash", - "codex-mini-latest", - "o4-mini", - "o3", - "o3-pro", - "gpt-4.1", - "gpt-4.1-mini", - "gpt-4.1-nano", - "gpt-4.5-preview", - "o3-mini", - "gpt-4o", - "gpt-4o-mini", - "anthropic.claude-opus-4-20250514-v1:0", - "anthropic.claude-sonnet-4-20250514-v1:0", - "anthropic.claude-3-7-sonnet-20250219-v1:0", - "anthropic.claude-3-5-haiku-20241022-v1:0", - "gemini-2.5-pro", - "gemini-2.5-flash", - "grok-3-mini", - "grok-3", - "mistralai/mistral-small-3.2-24b-instruct:free", - "mistralai/mistral-small-3.2-24b-instruct", - "minimax/minimax-m1:extended", - "minimax/minimax-m1", - "google/gemini-2.5-flash-lite-preview-06-17", - "google/gemini-2.5-flash", - "google/gemini-2.5-pro", - "openai/o3-pro", - "x-ai/grok-3-mini", - "x-ai/grok-3", - "mistralai/magistral-small-2506", - "mistralai/magistral-medium-2506", - "mistralai/magistral-medium-2506:thinking", - "google/gemini-2.5-pro-preview", - "deepseek/deepseek-r1-0528", - "anthropic/claude-opus-4", - "anthropic/claude-sonnet-4", - "mistralai/devstral-small:free", - "mistralai/devstral-small", - "google/gemini-2.5-flash-preview-05-20", - "google/gemini-2.5-flash-preview-05-20:thinking", - "openai/codex-mini", - "mistralai/mistral-medium-3", - "google/gemini-2.5-pro-preview-05-06", - "arcee-ai/caller-large", - "arcee-ai/virtuoso-large", - "arcee-ai/virtuoso-medium-v2", - "qwen/qwen3-30b-a3b", - "qwen/qwen3-14b", - "qwen/qwen3-32b", - "qwen/qwen3-235b-a22b", - "google/gemini-2.5-flash-preview", - "google/gemini-2.5-flash-preview:thinking", - "openai/o4-mini-high", - "openai/o3", - "openai/o4-mini", - "openai/gpt-4.1", - "openai/gpt-4.1-mini", - "openai/gpt-4.1-nano", - "x-ai/grok-3-mini-beta", - "x-ai/grok-3-beta", - "meta-llama/llama-4-maverick", - "meta-llama/llama-4-scout", - "all-hands/openhands-lm-32b-v0.1", - "google/gemini-2.5-pro-exp-03-25", - "deepseek/deepseek-chat-v3-0324:free", - "deepseek/deepseek-chat-v3-0324", - "mistralai/mistral-small-3.1-24b-instruct:free", - "mistralai/mistral-small-3.1-24b-instruct", - "ai21/jamba-1.6-large", - "ai21/jamba-1.6-mini", - "openai/gpt-4.5-preview", - "google/gemini-2.0-flash-lite-001", - "anthropic/claude-3.7-sonnet", - "anthropic/claude-3.7-sonnet:beta", - "anthropic/claude-3.7-sonnet:thinking", - "mistralai/mistral-saba", - "openai/o3-mini-high", - "google/gemini-2.0-flash-001", - "qwen/qwen-turbo", - "qwen/qwen-plus", - "qwen/qwen-max", - "openai/o3-mini", - "mistralai/mistral-small-24b-instruct-2501", - "deepseek/deepseek-r1-distill-llama-70b", - "deepseek/deepseek-r1", - "mistralai/codestral-2501", - "deepseek/deepseek-chat", - "openai/o1", - "x-ai/grok-2-1212", - "meta-llama/llama-3.3-70b-instruct", - "amazon/nova-lite-v1", - "amazon/nova-micro-v1", - "amazon/nova-pro-v1", - "openai/gpt-4o-2024-11-20", - "mistralai/mistral-large-2411", - "mistralai/mistral-large-2407", - "mistralai/pixtral-large-2411", - "thedrummer/unslopnemo-12b", - "anthropic/claude-3.5-haiku:beta", - "anthropic/claude-3.5-haiku", - "anthropic/claude-3.5-haiku-20241022:beta", - "anthropic/claude-3.5-haiku-20241022", - "anthropic/claude-3.5-sonnet:beta", - "anthropic/claude-3.5-sonnet", - "x-ai/grok-beta", - "mistralai/ministral-8b", - "mistralai/ministral-3b", - "nvidia/llama-3.1-nemotron-70b-instruct", - "google/gemini-flash-1.5-8b", - "meta-llama/llama-3.2-11b-vision-instruct", - "meta-llama/llama-3.2-3b-instruct", - "qwen/qwen-2.5-72b-instruct", - "mistralai/pixtral-12b", - "cohere/command-r-plus-08-2024", - "cohere/command-r-08-2024", - "microsoft/phi-3.5-mini-128k-instruct", - "nousresearch/hermes-3-llama-3.1-70b", - "openai/gpt-4o-2024-08-06", - "meta-llama/llama-3.1-405b-instruct", - "meta-llama/llama-3.1-70b-instruct", - "meta-llama/llama-3.1-8b-instruct", - "mistralai/mistral-nemo", - "openai/gpt-4o-mini", - "openai/gpt-4o-mini-2024-07-18", - "anthropic/claude-3.5-sonnet-20240620:beta", - "anthropic/claude-3.5-sonnet-20240620", - "mistralai/mistral-7b-instruct-v0.3", - "mistralai/mistral-7b-instruct:free", - "mistralai/mistral-7b-instruct", - "microsoft/phi-3-mini-128k-instruct", - "microsoft/phi-3-medium-128k-instruct", - "google/gemini-flash-1.5", - "openai/gpt-4o-2024-05-13", - "openai/gpt-4o", - "openai/gpt-4o:extended", - "meta-llama/llama-3-8b-instruct", - "meta-llama/llama-3-70b-instruct", - "mistralai/mixtral-8x22b-instruct", - "openai/gpt-4-turbo", - "google/gemini-pro-1.5", - "cohere/command-r-plus", - "cohere/command-r-plus-04-2024", - "cohere/command-r", - "anthropic/claude-3-haiku:beta", - "anthropic/claude-3-haiku", - "anthropic/claude-3-opus:beta", - "anthropic/claude-3-opus", - "anthropic/claude-3-sonnet:beta", - "anthropic/claude-3-sonnet", - "cohere/command-r-03-2024", - "mistralai/mistral-large", - "openai/gpt-3.5-turbo-0613", - "openai/gpt-4-turbo-preview", - "mistralai/mistral-small", - "mistralai/mistral-tiny", - "mistralai/mixtral-8x7b-instruct", - "openai/gpt-4-1106-preview", - "mistralai/mistral-7b-instruct-v0.1", - "openai/gpt-3.5-turbo-16k", - "openai/gpt-4", - "openai/gpt-4-0314" - ], - "title": "Model ID", - "description": "ID of the preferred model" - }, - "provider": { - "type": "string", - "enum": [ - "anthropic", - "openai", - "gemini", - "azure", - "bedrock", - "vertex", - "xai", - "openrouter" - ], - "title": "Provider", - "description": "Provider for the preferred model" - }, - "reasoning_effort": { - "type": "string", - "title": "Reasoning Effort", - "description": "Override reasoning effort for this model" - }, - "max_tokens": { - "type": "integer", - "minimum": 1, - "title": "Max Tokens", - "description": "Override max tokens for this model" - }, - "think": { - "type": "boolean", - "title": "Think", - "description": "Enable thinking for reasoning models", - "default": false - } - }, - "type": "object", - "required": [ - "model_id", - "provider" - ] - }, - "PreferredModels": { - "properties": { - "large": { - "$ref": "#/$defs/PreferredModel", - "title": "Large Model", - "description": "Preferred model configuration for large model type" - }, - "small": { - "$ref": "#/$defs/PreferredModel", - "title": "Small Model", - "description": "Preferred model configuration for small model type" - } - }, - "type": "object" - }, - "ProviderConfig": { - "properties": { - "id": { - "type": "string", - "enum": [ - "anthropic", - "openai", - "gemini", - "azure", - "bedrock", - "vertex", - "xai", - "openrouter" - ], - "title": "Provider ID", - "description": "Unique identifier for the provider" - }, - "base_url": { - "type": "string", - "title": "Base URL", - "description": "Base URL for the provider API (required for custom providers)" - }, - "provider_type": { - "type": "string", - "title": "Provider Type", - "description": "Type of the provider (openai" - }, - "api_key": { - "type": "string", - "title": "API Key", - "description": "API key for authenticating with the provider" - }, - "disabled": { - "type": "boolean", - "title": "Disabled", - "description": "Whether this provider is disabled", - "default": false - }, - "extra_headers": { - "additionalProperties": { - "type": "string" - }, - "type": "object", - "title": "Extra Headers", - "description": "Additional HTTP headers to send with requests" - }, - "extra_params": { - "additionalProperties": { - "type": "string" - }, - "type": "object", - "title": "Extra Parameters", - "description": "Additional provider-specific parameters" - }, - "default_large_model": { - "type": "string", - "title": "Default Large Model", - "description": "Default model ID for large model type" - }, - "default_small_model": { - "type": "string", - "title": "Default Small Model", - "description": "Default model ID for small model type" - }, - "models": { - "items": { - "$ref": "#/$defs/Model" - }, - "type": "array", - "title": "Models", - "description": "List of available models for this provider" - } - }, - "type": "object", - "required": [ - "provider_type" - ] - }, - "TUIOptions": { - "properties": { - "compact_mode": { - "type": "boolean", - "title": "Compact Mode", - "description": "Enable compact mode for the TUI", - "default": false - } - }, - "type": "object", - "required": [ - "compact_mode" - ] - } - }, - "properties": { - "models": { - "$ref": "#/$defs/PreferredModels", - "title": "Models", - "description": "Preferred model configurations for large and small model types" - }, - "providers": { - "additionalProperties": { - "$ref": "#/$defs/ProviderConfig" - }, - "type": "object", - "title": "Providers", - "description": "LLM provider configurations" - }, - "agents": { - "additionalProperties": { - "$ref": "#/$defs/Agent" - }, - "type": "object", - "title": "Agents", - "description": "Agent configurations for different tasks" - }, - "mcp": { - "additionalProperties": { - "$ref": "#/$defs/MCP" - }, - "type": "object", - "title": "MCP", - "description": "Model Control Protocol server configurations" - }, - "lsp": { - "additionalProperties": { - "$ref": "#/$defs/LSPConfig" - }, - "type": "object", - "title": "LSP", - "description": "Language Server Protocol configurations" - }, - "options": { - "$ref": "#/$defs/Options", - "title": "Options", - "description": "General application options and settings" - } - }, - "type": "object", - "title": "Crush Configuration", - "description": "Configuration schema for the Crush application" -} diff --git a/crush.json b/crush.json index 6d8a7e97dd55fcc6e27dda4f15c01a2e172cc4cc..b2653070b349f2c520dc5bd2c06c9001d33d9c3e 100644 --- a/crush.json +++ b/crush.json @@ -1,5 +1,4 @@ { - "$schema": "./crush-schema.json", "lsp": { "go": { "command": "gopls" diff --git a/go.mod b/go.mod index 25bc8d66440e76b3842d7e3a11770b57cddb4fa9..1c65635546a2407af280359450233652beb7a094 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/charmbracelet/fang v0.1.0 github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.2.0.20250703152125-8e1c474f8a71 + github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706 github.com/charmbracelet/x/ansi v0.9.3-0.20250602153603-fb931ed90413 github.com/charmbracelet/x/exp/charmtone v0.0.0-20250627134340-c144409e381c github.com/charmbracelet/x/exp/golden v0.0.0-20250207160936-21c02780d27a @@ -25,31 +26,28 @@ require ( github.com/go-logfmt/logfmt v0.6.0 github.com/google/uuid v1.6.0 github.com/invopop/jsonschema v0.13.0 + github.com/joho/godotenv v1.5.1 github.com/mark3labs/mcp-go v0.32.0 github.com/muesli/termenv v0.16.0 github.com/ncruces/go-sqlite3 v0.25.0 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 + github.com/nxadm/tail v1.4.11 github.com/openai/openai-go v1.8.2 github.com/pressly/goose/v3 v3.24.2 + github.com/qjebbs/go-jsons v0.0.0-20221222033332-a534c5fc1c4c github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 github.com/sahilm/fuzzy v0.1.1 github.com/spf13/cobra v1.9.1 github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef github.com/stretchr/testify v1.10.0 + golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 + gopkg.in/natefinch/lumberjack.v2 v2.2.1 mvdan.cc/sh/v3 v3.11.0 ) require ( - github.com/charmbracelet/lipgloss v1.1.0 // indirect - github.com/charmbracelet/log v0.4.2 // indirect - github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706 // indirect - github.com/joho/godotenv v1.5.1 // indirect - github.com/nxadm/tail v1.4.11 // indirect - github.com/qjebbs/go-jsons v0.0.0-20221222033332-a534c5fc1c4c // indirect github.com/spf13/cast v1.7.1 // indirect - golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect - gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) @@ -84,7 +82,7 @@ require ( github.com/charmbracelet/x/cellbuf v0.0.14-0.20250516160309-24eee56f89fa // indirect github.com/charmbracelet/x/exp/slice v0.0.0-20250611152503-f53cdd7e01ef github.com/charmbracelet/x/input v0.3.5-0.20250509021451-13796e822d86 // indirect - github.com/charmbracelet/x/term v0.2.1 // indirect + github.com/charmbracelet/x/term v0.2.1 github.com/charmbracelet/x/windows v0.2.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/disintegration/gift v1.1.2 // indirect diff --git a/go.sum b/go.sum index 266487d545c1dd62f6d83a56c6efc6d816a7ac19..084956c36c89634c9c9a94f7358813c4faa735c9 100644 --- a/go.sum +++ b/go.sum @@ -82,14 +82,8 @@ github.com/charmbracelet/fang v0.1.0 h1:SlZS2crf3/zQh7Mr4+W+7QR1k+L08rrPX5rm5z3d github.com/charmbracelet/fang v0.1.0/go.mod h1:Zl/zeUQ8EtQuGyiV0ZKZlZPDowKRTzu8s/367EpN/fc= github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe h1:i6ce4CcAlPpTj2ER69m1DBeLZ3RRcHnKExuwhKa3GfY= github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe/go.mod h1:p3Q+aN4eQKeM5jhrmXPMgPrlKbmc59rWSnMsSA3udhk= -github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY= -github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= -github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.1.0.20250523195325-2d1af06b557c h1:177KMz8zHRlEZJsWzafbKYh6OdjgvTspoH+UjaxgIXY= -github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.1.0.20250523195325-2d1af06b557c/go.mod h1:EJWvaCrhOhNGVZMvcjc0yVryl4qqpMs8tz0r9WyEkdQ= github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.2.0.20250703152125-8e1c474f8a71 h1:X0tsNa2UHCKNw+illiavosasVzqioRo32SRV35iwr2I= github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.2.0.20250703152125-8e1c474f8a71/go.mod h1:EJWvaCrhOhNGVZMvcjc0yVryl4qqpMs8tz0r9WyEkdQ= -github.com/charmbracelet/log v0.4.2 h1:hYt8Qj6a8yLnvR+h7MwsJv/XvmBJXiueUcI3cIxsyig= -github.com/charmbracelet/log v0.4.2/go.mod h1:qifHGX/tc7eluv2R6pWIpyHDDrrb/AG71Pf2ysQu5nw= github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706 h1:WkwO6Ks3mSIGnGuSdKl9qDSyfbYK50z2wc2gGMggegE= github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706/go.mod h1:mjJGp00cxcfvD5xdCa+bso251Jt4owrQvuimJtVmEmM= github.com/charmbracelet/x/ansi v0.9.3-0.20250602153603-fb931ed90413 h1:L07QkDqRF274IZ2UJ/mCTL8DR95efU9BNWLYCDXEjvQ= diff --git a/internal/app/app.go b/internal/app/app.go index 6dd1b9916d593c6f0e053aaef6714723f8fd5c60..85e619b5f23061accf07b644df90a409fd3cbe2d 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -57,7 +57,8 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) { cfg := config.Get() - coderAgentCfg := cfg.Agents[config.AgentCoder] + // TODO: remove the concept of agent config most likely + coderAgentCfg := cfg.Agents["coder"] if coderAgentCfg.ID == "" { return nil, fmt.Errorf("coder agent configuration is missing") } diff --git a/internal/app/lsp.go b/internal/app/lsp.go index a056676e1672454adba6d63dd7b7042cc47f6855..7b95458a61c5603df1396a7905eb70370b7adfeb 100644 --- a/internal/app/lsp.go +++ b/internal/app/lsp.go @@ -38,7 +38,7 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, comman defer cancel() // Initialize with the initialization context - _, err = lspClient.InitializeLSPClient(initCtx, config.WorkingDirectory()) + _, err = lspClient.InitializeLSPClient(initCtx, config.Get().WorkingDir()) if err != nil { logging.Error("Initialize failed", "name", name, "error", err) // Clean up the client to prevent resource leaks @@ -91,7 +91,7 @@ func (app *App) runWorkspaceWatcher(ctx context.Context, name string, workspaceW app.restartLSPClient(ctx, name) }) - workspaceWatcher.WatchWorkspace(ctx, config.WorkingDirectory()) + workspaceWatcher.WatchWorkspace(ctx, config.Get().WorkingDir()) logging.Info("Workspace watcher stopped", "client", name) } diff --git a/internal/config/config.go b/internal/config/config.go index 544d3ece6f7b653787d06ebc1ac2ff2d7a48cf3f..2b81094ed394e669b4d293200f911a875c2deacb 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,28 +1,17 @@ package config import ( - "encoding/json" - "errors" "fmt" - "log/slog" - "maps" - "os" - "path/filepath" "slices" "strings" - "sync" "github.com/charmbracelet/crush/internal/fur/provider" - "github.com/charmbracelet/crush/internal/logging" - "github.com/invopop/jsonschema" ) const ( + appName = "crush" defaultDataDirectory = ".crush" defaultLogLevel = "info" - appName = "crush" - - MaxTokensFallbackDefault = 4096 ) var defaultContextPaths = []string{ @@ -41,82 +30,51 @@ var defaultContextPaths = []string{ "CRUSH.local.md", } -type AgentID string +type SelectedModelType string const ( - AgentCoder AgentID = "coder" - AgentTask AgentID = "task" + SelectedModelTypeLarge SelectedModelType = "large" + SelectedModelTypeSmall SelectedModelType = "small" ) -type ModelType string +type SelectedModel struct { + // The model id as used by the provider API. + // Required. + Model string `json:"model"` + // The model provider, same as the key/id used in the providers config. + // Required. + Provider string `json:"provider"` -const ( - LargeModel ModelType = "large" - SmallModel ModelType = "small" -) + // Only used by models that use the openai provider and need this set. + ReasoningEffort string `json:"reasoning_effort,omitempty"` -type Model struct { - ID string `json:"id" jsonschema:"title=Model ID,description=Unique identifier for the model, the API model"` - Name string `json:"name" jsonschema:"title=Model Name,description=Display name of the model"` - CostPer1MIn float64 `json:"cost_per_1m_in,omitempty" jsonschema:"title=Input Cost,description=Cost per 1 million input tokens,minimum=0"` - CostPer1MOut float64 `json:"cost_per_1m_out,omitempty" jsonschema:"title=Output Cost,description=Cost per 1 million output tokens,minimum=0"` - CostPer1MInCached float64 `json:"cost_per_1m_in_cached,omitempty" jsonschema:"title=Cached Input Cost,description=Cost per 1 million cached input tokens,minimum=0"` - CostPer1MOutCached float64 `json:"cost_per_1m_out_cached,omitempty" jsonschema:"title=Cached Output Cost,description=Cost per 1 million cached output tokens,minimum=0"` - ContextWindow int64 `json:"context_window" jsonschema:"title=Context Window,description=Maximum context window size in tokens,minimum=1"` - DefaultMaxTokens int64 `json:"default_max_tokens" jsonschema:"title=Default Max Tokens,description=Default maximum tokens for responses,minimum=1"` - CanReason bool `json:"can_reason,omitempty" jsonschema:"title=Can Reason,description=Whether the model supports reasoning capabilities"` - ReasoningEffort string `json:"reasoning_effort,omitempty" jsonschema:"title=Reasoning Effort,description=Default reasoning effort level for reasoning models"` - HasReasoningEffort bool `json:"has_reasoning_effort,omitempty" jsonschema:"title=Has Reasoning Effort,description=Whether the model supports reasoning effort configuration"` - SupportsImages bool `json:"supports_attachments,omitempty" jsonschema:"title=Supports Images,description=Whether the model supports image attachments"` -} + // Overrides the default model configuration. + MaxTokens int64 `json:"max_tokens,omitempty"` -type VertexAIOptions struct { - APIKey string `json:"api_key,omitempty"` - Project string `json:"project,omitempty"` - Location string `json:"location,omitempty"` + // Used by anthropic models that can reason to indicate if the model should think. + Think bool `json:"think,omitempty"` } type ProviderConfig struct { - ID provider.InferenceProvider `json:"id,omitempty" jsonschema:"title=Provider ID,description=Unique identifier for the provider"` - BaseURL string `json:"base_url,omitempty" jsonschema:"title=Base URL,description=Base URL for the provider API (required for custom providers)"` - ProviderType provider.Type `json:"provider_type" jsonschema:"title=Provider Type,description=Type of the provider (openai, anthropic, etc.)"` - APIKey string `json:"api_key,omitempty" jsonschema:"title=API Key,description=API key for authenticating with the provider"` - Disabled bool `json:"disabled,omitempty" jsonschema:"title=Disabled,description=Whether this provider is disabled,default=false"` - ExtraHeaders map[string]string `json:"extra_headers,omitempty" jsonschema:"title=Extra Headers,description=Additional HTTP headers to send with requests"` - // used for e.x for vertex to set the project - ExtraParams map[string]string `json:"extra_params,omitempty" jsonschema:"title=Extra Parameters,description=Additional provider-specific parameters"` - - DefaultLargeModel string `json:"default_large_model,omitempty" jsonschema:"title=Default Large Model,description=Default model ID for large model type"` - DefaultSmallModel string `json:"default_small_model,omitempty" jsonschema:"title=Default Small Model,description=Default model ID for small model type"` - - Models []Model `json:"models,omitempty" jsonschema:"title=Models,description=List of available models for this provider"` -} - -type Agent struct { - ID AgentID `json:"id,omitempty" jsonschema:"title=Agent ID,description=Unique identifier for the agent,enum=coder,enum=task"` - Name string `json:"name,omitempty" jsonschema:"title=Name,description=Display name of the agent"` - Description string `json:"description,omitempty" jsonschema:"title=Description,description=Description of what the agent does"` - // This is the id of the system prompt used by the agent - Disabled bool `json:"disabled,omitempty" jsonschema:"title=Disabled,description=Whether this agent is disabled,default=false"` - - Model ModelType `json:"model" jsonschema:"title=Model Type,description=Type of model to use (large or small),enum=large,enum=small"` + // The provider's id. + ID string `json:"id,omitempty"` + // The provider's API endpoint. + BaseURL string `json:"base_url,omitempty"` + // The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai. + Type provider.Type `json:"type,omitempty"` + // The provider's API key. + APIKey string `json:"api_key,omitempty"` + // Marks the provider as disabled. + Disable bool `json:"disable,omitempty"` - // The available tools for the agent - // if this is nil, all tools are available - AllowedTools []string `json:"allowed_tools,omitempty" jsonschema:"title=Allowed Tools,description=List of tools this agent is allowed to use (if nil all tools are allowed)"` + // Extra headers to send with each request to the provider. + ExtraHeaders map[string]string - // this tells us which MCPs are available for this agent - // if this is empty all mcps are available - // the string array is the list of tools from the AllowedMCP the agent has available - // if the string array is nil, all tools from the AllowedMCP are available - AllowedMCP map[string][]string `json:"allowed_mcp,omitempty" jsonschema:"title=Allowed MCP,description=Map of MCP servers this agent can use and their allowed tools"` + // Used to pass extra parameters to the provider. + ExtraParams map[string]string `json:"-"` - // The list of LSPs that this agent can use - // if this is nil, all LSPs are available - AllowedLSP []string `json:"allowed_lsp,omitempty" jsonschema:"title=Allowed LSP,description=List of LSP servers this agent can use (if nil all LSPs are allowed)"` - - // Overrides the context paths for this agent - ContextPaths []string `json:"context_paths,omitempty" jsonschema:"title=Context Paths,description=Custom context paths for this agent (additive to global context paths)"` + // The provider models + Models []provider.Model `json:"models,omitempty"` } type MCPType string @@ -127,1358 +85,205 @@ const ( MCPHttp MCPType = "http" ) -type MCP struct { - Command string `json:"command,omitempty" jsonschema:"title=Command,description=Command to execute for stdio MCP servers"` - Env []string `json:"env,omitempty" jsonschema:"title=Environment,description=Environment variables for the MCP server"` - Args []string `json:"args,omitempty" jsonschema:"title=Arguments,description=Command line arguments for the MCP server"` - Type MCPType `json:"type" jsonschema:"title=Type,description=Type of MCP connection,enum=stdio,enum=sse,enum=http,default=stdio"` - URL string `json:"url,omitempty" jsonschema:"title=URL,description=URL for SSE MCP servers"` +type MCPConfig struct { + Command string `json:"command,omitempty" ` + Env []string `json:"env,omitempty"` + Args []string `json:"args,omitempty"` + Type MCPType `json:"type"` + URL string `json:"url,omitempty"` + // TODO: maybe make it possible to get the value from the env - Headers map[string]string `json:"headers,omitempty" jsonschema:"title=Headers,description=HTTP headers for SSE MCP servers"` + Headers map[string]string `json:"headers,omitempty"` } type LSPConfig struct { - Disabled bool `json:"enabled,omitempty" jsonschema:"title=Enabled,description=Whether this LSP server is enabled,default=true"` - Command string `json:"command" jsonschema:"title=Command,description=Command to execute for the LSP server"` - Args []string `json:"args,omitempty" jsonschema:"title=Arguments,description=Command line arguments for the LSP server"` - Options any `json:"options,omitempty" jsonschema:"title=Options,description=LSP server specific options"` + Disabled bool `json:"enabled,omitempty"` + Command string `json:"command"` + Args []string `json:"args,omitempty"` + Options any `json:"options,omitempty"` } type TUIOptions struct { - CompactMode bool `json:"compact_mode" jsonschema:"title=Compact Mode,description=Enable compact mode for the TUI,default=false"` + CompactMode bool `json:"compact_mode,omitempty"` // Here we can add themes later or any TUI related options } type Options struct { - ContextPaths []string `json:"context_paths,omitempty" jsonschema:"title=Context Paths,description=List of paths to search for context files"` - TUI TUIOptions `json:"tui,omitempty" jsonschema:"title=TUI Options,description=Terminal UI configuration options"` - Debug bool `json:"debug,omitempty" jsonschema:"title=Debug,description=Enable debug logging,default=false"` - DebugLSP bool `json:"debug_lsp,omitempty" jsonschema:"title=Debug LSP,description=Enable LSP debug logging,default=false"` - DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty" jsonschema:"title=Disable Auto Summarize,description=Disable automatic conversation summarization,default=false"` + ContextPaths []string `json:"context_paths,omitempty"` + TUI *TUIOptions `json:"tui,omitempty"` + Debug bool `json:"debug,omitempty"` + DebugLSP bool `json:"debug_lsp,omitempty"` + DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty"` // Relative to the cwd - DataDirectory string `json:"data_directory,omitempty" jsonschema:"title=Data Directory,description=Directory for storing application data,default=.crush"` + DataDirectory string `json:"data_directory,omitempty"` } -type PreferredModel struct { - ModelID string `json:"model_id" jsonschema:"title=Model ID,description=ID of the preferred model"` - Provider provider.InferenceProvider `json:"provider" jsonschema:"title=Provider,description=Provider for the preferred model"` - // ReasoningEffort overrides the default reasoning effort for this model - ReasoningEffort string `json:"reasoning_effort,omitempty" jsonschema:"title=Reasoning Effort,description=Override reasoning effort for this model"` - // MaxTokens overrides the default max tokens for this model - MaxTokens int64 `json:"max_tokens,omitempty" jsonschema:"title=Max Tokens,description=Override max tokens for this model,minimum=1"` +type MCPs map[string]MCPConfig - // Think indicates if the model should think, only applicable for anthropic reasoning models - Think bool `json:"think,omitempty" jsonschema:"title=Think,description=Enable thinking for reasoning models,default=false"` -} - -type PreferredModels struct { - Large PreferredModel `json:"large,omitempty" jsonschema:"title=Large Model,description=Preferred model configuration for large model type"` - Small PreferredModel `json:"small,omitempty" jsonschema:"title=Small Model,description=Preferred model configuration for small model type"` -} - -type Config struct { - Models PreferredModels `json:"models,omitempty" jsonschema:"title=Models,description=Preferred model configurations for large and small model types"` - // List of configured providers - Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty" jsonschema:"title=Providers,description=LLM provider configurations"` - - // List of configured agents - Agents map[AgentID]Agent `json:"agents,omitempty" jsonschema:"title=Agents,description=Agent configurations for different tasks"` - - // List of configured MCPs - MCP map[string]MCP `json:"mcp,omitempty" jsonschema:"title=MCP,description=Model Control Protocol server configurations"` - - // List of configured LSPs - LSP map[string]LSPConfig `json:"lsp,omitempty" jsonschema:"title=LSP,description=Language Server Protocol configurations"` - - // Miscellaneous options - Options Options `json:"options,omitempty" jsonschema:"title=Options,description=General application options and settings"` -} - -var ( - instance *Config // The single instance of the Singleton - cwd string - once sync.Once // Ensures the initialization happens only once - -) - -func readConfigFile(path string) (*Config, error) { - var cfg *Config - if _, err := os.Stat(path); err != nil && !os.IsNotExist(err) { - // some other error occurred while checking the file - return nil, err - } else if err == nil { - // config file exists, read it - file, err := os.ReadFile(path) - if err != nil { - return nil, err - } - cfg = &Config{} - if err := json.Unmarshal(file, cfg); err != nil { - return nil, err - } - } else { - // config file does not exist, create a new one - cfg = &Config{} - } - return cfg, nil +type MCP struct { + Name string `json:"name"` + MCP MCPConfig `json:"mcp"` } -func loadConfig(cwd string, debug bool) (*Config, error) { - // First read the global config file - cfgPath := ConfigPath() - - cfg := defaultConfigBasedOnEnv() - cfg.Options.Debug = debug - defaultLevel := slog.LevelInfo - if cfg.Options.Debug { - defaultLevel = slog.LevelDebug +func (m MCPs) Sorted() []MCP { + sorted := make([]MCP, 0, len(m)) + for k, v := range m { + sorted = append(sorted, MCP{ + Name: k, + MCP: v, + }) } - if os.Getenv("CRUSH_DEV_DEBUG") == "true" { - loggingFile := fmt.Sprintf("%s/%s", cfg.Options.DataDirectory, "debug.log") - - // if file does not exist create it - if _, err := os.Stat(loggingFile); os.IsNotExist(err) { - if err := os.MkdirAll(cfg.Options.DataDirectory, 0o755); err != nil { - return cfg, fmt.Errorf("failed to create directory: %w", err) - } - if _, err := os.Create(loggingFile); err != nil { - return cfg, fmt.Errorf("failed to create log file: %w", err) - } - } - - messagesPath := fmt.Sprintf("%s/%s", cfg.Options.DataDirectory, "messages") - - if _, err := os.Stat(messagesPath); os.IsNotExist(err) { - if err := os.MkdirAll(messagesPath, 0o756); err != nil { - return cfg, fmt.Errorf("failed to create directory: %w", err) - } - } - logging.MessageDir = messagesPath - - sloggingFileWriter, err := os.OpenFile(loggingFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666) - if err != nil { - return cfg, fmt.Errorf("failed to open log file: %w", err) - } - // Configure logger - logger := slog.New(slog.NewTextHandler(sloggingFileWriter, &slog.HandlerOptions{ - Level: defaultLevel, - })) - slog.SetDefault(logger) - } else { - // Configure logger - logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{ - Level: defaultLevel, - })) - slog.SetDefault(logger) - } - - priorityOrderedConfigFiles := []string{ - cfgPath, // Global config file - filepath.Join(cwd, "crush.json"), // Local config file - filepath.Join(cwd, ".crush.json"), // Local config file - } - - configs := make([]*Config, 0) - for _, path := range priorityOrderedConfigFiles { - localConfig, err := readConfigFile(path) - if err != nil { - return nil, fmt.Errorf("failed to read config file %s: %w", path, err) - } - if localConfig != nil { - // If the config file was read successfully, add it to the list - configs = append(configs, localConfig) - } - } - - // merge options - mergeOptions(cfg, configs...) - - mergeProviderConfigs(cfg, configs...) - // no providers found the app is not initialized yet - if len(cfg.Providers) == 0 { - return cfg, nil - } - preferredProvider := getPreferredProvider(cfg.Providers) - if preferredProvider != nil { - cfg.Models = PreferredModels{ - Large: PreferredModel{ - ModelID: preferredProvider.DefaultLargeModel, - Provider: preferredProvider.ID, - }, - Small: PreferredModel{ - ModelID: preferredProvider.DefaultSmallModel, - Provider: preferredProvider.ID, - }, - } - } else { - // No valid providers found, set empty models - cfg.Models = PreferredModels{} - } - - mergeModels(cfg, configs...) - - agents := map[AgentID]Agent{ - AgentCoder: { - ID: AgentCoder, - Name: "Coder", - Description: "An agent that helps with executing coding tasks.", - Model: LargeModel, - ContextPaths: cfg.Options.ContextPaths, - // All tools allowed - }, - AgentTask: { - ID: AgentTask, - Name: "Task", - Description: "An agent that helps with searching for context and finding implementation details.", - Model: LargeModel, - ContextPaths: cfg.Options.ContextPaths, - AllowedTools: []string{ - "glob", - "grep", - "ls", - "sourcegraph", - "view", - }, - // NO MCPs or LSPs by default - AllowedMCP: map[string][]string{}, - AllowedLSP: []string{}, - }, - } - cfg.Agents = agents - mergeAgents(cfg, configs...) - mergeMCPs(cfg, configs...) - mergeLSPs(cfg, configs...) - - // Validate the final configuration - if err := cfg.Validate(); err != nil { - return cfg, fmt.Errorf("configuration validation failed: %w", err) - } - - return cfg, nil -} - -func Init(workingDir string, debug bool) (*Config, error) { - var err error - once.Do(func() { - cwd = workingDir - instance, err = loadConfig(cwd, debug) - if err != nil { - logging.Error("Failed to load config", "error", err) - } + slices.SortFunc(sorted, func(a, b MCP) int { + return strings.Compare(a.Name, b.Name) }) - - return instance, err + return sorted } -func Get() *Config { - if instance == nil { - // TODO: Handle this better - panic("Config not initialized. Call InitConfig first.") - } - return instance -} - -func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig { - providers := Providers() - for _, p := range providers { - if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled { - return &providerConfig - } - } - // if none found return the first configured provider - for _, providerConfig := range configuredProviders { - if !providerConfig.Disabled { - return &providerConfig - } - } - return nil -} - -func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig { - if other.APIKey != "" { - base.APIKey = other.APIKey - } - // Only change these options if the provider is not a known provider - if !slices.Contains(provider.KnownProviders(), p) { - if other.BaseURL != "" { - base.BaseURL = other.BaseURL - } - if other.ProviderType != "" { - base.ProviderType = other.ProviderType - } - if len(other.ExtraHeaders) > 0 { - if base.ExtraHeaders == nil { - base.ExtraHeaders = make(map[string]string) - } - maps.Copy(base.ExtraHeaders, other.ExtraHeaders) - } - if len(other.ExtraParams) > 0 { - if base.ExtraParams == nil { - base.ExtraParams = make(map[string]string) - } - maps.Copy(base.ExtraParams, other.ExtraParams) - } - } - - if other.Disabled { - base.Disabled = other.Disabled - } - - if other.DefaultLargeModel != "" { - base.DefaultLargeModel = other.DefaultLargeModel - } - // Add new models if they don't exist - if other.Models != nil { - for _, model := range other.Models { - // check if the model already exists - exists := false - for _, existingModel := range base.Models { - if existingModel.ID == model.ID { - exists = true - break - } - } - if !exists { - base.Models = append(base.Models, model) - } - } - } - - return base -} +type LSPs map[string]LSPConfig -func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error { - if !slices.Contains(provider.KnownProviders(), p) { - if providerConfig.ProviderType != provider.TypeOpenAI { - return errors.New("invalid provider type: " + string(providerConfig.ProviderType)) - } - if providerConfig.BaseURL == "" { - return errors.New("base URL must be set for custom providers") - } - if providerConfig.APIKey == "" { - return errors.New("API key must be set for custom providers") - } - } - return nil +type LSP struct { + Name string `json:"name"` + LSP LSPConfig `json:"lsp"` } -func mergeModels(base *Config, others ...*Config) { - for _, cfg := range others { - if cfg == nil { - continue - } - if cfg.Models.Large.ModelID != "" && cfg.Models.Large.Provider != "" { - base.Models.Large = cfg.Models.Large - } - - if cfg.Models.Small.ModelID != "" && cfg.Models.Small.Provider != "" { - base.Models.Small = cfg.Models.Small - } +func (l LSPs) Sorted() []LSP { + sorted := make([]LSP, 0, len(l)) + for k, v := range l { + sorted = append(sorted, LSP{ + Name: k, + LSP: v, + }) } + slices.SortFunc(sorted, func(a, b LSP) int { + return strings.Compare(a.Name, b.Name) + }) + return sorted } -func mergeOptions(base *Config, others ...*Config) { - for _, cfg := range others { - if cfg == nil { - continue - } - baseOptions := base.Options - other := cfg.Options - if len(other.ContextPaths) > 0 { - baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...) - } +type Agent struct { + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + // This is the id of the system prompt used by the agent + Disabled bool `json:"disabled,omitempty"` - if other.TUI.CompactMode { - baseOptions.TUI.CompactMode = other.TUI.CompactMode - } + Model SelectedModelType `json:"model"` - if other.Debug { - baseOptions.Debug = other.Debug - } + // The available tools for the agent + // if this is nil, all tools are available + AllowedTools []string `json:"allowed_tools,omitempty"` - if other.DebugLSP { - baseOptions.DebugLSP = other.DebugLSP - } + // this tells us which MCPs are available for this agent + // if this is empty all mcps are available + // the string array is the list of tools from the AllowedMCP the agent has available + // if the string array is nil, all tools from the AllowedMCP are available + AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"` - if other.DisableAutoSummarize { - baseOptions.DisableAutoSummarize = other.DisableAutoSummarize - } + // The list of LSPs that this agent can use + // if this is nil, all LSPs are available + AllowedLSP []string `json:"allowed_lsp,omitempty"` - if other.DataDirectory != "" { - baseOptions.DataDirectory = other.DataDirectory - } - base.Options = baseOptions - } + // Overrides the context paths for this agent + ContextPaths []string `json:"context_paths,omitempty"` } -func mergeAgents(base *Config, others ...*Config) { - for _, cfg := range others { - if cfg == nil { - continue - } - for agentID, newAgent := range cfg.Agents { - if _, ok := base.Agents[agentID]; !ok { - newAgent.ID = agentID - if newAgent.Model == "" { - newAgent.Model = LargeModel - } - if len(newAgent.ContextPaths) > 0 { - newAgent.ContextPaths = append(base.Options.ContextPaths, newAgent.ContextPaths...) - } else { - newAgent.ContextPaths = base.Options.ContextPaths - } - base.Agents[agentID] = newAgent - } else { - baseAgent := base.Agents[agentID] +// Config holds the configuration for crush. +type Config struct { + // We currently only support large/small as values here. + Models map[SelectedModelType]SelectedModel `json:"models,omitempty"` - if agentID == AgentCoder || agentID == AgentTask { - if newAgent.Model != "" { - baseAgent.Model = newAgent.Model - } - if newAgent.AllowedMCP != nil { - baseAgent.AllowedMCP = newAgent.AllowedMCP - } - if newAgent.AllowedLSP != nil { - baseAgent.AllowedLSP = newAgent.AllowedLSP - } - // Context paths are additive for known agents too - if len(newAgent.ContextPaths) > 0 { - baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...) - } - } else { - if newAgent.Name != "" { - baseAgent.Name = newAgent.Name - } - if newAgent.Description != "" { - baseAgent.Description = newAgent.Description - } - if newAgent.Model != "" { - baseAgent.Model = newAgent.Model - } else if baseAgent.Model == "" { - baseAgent.Model = LargeModel - } + // The providers that are configured + Providers map[string]ProviderConfig `json:"providers,omitempty"` - baseAgent.Disabled = newAgent.Disabled + MCP MCPs `json:"mcp,omitempty"` - if newAgent.AllowedTools != nil { - baseAgent.AllowedTools = newAgent.AllowedTools - } - if newAgent.AllowedMCP != nil { - baseAgent.AllowedMCP = newAgent.AllowedMCP - } - if newAgent.AllowedLSP != nil { - baseAgent.AllowedLSP = newAgent.AllowedLSP - } - if len(newAgent.ContextPaths) > 0 { - baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...) - } - } + LSP LSPs `json:"lsp,omitempty"` - base.Agents[agentID] = baseAgent - } - } - } -} + Options *Options `json:"options,omitempty"` -func mergeMCPs(base *Config, others ...*Config) { - for _, cfg := range others { - if cfg == nil { - continue - } - maps.Copy(base.MCP, cfg.MCP) - } + // Internal + workingDir string `json:"-"` + // TODO: most likely remove this concept when I come back to it + Agents map[string]Agent `json:"-"` + // TODO: find a better way to do this this should probably not be part of the config + resolver VariableResolver } -func mergeLSPs(base *Config, others ...*Config) { - for _, cfg := range others { - if cfg == nil { - continue - } - maps.Copy(base.LSP, cfg.LSP) - } +func (c *Config) WorkingDir() string { + return c.workingDir } -func mergeProviderConfigs(base *Config, others ...*Config) { - for _, cfg := range others { - if cfg == nil { - continue - } - for providerName, p := range cfg.Providers { - p.ID = providerName - if _, ok := base.Providers[providerName]; !ok { - if slices.Contains(provider.KnownProviders(), providerName) { - providers := Providers() - for _, providerDef := range providers { - if providerDef.ID == providerName { - logging.Info("Using default provider config for", "provider", providerName) - baseProvider := getDefaultProviderConfig(providerDef, providerDef.APIKey) - base.Providers[providerName] = mergeProviderConfig(providerName, baseProvider, p) - break - } - } - } else { - base.Providers[providerName] = p - } - } else { - base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], p) - } - } - } - - finalProviders := make(map[provider.InferenceProvider]ProviderConfig) - for providerName, providerConfig := range base.Providers { - err := validateProvider(providerName, providerConfig) - if err != nil { - logging.Warn("Skipping provider", "name", providerName, "error", err) - continue // Skip invalid providers +func (c *Config) EnabledProviders() []ProviderConfig { + enabled := make([]ProviderConfig, 0, len(c.Providers)) + for _, p := range c.Providers { + if !p.Disable { + enabled = append(enabled, p) } - finalProviders[providerName] = providerConfig } - base.Providers = finalProviders + return enabled } -func providerDefaultConfig(providerID provider.InferenceProvider) ProviderConfig { - switch providerID { - case provider.InferenceProviderAnthropic: - return ProviderConfig{ - ID: providerID, - ProviderType: provider.TypeAnthropic, - } - case provider.InferenceProviderOpenAI: - return ProviderConfig{ - ID: providerID, - ProviderType: provider.TypeOpenAI, - } - case provider.InferenceProviderGemini: - return ProviderConfig{ - ID: providerID, - ProviderType: provider.TypeGemini, - } - case provider.InferenceProviderBedrock: - return ProviderConfig{ - ID: providerID, - ProviderType: provider.TypeBedrock, - } - case provider.InferenceProviderAzure: - return ProviderConfig{ - ID: providerID, - ProviderType: provider.TypeAzure, - } - case provider.InferenceProviderOpenRouter: - return ProviderConfig{ - ID: providerID, - ProviderType: provider.TypeOpenAI, - BaseURL: "https://openrouter.ai/api/v1", - ExtraHeaders: map[string]string{ - "HTTP-Referer": "crush.charm.land", - "X-Title": "Crush", - }, - } - case provider.InferenceProviderXAI: - return ProviderConfig{ - ID: providerID, - ProviderType: provider.TypeXAI, - BaseURL: "https://api.x.ai/v1", - } - case provider.InferenceProviderVertexAI: - return ProviderConfig{ - ID: providerID, - ProviderType: provider.TypeVertexAI, - } - default: - return ProviderConfig{ - ID: providerID, - ProviderType: provider.TypeOpenAI, - } - } +// IsConfigured return true if at least one provider is configured +func (c *Config) IsConfigured() bool { + return len(c.EnabledProviders()) > 0 } -func getDefaultProviderConfig(p provider.Provider, apiKey string) ProviderConfig { - providerConfig := providerDefaultConfig(p.ID) - providerConfig.APIKey = apiKey - providerConfig.DefaultLargeModel = p.DefaultLargeModelID - providerConfig.DefaultSmallModel = p.DefaultSmallModelID - baseURL := p.APIEndpoint - if strings.HasPrefix(baseURL, "$") { - envVar := strings.TrimPrefix(baseURL, "$") - baseURL = os.Getenv(envVar) - } - providerConfig.BaseURL = baseURL - for _, model := range p.Models { - configModel := Model{ - ID: model.ID, - Name: model.Name, - CostPer1MIn: model.CostPer1MIn, - CostPer1MOut: model.CostPer1MOut, - CostPer1MInCached: model.CostPer1MInCached, - CostPer1MOutCached: model.CostPer1MOutCached, - ContextWindow: model.ContextWindow, - DefaultMaxTokens: model.DefaultMaxTokens, - CanReason: model.CanReason, - SupportsImages: model.SupportsImages, - } - // Set reasoning effort for reasoning models - if model.HasReasoningEffort && model.DefaultReasoningEffort != "" { - configModel.HasReasoningEffort = model.HasReasoningEffort - configModel.ReasoningEffort = model.DefaultReasoningEffort - } - providerConfig.Models = append(providerConfig.Models, configModel) - } - return providerConfig -} - -func defaultConfigBasedOnEnv() *Config { - cfg := &Config{ - Options: Options{ - DataDirectory: defaultDataDirectory, - ContextPaths: defaultContextPaths, - }, - Providers: make(map[provider.InferenceProvider]ProviderConfig), - Agents: make(map[AgentID]Agent), - LSP: make(map[string]LSPConfig), - MCP: make(map[string]MCP), - } - - providers := Providers() - - for _, p := range providers { - if strings.HasPrefix(p.APIKey, "$") { - envVar := strings.TrimPrefix(p.APIKey, "$") - if apiKey := os.Getenv(envVar); apiKey != "" { - cfg.Providers[p.ID] = getDefaultProviderConfig(p, apiKey) - } - } - } - // TODO: support local models - - if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" { - providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI) - providerConfig.ExtraParams = map[string]string{ - "project": os.Getenv("GOOGLE_CLOUD_PROJECT"), - "location": os.Getenv("GOOGLE_CLOUD_LOCATION"), - } - // Find the VertexAI provider definition to get default models - for _, p := range providers { - if p.ID == provider.InferenceProviderVertexAI { - providerConfig.DefaultLargeModel = p.DefaultLargeModelID - providerConfig.DefaultSmallModel = p.DefaultSmallModelID - for _, model := range p.Models { - configModel := Model{ - ID: model.ID, - Name: model.Name, - CostPer1MIn: model.CostPer1MIn, - CostPer1MOut: model.CostPer1MOut, - CostPer1MInCached: model.CostPer1MInCached, - CostPer1MOutCached: model.CostPer1MOutCached, - ContextWindow: model.ContextWindow, - DefaultMaxTokens: model.DefaultMaxTokens, - CanReason: model.CanReason, - SupportsImages: model.SupportsImages, - } - // Set reasoning effort for reasoning models - if model.HasReasoningEffort && model.DefaultReasoningEffort != "" { - configModel.HasReasoningEffort = model.HasReasoningEffort - configModel.ReasoningEffort = model.DefaultReasoningEffort - } - providerConfig.Models = append(providerConfig.Models, configModel) - } - break - } - } - cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig - } - - if hasAWSCredentials() { - providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock) - providerConfig.ExtraParams = map[string]string{ - "region": os.Getenv("AWS_DEFAULT_REGION"), - } - if providerConfig.ExtraParams["region"] == "" { - providerConfig.ExtraParams["region"] = os.Getenv("AWS_REGION") - } - // Find the Bedrock provider definition to get default models - for _, p := range providers { - if p.ID == provider.InferenceProviderBedrock { - providerConfig.DefaultLargeModel = p.DefaultLargeModelID - providerConfig.DefaultSmallModel = p.DefaultSmallModelID - for _, model := range p.Models { - configModel := Model{ - ID: model.ID, - Name: model.Name, - CostPer1MIn: model.CostPer1MIn, - CostPer1MOut: model.CostPer1MOut, - CostPer1MInCached: model.CostPer1MInCached, - CostPer1MOutCached: model.CostPer1MOutCached, - ContextWindow: model.ContextWindow, - DefaultMaxTokens: model.DefaultMaxTokens, - CanReason: model.CanReason, - SupportsImages: model.SupportsImages, - } - // Set reasoning effort for reasoning models - if model.HasReasoningEffort && model.DefaultReasoningEffort != "" { - configModel.HasReasoningEffort = model.HasReasoningEffort - configModel.ReasoningEffort = model.DefaultReasoningEffort - } - providerConfig.Models = append(providerConfig.Models, configModel) - } - break +func (c *Config) GetModel(provider, model string) *provider.Model { + if providerConfig, ok := c.Providers[provider]; ok { + for _, m := range providerConfig.Models { + if m.ID == model { + return &m } } - cfg.Providers[provider.InferenceProviderBedrock] = providerConfig - } - return cfg -} - -func hasAWSCredentials() bool { - if os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "" { - return true } - - if os.Getenv("AWS_PROFILE") != "" || os.Getenv("AWS_DEFAULT_PROFILE") != "" { - return true - } - - if os.Getenv("AWS_REGION") != "" || os.Getenv("AWS_DEFAULT_REGION") != "" { - return true - } - - if os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" || - os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" { - return true - } - - return false -} - -func WorkingDirectory() string { - return cwd + return nil } -// TODO: Handle error state - -func GetAgentModel(agentID AgentID) Model { - cfg := Get() - agent, ok := cfg.Agents[agentID] - if !ok { - logging.Error("Agent not found", "agent_id", agentID) - return Model{} - } - - var model PreferredModel - switch agent.Model { - case LargeModel: - model = cfg.Models.Large - case SmallModel: - model = cfg.Models.Small - default: - logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model) - model = cfg.Models.Large // Fallback to large model - } - providerConfig, ok := cfg.Providers[model.Provider] +func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig { + model, ok := c.Models[modelType] if !ok { - logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider) - return Model{} + return nil } - - for _, m := range providerConfig.Models { - if m.ID == model.ModelID { - return m - } + if providerConfig, ok := c.Providers[model.Provider]; ok { + return &providerConfig } - - logging.Error("Model not found for agent", "agent_id", agentID, "model", agent.Model) - return Model{} + return nil } -func GetAgentProvider(agentID AgentID) ProviderConfig { - cfg := Get() - agent, ok := cfg.Agents[agentID] +func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model { + model, ok := c.Models[modelType] if !ok { - logging.Error("Agent not found", "agent_id", agentID) - return ProviderConfig{} + return nil } - - var model PreferredModel - switch agent.Model { - case LargeModel: - model = cfg.Models.Large - case SmallModel: - model = cfg.Models.Small - default: - logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model) - model = cfg.Models.Large // Fallback to large model - } - - providerConfig, ok := cfg.Providers[model.Provider] - if !ok { - logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider) - return ProviderConfig{} - } - - return providerConfig + return c.GetModel(model.Provider, model.Model) } -func GetProviderModel(provider provider.InferenceProvider, modelID string) Model { - cfg := Get() - providerConfig, ok := cfg.Providers[provider] +func (c *Config) LargeModel() *provider.Model { + model, ok := c.Models[SelectedModelTypeLarge] if !ok { - logging.Error("Provider not found", "provider", provider) - return Model{} + return nil } - - for _, model := range providerConfig.Models { - if model.ID == modelID { - return model - } - } - - logging.Error("Model not found for provider", "provider", provider, "model_id", modelID) - return Model{} + return c.GetModel(model.Provider, model.Model) } -func GetModel(modelType ModelType) Model { - cfg := Get() - var model PreferredModel - switch modelType { - case LargeModel: - model = cfg.Models.Large - case SmallModel: - model = cfg.Models.Small - default: - model = cfg.Models.Large // Fallback to large model - } - providerConfig, ok := cfg.Providers[model.Provider] +func (c *Config) SmallModel() *provider.Model { + model, ok := c.Models[SelectedModelTypeSmall] if !ok { - return Model{} + return nil } - - for _, m := range providerConfig.Models { - if m.ID == model.ModelID { - return m - } - } - return Model{} + return c.GetModel(model.Provider, model.Model) } -func UpdatePreferredModel(modelType ModelType, model PreferredModel) error { - cfg := Get() - switch modelType { - case LargeModel: - cfg.Models.Large = model - case SmallModel: - cfg.Models.Small = model - default: - return fmt.Errorf("unknown model type: %s", modelType) +func (c *Config) Resolve(key string) (string, error) { + if c.resolver == nil { + return "", fmt.Errorf("no variable resolver configured") } - return nil -} - -// ValidationError represents a configuration validation error -type ValidationError struct { - Field string - Message string + return c.resolver.ResolveValue(key) } -func (e ValidationError) Error() string { - return fmt.Sprintf("validation error in %s: %s", e.Field, e.Message) -} - -// ValidationErrors represents multiple validation errors -type ValidationErrors []ValidationError - -func (e ValidationErrors) Error() string { - if len(e) == 0 { - return "no validation errors" - } - if len(e) == 1 { - return e[0].Error() - } - - var messages []string - for _, err := range e { - messages = append(messages, err.Error()) - } - return fmt.Sprintf("multiple validation errors: %s", strings.Join(messages, "; ")) -} - -// HasErrors returns true if there are any validation errors -func (e ValidationErrors) HasErrors() bool { - return len(e) > 0 -} - -// Add appends a new validation error -func (e *ValidationErrors) Add(field, message string) { - *e = append(*e, ValidationError{Field: field, Message: message}) -} - -// Validate performs comprehensive validation of the configuration -func (c *Config) Validate() error { - var errors ValidationErrors - - // Validate providers - c.validateProviders(&errors) - - // Validate models - c.validateModels(&errors) - - // Validate agents - c.validateAgents(&errors) - - // Validate options - c.validateOptions(&errors) - - // Validate MCP configurations - c.validateMCPs(&errors) - - // Validate LSP configurations - c.validateLSPs(&errors) - - // Validate cross-references - c.validateCrossReferences(&errors) - - // Validate completeness - c.validateCompleteness(&errors) - - if errors.HasErrors() { - return errors - } - +// TODO: maybe handle this better +func UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error { + cfg := Get() + cfg.Models[modelType] = model return nil } - -// validateProviders validates all provider configurations -func (c *Config) validateProviders(errors *ValidationErrors) { - if c.Providers == nil { - c.Providers = make(map[provider.InferenceProvider]ProviderConfig) - } - - knownProviders := provider.KnownProviders() - validTypes := []provider.Type{ - provider.TypeOpenAI, - provider.TypeAnthropic, - provider.TypeGemini, - provider.TypeAzure, - provider.TypeBedrock, - provider.TypeVertexAI, - provider.TypeXAI, - } - - for providerID, providerConfig := range c.Providers { - fieldPrefix := fmt.Sprintf("providers.%s", providerID) - - // Validate API key for non-disabled providers - if !providerConfig.Disabled && providerConfig.APIKey == "" { - // Special case for AWS Bedrock and VertexAI which may use other auth methods - if providerID != provider.InferenceProviderBedrock && providerID != provider.InferenceProviderVertexAI { - errors.Add(fieldPrefix+".api_key", "API key is required for non-disabled providers") - } - } - - // Validate provider type - validType := slices.Contains(validTypes, providerConfig.ProviderType) - if !validType { - errors.Add(fieldPrefix+".provider_type", fmt.Sprintf("invalid provider type: %s", providerConfig.ProviderType)) - } - - // Validate custom providers - isKnownProvider := slices.Contains(knownProviders, providerID) - - if !isKnownProvider { - // Custom provider validation - if providerConfig.BaseURL == "" { - errors.Add(fieldPrefix+".base_url", "BaseURL is required for custom providers") - } - if providerConfig.ProviderType != provider.TypeOpenAI { - errors.Add(fieldPrefix+".provider_type", "custom providers currently only support OpenAI type") - } - } - - // Validate models - modelIDs := make(map[string]bool) - for i, model := range providerConfig.Models { - modelFieldPrefix := fmt.Sprintf("%s.models[%d]", fieldPrefix, i) - - // Check for duplicate model IDs - if modelIDs[model.ID] { - errors.Add(modelFieldPrefix+".id", fmt.Sprintf("duplicate model ID: %s", model.ID)) - } - modelIDs[model.ID] = true - - // Validate required model fields - if model.ID == "" { - errors.Add(modelFieldPrefix+".id", "model ID is required") - } - if model.Name == "" { - errors.Add(modelFieldPrefix+".name", "model name is required") - } - if model.ContextWindow <= 0 { - errors.Add(modelFieldPrefix+".context_window", "context window must be positive") - } - if model.DefaultMaxTokens <= 0 { - errors.Add(modelFieldPrefix+".default_max_tokens", "default max tokens must be positive") - } - if model.DefaultMaxTokens > model.ContextWindow { - errors.Add(modelFieldPrefix+".default_max_tokens", "default max tokens cannot exceed context window") - } - - // Validate cost fields - if model.CostPer1MIn < 0 { - errors.Add(modelFieldPrefix+".cost_per_1m_in", "cost per 1M input tokens cannot be negative") - } - if model.CostPer1MOut < 0 { - errors.Add(modelFieldPrefix+".cost_per_1m_out", "cost per 1M output tokens cannot be negative") - } - if model.CostPer1MInCached < 0 { - errors.Add(modelFieldPrefix+".cost_per_1m_in_cached", "cached cost per 1M input tokens cannot be negative") - } - if model.CostPer1MOutCached < 0 { - errors.Add(modelFieldPrefix+".cost_per_1m_out_cached", "cached cost per 1M output tokens cannot be negative") - } - } - - // Validate default model references - if providerConfig.DefaultLargeModel != "" { - if !modelIDs[providerConfig.DefaultLargeModel] { - errors.Add(fieldPrefix+".default_large_model", fmt.Sprintf("default large model '%s' not found in provider models", providerConfig.DefaultLargeModel)) - } - } - if providerConfig.DefaultSmallModel != "" { - if !modelIDs[providerConfig.DefaultSmallModel] { - errors.Add(fieldPrefix+".default_small_model", fmt.Sprintf("default small model '%s' not found in provider models", providerConfig.DefaultSmallModel)) - } - } - - // Validate provider-specific requirements - c.validateProviderSpecific(providerID, providerConfig, errors) - } -} - -// validateProviderSpecific validates provider-specific requirements -func (c *Config) validateProviderSpecific(providerID provider.InferenceProvider, providerConfig ProviderConfig, errors *ValidationErrors) { - fieldPrefix := fmt.Sprintf("providers.%s", providerID) - - switch providerID { - case provider.InferenceProviderVertexAI: - if !providerConfig.Disabled { - if providerConfig.ExtraParams == nil { - errors.Add(fieldPrefix+".extra_params", "VertexAI requires extra_params configuration") - } else { - if providerConfig.ExtraParams["project"] == "" { - errors.Add(fieldPrefix+".extra_params.project", "VertexAI requires project parameter") - } - if providerConfig.ExtraParams["location"] == "" { - errors.Add(fieldPrefix+".extra_params.location", "VertexAI requires location parameter") - } - } - } - case provider.InferenceProviderBedrock: - if !providerConfig.Disabled { - if providerConfig.ExtraParams == nil || providerConfig.ExtraParams["region"] == "" { - errors.Add(fieldPrefix+".extra_params.region", "Bedrock requires region parameter") - } - // Check for AWS credentials in environment - if !hasAWSCredentials() { - errors.Add(fieldPrefix, "Bedrock requires AWS credentials in environment") - } - } - } -} - -// validateModels validates preferred model configurations -func (c *Config) validateModels(errors *ValidationErrors) { - // Validate large model - if c.Models.Large.ModelID != "" || c.Models.Large.Provider != "" { - if c.Models.Large.ModelID == "" { - errors.Add("models.large.model_id", "large model ID is required when provider is set") - } - if c.Models.Large.Provider == "" { - errors.Add("models.large.provider", "large model provider is required when model ID is set") - } - - // Check if provider exists and is not disabled - if providerConfig, exists := c.Providers[c.Models.Large.Provider]; exists { - if providerConfig.Disabled { - errors.Add("models.large.provider", "large model provider is disabled") - } - - // Check if model exists in provider - modelExists := false - for _, model := range providerConfig.Models { - if model.ID == c.Models.Large.ModelID { - modelExists = true - break - } - } - if !modelExists { - errors.Add("models.large.model_id", fmt.Sprintf("large model '%s' not found in provider '%s'", c.Models.Large.ModelID, c.Models.Large.Provider)) - } - } else { - errors.Add("models.large.provider", fmt.Sprintf("large model provider '%s' not found", c.Models.Large.Provider)) - } - } - - // Validate small model - if c.Models.Small.ModelID != "" || c.Models.Small.Provider != "" { - if c.Models.Small.ModelID == "" { - errors.Add("models.small.model_id", "small model ID is required when provider is set") - } - if c.Models.Small.Provider == "" { - errors.Add("models.small.provider", "small model provider is required when model ID is set") - } - - // Check if provider exists and is not disabled - if providerConfig, exists := c.Providers[c.Models.Small.Provider]; exists { - if providerConfig.Disabled { - errors.Add("models.small.provider", "small model provider is disabled") - } - - // Check if model exists in provider - modelExists := false - for _, model := range providerConfig.Models { - if model.ID == c.Models.Small.ModelID { - modelExists = true - break - } - } - if !modelExists { - errors.Add("models.small.model_id", fmt.Sprintf("small model '%s' not found in provider '%s'", c.Models.Small.ModelID, c.Models.Small.Provider)) - } - } else { - errors.Add("models.small.provider", fmt.Sprintf("small model provider '%s' not found", c.Models.Small.Provider)) - } - } -} - -// validateAgents validates agent configurations -func (c *Config) validateAgents(errors *ValidationErrors) { - if c.Agents == nil { - c.Agents = make(map[AgentID]Agent) - } - - validTools := []string{ - "bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent", - } - - for agentID, agent := range c.Agents { - fieldPrefix := fmt.Sprintf("agents.%s", agentID) - - // Validate agent ID consistency - if agent.ID != agentID { - errors.Add(fieldPrefix+".id", fmt.Sprintf("agent ID mismatch: expected '%s', got '%s'", agentID, agent.ID)) - } - - // Validate required fields - if agent.ID == "" { - errors.Add(fieldPrefix+".id", "agent ID is required") - } - if agent.Name == "" { - errors.Add(fieldPrefix+".name", "agent name is required") - } - - // Validate model type - if agent.Model != LargeModel && agent.Model != SmallModel { - errors.Add(fieldPrefix+".model", fmt.Sprintf("invalid model type: %s (must be 'large' or 'small')", agent.Model)) - } - - // Validate allowed tools - if agent.AllowedTools != nil { - for i, tool := range agent.AllowedTools { - validTool := slices.Contains(validTools, tool) - if !validTool { - errors.Add(fmt.Sprintf("%s.allowed_tools[%d]", fieldPrefix, i), fmt.Sprintf("unknown tool: %s", tool)) - } - } - } - - // Validate MCP references - if agent.AllowedMCP != nil { - for mcpName := range agent.AllowedMCP { - if _, exists := c.MCP[mcpName]; !exists { - errors.Add(fieldPrefix+".allowed_mcp", fmt.Sprintf("referenced MCP '%s' not found", mcpName)) - } - } - } - - // Validate LSP references - if agent.AllowedLSP != nil { - for _, lspName := range agent.AllowedLSP { - if _, exists := c.LSP[lspName]; !exists { - errors.Add(fieldPrefix+".allowed_lsp", fmt.Sprintf("referenced LSP '%s' not found", lspName)) - } - } - } - - // Validate context paths (basic path validation) - for i, contextPath := range agent.ContextPaths { - if contextPath == "" { - errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path cannot be empty") - } - // Check for invalid characters in path - if strings.Contains(contextPath, "\x00") { - errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path contains invalid characters") - } - } - - // Validate known agents maintain their core properties - if agentID == AgentCoder { - if agent.Name != "Coder" { - errors.Add(fieldPrefix+".name", "coder agent name cannot be changed") - } - if agent.Description != "An agent that helps with executing coding tasks." { - errors.Add(fieldPrefix+".description", "coder agent description cannot be changed") - } - } else if agentID == AgentTask { - if agent.Name != "Task" { - errors.Add(fieldPrefix+".name", "task agent name cannot be changed") - } - if agent.Description != "An agent that helps with searching for context and finding implementation details." { - errors.Add(fieldPrefix+".description", "task agent description cannot be changed") - } - expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"} - if agent.AllowedTools != nil && !slices.Equal(agent.AllowedTools, expectedTools) { - errors.Add(fieldPrefix+".allowed_tools", "task agent allowed tools cannot be changed") - } - } - } -} - -// validateOptions validates configuration options -func (c *Config) validateOptions(errors *ValidationErrors) { - // Validate data directory - if c.Options.DataDirectory == "" { - errors.Add("options.data_directory", "data directory is required") - } - - // Validate context paths - for i, contextPath := range c.Options.ContextPaths { - if contextPath == "" { - errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path cannot be empty") - } - if strings.Contains(contextPath, "\x00") { - errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path contains invalid characters") - } - } -} - -// validateMCPs validates MCP configurations -func (c *Config) validateMCPs(errors *ValidationErrors) { - if c.MCP == nil { - c.MCP = make(map[string]MCP) - } - - for mcpName, mcpConfig := range c.MCP { - fieldPrefix := fmt.Sprintf("mcp.%s", mcpName) - - // Validate MCP type - if mcpConfig.Type != MCPStdio && mcpConfig.Type != MCPSse && mcpConfig.Type != MCPHttp { - errors.Add(fieldPrefix+".type", fmt.Sprintf("invalid MCP type: %s (must be 'stdio' or 'sse' or 'http')", mcpConfig.Type)) - } - - // Validate based on type - if mcpConfig.Type == MCPStdio { - if mcpConfig.Command == "" { - errors.Add(fieldPrefix+".command", "command is required for stdio MCP") - } - } else if mcpConfig.Type == MCPSse { - if mcpConfig.URL == "" { - errors.Add(fieldPrefix+".url", "URL is required for SSE MCP") - } - } - } -} - -// validateLSPs validates LSP configurations -func (c *Config) validateLSPs(errors *ValidationErrors) { - if c.LSP == nil { - c.LSP = make(map[string]LSPConfig) - } - - for lspName, lspConfig := range c.LSP { - fieldPrefix := fmt.Sprintf("lsp.%s", lspName) - - if lspConfig.Command == "" { - errors.Add(fieldPrefix+".command", "command is required for LSP") - } - } -} - -// validateCrossReferences validates cross-references between different config sections -func (c *Config) validateCrossReferences(errors *ValidationErrors) { - // Validate that agents can use their assigned model types - for agentID, agent := range c.Agents { - fieldPrefix := fmt.Sprintf("agents.%s", agentID) - - var preferredModel PreferredModel - switch agent.Model { - case LargeModel: - preferredModel = c.Models.Large - case SmallModel: - preferredModel = c.Models.Small - } - - if preferredModel.Provider != "" { - if providerConfig, exists := c.Providers[preferredModel.Provider]; exists { - if providerConfig.Disabled { - errors.Add(fieldPrefix+".model", fmt.Sprintf("agent cannot use model type '%s' because provider '%s' is disabled", agent.Model, preferredModel.Provider)) - } - } - } - } -} - -// validateCompleteness validates that the configuration is complete and usable -func (c *Config) validateCompleteness(errors *ValidationErrors) { - // Check for at least one valid, non-disabled provider - hasValidProvider := false - for _, providerConfig := range c.Providers { - if !providerConfig.Disabled { - hasValidProvider = true - break - } - } - if !hasValidProvider { - errors.Add("providers", "at least one non-disabled provider is required") - } - - // Check that default agents exist - if _, exists := c.Agents[AgentCoder]; !exists { - errors.Add("agents", "coder agent is required") - } - if _, exists := c.Agents[AgentTask]; !exists { - errors.Add("agents", "task agent is required") - } - - // Check that preferred models are set if providers exist - if hasValidProvider { - if c.Models.Large.ModelID == "" || c.Models.Large.Provider == "" { - errors.Add("models.large", "large preferred model must be configured when providers are available") - } - if c.Models.Small.ModelID == "" || c.Models.Small.Provider == "" { - errors.Add("models.small", "small preferred model must be configured when providers are available") - } - } -} - -// JSONSchemaExtend adds custom schema properties for AgentID -func (AgentID) JSONSchemaExtend(schema *jsonschema.Schema) { - schema.Enum = []any{ - string(AgentCoder), - string(AgentTask), - } -} - -// JSONSchemaExtend adds custom schema properties for ModelType -func (ModelType) JSONSchemaExtend(schema *jsonschema.Schema) { - schema.Enum = []any{ - string(LargeModel), - string(SmallModel), - } -} - -// JSONSchemaExtend adds custom schema properties for MCPType -func (MCPType) JSONSchemaExtend(schema *jsonschema.Schema) { - schema.Enum = []any{ - string(MCPStdio), - string(MCPSse), - } -} diff --git a/internal/config/config_test.go b/internal/config/config_test.go deleted file mode 100644 index de8024bdd126bd46e13eb6ece102c9de69458266..0000000000000000000000000000000000000000 --- a/internal/config/config_test.go +++ /dev/null @@ -1,2075 +0,0 @@ -package config - -import ( - "encoding/json" - "os" - "path/filepath" - "sync" - "testing" - - "github.com/charmbracelet/crush/internal/fur/provider" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func reset() { - // Clear all environment variables that could affect config - envVarsToUnset := []string{ - // API Keys - "ANTHROPIC_API_KEY", - "OPENAI_API_KEY", - "GEMINI_API_KEY", - "XAI_API_KEY", - "OPENROUTER_API_KEY", - - // Google Cloud / VertexAI - "GOOGLE_GENAI_USE_VERTEXAI", - "GOOGLE_CLOUD_PROJECT", - "GOOGLE_CLOUD_LOCATION", - - // AWS Credentials - "AWS_ACCESS_KEY_ID", - "AWS_SECRET_ACCESS_KEY", - "AWS_REGION", - "AWS_DEFAULT_REGION", - "AWS_PROFILE", - "AWS_DEFAULT_PROFILE", - "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", - "AWS_CONTAINER_CREDENTIALS_FULL_URI", - - // Other - "CRUSH_DEV_DEBUG", - } - - for _, envVar := range envVarsToUnset { - os.Unsetenv(envVar) - } - - // Reset singleton - once = sync.Once{} - instance = nil - cwd = "" - testConfigDir = "" - - // Enable mock providers for all tests to avoid API calls - UseMockProviders = true - ResetProviders() -} - -// Core Configuration Loading Tests - -func TestInit_ValidWorkingDirectory(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - assert.NotNil(t, cfg) - assert.Equal(t, cwdDir, WorkingDirectory()) - assert.Equal(t, defaultDataDirectory, cfg.Options.DataDirectory) - assert.Equal(t, defaultContextPaths, cfg.Options.ContextPaths) -} - -func TestInit_WithDebugFlag(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - cfg, err := Init(cwdDir, true) - - require.NoError(t, err) - assert.True(t, cfg.Options.Debug) -} - -func TestInit_SingletonBehavior(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - cfg1, err1 := Init(cwdDir, false) - cfg2, err2 := Init(cwdDir, false) - - require.NoError(t, err1) - require.NoError(t, err2) - assert.Same(t, cfg1, cfg2) -} - -func TestGet_BeforeInitialization(t *testing.T) { - reset() - - assert.Panics(t, func() { - Get() - }) -} - -func TestGet_AfterInitialization(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - cfg1, err := Init(cwdDir, false) - require.NoError(t, err) - - cfg2 := Get() - assert.Same(t, cfg1, cfg2) -} - -func TestLoadConfig_NoConfigFiles(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - assert.Len(t, cfg.Providers, 0) - assert.Equal(t, defaultContextPaths, cfg.Options.ContextPaths) -} - -func TestLoadConfig_OnlyGlobalConfig(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - globalConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "test-key", - ProviderType: provider.TypeOpenAI, - DefaultLargeModel: "gpt-4", - DefaultSmallModel: "gpt-3.5-turbo", - Models: []Model{ - { - ID: "gpt-4", - Name: "GPT-4", - CostPer1MIn: 30.0, - CostPer1MOut: 60.0, - ContextWindow: 8192, - DefaultMaxTokens: 4096, - }, - { - ID: "gpt-3.5-turbo", - Name: "GPT-3.5 Turbo", - CostPer1MIn: 1.0, - CostPer1MOut: 2.0, - ContextWindow: 4096, - DefaultMaxTokens: 4096, - }, - }, - }, - }, - Options: Options{ - ContextPaths: []string{"custom-context.md"}, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - assert.Len(t, cfg.Providers, 1) - assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI) - assert.Contains(t, cfg.Options.ContextPaths, "custom-context.md") -} - -func TestLoadConfig_OnlyLocalConfig(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - localConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderAnthropic: { - ID: provider.InferenceProviderAnthropic, - APIKey: "local-key", - ProviderType: provider.TypeAnthropic, - DefaultLargeModel: "claude-3-opus", - DefaultSmallModel: "claude-3-haiku", - Models: []Model{ - { - ID: "claude-3-opus", - Name: "Claude 3 Opus", - CostPer1MIn: 15.0, - CostPer1MOut: 75.0, - ContextWindow: 200000, - DefaultMaxTokens: 4096, - }, - { - ID: "claude-3-haiku", - Name: "Claude 3 Haiku", - CostPer1MIn: 0.25, - CostPer1MOut: 1.25, - ContextWindow: 200000, - DefaultMaxTokens: 4096, - }, - }, - }, - }, - Options: Options{ - TUI: TUIOptions{CompactMode: true}, - }, - } - - localConfigPath := filepath.Join(cwdDir, "crush.json") - data, err := json.Marshal(localConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - assert.Len(t, cfg.Providers, 1) - assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic) - assert.True(t, cfg.Options.TUI.CompactMode) -} - -func TestLoadConfig_BothGlobalAndLocal(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - globalConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "global-key", - ProviderType: provider.TypeOpenAI, - DefaultLargeModel: "gpt-4", - DefaultSmallModel: "gpt-3.5-turbo", - Models: []Model{ - { - ID: "gpt-4", - Name: "GPT-4", - CostPer1MIn: 30.0, - CostPer1MOut: 60.0, - ContextWindow: 8192, - DefaultMaxTokens: 4096, - }, - { - ID: "gpt-3.5-turbo", - Name: "GPT-3.5 Turbo", - CostPer1MIn: 1.0, - CostPer1MOut: 2.0, - ContextWindow: 4096, - DefaultMaxTokens: 4096, - }, - }, - }, - }, - Options: Options{ - ContextPaths: []string{"global-context.md"}, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - localConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - APIKey: "local-key", // Override global - }, - provider.InferenceProviderAnthropic: { - ID: provider.InferenceProviderAnthropic, - APIKey: "anthropic-key", - ProviderType: provider.TypeAnthropic, - DefaultLargeModel: "claude-3-opus", - DefaultSmallModel: "claude-3-haiku", - Models: []Model{ - { - ID: "claude-3-opus", - Name: "Claude 3 Opus", - CostPer1MIn: 15.0, - CostPer1MOut: 75.0, - ContextWindow: 200000, - DefaultMaxTokens: 4096, - }, - { - ID: "claude-3-haiku", - Name: "Claude 3 Haiku", - CostPer1MIn: 0.25, - CostPer1MOut: 1.25, - ContextWindow: 200000, - DefaultMaxTokens: 4096, - }, - }, - }, - }, - Options: Options{ - ContextPaths: []string{"local-context.md"}, - TUI: TUIOptions{CompactMode: true}, - }, - } - - localConfigPath := filepath.Join(cwdDir, "crush.json") - data, err = json.Marshal(localConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - assert.Len(t, cfg.Providers, 2) - - openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] - assert.Equal(t, "local-key", openaiProvider.APIKey) - - assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic) - - assert.Contains(t, cfg.Options.ContextPaths, "global-context.md") - assert.Contains(t, cfg.Options.ContextPaths, "local-context.md") - assert.True(t, cfg.Options.TUI.CompactMode) -} - -func TestLoadConfig_MalformedGlobalJSON(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - require.NoError(t, os.WriteFile(configPath, []byte(`{invalid json`), 0o644)) - - _, err := Init(cwdDir, false) - assert.Error(t, err) -} - -func TestLoadConfig_MalformedLocalJSON(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - localConfigPath := filepath.Join(cwdDir, "crush.json") - require.NoError(t, os.WriteFile(localConfigPath, []byte(`{invalid json`), 0o644)) - - _, err := Init(cwdDir, false) - assert.Error(t, err) -} - -func TestConfigWithoutEnv(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - cfg, _ := Init(cwdDir, false) - assert.Len(t, cfg.Providers, 0) -} - -func TestConfigWithEnv(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key") - os.Setenv("OPENAI_API_KEY", "test-openai-key") - os.Setenv("GEMINI_API_KEY", "test-gemini-key") - os.Setenv("XAI_API_KEY", "test-xai-key") - os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key") - - cfg, _ := Init(cwdDir, false) - assert.Len(t, cfg.Providers, 5) -} - -// Environment Variable Tests - -func TestEnvVars_NoEnvironmentVariables(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - assert.Len(t, cfg.Providers, 0) -} - -func TestEnvVars_AllSupportedAPIKeys(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key") - os.Setenv("OPENAI_API_KEY", "test-openai-key") - os.Setenv("GEMINI_API_KEY", "test-gemini-key") - os.Setenv("XAI_API_KEY", "test-xai-key") - os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key") - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - assert.Len(t, cfg.Providers, 5) - - anthropicProvider := cfg.Providers[provider.InferenceProviderAnthropic] - assert.Equal(t, "test-anthropic-key", anthropicProvider.APIKey) - assert.Equal(t, provider.TypeAnthropic, anthropicProvider.ProviderType) - - openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] - assert.Equal(t, "test-openai-key", openaiProvider.APIKey) - assert.Equal(t, provider.TypeOpenAI, openaiProvider.ProviderType) - - geminiProvider := cfg.Providers[provider.InferenceProviderGemini] - assert.Equal(t, "test-gemini-key", geminiProvider.APIKey) - assert.Equal(t, provider.TypeGemini, geminiProvider.ProviderType) - - xaiProvider := cfg.Providers[provider.InferenceProviderXAI] - assert.Equal(t, "test-xai-key", xaiProvider.APIKey) - assert.Equal(t, provider.TypeXAI, xaiProvider.ProviderType) - - openrouterProvider := cfg.Providers[provider.InferenceProviderOpenRouter] - assert.Equal(t, "test-openrouter-key", openrouterProvider.APIKey) - assert.Equal(t, provider.TypeOpenAI, openrouterProvider.ProviderType) - assert.Equal(t, "https://openrouter.ai/api/v1", openrouterProvider.BaseURL) -} - -func TestEnvVars_PartialEnvironmentVariables(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key") - os.Setenv("OPENAI_API_KEY", "test-openai-key") - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - assert.Len(t, cfg.Providers, 2) - assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic) - assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI) - assert.NotContains(t, cfg.Providers, provider.InferenceProviderGemini) -} - -func TestEnvVars_VertexAIConfiguration(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("GOOGLE_GENAI_USE_VERTEXAI", "true") - os.Setenv("GOOGLE_CLOUD_PROJECT", "test-project") - os.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1") - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - assert.Contains(t, cfg.Providers, provider.InferenceProviderVertexAI) - - vertexProvider := cfg.Providers[provider.InferenceProviderVertexAI] - assert.Equal(t, provider.TypeVertexAI, vertexProvider.ProviderType) - assert.Equal(t, "test-project", vertexProvider.ExtraParams["project"]) - assert.Equal(t, "us-central1", vertexProvider.ExtraParams["location"]) -} - -func TestEnvVars_VertexAIWithoutUseFlag(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("GOOGLE_CLOUD_PROJECT", "test-project") - os.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1") - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - assert.NotContains(t, cfg.Providers, provider.InferenceProviderVertexAI) -} - -func TestEnvVars_AWSBedrockWithAccessKeys(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("AWS_ACCESS_KEY_ID", "test-access-key") - os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key") - os.Setenv("AWS_DEFAULT_REGION", "us-east-1") - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - assert.Contains(t, cfg.Providers, provider.InferenceProviderBedrock) - - bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock] - assert.Equal(t, provider.TypeBedrock, bedrockProvider.ProviderType) - assert.Equal(t, "us-east-1", bedrockProvider.ExtraParams["region"]) -} - -func TestEnvVars_AWSBedrockWithProfile(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("AWS_PROFILE", "test-profile") - os.Setenv("AWS_REGION", "eu-west-1") - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - assert.Contains(t, cfg.Providers, provider.InferenceProviderBedrock) - - bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock] - assert.Equal(t, "eu-west-1", bedrockProvider.ExtraParams["region"]) -} - -func TestEnvVars_AWSBedrockWithContainerCredentials(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/v2/credentials/test") - os.Setenv("AWS_DEFAULT_REGION", "ap-southeast-1") - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - assert.Contains(t, cfg.Providers, provider.InferenceProviderBedrock) -} - -func TestEnvVars_AWSBedrockRegionPriority(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("AWS_ACCESS_KEY_ID", "test-key") - os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret") - os.Setenv("AWS_DEFAULT_REGION", "us-west-2") - os.Setenv("AWS_REGION", "us-east-1") - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock] - assert.Equal(t, "us-west-2", bedrockProvider.ExtraParams["region"]) -} - -func TestEnvVars_AWSBedrockFallbackRegion(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("AWS_ACCESS_KEY_ID", "test-key") - os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret") - os.Setenv("AWS_REGION", "us-east-1") - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock] - assert.Equal(t, "us-east-1", bedrockProvider.ExtraParams["region"]) -} - -func TestEnvVars_NoAWSCredentials(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - assert.NotContains(t, cfg.Providers, provider.InferenceProviderBedrock) -} - -func TestEnvVars_CustomEnvironmentVariables(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("ANTHROPIC_API_KEY", "resolved-anthropic-key") - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - if len(cfg.Providers) > 0 { - if anthropicProvider, exists := cfg.Providers[provider.InferenceProviderAnthropic]; exists { - assert.Equal(t, "resolved-anthropic-key", anthropicProvider.APIKey) - } - } -} - -func TestEnvVars_CombinedEnvironmentVariables(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("ANTHROPIC_API_KEY", "test-anthropic") - os.Setenv("OPENAI_API_KEY", "test-openai") - os.Setenv("GOOGLE_GENAI_USE_VERTEXAI", "true") - os.Setenv("GOOGLE_CLOUD_PROJECT", "test-project") - os.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1") - os.Setenv("AWS_ACCESS_KEY_ID", "test-aws-key") - os.Setenv("AWS_SECRET_ACCESS_KEY", "test-aws-secret") - os.Setenv("AWS_DEFAULT_REGION", "us-west-1") - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - - expectedProviders := []provider.InferenceProvider{ - provider.InferenceProviderAnthropic, - provider.InferenceProviderOpenAI, - provider.InferenceProviderVertexAI, - provider.InferenceProviderBedrock, - } - - for _, expectedProvider := range expectedProviders { - assert.Contains(t, cfg.Providers, expectedProvider) - } -} - -func TestHasAWSCredentials_AccessKeys(t *testing.T) { - reset() - - os.Setenv("AWS_ACCESS_KEY_ID", "test-key") - os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret") - - assert.True(t, hasAWSCredentials()) -} - -func TestHasAWSCredentials_Profile(t *testing.T) { - reset() - - os.Setenv("AWS_PROFILE", "test-profile") - - assert.True(t, hasAWSCredentials()) -} - -func TestHasAWSCredentials_DefaultProfile(t *testing.T) { - reset() - - os.Setenv("AWS_DEFAULT_PROFILE", "default") - - assert.True(t, hasAWSCredentials()) -} - -func TestHasAWSCredentials_Region(t *testing.T) { - reset() - - os.Setenv("AWS_REGION", "us-east-1") - - assert.True(t, hasAWSCredentials()) -} - -func TestHasAWSCredentials_ContainerCredentials(t *testing.T) { - reset() - - os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/v2/credentials/test") - - assert.True(t, hasAWSCredentials()) -} - -func TestHasAWSCredentials_NoCredentials(t *testing.T) { - reset() - - assert.False(t, hasAWSCredentials()) -} - -func TestProviderMerging_GlobalToBase(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - globalConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "global-openai-key", - ProviderType: provider.TypeOpenAI, - DefaultLargeModel: "gpt-4", - DefaultSmallModel: "gpt-3.5-turbo", - Models: []Model{ - { - ID: "gpt-4", - Name: "GPT-4", - ContextWindow: 8192, - DefaultMaxTokens: 4096, - }, - { - ID: "gpt-3.5-turbo", - Name: "GPT-3.5 Turbo", - ContextWindow: 4096, - DefaultMaxTokens: 2048, - }, - }, - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - assert.Len(t, cfg.Providers, 1) - - openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] - assert.Equal(t, "global-openai-key", openaiProvider.APIKey) - assert.Equal(t, "gpt-4", openaiProvider.DefaultLargeModel) - assert.Equal(t, "gpt-4o", openaiProvider.DefaultSmallModel) - assert.GreaterOrEqual(t, len(openaiProvider.Models), 2) -} - -func TestProviderMerging_LocalToBase(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - localConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderAnthropic: { - ID: provider.InferenceProviderAnthropic, - APIKey: "local-anthropic-key", - ProviderType: provider.TypeAnthropic, - DefaultLargeModel: "claude-3-opus", - DefaultSmallModel: "claude-3-haiku", - Models: []Model{ - { - ID: "claude-3-opus", - Name: "Claude 3 Opus", - ContextWindow: 200000, - DefaultMaxTokens: 4096, - CostPer1MIn: 15.0, - CostPer1MOut: 75.0, - }, - { - ID: "claude-3-haiku", - Name: "Claude 3 Haiku", - ContextWindow: 200000, - DefaultMaxTokens: 4096, - CostPer1MIn: 0.25, - CostPer1MOut: 1.25, - }, - }, - }, - }, - } - - localConfigPath := filepath.Join(cwdDir, "crush.json") - data, err := json.Marshal(localConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - assert.Len(t, cfg.Providers, 1) - - anthropicProvider := cfg.Providers[provider.InferenceProviderAnthropic] - assert.Equal(t, "local-anthropic-key", anthropicProvider.APIKey) - assert.Equal(t, "claude-3-opus", anthropicProvider.DefaultLargeModel) - assert.Equal(t, "claude-3-5-haiku-20241022", anthropicProvider.DefaultSmallModel) - assert.GreaterOrEqual(t, len(anthropicProvider.Models), 2) -} - -func TestProviderMerging_ConflictingSettings(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - globalConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "global-key", - ProviderType: provider.TypeOpenAI, - DefaultLargeModel: "gpt-4", - DefaultSmallModel: "gpt-3.5-turbo", - Models: []Model{ - { - ID: "gpt-4", - Name: "GPT-4", - ContextWindow: 8192, - DefaultMaxTokens: 4096, - }, - { - ID: "gpt-3.5-turbo", - Name: "GPT-3.5 Turbo", - ContextWindow: 4096, - DefaultMaxTokens: 2048, - }, - { - ID: "gpt-4-turbo", - Name: "GPT-4 Turbo", - ContextWindow: 128000, - DefaultMaxTokens: 4096, - }, - }, - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - // Create local config that overrides - localConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - APIKey: "local-key", - DefaultLargeModel: "gpt-4-turbo", - }, - }, - } - - localConfigPath := filepath.Join(cwdDir, "crush.json") - data, err = json.Marshal(localConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - - openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] - assert.Equal(t, "local-key", openaiProvider.APIKey) - assert.Equal(t, "gpt-4-turbo", openaiProvider.DefaultLargeModel) - assert.False(t, openaiProvider.Disabled) - assert.Equal(t, "gpt-4o", openaiProvider.DefaultSmallModel) -} - -func TestProviderMerging_CustomVsKnownProviders(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - customProviderID := provider.InferenceProvider("custom-provider") - - globalConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "openai-key", - BaseURL: "should-not-override", - ProviderType: provider.TypeAnthropic, - DefaultLargeModel: "gpt-4", - DefaultSmallModel: "gpt-3.5-turbo", - Models: []Model{ - { - ID: "gpt-4", - Name: "GPT-4", - ContextWindow: 8192, - DefaultMaxTokens: 4096, - }, - { - ID: "gpt-3.5-turbo", - Name: "GPT-3.5 Turbo", - ContextWindow: 4096, - DefaultMaxTokens: 2048, - }, - }, - }, - customProviderID: { - ID: customProviderID, - APIKey: "custom-key", - BaseURL: "https://custom.api.com", - ProviderType: provider.TypeOpenAI, - DefaultLargeModel: "custom-large", - DefaultSmallModel: "custom-small", - Models: []Model{ - { - ID: "custom-large", - Name: "Custom Large", - ContextWindow: 8192, - DefaultMaxTokens: 4096, - }, - { - ID: "custom-small", - Name: "Custom Small", - ContextWindow: 4096, - DefaultMaxTokens: 2048, - }, - }, - }, - }, - } - - localConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - BaseURL: "https://should-not-change.com", - ProviderType: provider.TypeGemini, // Should not change - }, - customProviderID: { - BaseURL: "https://updated-custom.api.com", - ProviderType: provider.TypeOpenAI, - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - localConfigPath := filepath.Join(cwdDir, "crush.json") - data, err = json.Marshal(localConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - - openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] - assert.NotEqual(t, "https://should-not-change.com", openaiProvider.BaseURL) - assert.NotEqual(t, provider.TypeGemini, openaiProvider.ProviderType) - - customProvider := cfg.Providers[customProviderID] - assert.Equal(t, "custom-key", customProvider.APIKey) - assert.Equal(t, "https://updated-custom.api.com", customProvider.BaseURL) - assert.Equal(t, provider.TypeOpenAI, customProvider.ProviderType) -} - -func TestProviderValidation_CustomProviderMissingBaseURL(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - customProviderID := provider.InferenceProvider("custom-provider") - - globalConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - customProviderID: { - ID: customProviderID, - APIKey: "custom-key", - ProviderType: provider.TypeOpenAI, - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - assert.NotContains(t, cfg.Providers, customProviderID) -} - -func TestProviderValidation_CustomProviderMissingAPIKey(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - customProviderID := provider.InferenceProvider("custom-provider") - - globalConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - customProviderID: { - ID: customProviderID, - BaseURL: "https://custom.api.com", - ProviderType: provider.TypeOpenAI, - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - assert.NotContains(t, cfg.Providers, customProviderID) -} - -func TestProviderValidation_CustomProviderInvalidType(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - customProviderID := provider.InferenceProvider("custom-provider") - - globalConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - customProviderID: { - ID: customProviderID, - APIKey: "custom-key", - BaseURL: "https://custom.api.com", - ProviderType: provider.Type("invalid-type"), - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - assert.NotContains(t, cfg.Providers, customProviderID) -} - -func TestProviderValidation_KnownProviderValid(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - globalConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "openai-key", - ProviderType: provider.TypeOpenAI, - DefaultLargeModel: "gpt-4", - DefaultSmallModel: "gpt-3.5-turbo", - Models: []Model{ - { - ID: "gpt-4", - Name: "GPT-4", - ContextWindow: 8192, - DefaultMaxTokens: 4096, - }, - { - ID: "gpt-3.5-turbo", - Name: "GPT-3.5 Turbo", - ContextWindow: 4096, - DefaultMaxTokens: 2048, - }, - }, - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI) -} - -func TestProviderValidation_DisabledProvider(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - globalConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "openai-key", - ProviderType: provider.TypeOpenAI, - Disabled: true, - DefaultLargeModel: "gpt-4", - DefaultSmallModel: "gpt-3.5-turbo", - Models: []Model{ - { - ID: "gpt-4", - Name: "GPT-4", - ContextWindow: 8192, - DefaultMaxTokens: 4096, - }, - { - ID: "gpt-3.5-turbo", - Name: "GPT-3.5 Turbo", - ContextWindow: 4096, - DefaultMaxTokens: 2048, - }, - }, - }, - provider.InferenceProviderAnthropic: { - ID: provider.InferenceProviderAnthropic, - APIKey: "anthropic-key", - ProviderType: provider.TypeAnthropic, - Disabled: false, // This one is enabled - DefaultLargeModel: "claude-3-opus", - DefaultSmallModel: "claude-3-haiku", - Models: []Model{ - { - ID: "claude-3-opus", - Name: "Claude 3 Opus", - ContextWindow: 200000, - DefaultMaxTokens: 4096, - }, - { - ID: "claude-3-haiku", - Name: "Claude 3 Haiku", - ContextWindow: 200000, - DefaultMaxTokens: 4096, - }, - }, - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI) - assert.True(t, cfg.Providers[provider.InferenceProviderOpenAI].Disabled) - assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic) - assert.False(t, cfg.Providers[provider.InferenceProviderAnthropic].Disabled) -} - -func TestProviderModels_AddingNewModels(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - globalConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "openai-key", - ProviderType: provider.TypeOpenAI, - DefaultLargeModel: "gpt-4", - DefaultSmallModel: "gpt-4-turbo", - Models: []Model{ - { - ID: "gpt-4", - Name: "GPT-4", - ContextWindow: 8192, - DefaultMaxTokens: 4096, - }, - }, - }, - }, - } - - localConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - Models: []Model{ - { - ID: "gpt-4-turbo", - Name: "GPT-4 Turbo", - ContextWindow: 128000, - DefaultMaxTokens: 4096, - }, - }, - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - localConfigPath := filepath.Join(cwdDir, "crush.json") - data, err = json.Marshal(localConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - - openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] - assert.GreaterOrEqual(t, len(openaiProvider.Models), 2) - - modelIDs := make([]string, len(openaiProvider.Models)) - for i, model := range openaiProvider.Models { - modelIDs[i] = model.ID - } - assert.Contains(t, modelIDs, "gpt-4") - assert.Contains(t, modelIDs, "gpt-4-turbo") -} - -func TestProviderModels_DuplicateModelHandling(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - globalConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "openai-key", - ProviderType: provider.TypeOpenAI, - DefaultLargeModel: "gpt-4", - DefaultSmallModel: "gpt-4", - Models: []Model{ - { - ID: "gpt-4", - Name: "GPT-4", - ContextWindow: 8192, - DefaultMaxTokens: 4096, - }, - }, - }, - }, - } - - localConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - Models: []Model{ - { - ID: "gpt-4", - Name: "GPT-4 Updated", - ContextWindow: 16384, - DefaultMaxTokens: 8192, - }, - }, - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - localConfigPath := filepath.Join(cwdDir, "crush.json") - data, err = json.Marshal(localConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - - openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] - assert.GreaterOrEqual(t, len(openaiProvider.Models), 1) - - // Find the first model that matches our test data - var testModel *Model - for _, model := range openaiProvider.Models { - if model.ID == "gpt-4" { - testModel = &model - break - } - } - - // If gpt-4 not found, use the first available model - if testModel == nil { - testModel = &openaiProvider.Models[0] - } - - assert.NotEmpty(t, testModel.ID) - assert.NotEmpty(t, testModel.Name) - assert.Greater(t, testModel.ContextWindow, int64(0)) -} - -func TestProviderModels_ModelCostAndCapabilities(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - globalConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "openai-key", - ProviderType: provider.TypeOpenAI, - DefaultLargeModel: "gpt-4", - DefaultSmallModel: "gpt-4", - Models: []Model{ - { - ID: "gpt-4", - Name: "GPT-4", - CostPer1MIn: 30.0, - CostPer1MOut: 60.0, - CostPer1MInCached: 15.0, - CostPer1MOutCached: 30.0, - ContextWindow: 8192, - DefaultMaxTokens: 4096, - CanReason: true, - ReasoningEffort: "medium", - SupportsImages: true, - }, - }, - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - - openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] - require.GreaterOrEqual(t, len(openaiProvider.Models), 1) - - // Find the test model or use the first one - var testModel *Model - for _, model := range openaiProvider.Models { - if model.ID == "gpt-4" { - testModel = &model - break - } - } - - if testModel == nil { - testModel = &openaiProvider.Models[0] - } - - // Only test the custom properties if this is actually our test model - if testModel.ID == "gpt-4" { - assert.Equal(t, 30.0, testModel.CostPer1MIn) - assert.Equal(t, 60.0, testModel.CostPer1MOut) - assert.Equal(t, 15.0, testModel.CostPer1MInCached) - assert.Equal(t, 30.0, testModel.CostPer1MOutCached) - assert.True(t, testModel.CanReason) - assert.Equal(t, "medium", testModel.ReasoningEffort) - assert.True(t, testModel.SupportsImages) - } -} - -func TestDefaultAgents_CoderAgent(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("ANTHROPIC_API_KEY", "test-key") - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - assert.Contains(t, cfg.Agents, AgentCoder) - - coderAgent := cfg.Agents[AgentCoder] - assert.Equal(t, AgentCoder, coderAgent.ID) - assert.Equal(t, "Coder", coderAgent.Name) - assert.Equal(t, "An agent that helps with executing coding tasks.", coderAgent.Description) - assert.Equal(t, LargeModel, coderAgent.Model) - assert.False(t, coderAgent.Disabled) - assert.Equal(t, cfg.Options.ContextPaths, coderAgent.ContextPaths) - assert.Nil(t, coderAgent.AllowedTools) -} - -func TestDefaultAgents_TaskAgent(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("ANTHROPIC_API_KEY", "test-key") - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - assert.Contains(t, cfg.Agents, AgentTask) - - taskAgent := cfg.Agents[AgentTask] - assert.Equal(t, AgentTask, taskAgent.ID) - assert.Equal(t, "Task", taskAgent.Name) - assert.Equal(t, "An agent that helps with searching for context and finding implementation details.", taskAgent.Description) - assert.Equal(t, LargeModel, taskAgent.Model) - assert.False(t, taskAgent.Disabled) - assert.Equal(t, cfg.Options.ContextPaths, taskAgent.ContextPaths) - - expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"} - assert.Equal(t, expectedTools, taskAgent.AllowedTools) - - assert.Equal(t, map[string][]string{}, taskAgent.AllowedMCP) - assert.Equal(t, []string{}, taskAgent.AllowedLSP) -} - -func TestAgentMerging_CustomAgent(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("ANTHROPIC_API_KEY", "test-key") - - globalConfig := Config{ - Agents: map[AgentID]Agent{ - AgentID("custom-agent"): { - ID: AgentID("custom-agent"), - Name: "Custom Agent", - Description: "A custom agent for testing", - Model: SmallModel, - AllowedTools: []string{"glob", "grep"}, - AllowedMCP: map[string][]string{"mcp1": {"tool1", "tool2"}}, - AllowedLSP: []string{"typescript", "go"}, - ContextPaths: []string{"custom-context.md"}, - }, - }, - MCP: map[string]MCP{ - "mcp1": { - Type: MCPStdio, - Command: "test-mcp-command", - Args: []string{"--test"}, - }, - }, - LSP: map[string]LSPConfig{ - "typescript": { - Command: "typescript-language-server", - Args: []string{"--stdio"}, - }, - "go": { - Command: "gopls", - Args: []string{}, - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - - assert.Contains(t, cfg.Agents, AgentCoder) - assert.Contains(t, cfg.Agents, AgentTask) - assert.Contains(t, cfg.Agents, AgentID("custom-agent")) - - customAgent := cfg.Agents[AgentID("custom-agent")] - assert.Equal(t, "Custom Agent", customAgent.Name) - assert.Equal(t, "A custom agent for testing", customAgent.Description) - assert.Equal(t, SmallModel, customAgent.Model) - assert.Equal(t, []string{"glob", "grep"}, customAgent.AllowedTools) - assert.Equal(t, map[string][]string{"mcp1": {"tool1", "tool2"}}, customAgent.AllowedMCP) - assert.Equal(t, []string{"typescript", "go"}, customAgent.AllowedLSP) - expectedContextPaths := append(defaultContextPaths, "custom-context.md") - assert.Equal(t, expectedContextPaths, customAgent.ContextPaths) -} - -func TestAgentMerging_ModifyDefaultCoderAgent(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("ANTHROPIC_API_KEY", "test-key") - - globalConfig := Config{ - Agents: map[AgentID]Agent{ - AgentCoder: { - Model: SmallModel, - AllowedMCP: map[string][]string{"mcp1": {"tool1"}}, - AllowedLSP: []string{"typescript"}, - ContextPaths: []string{"coder-specific.md"}, - }, - }, - MCP: map[string]MCP{ - "mcp1": { - Type: MCPStdio, - Command: "test-mcp-command", - Args: []string{"--test"}, - }, - }, - LSP: map[string]LSPConfig{ - "typescript": { - Command: "typescript-language-server", - Args: []string{"--stdio"}, - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - - coderAgent := cfg.Agents[AgentCoder] - assert.Equal(t, AgentCoder, coderAgent.ID) - assert.Equal(t, "Coder", coderAgent.Name) - assert.Equal(t, "An agent that helps with executing coding tasks.", coderAgent.Description) - - expectedContextPaths := append(cfg.Options.ContextPaths, "coder-specific.md") - assert.Equal(t, expectedContextPaths, coderAgent.ContextPaths) - - assert.Equal(t, SmallModel, coderAgent.Model) - assert.Equal(t, map[string][]string{"mcp1": {"tool1"}}, coderAgent.AllowedMCP) - assert.Equal(t, []string{"typescript"}, coderAgent.AllowedLSP) -} - -func TestAgentMerging_ModifyDefaultTaskAgent(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("ANTHROPIC_API_KEY", "test-key") - - globalConfig := Config{ - Agents: map[AgentID]Agent{ - AgentTask: { - Model: SmallModel, - AllowedMCP: map[string][]string{"search-mcp": nil}, - AllowedLSP: []string{"python"}, - Name: "Search Agent", - Description: "Custom search agent", - Disabled: true, - AllowedTools: []string{"glob", "grep", "view"}, - }, - }, - MCP: map[string]MCP{ - "search-mcp": { - Type: MCPStdio, - Command: "search-mcp-command", - Args: []string{"--search"}, - }, - }, - LSP: map[string]LSPConfig{ - "python": { - Command: "pylsp", - Args: []string{}, - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - - taskAgent := cfg.Agents[AgentTask] - assert.Equal(t, "Task", taskAgent.Name) - assert.Equal(t, "An agent that helps with searching for context and finding implementation details.", taskAgent.Description) - assert.False(t, taskAgent.Disabled) - assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools) - - assert.Equal(t, SmallModel, taskAgent.Model) - assert.Equal(t, map[string][]string{"search-mcp": nil}, taskAgent.AllowedMCP) - assert.Equal(t, []string{"python"}, taskAgent.AllowedLSP) -} - -func TestAgentMerging_LocalOverridesGlobal(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("ANTHROPIC_API_KEY", "test-key") - - globalConfig := Config{ - Agents: map[AgentID]Agent{ - AgentID("test-agent"): { - ID: AgentID("test-agent"), - Name: "Global Agent", - Description: "Global description", - Model: LargeModel, - AllowedTools: []string{"glob"}, - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - // Create local config that overrides - localConfig := Config{ - Agents: map[AgentID]Agent{ - AgentID("test-agent"): { - Name: "Local Agent", - Description: "Local description", - Model: SmallModel, - Disabled: true, - AllowedTools: []string{"grep", "view"}, - AllowedMCP: map[string][]string{"local-mcp": {"tool1"}}, - }, - }, - MCP: map[string]MCP{ - "local-mcp": { - Type: MCPStdio, - Command: "local-mcp-command", - Args: []string{"--local"}, - }, - }, - } - - localConfigPath := filepath.Join(cwdDir, "crush.json") - data, err = json.Marshal(localConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - - testAgent := cfg.Agents[AgentID("test-agent")] - assert.Equal(t, "Local Agent", testAgent.Name) - assert.Equal(t, "Local description", testAgent.Description) - assert.Equal(t, SmallModel, testAgent.Model) - assert.True(t, testAgent.Disabled) - assert.Equal(t, []string{"grep", "view"}, testAgent.AllowedTools) - assert.Equal(t, map[string][]string{"local-mcp": {"tool1"}}, testAgent.AllowedMCP) -} - -func TestAgentModelTypeAssignment(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("ANTHROPIC_API_KEY", "test-key") - - globalConfig := Config{ - Agents: map[AgentID]Agent{ - AgentID("large-agent"): { - ID: AgentID("large-agent"), - Name: "Large Model Agent", - Model: LargeModel, - }, - AgentID("small-agent"): { - ID: AgentID("small-agent"), - Name: "Small Model Agent", - Model: SmallModel, - }, - AgentID("default-agent"): { - ID: AgentID("default-agent"), - Name: "Default Model Agent", - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - - assert.Equal(t, LargeModel, cfg.Agents[AgentID("large-agent")].Model) - assert.Equal(t, SmallModel, cfg.Agents[AgentID("small-agent")].Model) - assert.Equal(t, LargeModel, cfg.Agents[AgentID("default-agent")].Model) -} - -func TestAgentContextPathOverrides(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("ANTHROPIC_API_KEY", "test-key") - - globalConfig := Config{ - Options: Options{ - ContextPaths: []string{"global-context.md", "shared-context.md"}, - }, - Agents: map[AgentID]Agent{ - AgentID("custom-context-agent"): { - ID: AgentID("custom-context-agent"), - Name: "Custom Context Agent", - ContextPaths: []string{"agent-specific.md", "custom.md"}, - }, - AgentID("default-context-agent"): { - ID: AgentID("default-context-agent"), - Name: "Default Context Agent", - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - - customAgent := cfg.Agents[AgentID("custom-context-agent")] - expectedCustomPaths := append(defaultContextPaths, "global-context.md", "shared-context.md", "agent-specific.md", "custom.md") - assert.Equal(t, expectedCustomPaths, customAgent.ContextPaths) - - defaultAgent := cfg.Agents[AgentID("default-context-agent")] - expectedContextPaths := append(defaultContextPaths, "global-context.md", "shared-context.md") - assert.Equal(t, expectedContextPaths, defaultAgent.ContextPaths) - - coderAgent := cfg.Agents[AgentCoder] - assert.Equal(t, expectedContextPaths, coderAgent.ContextPaths) -} - -func TestOptionsMerging_ContextPaths(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("ANTHROPIC_API_KEY", "test-key") - - globalConfig := Config{ - Options: Options{ - ContextPaths: []string{"global1.md", "global2.md"}, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - localConfig := Config{ - Options: Options{ - ContextPaths: []string{"local1.md", "local2.md"}, - }, - } - - localConfigPath := filepath.Join(cwdDir, "crush.json") - data, err = json.Marshal(localConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - - expectedContextPaths := append(defaultContextPaths, "global1.md", "global2.md", "local1.md", "local2.md") - assert.Equal(t, expectedContextPaths, cfg.Options.ContextPaths) -} - -func TestOptionsMerging_TUIOptions(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("ANTHROPIC_API_KEY", "test-key") - - globalConfig := Config{ - Options: Options{ - TUI: TUIOptions{ - CompactMode: false, - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - localConfig := Config{ - Options: Options{ - TUI: TUIOptions{ - CompactMode: true, - }, - }, - } - - localConfigPath := filepath.Join(cwdDir, "crush.json") - data, err = json.Marshal(localConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - - assert.True(t, cfg.Options.TUI.CompactMode) -} - -func TestOptionsMerging_DebugFlags(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("ANTHROPIC_API_KEY", "test-key") - - globalConfig := Config{ - Options: Options{ - Debug: false, - DebugLSP: false, - DisableAutoSummarize: false, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - localConfig := Config{ - Options: Options{ - DebugLSP: true, - DisableAutoSummarize: true, - }, - } - - localConfigPath := filepath.Join(cwdDir, "crush.json") - data, err = json.Marshal(localConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - - assert.False(t, cfg.Options.Debug) - assert.True(t, cfg.Options.DebugLSP) - assert.True(t, cfg.Options.DisableAutoSummarize) -} - -func TestOptionsMerging_DataDirectory(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("ANTHROPIC_API_KEY", "test-key") - - globalConfig := Config{ - Options: Options{ - DataDirectory: "global-data", - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - localConfig := Config{ - Options: Options{ - DataDirectory: "local-data", - }, - } - - localConfigPath := filepath.Join(cwdDir, "crush.json") - data, err = json.Marshal(localConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - - assert.Equal(t, "local-data", cfg.Options.DataDirectory) -} - -func TestOptionsMerging_DefaultValues(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("ANTHROPIC_API_KEY", "test-key") - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - - assert.Equal(t, defaultDataDirectory, cfg.Options.DataDirectory) - assert.Equal(t, defaultContextPaths, cfg.Options.ContextPaths) - assert.False(t, cfg.Options.TUI.CompactMode) - assert.False(t, cfg.Options.Debug) - assert.False(t, cfg.Options.DebugLSP) - assert.False(t, cfg.Options.DisableAutoSummarize) -} - -func TestOptionsMerging_DebugFlagFromInit(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("ANTHROPIC_API_KEY", "test-key") - - globalConfig := Config{ - Options: Options{ - Debug: false, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - cfg, err := Init(cwdDir, true) - - require.NoError(t, err) - - // Debug flag from Init should take precedence - assert.True(t, cfg.Options.Debug) -} - -func TestOptionsMerging_ComplexScenario(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - // Set up a provider - os.Setenv("ANTHROPIC_API_KEY", "test-key") - - // Create global config with various options - globalConfig := Config{ - Options: Options{ - ContextPaths: []string{"global-context.md"}, - DataDirectory: "global-data", - Debug: false, - DebugLSP: false, - DisableAutoSummarize: false, - TUI: TUIOptions{ - CompactMode: false, - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - // Create local config that partially overrides - localConfig := Config{ - Options: Options{ - ContextPaths: []string{"local-context.md"}, - DebugLSP: true, // Override - DisableAutoSummarize: true, // Override - TUI: TUIOptions{ - CompactMode: true, // Override - }, - // DataDirectory and Debug not specified - should keep global values - }, - } - - localConfigPath := filepath.Join(cwdDir, "crush.json") - data, err = json.Marshal(localConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - - // Check merged results - expectedContextPaths := append(defaultContextPaths, "global-context.md", "local-context.md") - assert.Equal(t, expectedContextPaths, cfg.Options.ContextPaths) - assert.Equal(t, "global-data", cfg.Options.DataDirectory) // From global - assert.False(t, cfg.Options.Debug) // From global - assert.True(t, cfg.Options.DebugLSP) // From local - assert.True(t, cfg.Options.DisableAutoSummarize) // From local - assert.True(t, cfg.Options.TUI.CompactMode) // From local -} - -// Model Selection Tests - -func TestModelSelection_PreferredModelSelection(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - // Set up multiple providers to test selection logic - os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key") - os.Setenv("OPENAI_API_KEY", "test-openai-key") - - cfg, err := Init(cwdDir, false) - - require.NoError(t, err) - require.Len(t, cfg.Providers, 2) - - // Should have preferred models set - assert.NotEmpty(t, cfg.Models.Large.ModelID) - assert.NotEmpty(t, cfg.Models.Large.Provider) - assert.NotEmpty(t, cfg.Models.Small.ModelID) - assert.NotEmpty(t, cfg.Models.Small.Provider) - - // Both should use the same provider (first available) - assert.Equal(t, cfg.Models.Large.Provider, cfg.Models.Small.Provider) -} - -func TestValidation_InvalidModelReference(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - globalConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "test-key", - ProviderType: provider.TypeOpenAI, - DefaultLargeModel: "non-existent-model", - DefaultSmallModel: "gpt-3.5-turbo", - Models: []Model{ - { - ID: "gpt-3.5-turbo", - Name: "GPT-3.5 Turbo", - ContextWindow: 4096, - DefaultMaxTokens: 2048, - }, - }, - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - _, err = Init(cwdDir, false) - assert.Error(t, err) -} - -func TestValidation_InvalidAgentModelType(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - os.Setenv("ANTHROPIC_API_KEY", "test-key") - - globalConfig := Config{ - Agents: map[AgentID]Agent{ - AgentID("invalid-agent"): { - ID: AgentID("invalid-agent"), - Name: "Invalid Agent", - Model: ModelType("invalid"), - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - _, err = Init(cwdDir, false) - assert.Error(t, err) -} diff --git a/internal/config/fs.go b/internal/config/fs.go deleted file mode 100644 index efa622cf937846370616042de4fe2bcd6f33b7a1..0000000000000000000000000000000000000000 --- a/internal/config/fs.go +++ /dev/null @@ -1,71 +0,0 @@ -package config - -import ( - "fmt" - "os" - "path/filepath" - "runtime" -) - -var testConfigDir string - -func baseConfigPath() string { - if testConfigDir != "" { - return testConfigDir - } - - xdgConfigHome := os.Getenv("XDG_CONFIG_HOME") - if xdgConfigHome != "" { - return filepath.Join(xdgConfigHome, "crush") - } - - // return the path to the main config directory - // for windows, it should be in `%LOCALAPPDATA%/crush/` - // for linux and macOS, it should be in `$HOME/.config/crush/` - if runtime.GOOS == "windows" { - localAppData := os.Getenv("LOCALAPPDATA") - if localAppData == "" { - localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local") - } - return filepath.Join(localAppData, appName) - } - - return filepath.Join(os.Getenv("HOME"), ".config", appName) -} - -func baseDataPath() string { - if testConfigDir != "" { - return testConfigDir - } - - xdgDataHome := os.Getenv("XDG_DATA_HOME") - if xdgDataHome != "" { - return filepath.Join(xdgDataHome, appName) - } - - // return the path to the main data directory - // for windows, it should be in `%LOCALAPPDATA%/crush/` - // for linux and macOS, it should be in `$HOME/.local/share/crush/` - if runtime.GOOS == "windows" { - localAppData := os.Getenv("LOCALAPPDATA") - if localAppData == "" { - localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local") - } - return filepath.Join(localAppData, appName) - } - - return filepath.Join(os.Getenv("HOME"), ".local", "share", appName) -} - -func ConfigPath() string { - return filepath.Join(baseConfigPath(), fmt.Sprintf("%s.json", appName)) -} - -func CrushInitialized() bool { - cfgPath := ConfigPath() - if _, err := os.Stat(cfgPath); os.IsNotExist(err) { - // config file does not exist, so Crush is not initialized - return false - } - return true -} diff --git a/internal/config/init.go b/internal/config/init.go index f17e1db28e41cc44e168765e55e88311423e1102..9d4614c81ce1bb71aa0d18c2084c3ee9587ad6e1 100644 --- a/internal/config/init.go +++ b/internal/config/init.go @@ -5,27 +5,53 @@ import ( "os" "path/filepath" "strings" + "sync" + "sync/atomic" + + "github.com/charmbracelet/crush/internal/logging" ) const ( - // InitFlagFilename is the name of the file that indicates whether the project has been initialized InitFlagFilename = "init" ) -// ProjectInitFlag represents the initialization status for a project directory type ProjectInitFlag struct { Initialized bool `json:"initialized"` } -// ProjectNeedsInitialization checks if the current project needs initialization +// TODO: we need to remove the global config instance keeping it now just until everything is migrated +var ( + instance atomic.Pointer[Config] + cwd string + once sync.Once // Ensures the initialization happens only once +) + +func Init(workingDir string, debug bool) (*Config, error) { + var err error + once.Do(func() { + cwd = workingDir + cfg, err := Load(cwd, debug) + if err != nil { + logging.Error("Failed to load config", "error", err) + } + instance.Store(cfg) + }) + + return instance.Load(), err +} + +func Get() *Config { + return instance.Load() +} + func ProjectNeedsInitialization() (bool, error) { - if instance == nil { + cfg := Get() + if cfg == nil { return false, fmt.Errorf("config not loaded") } - flagFilePath := filepath.Join(instance.Options.DataDirectory, InitFlagFilename) + flagFilePath := filepath.Join(cfg.Options.DataDirectory, InitFlagFilename) - // Check if the flag file exists _, err := os.Stat(flagFilePath) if err == nil { return false, nil @@ -35,8 +61,7 @@ func ProjectNeedsInitialization() (bool, error) { return false, fmt.Errorf("failed to check init flag file: %w", err) } - // Check if any variation of CRUSH.md already exists in working directory - crushExists, err := crushMdExists(WorkingDirectory()) + crushExists, err := crushMdExists(cfg.WorkingDir()) if err != nil { return false, fmt.Errorf("failed to check for CRUSH.md files: %w", err) } @@ -47,7 +72,6 @@ func ProjectNeedsInitialization() (bool, error) { return true, nil } -// crushMdExists checks if any case variation of crush.md exists in the directory func crushMdExists(dir string) (bool, error) { entries, err := os.ReadDir(dir) if err != nil { @@ -68,12 +92,12 @@ func crushMdExists(dir string) (bool, error) { return false, nil } -// MarkProjectInitialized marks the current project as initialized func MarkProjectInitialized() error { - if instance == nil { + cfg := Get() + if cfg == nil { return fmt.Errorf("config not loaded") } - flagFilePath := filepath.Join(instance.Options.DataDirectory, InitFlagFilename) + flagFilePath := filepath.Join(cfg.Options.DataDirectory, InitFlagFilename) file, err := os.Create(flagFilePath) if err != nil { diff --git a/pkg/config/load.go b/internal/config/load.go similarity index 94% rename from pkg/config/load.go rename to internal/config/load.go index 2dc8a89735c6ab047eaa21b1093ea3217f019269..ec3d1c9e650e51629894bedafb9b15186936224d 100644 --- a/pkg/config/load.go +++ b/internal/config/load.go @@ -10,10 +10,10 @@ import ( "slices" "strings" + "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/fur/client" "github.com/charmbracelet/crush/internal/fur/provider" - "github.com/charmbracelet/crush/pkg/env" - "github.com/charmbracelet/crush/pkg/log" + "github.com/charmbracelet/crush/internal/log" "golang.org/x/exp/slog" ) @@ -68,6 +68,7 @@ func Load(workingDir string, debug bool) (*Config, error) { env := env.New() // Configure providers valueResolver := NewShellVariableResolver(env) + cfg.resolver = valueResolver if err := cfg.configureProviders(env, valueResolver, providers); err != nil { return nil, fmt.Errorf("failed to configure providers: %w", err) } @@ -81,6 +82,36 @@ func Load(workingDir string, debug bool) (*Config, error) { return nil, fmt.Errorf("failed to configure selected models: %w", err) } + // TODO: remove the agents concept from the config + agents := map[string]Agent{ + "coder": { + ID: "coder", + Name: "Coder", + Description: "An agent that helps with executing coding tasks.", + Model: SelectedModelTypeLarge, + ContextPaths: cfg.Options.ContextPaths, + // All tools allowed + }, + "task": { + ID: "task", + Name: "Task", + Description: "An agent that helps with searching for context and finding implementation details.", + Model: SelectedModelTypeLarge, + ContextPaths: cfg.Options.ContextPaths, + AllowedTools: []string{ + "glob", + "grep", + "ls", + "sourcegraph", + "view", + }, + // NO MCPs or LSPs by default + AllowedMCP: map[string][]string{}, + AllowedLSP: []string{}, + }, + } + cfg.Agents = agents + return cfg, nil } diff --git a/pkg/config/load_test.go b/internal/config/load_test.go similarity index 99% rename from pkg/config/load_test.go rename to internal/config/load_test.go index 01cf088c5b639683c30dc7d61505cde6b28ff593..0f46f7899b2e443e9f99c646d13370c3a2e146d4 100644 --- a/pkg/config/load_test.go +++ b/internal/config/load_test.go @@ -8,7 +8,7 @@ import ( "testing" "github.com/charmbracelet/crush/internal/fur/provider" - "github.com/charmbracelet/crush/pkg/env" + "github.com/charmbracelet/crush/internal/env" "github.com/stretchr/testify/assert" ) diff --git a/pkg/config/merge.go b/internal/config/merge.go similarity index 100% rename from pkg/config/merge.go rename to internal/config/merge.go diff --git a/pkg/config/merge_test.go b/internal/config/merge_test.go similarity index 100% rename from pkg/config/merge_test.go rename to internal/config/merge_test.go diff --git a/internal/config/provider.go b/internal/config/provider.go index 09e3b0e3fc84b9e2688ccc4d2559604aca83ddfc..af61a07776da15e181215eece94da5bb6f64fd9e 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -4,27 +4,44 @@ import ( "encoding/json" "os" "path/filepath" + "runtime" "sync" "github.com/charmbracelet/crush/internal/fur/client" "github.com/charmbracelet/crush/internal/fur/provider" ) -var fur = client.New() +type ProviderClient interface { + GetProviders() ([]provider.Provider, error) +} var ( - providerOnc sync.Once // Ensures the initialization happens only once + providerOnce sync.Once providerList []provider.Provider - // UseMockProviders can be set to true in tests to avoid API calls - UseMockProviders bool ) -func providersPath() string { - return filepath.Join(baseDataPath(), "providers.json") +// file to cache provider data +func providerCacheFileData() string { + xdgDataHome := os.Getenv("XDG_DATA_HOME") + if xdgDataHome != "" { + return filepath.Join(xdgDataHome, appName) + } + + // return the path to the main data directory + // for windows, it should be in `%LOCALAPPDATA%/crush/` + // for linux and macOS, it should be in `$HOME/.local/share/crush/` + if runtime.GOOS == "windows" { + localAppData := os.Getenv("LOCALAPPDATA") + if localAppData == "" { + localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local") + } + return filepath.Join(localAppData, appName) + } + + return filepath.Join(os.Getenv("HOME"), ".local", "share", appName, "providers.json") } -func saveProviders(providers []provider.Provider) error { - path := providersPath() +func saveProvidersInCache(path string, providers []provider.Provider) error { dir := filepath.Dir(path) if err := os.MkdirAll(dir, 0o755); err != nil { return err @@ -38,8 +55,7 @@ func saveProviders(providers []provider.Provider) error { return os.WriteFile(path, data, 0o644) } -func loadProviders() ([]provider.Provider, error) { - path := providersPath() +func loadProvidersFromCache(path string) ([]provider.Provider, error) { data, err := os.ReadFile(path) if err != nil { return nil, err @@ -50,34 +66,33 @@ func loadProviders() ([]provider.Provider, error) { return providers, err } -func Providers() []provider.Provider { - providerOnc.Do(func() { - // Use mock providers when testing - if UseMockProviders { - providerList = MockProviders() - return +func loadProviders(path string, client ProviderClient) ([]provider.Provider, error) { + providers, err := client.GetProviders() + if err != nil { + fallbackToCache, err := loadProvidersFromCache(path) + if err != nil { + return nil, err } - - // Try to get providers from upstream API - if providers, err := fur.GetProviders(); err == nil { - providerList = providers - // Save providers locally for future fallback - _ = saveProviders(providers) - } else { - // If upstream fails, try to load from local cache - if localProviders, localErr := loadProviders(); localErr == nil { - providerList = localProviders - } else { - // If both fail, return empty list - providerList = []provider.Provider{} - } + providers = fallbackToCache + } else { + if err := saveProvidersInCache(path, providerList); err != nil { + return nil, err } - }) - return providerList + } + return providers, nil +} + +func Providers() ([]provider.Provider, error) { + return LoadProviders(client.New()) } -// ResetProviders resets the provider cache. Useful for testing. -func ResetProviders() { - providerOnc = sync.Once{} - providerList = nil +func LoadProviders(client ProviderClient) ([]provider.Provider, error) { + var err error + providerOnce.Do(func() { + providerList, err = loadProviders(providerCacheFileData(), client) + }) + if err != nil { + return nil, err + } + return providerList, nil } diff --git a/internal/config/provider_mock.go b/internal/config/provider_mock.go deleted file mode 100644 index 801afdd8d6c9891eb47fa53294c047917b031637..0000000000000000000000000000000000000000 --- a/internal/config/provider_mock.go +++ /dev/null @@ -1,293 +0,0 @@ -package config - -import ( - "github.com/charmbracelet/crush/internal/fur/provider" -) - -// MockProviders returns a mock list of providers for testing. -// This avoids making API calls during tests and provides consistent test data. -// Simplified version with only default models from each provider. -func MockProviders() []provider.Provider { - return []provider.Provider{ - { - Name: "Anthropic", - ID: provider.InferenceProviderAnthropic, - APIKey: "$ANTHROPIC_API_KEY", - APIEndpoint: "$ANTHROPIC_API_ENDPOINT", - Type: provider.TypeAnthropic, - DefaultLargeModelID: "claude-sonnet-4-20250514", - DefaultSmallModelID: "claude-3-5-haiku-20241022", - Models: []provider.Model{ - { - ID: "claude-sonnet-4-20250514", - Name: "Claude Sonnet 4", - CostPer1MIn: 3.0, - CostPer1MOut: 15.0, - CostPer1MInCached: 3.75, - CostPer1MOutCached: 0.3, - ContextWindow: 200000, - DefaultMaxTokens: 50000, - CanReason: true, - SupportsImages: true, - }, - { - ID: "claude-3-5-haiku-20241022", - Name: "Claude 3.5 Haiku", - CostPer1MIn: 0.8, - CostPer1MOut: 4.0, - CostPer1MInCached: 1.0, - CostPer1MOutCached: 0.08, - ContextWindow: 200000, - DefaultMaxTokens: 5000, - CanReason: false, - SupportsImages: true, - }, - }, - }, - { - Name: "OpenAI", - ID: provider.InferenceProviderOpenAI, - APIKey: "$OPENAI_API_KEY", - APIEndpoint: "$OPENAI_API_ENDPOINT", - Type: provider.TypeOpenAI, - DefaultLargeModelID: "codex-mini-latest", - DefaultSmallModelID: "gpt-4o", - Models: []provider.Model{ - { - ID: "codex-mini-latest", - Name: "Codex Mini", - CostPer1MIn: 1.5, - CostPer1MOut: 6.0, - CostPer1MInCached: 0.0, - CostPer1MOutCached: 0.375, - ContextWindow: 200000, - DefaultMaxTokens: 50000, - CanReason: true, - HasReasoningEffort: true, - DefaultReasoningEffort: "medium", - SupportsImages: true, - }, - { - ID: "gpt-4o", - Name: "GPT-4o", - CostPer1MIn: 2.5, - CostPer1MOut: 10.0, - CostPer1MInCached: 0.0, - CostPer1MOutCached: 1.25, - ContextWindow: 128000, - DefaultMaxTokens: 20000, - CanReason: false, - SupportsImages: true, - }, - }, - }, - { - Name: "Google Gemini", - ID: provider.InferenceProviderGemini, - APIKey: "$GEMINI_API_KEY", - APIEndpoint: "$GEMINI_API_ENDPOINT", - Type: provider.TypeGemini, - DefaultLargeModelID: "gemini-2.5-pro", - DefaultSmallModelID: "gemini-2.5-flash", - Models: []provider.Model{ - { - ID: "gemini-2.5-pro", - Name: "Gemini 2.5 Pro", - CostPer1MIn: 1.25, - CostPer1MOut: 10.0, - CostPer1MInCached: 1.625, - CostPer1MOutCached: 0.31, - ContextWindow: 1048576, - DefaultMaxTokens: 50000, - CanReason: true, - SupportsImages: true, - }, - { - ID: "gemini-2.5-flash", - Name: "Gemini 2.5 Flash", - CostPer1MIn: 0.3, - CostPer1MOut: 2.5, - CostPer1MInCached: 0.3833, - CostPer1MOutCached: 0.075, - ContextWindow: 1048576, - DefaultMaxTokens: 50000, - CanReason: true, - SupportsImages: true, - }, - }, - }, - { - Name: "xAI", - ID: provider.InferenceProviderXAI, - APIKey: "$XAI_API_KEY", - APIEndpoint: "https://api.x.ai/v1", - Type: provider.TypeXAI, - DefaultLargeModelID: "grok-3", - DefaultSmallModelID: "grok-3-mini", - Models: []provider.Model{ - { - ID: "grok-3", - Name: "Grok 3", - CostPer1MIn: 3.0, - CostPer1MOut: 15.0, - CostPer1MInCached: 0.0, - CostPer1MOutCached: 0.75, - ContextWindow: 131072, - DefaultMaxTokens: 20000, - CanReason: false, - SupportsImages: false, - }, - { - ID: "grok-3-mini", - Name: "Grok 3 Mini", - CostPer1MIn: 0.3, - CostPer1MOut: 0.5, - CostPer1MInCached: 0.0, - CostPer1MOutCached: 0.075, - ContextWindow: 131072, - DefaultMaxTokens: 20000, - CanReason: true, - SupportsImages: false, - }, - }, - }, - { - Name: "Azure OpenAI", - ID: provider.InferenceProviderAzure, - APIKey: "$AZURE_OPENAI_API_KEY", - APIEndpoint: "$AZURE_OPENAI_API_ENDPOINT", - Type: provider.TypeAzure, - DefaultLargeModelID: "o4-mini", - DefaultSmallModelID: "gpt-4o", - Models: []provider.Model{ - { - ID: "o4-mini", - Name: "o4 Mini", - CostPer1MIn: 1.1, - CostPer1MOut: 4.4, - CostPer1MInCached: 0.0, - CostPer1MOutCached: 0.275, - ContextWindow: 200000, - DefaultMaxTokens: 50000, - CanReason: true, - HasReasoningEffort: false, - DefaultReasoningEffort: "medium", - SupportsImages: true, - }, - { - ID: "gpt-4o", - Name: "GPT-4o", - CostPer1MIn: 2.5, - CostPer1MOut: 10.0, - CostPer1MInCached: 0.0, - CostPer1MOutCached: 1.25, - ContextWindow: 128000, - DefaultMaxTokens: 20000, - CanReason: false, - SupportsImages: true, - }, - }, - }, - { - Name: "AWS Bedrock", - ID: provider.InferenceProviderBedrock, - Type: provider.TypeBedrock, - DefaultLargeModelID: "anthropic.claude-sonnet-4-20250514-v1:0", - DefaultSmallModelID: "anthropic.claude-3-5-haiku-20241022-v1:0", - Models: []provider.Model{ - { - ID: "anthropic.claude-sonnet-4-20250514-v1:0", - Name: "AWS Claude Sonnet 4", - CostPer1MIn: 3.0, - CostPer1MOut: 15.0, - CostPer1MInCached: 3.75, - CostPer1MOutCached: 0.3, - ContextWindow: 200000, - DefaultMaxTokens: 50000, - CanReason: true, - SupportsImages: true, - }, - { - ID: "anthropic.claude-3-5-haiku-20241022-v1:0", - Name: "AWS Claude 3.5 Haiku", - CostPer1MIn: 0.8, - CostPer1MOut: 4.0, - CostPer1MInCached: 1.0, - CostPer1MOutCached: 0.08, - ContextWindow: 200000, - DefaultMaxTokens: 50000, - CanReason: false, - SupportsImages: true, - }, - }, - }, - { - Name: "Google Vertex AI", - ID: provider.InferenceProviderVertexAI, - Type: provider.TypeVertexAI, - DefaultLargeModelID: "gemini-2.5-pro", - DefaultSmallModelID: "gemini-2.5-flash", - Models: []provider.Model{ - { - ID: "gemini-2.5-pro", - Name: "Gemini 2.5 Pro", - CostPer1MIn: 1.25, - CostPer1MOut: 10.0, - CostPer1MInCached: 1.625, - CostPer1MOutCached: 0.31, - ContextWindow: 1048576, - DefaultMaxTokens: 50000, - CanReason: true, - SupportsImages: true, - }, - { - ID: "gemini-2.5-flash", - Name: "Gemini 2.5 Flash", - CostPer1MIn: 0.3, - CostPer1MOut: 2.5, - CostPer1MInCached: 0.3833, - CostPer1MOutCached: 0.075, - ContextWindow: 1048576, - DefaultMaxTokens: 50000, - CanReason: true, - SupportsImages: true, - }, - }, - }, - { - Name: "OpenRouter", - ID: provider.InferenceProviderOpenRouter, - APIKey: "$OPENROUTER_API_KEY", - APIEndpoint: "https://openrouter.ai/api/v1", - Type: provider.TypeOpenAI, - DefaultLargeModelID: "anthropic/claude-sonnet-4", - DefaultSmallModelID: "anthropic/claude-haiku-3.5", - Models: []provider.Model{ - { - ID: "anthropic/claude-sonnet-4", - Name: "Anthropic: Claude Sonnet 4", - CostPer1MIn: 3.0, - CostPer1MOut: 15.0, - CostPer1MInCached: 3.75, - CostPer1MOutCached: 0.3, - ContextWindow: 200000, - DefaultMaxTokens: 32000, - CanReason: true, - SupportsImages: true, - }, - { - ID: "anthropic/claude-haiku-3.5", - Name: "Anthropic: Claude 3.5 Haiku", - CostPer1MIn: 0.8, - CostPer1MOut: 4.0, - CostPer1MInCached: 1.0, - CostPer1MOutCached: 0.08, - ContextWindow: 200000, - DefaultMaxTokens: 4096, - CanReason: false, - SupportsImages: true, - }, - }, - }, - } -} diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index 53a084d244c2d48538a514e8c72530a3850782d7..a3562838c7103239aa303c906c866220164a4ba0 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -1,81 +1,73 @@ package config import ( + "encoding/json" + "errors" + "os" "testing" "github.com/charmbracelet/crush/internal/fur/provider" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -func TestProviders_MockEnabled(t *testing.T) { - originalUseMock := UseMockProviders - UseMockProviders = true - defer func() { - UseMockProviders = originalUseMock - ResetProviders() - }() - - ResetProviders() - providers := Providers() - require.NotEmpty(t, providers) +type mockProviderClient struct { + shouldFail bool +} - providerIDs := make(map[provider.InferenceProvider]bool) - for _, p := range providers { - providerIDs[p.ID] = true +func (m *mockProviderClient) GetProviders() ([]provider.Provider, error) { + if m.shouldFail { + return nil, errors.New("failed to load providers") } - - assert.True(t, providerIDs[provider.InferenceProviderAnthropic]) - assert.True(t, providerIDs[provider.InferenceProviderOpenAI]) - assert.True(t, providerIDs[provider.InferenceProviderGemini]) + return []provider.Provider{ + { + Name: "Mock", + }, + }, nil } -func TestProviders_ResetFunctionality(t *testing.T) { - UseMockProviders = true - defer func() { - UseMockProviders = false - ResetProviders() - }() - - providers1 := Providers() - require.NotEmpty(t, providers1) - - ResetProviders() - providers2 := Providers() - require.NotEmpty(t, providers2) +func TestProvider_loadProvidersNoIssues(t *testing.T) { + client := &mockProviderClient{shouldFail: false} + tmpPath := t.TempDir() + "/providers.json" + providers, err := loadProviders(tmpPath, client) + assert.NoError(t, err) + assert.NotNil(t, providers) + assert.Len(t, providers, 1) - assert.Equal(t, len(providers1), len(providers2)) + // check if file got saved + fileInfo, err := os.Stat(tmpPath) + assert.NoError(t, err) + assert.False(t, fileInfo.IsDir(), "Expected a file, not a directory") } -func TestProviders_ModelCapabilities(t *testing.T) { - originalUseMock := UseMockProviders - UseMockProviders = true - defer func() { - UseMockProviders = originalUseMock - ResetProviders() - }() - - ResetProviders() - providers := Providers() - - var openaiProvider provider.Provider - for _, p := range providers { - if p.ID == provider.InferenceProviderOpenAI { - openaiProvider = p - break - } +func TestProvider_loadProvidersWithIssues(t *testing.T) { + client := &mockProviderClient{shouldFail: true} + tmpPath := t.TempDir() + "/providers.json" + // store providers to a temporary file + oldProviders := []provider.Provider{ + { + Name: "OldProvider", + }, + } + data, err := json.Marshal(oldProviders) + if err != nil { + t.Fatalf("Failed to marshal old providers: %v", err) } - require.NotEmpty(t, openaiProvider.ID) - var foundReasoning, foundNonReasoning bool - for _, model := range openaiProvider.Models { - if model.CanReason && model.HasReasoningEffort { - foundReasoning = true - } else if !model.CanReason { - foundNonReasoning = true - } + err = os.WriteFile(tmpPath, data, 0o644) + if err != nil { + t.Fatalf("Failed to write old providers to file: %v", err) } + providers, err := loadProviders(tmpPath, client) + assert.NoError(t, err) + assert.NotNil(t, providers) + assert.Len(t, providers, 1) + assert.Equal(t, "OldProvider", providers[0].Name, "Expected to keep old provider when loading fails") +} - assert.True(t, foundReasoning) - assert.True(t, foundNonReasoning) +func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) { + client := &mockProviderClient{shouldFail: true} + tmpPath := t.TempDir() + "/providers.json" + providers, err := loadProviders(tmpPath, client) + assert.Error(t, err) + assert.Nil(t, providers, "Expected nil providers when loading fails and no cache exists") } diff --git a/pkg/config/resolve.go b/internal/config/resolve.go similarity index 97% rename from pkg/config/resolve.go rename to internal/config/resolve.go index 9e88a8f06e3572bd557ac09abdef8e84ada71f9e..776897e78be6fb4f20058e700cf2906d1072fc4b 100644 --- a/pkg/config/resolve.go +++ b/internal/config/resolve.go @@ -7,7 +7,7 @@ import ( "time" "github.com/charmbracelet/crush/internal/shell" - "github.com/charmbracelet/crush/pkg/env" + "github.com/charmbracelet/crush/internal/env" ) type VariableResolver interface { diff --git a/pkg/config/resolve_test.go b/internal/config/resolve_test.go similarity index 98% rename from pkg/config/resolve_test.go rename to internal/config/resolve_test.go index b1de46ad7565a6574af51f05d8292ef062e706c3..7cdcd2a7913cb581e5312f787791e8e89e699281 100644 --- a/pkg/config/resolve_test.go +++ b/internal/config/resolve_test.go @@ -5,7 +5,7 @@ import ( "errors" "testing" - "github.com/charmbracelet/crush/pkg/env" + "github.com/charmbracelet/crush/internal/env" "github.com/stretchr/testify/assert" ) diff --git a/internal/config/shell.go b/internal/config/shell.go deleted file mode 100644 index b7c3c8c5a787def8ff28aec677193f5ac58b652a..0000000000000000000000000000000000000000 --- a/internal/config/shell.go +++ /dev/null @@ -1,73 +0,0 @@ -package config - -import ( - "context" - "fmt" - "os" - "strings" - "time" - - "github.com/charmbracelet/crush/internal/logging" - "github.com/charmbracelet/crush/internal/shell" -) - -// ExecuteCommand executes a shell command and returns the output -// This is a shared utility that can be used by both provider config and tools -func ExecuteCommand(ctx context.Context, command string, workingDir string) (string, error) { - if workingDir == "" { - workingDir = WorkingDirectory() - } - - persistentShell := shell.NewShell(&shell.Options{WorkingDir: workingDir}) - - stdout, stderr, err := persistentShell.Exec(ctx, command) - if err != nil { - logging.Debug("Command execution failed", "command", command, "error", err, "stderr", stderr) - return "", fmt.Errorf("command execution failed: %w", err) - } - - return strings.TrimSpace(stdout), nil -} - -// ResolveAPIKey resolves an API key that can be either: -// - A direct string value -// - An environment variable (prefixed with $) -// - A shell command (wrapped in $(...)) -func ResolveAPIKey(apiKey string) (string, error) { - if !strings.HasPrefix(apiKey, "$") { - return apiKey, nil - } - - if strings.HasPrefix(apiKey, "$(") && strings.HasSuffix(apiKey, ")") { - command := strings.TrimSuffix(strings.TrimPrefix(apiKey, "$("), ")") - logging.Debug("Resolving API key from command", "command", command) - return resolveCommandAPIKey(command) - } - - envVar := strings.TrimPrefix(apiKey, "$") - if value := os.Getenv(envVar); value != "" { - logging.Debug("Resolved environment variable", "envVar", envVar, "value", value) - return value, nil - } - - logging.Debug("Environment variable not found", "envVar", envVar) - - return "", fmt.Errorf("environment variable %s not found", envVar) -} - -// resolveCommandAPIKey executes a command to get an API key, with caching support -func resolveCommandAPIKey(command string) (string, error) { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - logging.Debug("Executing command for API key", "command", command) - - workingDir := WorkingDirectory() - - result, err := ExecuteCommand(ctx, command, workingDir) - if err != nil { - return "", fmt.Errorf("failed to execute API key command: %w", err) - } - logging.Debug("Command executed successfully", "command", command, "result", result) - return result, nil -} diff --git a/internal/config/validation_test.go b/internal/config/validation_test.go deleted file mode 100644 index 0aef035ae7bddfc7532e9dde550ab0184ed180db..0000000000000000000000000000000000000000 --- a/internal/config/validation_test.go +++ /dev/null @@ -1,462 +0,0 @@ -package config - -import ( - "testing" - - "github.com/charmbracelet/crush/internal/fur/provider" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestConfig_Validate_ValidConfig(t *testing.T) { - cfg := &Config{ - Models: PreferredModels{ - Large: PreferredModel{ - ModelID: "gpt-4", - Provider: provider.InferenceProviderOpenAI, - }, - Small: PreferredModel{ - ModelID: "gpt-3.5-turbo", - Provider: provider.InferenceProviderOpenAI, - }, - }, - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "test-key", - ProviderType: provider.TypeOpenAI, - DefaultLargeModel: "gpt-4", - DefaultSmallModel: "gpt-3.5-turbo", - Models: []Model{ - { - ID: "gpt-4", - Name: "GPT-4", - ContextWindow: 8192, - DefaultMaxTokens: 4096, - CostPer1MIn: 30.0, - CostPer1MOut: 60.0, - }, - { - ID: "gpt-3.5-turbo", - Name: "GPT-3.5 Turbo", - ContextWindow: 4096, - DefaultMaxTokens: 2048, - CostPer1MIn: 1.5, - CostPer1MOut: 2.0, - }, - }, - }, - }, - Agents: map[AgentID]Agent{ - AgentCoder: { - ID: AgentCoder, - Name: "Coder", - Description: "An agent that helps with executing coding tasks.", - Model: LargeModel, - ContextPaths: []string{"CRUSH.md"}, - }, - AgentTask: { - ID: AgentTask, - Name: "Task", - Description: "An agent that helps with searching for context and finding implementation details.", - Model: LargeModel, - ContextPaths: []string{"CRUSH.md"}, - AllowedTools: []string{"glob", "grep", "ls", "sourcegraph", "view"}, - AllowedMCP: map[string][]string{}, - AllowedLSP: []string{}, - }, - }, - MCP: map[string]MCP{}, - LSP: map[string]LSPConfig{}, - Options: Options{ - DataDirectory: ".crush", - ContextPaths: []string{"CRUSH.md"}, - }, - } - - err := cfg.Validate() - assert.NoError(t, err) -} - -func TestConfig_Validate_MissingAPIKey(t *testing.T) { - cfg := &Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - ProviderType: provider.TypeOpenAI, - // Missing APIKey - }, - }, - Options: Options{ - DataDirectory: ".crush", - ContextPaths: []string{"CRUSH.md"}, - }, - } - - err := cfg.Validate() - require.Error(t, err) - assert.Contains(t, err.Error(), "API key is required") -} - -func TestConfig_Validate_InvalidProviderType(t *testing.T) { - cfg := &Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "test-key", - ProviderType: provider.Type("invalid"), - }, - }, - Options: Options{ - DataDirectory: ".crush", - ContextPaths: []string{"CRUSH.md"}, - }, - } - - err := cfg.Validate() - require.Error(t, err) - assert.Contains(t, err.Error(), "invalid provider type") -} - -func TestConfig_Validate_CustomProviderMissingBaseURL(t *testing.T) { - customProvider := provider.InferenceProvider("custom-provider") - cfg := &Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - customProvider: { - ID: customProvider, - APIKey: "test-key", - ProviderType: provider.TypeOpenAI, - // Missing BaseURL for custom provider - }, - }, - Options: Options{ - DataDirectory: ".crush", - ContextPaths: []string{"CRUSH.md"}, - }, - } - - err := cfg.Validate() - require.Error(t, err) - assert.Contains(t, err.Error(), "BaseURL is required for custom providers") -} - -func TestConfig_Validate_DuplicateModelIDs(t *testing.T) { - cfg := &Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "test-key", - ProviderType: provider.TypeOpenAI, - Models: []Model{ - { - ID: "gpt-4", - Name: "GPT-4", - ContextWindow: 8192, - DefaultMaxTokens: 4096, - }, - { - ID: "gpt-4", // Duplicate ID - Name: "GPT-4 Duplicate", - ContextWindow: 8192, - DefaultMaxTokens: 4096, - }, - }, - }, - }, - Options: Options{ - DataDirectory: ".crush", - ContextPaths: []string{"CRUSH.md"}, - }, - } - - err := cfg.Validate() - require.Error(t, err) - assert.Contains(t, err.Error(), "duplicate model ID") -} - -func TestConfig_Validate_InvalidModelFields(t *testing.T) { - cfg := &Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "test-key", - ProviderType: provider.TypeOpenAI, - Models: []Model{ - { - ID: "", // Empty ID - Name: "GPT-4", - ContextWindow: 0, // Invalid context window - DefaultMaxTokens: -1, // Invalid max tokens - CostPer1MIn: -5.0, // Negative cost - }, - }, - }, - }, - Options: Options{ - DataDirectory: ".crush", - ContextPaths: []string{"CRUSH.md"}, - }, - } - - err := cfg.Validate() - require.Error(t, err) - validationErr := err.(ValidationErrors) - assert.True(t, len(validationErr) >= 4) // Should have multiple validation errors -} - -func TestConfig_Validate_DefaultModelNotFound(t *testing.T) { - cfg := &Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "test-key", - ProviderType: provider.TypeOpenAI, - DefaultLargeModel: "nonexistent-model", - Models: []Model{ - { - ID: "gpt-4", - Name: "GPT-4", - ContextWindow: 8192, - DefaultMaxTokens: 4096, - }, - }, - }, - }, - Options: Options{ - DataDirectory: ".crush", - ContextPaths: []string{"CRUSH.md"}, - }, - } - - err := cfg.Validate() - require.Error(t, err) - assert.Contains(t, err.Error(), "default large model 'nonexistent-model' not found") -} - -func TestConfig_Validate_AgentIDMismatch(t *testing.T) { - cfg := &Config{ - Agents: map[AgentID]Agent{ - AgentCoder: { - ID: AgentTask, // Wrong ID - Name: "Coder", - }, - }, - Options: Options{ - DataDirectory: ".crush", - ContextPaths: []string{"CRUSH.md"}, - }, - } - - err := cfg.Validate() - require.Error(t, err) - assert.Contains(t, err.Error(), "agent ID mismatch") -} - -func TestConfig_Validate_InvalidAgentModelType(t *testing.T) { - cfg := &Config{ - Agents: map[AgentID]Agent{ - AgentCoder: { - ID: AgentCoder, - Name: "Coder", - Model: ModelType("invalid"), - }, - }, - Options: Options{ - DataDirectory: ".crush", - ContextPaths: []string{"CRUSH.md"}, - }, - } - - err := cfg.Validate() - require.Error(t, err) - assert.Contains(t, err.Error(), "invalid model type") -} - -func TestConfig_Validate_UnknownTool(t *testing.T) { - cfg := &Config{ - Agents: map[AgentID]Agent{ - AgentID("custom-agent"): { - ID: AgentID("custom-agent"), - Name: "Custom Agent", - Model: LargeModel, - AllowedTools: []string{"unknown-tool"}, - }, - }, - Options: Options{ - DataDirectory: ".crush", - ContextPaths: []string{"CRUSH.md"}, - }, - } - - err := cfg.Validate() - require.Error(t, err) - assert.Contains(t, err.Error(), "unknown tool") -} - -func TestConfig_Validate_MCPReference(t *testing.T) { - cfg := &Config{ - Agents: map[AgentID]Agent{ - AgentID("custom-agent"): { - ID: AgentID("custom-agent"), - Name: "Custom Agent", - Model: LargeModel, - AllowedMCP: map[string][]string{"nonexistent-mcp": nil}, - }, - }, - MCP: map[string]MCP{}, // Empty MCP map - Options: Options{ - DataDirectory: ".crush", - ContextPaths: []string{"CRUSH.md"}, - }, - } - - err := cfg.Validate() - require.Error(t, err) - assert.Contains(t, err.Error(), "referenced MCP 'nonexistent-mcp' not found") -} - -func TestConfig_Validate_InvalidMCPType(t *testing.T) { - cfg := &Config{ - MCP: map[string]MCP{ - "test-mcp": { - Type: MCPType("invalid"), - }, - }, - Options: Options{ - DataDirectory: ".crush", - ContextPaths: []string{"CRUSH.md"}, - }, - } - - err := cfg.Validate() - require.Error(t, err) - assert.Contains(t, err.Error(), "invalid MCP type") -} - -func TestConfig_Validate_MCPMissingCommand(t *testing.T) { - cfg := &Config{ - MCP: map[string]MCP{ - "test-mcp": { - Type: MCPStdio, - // Missing Command - }, - }, - Options: Options{ - DataDirectory: ".crush", - ContextPaths: []string{"CRUSH.md"}, - }, - } - - err := cfg.Validate() - require.Error(t, err) - assert.Contains(t, err.Error(), "command is required for stdio MCP") -} - -func TestConfig_Validate_LSPMissingCommand(t *testing.T) { - cfg := &Config{ - LSP: map[string]LSPConfig{ - "test-lsp": { - // Missing Command - }, - }, - Options: Options{ - DataDirectory: ".crush", - ContextPaths: []string{"CRUSH.md"}, - }, - } - - err := cfg.Validate() - require.Error(t, err) - assert.Contains(t, err.Error(), "command is required for LSP") -} - -func TestConfig_Validate_NoValidProviders(t *testing.T) { - cfg := &Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "test-key", - ProviderType: provider.TypeOpenAI, - Disabled: true, // Disabled - }, - }, - Options: Options{ - DataDirectory: ".crush", - ContextPaths: []string{"CRUSH.md"}, - }, - } - - err := cfg.Validate() - require.Error(t, err) - assert.Contains(t, err.Error(), "at least one non-disabled provider is required") -} - -func TestConfig_Validate_MissingDefaultAgents(t *testing.T) { - cfg := &Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "test-key", - ProviderType: provider.TypeOpenAI, - }, - }, - Agents: map[AgentID]Agent{}, // Missing default agents - Options: Options{ - DataDirectory: ".crush", - ContextPaths: []string{"CRUSH.md"}, - }, - } - - err := cfg.Validate() - require.Error(t, err) - assert.Contains(t, err.Error(), "coder agent is required") - assert.Contains(t, err.Error(), "task agent is required") -} - -func TestConfig_Validate_KnownAgentProtection(t *testing.T) { - cfg := &Config{ - Agents: map[AgentID]Agent{ - AgentCoder: { - ID: AgentCoder, - Name: "Modified Coder", // Should not be allowed - Description: "Modified description", // Should not be allowed - Model: LargeModel, - }, - }, - Options: Options{ - DataDirectory: ".crush", - ContextPaths: []string{"CRUSH.md"}, - }, - } - - err := cfg.Validate() - require.Error(t, err) - assert.Contains(t, err.Error(), "coder agent name cannot be changed") - assert.Contains(t, err.Error(), "coder agent description cannot be changed") -} - -func TestConfig_Validate_EmptyDataDirectory(t *testing.T) { - cfg := &Config{ - Options: Options{ - DataDirectory: "", // Empty - ContextPaths: []string{"CRUSH.md"}, - }, - } - - err := cfg.Validate() - require.Error(t, err) - assert.Contains(t, err.Error(), "data directory is required") -} - -func TestConfig_Validate_EmptyContextPath(t *testing.T) { - cfg := &Config{ - Options: Options{ - DataDirectory: ".crush", - ContextPaths: []string{""}, // Empty context path - }, - } - - err := cfg.Validate() - require.Error(t, err) - assert.Contains(t, err.Error(), "context path cannot be empty") -} diff --git a/internal/diff/diff.go b/internal/diff/diff.go index 694b9db417790bbcda679ec84f426a2a4d0e8f7b..928785339de639ee910b5b7bfb282296662e4687 100644 --- a/internal/diff/diff.go +++ b/internal/diff/diff.go @@ -11,7 +11,7 @@ import ( func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, int) { // remove the cwd prefix and ensure consistent path format // this prevents issues with absolute paths in different environments - cwd := config.WorkingDirectory() + cwd := config.Get().WorkingDir() fileName = strings.TrimPrefix(fileName, cwd) fileName = strings.TrimPrefix(fileName, "/") diff --git a/pkg/env/env.go b/internal/env/env.go similarity index 100% rename from pkg/env/env.go rename to internal/env/env.go diff --git a/pkg/env/env_test.go b/internal/env/env_test.go similarity index 100% rename from pkg/env/env_test.go rename to internal/env/env_test.go diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index da652dc0af9c0fd6dbf768f759009a65b9ef0574..c515f4b60efb6e1a0a1ea19d848f42eee333d2a9 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -10,6 +10,7 @@ import ( "time" "github.com/charmbracelet/crush/internal/config" + fur "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/history" "github.com/charmbracelet/crush/internal/llm/prompt" "github.com/charmbracelet/crush/internal/llm/provider" @@ -49,7 +50,7 @@ type AgentEvent struct { type Service interface { pubsub.Suscriber[AgentEvent] - Model() config.Model + Model() fur.Model Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) Cancel(sessionID string) CancelAll() @@ -76,9 +77,9 @@ type agent struct { activeRequests sync.Map } -var agentPromptMap = map[config.AgentID]prompt.PromptID{ - config.AgentCoder: prompt.PromptCoder, - config.AgentTask: prompt.PromptTask, +var agentPromptMap = map[string]prompt.PromptID{ + "coder": prompt.PromptCoder, + "task": prompt.PromptTask, } func NewAgent( @@ -109,8 +110,8 @@ func NewAgent( tools.NewWriteTool(lspClients, permissions, history), } - if agentCfg.ID == config.AgentCoder { - taskAgentCfg := config.Get().Agents[config.AgentTask] + if agentCfg.ID == "coder" { + taskAgentCfg := config.Get().Agents["task"] if taskAgentCfg.ID == "" { return nil, fmt.Errorf("task agent not found in config") } @@ -130,13 +131,13 @@ func NewAgent( } allTools = append(allTools, otherTools...) - providerCfg := config.GetAgentProvider(agentCfg.ID) - if providerCfg.ID == "" { + providerCfg := config.Get().GetProviderForModel(agentCfg.Model) + if providerCfg == nil { return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name) } - model := config.GetAgentModel(agentCfg.ID) + model := config.Get().GetModelByType(agentCfg.Model) - if model.ID == "" { + if model == nil { return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name) } @@ -148,51 +149,40 @@ func NewAgent( provider.WithModel(agentCfg.Model), provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)), } - agentProvider, err := provider.NewProvider(providerCfg, opts...) + agentProvider, err := provider.NewProvider(*providerCfg, opts...) if err != nil { return nil, err } - smallModelCfg := cfg.Models.Small - var smallModel config.Model - - var smallModelProviderCfg config.ProviderConfig + smallModelCfg := cfg.Models[config.SelectedModelTypeSmall] + var smallModelProviderCfg *config.ProviderConfig if smallModelCfg.Provider == providerCfg.ID { smallModelProviderCfg = providerCfg } else { - for _, p := range cfg.Providers { - if p.ID == smallModelCfg.Provider { - smallModelProviderCfg = p - break - } - } + smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall) + if smallModelProviderCfg.ID == "" { return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider) } } - for _, m := range smallModelProviderCfg.Models { - if m.ID == smallModelCfg.ModelID { - smallModel = m - break - } - } + smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall) if smallModel.ID == "" { - return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID) + return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID) } titleOpts := []provider.ProviderClientOption{ - provider.WithModel(config.SmallModel), + provider.WithModel(config.SelectedModelTypeSmall), provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)), } - titleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...) + titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...) if err != nil { return nil, err } summarizeOpts := []provider.ProviderClientOption{ - provider.WithModel(config.SmallModel), + provider.WithModel(config.SelectedModelTypeSmall), provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)), } - summarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...) + summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...) if err != nil { return nil, err } @@ -225,8 +215,8 @@ func NewAgent( return agent, nil } -func (a *agent) Model() config.Model { - return config.GetAgentModel(a.agentCfg.ID) +func (a *agent) Model() fur.Model { + return *config.Get().GetModelByType(a.agentCfg.Model) } func (a *agent) Cancel(sessionID string) { @@ -610,7 +600,7 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg return nil } -func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error { +func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error { sess, err := a.sessions.Get(ctx, sessionID) if err != nil { return fmt.Errorf("failed to get session: %w", err) @@ -819,7 +809,7 @@ func (a *agent) UpdateModel() error { cfg := config.Get() // Get current provider configuration - currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID) + currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model) if currentProviderCfg.ID == "" { return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name) } @@ -827,7 +817,7 @@ func (a *agent) UpdateModel() error { // Check if provider has changed if string(currentProviderCfg.ID) != a.providerID { // Provider changed, need to recreate the main provider - model := config.GetAgentModel(a.agentCfg.ID) + model := cfg.GetModelByType(a.agentCfg.Model) if model.ID == "" { return fmt.Errorf("model not found for agent %s", a.agentCfg.Name) } @@ -842,7 +832,7 @@ func (a *agent) UpdateModel() error { provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)), } - newProvider, err := provider.NewProvider(currentProviderCfg, opts...) + newProvider, err := provider.NewProvider(*currentProviderCfg, opts...) if err != nil { return fmt.Errorf("failed to create new provider: %w", err) } @@ -853,7 +843,7 @@ func (a *agent) UpdateModel() error { } // Check if small model provider has changed (affects title and summarize providers) - smallModelCfg := cfg.Models.Small + smallModelCfg := cfg.Models[config.SelectedModelTypeSmall] var smallModelProviderCfg config.ProviderConfig for _, p := range cfg.Providers { @@ -869,20 +859,14 @@ func (a *agent) UpdateModel() error { // Check if summarize provider has changed if string(smallModelProviderCfg.ID) != a.summarizeProviderID { - var smallModel config.Model - for _, m := range smallModelProviderCfg.Models { - if m.ID == smallModelCfg.ModelID { - smallModel = m - break - } - } - if smallModel.ID == "" { - return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID) + smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall) + if smallModel == nil { + return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID) } // Recreate title provider titleOpts := []provider.ProviderClientOption{ - provider.WithModel(config.SmallModel), + provider.WithModel(config.SelectedModelTypeSmall), provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)), // We want the title to be short, so we limit the max tokens provider.WithMaxTokens(40), @@ -894,7 +878,7 @@ func (a *agent) UpdateModel() error { // Recreate summarize provider summarizeOpts := []provider.ProviderClientOption{ - provider.WithModel(config.SmallModel), + provider.WithModel(config.SelectedModelTypeSmall), provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)), } newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...) diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index 3fa4e778e9df09f1728641ca578cb7382d9c87b0..5d1bd44d56056051a841e610bdd31e0bf91f2183 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -19,7 +19,7 @@ import ( type mcpTool struct { mcpName string tool mcp.Tool - mcpConfig config.MCP + mcpConfig config.MCPConfig permissions permission.Service } @@ -97,7 +97,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes p := b.permissions.Request( permission.CreatePermissionRequest{ SessionID: sessionID, - Path: config.WorkingDirectory(), + Path: config.Get().WorkingDir(), ToolName: b.Info().Name, Action: "execute", Description: permissionDescription, @@ -142,7 +142,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes return tools.NewTextErrorResponse("invalid mcp type"), nil } -func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCP) tools.BaseTool { +func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig) tools.BaseTool { return &mcpTool{ mcpName: name, tool: tool, @@ -153,7 +153,7 @@ func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpC var mcpTools []tools.BaseTool -func getTools(ctx context.Context, name string, m config.MCP, permissions permission.Service, c MCPClient) []tools.BaseTool { +func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient) []tools.BaseTool { var stdioTools []tools.BaseTool initRequest := mcp.InitializeRequest{} initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index 523933d18e5c39ea766c42e1aafe09b5aaff3e63..b85b084e9954ad1afa35ebf58acfcc299038db2b 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -14,12 +14,12 @@ import ( "github.com/charmbracelet/crush/internal/logging" ) -func CoderPrompt(p provider.InferenceProvider, contextFiles ...string) string { +func CoderPrompt(p string, contextFiles ...string) string { var basePrompt string switch p { - case provider.InferenceProviderOpenAI: + case string(provider.InferenceProviderOpenAI): basePrompt = baseOpenAICoderPrompt - case provider.InferenceProviderGemini, provider.InferenceProviderVertexAI: + case string(provider.InferenceProviderGemini), string(provider.InferenceProviderVertexAI): basePrompt = baseGeminiCoderPrompt default: basePrompt = baseAnthropicCoderPrompt @@ -380,7 +380,7 @@ Your core function is efficient and safe assistance. Balance extreme conciseness ` func getEnvironmentInfo() string { - cwd := config.WorkingDirectory() + cwd := config.Get().WorkingDir() isGit := isGitRepo(cwd) platform := runtime.GOOS date := time.Now().Format("1/2/2006") diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go index 36148edd9c71790c3a4cb06d551cdee06272c8b7..0c70d41770ca0d35bcf1cb1d0e7353d1b194790d 100644 --- a/internal/llm/prompt/prompt.go +++ b/internal/llm/prompt/prompt.go @@ -7,7 +7,6 @@ import ( "sync" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" ) type PromptID string @@ -20,17 +19,17 @@ const ( PromptDefault PromptID = "default" ) -func GetPrompt(promptID PromptID, provider provider.InferenceProvider, contextPaths ...string) string { +func GetPrompt(promptID PromptID, provider string, contextPaths ...string) string { basePrompt := "" switch promptID { case PromptCoder: basePrompt = CoderPrompt(provider) case PromptTitle: - basePrompt = TitlePrompt(provider) + basePrompt = TitlePrompt() case PromptTask: - basePrompt = TaskPrompt(provider) + basePrompt = TaskPrompt() case PromptSummarizer: - basePrompt = SummarizerPrompt(provider) + basePrompt = SummarizerPrompt() default: basePrompt = "You are a helpful assistant" } @@ -38,7 +37,7 @@ func GetPrompt(promptID PromptID, provider provider.InferenceProvider, contextPa } func getContextFromPaths(contextPaths []string) string { - return processContextPaths(config.WorkingDirectory(), contextPaths) + return processContextPaths(config.Get().WorkingDir(), contextPaths) } func processContextPaths(workDir string, paths []string) string { diff --git a/internal/llm/prompt/summarizer.go b/internal/llm/prompt/summarizer.go index 77d98184bcf985ebb2bc569205b6b4cc77b3d601..f9c4c336390c30dcfd8bf6fe950aff2b76a386a4 100644 --- a/internal/llm/prompt/summarizer.go +++ b/internal/llm/prompt/summarizer.go @@ -1,10 +1,6 @@ package prompt -import ( - "github.com/charmbracelet/crush/internal/fur/provider" -) - -func SummarizerPrompt(_ provider.InferenceProvider) string { +func SummarizerPrompt() string { return `You are a helpful AI assistant tasked with summarizing conversations. When asked to summarize, provide a detailed but concise summary of the conversation. diff --git a/internal/llm/prompt/task.go b/internal/llm/prompt/task.go index 719c0ef45778814e38b391e86174708edcdd7c3e..e4f021d4ab7ef9f49873bc6893a231d72f2f3994 100644 --- a/internal/llm/prompt/task.go +++ b/internal/llm/prompt/task.go @@ -2,11 +2,9 @@ package prompt import ( "fmt" - - "github.com/charmbracelet/crush/internal/fur/provider" ) -func TaskPrompt(_ provider.InferenceProvider) string { +func TaskPrompt() string { agentPrompt := `You are an agent for Crush. Given the user's prompt, you should use the tools available to you to answer the user's question. Notes: 1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is .", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...". diff --git a/internal/llm/prompt/title.go b/internal/llm/prompt/title.go index 11bab4b6835ac0e53adc578cfddd3133f8b654e5..0dae6fde63d1a4ccc6996c5186c0deca74126984 100644 --- a/internal/llm/prompt/title.go +++ b/internal/llm/prompt/title.go @@ -1,10 +1,6 @@ package prompt -import ( - "github.com/charmbracelet/crush/internal/fur/provider" -) - -func TitlePrompt(_ provider.InferenceProvider) string { +func TitlePrompt() string { return `you will generate a short title based on the first message a user begins a conversation with - ensure it is not more than 50 characters long - the title should be a summary of the user's message diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 25f418878e071a46ac122d8bc51db6969f1fcbc7..4ed18d7b6595bfd28c4b02e473a186c21c0f84eb 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -153,9 +153,9 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to model := a.providerOptions.model(a.providerOptions.modelType) var thinkingParam anthropic.ThinkingConfigParamUnion cfg := config.Get() - modelConfig := cfg.Models.Large - if a.providerOptions.modelType == config.SmallModel { - modelConfig = cfg.Models.Small + modelConfig := cfg.Models[config.SelectedModelTypeLarge] + if a.providerOptions.modelType == config.SelectedModelTypeSmall { + modelConfig = cfg.Models[config.SelectedModelTypeSmall] } temperature := anthropic.Float(0) @@ -399,7 +399,7 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err } if apiErr.StatusCode == 401 { - a.providerOptions.apiKey, err = config.ResolveAPIKey(a.providerOptions.config.APIKey) + a.providerOptions.apiKey, err = config.Get().Resolve(a.providerOptions.config.APIKey) if err != nil { return false, 0, fmt.Errorf("failed to resolve API key: %w", err) } @@ -490,6 +490,6 @@ func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage { } } -func (a *anthropicClient) Model() config.Model { +func (a *anthropicClient) Model() provider.Model { return a.providerOptions.model(a.providerOptions.modelType) } diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go index 1519099b00401e32ad5f19c1f6ed253eb8b7130d..0c0ccdbab2d642f139a2b1ab2f19f6298f1ac73d 100644 --- a/internal/llm/provider/bedrock.go +++ b/internal/llm/provider/bedrock.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" ) @@ -31,14 +32,14 @@ func newBedrockClient(opts providerClientOptions) BedrockClient { } } - opts.model = func(modelType config.ModelType) config.Model { - model := config.GetModel(modelType) + opts.model = func(modelType config.SelectedModelType) provider.Model { + model := config.Get().GetModelByType(modelType) // Prefix the model name with region regionPrefix := region[:2] modelName := model.ID model.ID = fmt.Sprintf("%s.%s", regionPrefix, modelName) - return model + return *model } model := opts.model(opts.modelType) @@ -87,6 +88,6 @@ func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, return b.childProvider.stream(ctx, messages, tools) } -func (b *bedrockClient) Model() config.Model { +func (b *bedrockClient) Model() provider.Model { return b.providerOptions.model(b.providerOptions.modelType) } diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index e80af34d0815695ea6ed76d01c25262381a836ec..7e9fdbd405dbf64c58873a6e1cfc108e9d4b4f7f 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -10,6 +10,7 @@ import ( "time" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/logging" "github.com/charmbracelet/crush/internal/message" @@ -170,9 +171,9 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too logging.Debug("Prepared messages", "messages", string(jsonData)) } - modelConfig := cfg.Models.Large - if g.providerOptions.modelType == config.SmallModel { - modelConfig = cfg.Models.Small + modelConfig := cfg.Models[config.SelectedModelTypeLarge] + if g.providerOptions.modelType == config.SelectedModelTypeSmall { + modelConfig = cfg.Models[config.SelectedModelTypeSmall] } maxTokens := model.DefaultMaxTokens @@ -268,9 +269,9 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t logging.Debug("Prepared messages", "messages", string(jsonData)) } - modelConfig := cfg.Models.Large - if g.providerOptions.modelType == config.SmallModel { - modelConfig = cfg.Models.Small + modelConfig := cfg.Models[config.SelectedModelTypeLarge] + if g.providerOptions.modelType == config.SelectedModelTypeSmall { + modelConfig = cfg.Models[config.SelectedModelTypeSmall] } maxTokens := model.DefaultMaxTokens if modelConfig.MaxTokens > 0 { @@ -424,7 +425,7 @@ func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) // Check for token expiration (401 Unauthorized) if contains(errMsg, "unauthorized", "invalid api key", "api key expired") { - g.providerOptions.apiKey, err = config.ResolveAPIKey(g.providerOptions.config.APIKey) + g.providerOptions.apiKey, err = config.Get().Resolve(g.providerOptions.config.APIKey) if err != nil { return false, 0, fmt.Errorf("failed to resolve API key: %w", err) } @@ -462,7 +463,7 @@ func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage { } } -func (g *geminiClient) Model() config.Model { +func (g *geminiClient) Model() provider.Model { return g.providerOptions.model(g.providerOptions.modelType) } diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 561046b74cf7d4d0c5fd871f65b82d9634a8cfa1..6fcc0b25bb2f0721d3a46f0f0bcb32589e477816 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -148,15 +148,12 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar model := o.providerOptions.model(o.providerOptions.modelType) cfg := config.Get() - modelConfig := cfg.Models.Large - if o.providerOptions.modelType == config.SmallModel { - modelConfig = cfg.Models.Small + modelConfig := cfg.Models[config.SelectedModelTypeLarge] + if o.providerOptions.modelType == config.SelectedModelTypeSmall { + modelConfig = cfg.Models[config.SelectedModelTypeSmall] } - reasoningEffort := model.ReasoningEffort - if modelConfig.ReasoningEffort != "" { - reasoningEffort = modelConfig.ReasoningEffort - } + reasoningEffort := modelConfig.ReasoningEffort params := openai.ChatCompletionNewParams{ Model: openai.ChatModel(model.ID), @@ -363,7 +360,7 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) // Check for token expiration (401 Unauthorized) if apiErr.StatusCode == 401 { - o.providerOptions.apiKey, err = config.ResolveAPIKey(o.providerOptions.config.APIKey) + o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey) if err != nil { return false, 0, fmt.Errorf("failed to resolve API key: %w", err) } @@ -420,6 +417,6 @@ func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage { } } -func (a *openaiClient) Model() config.Model { +func (a *openaiClient) Model() provider.Model { return a.providerOptions.model(a.providerOptions.modelType) } diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 3ffbf86c00c5e3ca27f1b68965f4ff950f1f7454..193affc2a2b5a6dcdecee596a839882c40f70a42 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -55,15 +55,15 @@ type Provider interface { StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent - Model() config.Model + Model() provider.Model } type providerClientOptions struct { baseURL string config config.ProviderConfig apiKey string - modelType config.ModelType - model func(config.ModelType) config.Model + modelType config.SelectedModelType + model func(config.SelectedModelType) provider.Model disableCache bool systemMessage string maxTokens int64 @@ -77,7 +77,7 @@ type ProviderClient interface { send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent - Model() config.Model + Model() provider.Model } type baseProvider[C ProviderClient] struct { @@ -106,11 +106,11 @@ func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message return p.client.stream(ctx, messages, tools) } -func (p *baseProvider[C]) Model() config.Model { +func (p *baseProvider[C]) Model() provider.Model { return p.client.Model() } -func WithModel(model config.ModelType) ProviderClientOption { +func WithModel(model config.SelectedModelType) ProviderClientOption { return func(options *providerClientOptions) { options.modelType = model } @@ -135,7 +135,7 @@ func WithMaxTokens(maxTokens int64) ProviderClientOption { } func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) { - resolvedAPIKey, err := config.ResolveAPIKey(cfg.APIKey) + resolvedAPIKey, err := config.Get().Resolve(cfg.APIKey) if err != nil { return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err) } @@ -145,14 +145,14 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi config: cfg, apiKey: resolvedAPIKey, extraHeaders: cfg.ExtraHeaders, - model: func(tp config.ModelType) config.Model { - return config.GetModel(tp) + model: func(tp config.SelectedModelType) provider.Model { + return *config.Get().GetModelByType(tp) }, } for _, o := range opts { o(&clientOptions) } - switch cfg.ProviderType { + switch cfg.Type { case provider.TypeAnthropic: return &baseProvider[AnthropicClient]{ options: clientOptions, @@ -190,5 +190,5 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi client: newOpenAIClient(clientOptions), }, nil } - return nil, fmt.Errorf("provider not supported: %s", cfg.ProviderType) + return nil, fmt.Errorf("provider not supported: %s", cfg.Type) } diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 5f8b41338c8c5ef6f771e80fbd4e1355b27eb036..03d3af32d95fd032d2f0f7092d66493c60867db8 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -317,7 +317,7 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) p := b.permissions.Request( permission.CreatePermissionRequest{ SessionID: sessionID, - Path: config.WorkingDirectory(), + Path: config.Get().WorkingDir(), ToolName: BashToolName, Action: "execute", Description: fmt.Sprintf("Execute command: %s", params.Command), @@ -337,7 +337,7 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) defer cancel() } stdout, stderr, err := shell. - GetPersistentShell(config.WorkingDirectory()). + GetPersistentShell(config.Get().WorkingDir()). Exec(ctx, params.Command) interrupted := shell.IsInterrupt(err) exitCode := shell.ExitCode(err) diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index b72112f43e140edd7298e802ab88ba2747784d7c..1602f65ea109e9e7ad0468687ba24ce674fcaea9 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -143,7 +143,7 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) } if !filepath.IsAbs(params.FilePath) { - wd := config.WorkingDirectory() + wd := config.Get().WorkingDir() params.FilePath = filepath.Join(wd, params.FilePath) } @@ -207,7 +207,7 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string) content, filePath, ) - rootDir := config.WorkingDirectory() + rootDir := config.Get().WorkingDir() permissionPath := filepath.Dir(filePath) if strings.HasPrefix(filePath, rootDir) { permissionPath = rootDir @@ -320,7 +320,7 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string filePath, ) - rootDir := config.WorkingDirectory() + rootDir := config.Get().WorkingDir() permissionPath := filepath.Dir(filePath) if strings.HasPrefix(filePath, rootDir) { permissionPath = rootDir @@ -442,7 +442,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS newContent, filePath, ) - rootDir := config.WorkingDirectory() + rootDir := config.Get().WorkingDir() permissionPath := filepath.Dir(filePath) if strings.HasPrefix(filePath, rootDir) { permissionPath = rootDir diff --git a/internal/llm/tools/fetch.go b/internal/llm/tools/fetch.go index ac73ddbf3b0033cf503bdc8cfa2ef065a0072477..6895556dbd925d8396b0258ab16c422cdc1a1810 100644 --- a/internal/llm/tools/fetch.go +++ b/internal/llm/tools/fetch.go @@ -133,7 +133,7 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error p := t.permissions.Request( permission.CreatePermissionRequest{ SessionID: sessionID, - Path: config.WorkingDirectory(), + Path: config.Get().WorkingDir(), ToolName: FetchToolName, Action: "fetch", Description: fmt.Sprintf("Fetch content from URL: %s", params.URL), diff --git a/internal/llm/tools/glob.go b/internal/llm/tools/glob.go index 25c80860b791a5b601366d455f5ddd1ea91523ed..6a8ba40208b0d59b9034d8502ff576d7888481ca 100644 --- a/internal/llm/tools/glob.go +++ b/internal/llm/tools/glob.go @@ -108,7 +108,7 @@ func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) searchPath := params.Path if searchPath == "" { - searchPath = config.WorkingDirectory() + searchPath = config.Get().WorkingDir() } files, truncated, err := globFiles(params.Pattern, searchPath, 100) diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go index c3e13766884f17932187ad63cb5ffaacdf375b45..ede19c1daa75c3fea16bb52f0ed4b9ff5093e247 100644 --- a/internal/llm/tools/grep.go +++ b/internal/llm/tools/grep.go @@ -200,7 +200,7 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) searchPath := params.Path if searchPath == "" { - searchPath = config.WorkingDirectory() + searchPath = config.Get().WorkingDir() } matches, truncated, err := searchFiles(searchPattern, searchPath, params.Include, 100) diff --git a/internal/llm/tools/ls.go b/internal/llm/tools/ls.go index a51b5bdb5dccb7c209d9cdc28e94dad328e8c093..6e858a6990b64c795ad6f4df957f9e0d5c7ad6d3 100644 --- a/internal/llm/tools/ls.go +++ b/internal/llm/tools/ls.go @@ -107,11 +107,11 @@ func (l *lsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { searchPath := params.Path if searchPath == "" { - searchPath = config.WorkingDirectory() + searchPath = config.Get().WorkingDir() } if !filepath.IsAbs(searchPath) { - searchPath = filepath.Join(config.WorkingDirectory(), searchPath) + searchPath = filepath.Join(config.Get().WorkingDir(), searchPath) } if _, err := os.Stat(searchPath); os.IsNotExist(err) { diff --git a/internal/llm/tools/view.go b/internal/llm/tools/view.go index 750efef73795f115e3ad90e4da9a2d955ee10529..b156f89a26628982417aee1ab23354abf3415f61 100644 --- a/internal/llm/tools/view.go +++ b/internal/llm/tools/view.go @@ -117,7 +117,7 @@ func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) // Handle relative paths filePath := params.FilePath if !filepath.IsAbs(filePath) { - filePath = filepath.Join(config.WorkingDirectory(), filePath) + filePath = filepath.Join(config.Get().WorkingDir(), filePath) } // Check if file exists diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 0c213cec1f4e0a9bc8fc205a183206c0842f9688..676b6e02b7c0cdae28b2e256f664a339315b0eb9 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -122,7 +122,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error filePath := params.FilePath if !filepath.IsAbs(filePath) { - filePath = filepath.Join(config.WorkingDirectory(), filePath) + filePath = filepath.Join(config.Get().WorkingDir(), filePath) } fileInfo, err := os.Stat(filePath) @@ -170,7 +170,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error filePath, ) - rootDir := config.WorkingDirectory() + rootDir := config.Get().WorkingDir() permissionPath := filepath.Dir(filePath) if strings.HasPrefix(filePath, rootDir) { permissionPath = rootDir diff --git a/pkg/log/log.go b/internal/log/log.go similarity index 100% rename from pkg/log/log.go rename to internal/log/log.go diff --git a/internal/lsp/client.go b/internal/lsp/client.go index 24ff0238c355edb5499640b93f9e06f0f07568c9..0fec0c7d79fa64abba9b8c5a9650568d55334450 100644 --- a/internal/lsp/client.go +++ b/internal/lsp/client.go @@ -376,7 +376,7 @@ func (c *Client) detectServerType() ServerType { // openKeyConfigFiles opens important configuration files that help initialize the server func (c *Client) openKeyConfigFiles(ctx context.Context) { - workDir := config.WorkingDirectory() + workDir := config.Get().WorkingDir() serverType := c.detectServerType() var filesToOpen []string @@ -464,7 +464,7 @@ func (c *Client) pingTypeScriptServer(ctx context.Context) error { } // If we have no open TypeScript files, try to find and open one - workDir := config.WorkingDirectory() + workDir := config.Get().WorkingDir() err := filepath.WalkDir(workDir, func(path string, d os.DirEntry, err error) error { if err != nil { return err diff --git a/internal/permission/permission.go b/internal/permission/permission.go index f2d92249f9ce926d6421b06714b1302de0b1530d..dcdb0c46acad2dd28d781ed12f4b13f284a7bc40 100644 --- a/internal/permission/permission.go +++ b/internal/permission/permission.go @@ -87,7 +87,7 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool { dir := filepath.Dir(opts.Path) if dir == "." { - dir = config.WorkingDirectory() + dir = config.Get().WorkingDir() } permission := PermissionRequest{ ID: uuid.New().String(), diff --git a/internal/tui/components/chat/header/header.go b/internal/tui/components/chat/header/header.go index d924bdc3453dc3fce0351c490cb17b726fcc2549..45874a188f59c272cfcef4a7e41f0f631afea954 100644 --- a/internal/tui/components/chat/header/header.go +++ b/internal/tui/components/chat/header/header.go @@ -91,7 +91,7 @@ func (p *header) View() tea.View { func (h *header) details() string { t := styles.CurrentTheme() - cwd := fsext.DirTrim(fsext.PrettyPath(config.WorkingDirectory()), 4) + cwd := fsext.DirTrim(fsext.PrettyPath(config.Get().WorkingDir()), 4) parts := []string{ t.S().Muted.Render(cwd), } @@ -111,7 +111,8 @@ func (h *header) details() string { parts = append(parts, t.S().Error.Render(fmt.Sprintf("%s%d", styles.ErrorIcon, errorCount))) } - model := config.GetAgentModel(config.AgentCoder) + agentCfg := config.Get().Agents["coder"] + model := config.Get().GetModelByType(agentCfg.Model) percentage := (float64(h.session.CompletionTokens+h.session.PromptTokens) / float64(model.ContextWindow)) * 100 formattedPercentage := t.S().Muted.Render(fmt.Sprintf("%d%%", int(percentage))) parts = append(parts, formattedPercentage) diff --git a/internal/tui/components/chat/messages/messages.go b/internal/tui/components/chat/messages/messages.go index 5b60297221e647218dde3b3f24b64abc97873df5..770b0729fd27a6d110605b05d6f66fae56981716 100644 --- a/internal/tui/components/chat/messages/messages.go +++ b/internal/tui/components/chat/messages/messages.go @@ -10,7 +10,6 @@ import ( "github.com/charmbracelet/lipgloss/v2" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/tui/components/anim" "github.com/charmbracelet/crush/internal/tui/components/core" @@ -296,7 +295,7 @@ func (m *assistantSectionModel) View() tea.View { duration := finishTime.Sub(m.lastUserMessageTime) infoMsg := t.S().Subtle.Render(duration.String()) icon := t.S().Subtle.Render(styles.ModelIcon) - model := config.GetProviderModel(provider.InferenceProvider(m.message.Provider), m.message.Model) + model := config.Get().GetModel(m.message.Provider, m.message.Model) modelFormatted := t.S().Muted.Render(model.Name) assistant := fmt.Sprintf("%s %s %s", icon, modelFormatted, infoMsg) return tea.NewView( diff --git a/internal/tui/components/chat/sidebar/sidebar.go b/internal/tui/components/chat/sidebar/sidebar.go index 54e3557a682dde3075fcb16a95016d9e25941183..1f01a4b228ab0e538aa1b68effcabc4ad1e03ac7 100644 --- a/internal/tui/components/chat/sidebar/sidebar.go +++ b/internal/tui/components/chat/sidebar/sidebar.go @@ -297,7 +297,7 @@ func (m *sidebarCmp) filesBlock() string { } extraContent := strings.Join(statusParts, " ") - cwd := config.WorkingDirectory() + string(os.PathSeparator) + cwd := config.Get().WorkingDir() + string(os.PathSeparator) filePath := file.FilePath filePath = strings.TrimPrefix(filePath, cwd) filePath = fsext.DirTrim(fsext.PrettyPath(filePath), 2) @@ -474,7 +474,8 @@ func formatTokensAndCost(tokens, contextWindow int64, cost float64) string { } func (s *sidebarCmp) currentModelBlock() string { - model := config.GetAgentModel(config.AgentCoder) + agentCfg := config.Get().Agents["coder"] + model := config.Get().GetModelByType(agentCfg.Model) t := styles.CurrentTheme() @@ -507,7 +508,7 @@ func (m *sidebarCmp) SetSession(session session.Session) tea.Cmd { } func cwd() string { - cwd := config.WorkingDirectory() + cwd := config.Get().WorkingDir() t := styles.CurrentTheme() // Replace home directory with ~, unless we're at the top level of the // home directory). diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index aa7a505bd19af72a55e134fc0b077085a761faa6..dc82e2fa1c745fc46f14895680c93d30864f317a 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -31,8 +31,8 @@ const ( // ModelSelectedMsg is sent when a model is selected type ModelSelectedMsg struct { - Model config.PreferredModel - ModelType config.ModelType + Model config.SelectedModel + ModelType config.SelectedModelType } // CloseModelDialogMsg is sent when a model is selected @@ -115,19 +115,19 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { items := m.modelList.Items() selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption) - var modelType config.ModelType + var modelType config.SelectedModelType if m.modelType == LargeModelType { - modelType = config.LargeModel + modelType = config.SelectedModelTypeLarge } else { - modelType = config.SmallModel + modelType = config.SelectedModelTypeSmall } return m, tea.Sequence( util.CmdHandler(dialogs.CloseDialogMsg{}), util.CmdHandler(ModelSelectedMsg{ - Model: config.PreferredModel{ - ModelID: selectedItem.Model.ID, - Provider: selectedItem.Provider.ID, + Model: config.SelectedModel{ + Model: selectedItem.Model.ID, + Provider: string(selectedItem.Provider.ID), }, ModelType: modelType, }), @@ -218,35 +218,39 @@ func (m *modelDialogCmp) modelTypeRadio() string { func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd { m.modelType = modelType - providers := config.Providers() + providers, err := config.Providers() + if err != nil { + return util.ReportError(err) + } + modelItems := []util.Model{} selectIndex := 0 cfg := config.Get() - var currentModel config.PreferredModel + var currentModel config.SelectedModel if m.modelType == LargeModelType { - currentModel = cfg.Models.Large + currentModel = cfg.Models[config.SelectedModelTypeLarge] } else { - currentModel = cfg.Models.Small + currentModel = cfg.Models[config.SelectedModelTypeSmall] } // Create a map to track which providers we've already added - addedProviders := make(map[provider.InferenceProvider]bool) + addedProviders := make(map[string]bool) // First, add any configured providers that are not in the known providers list // These should appear at the top of the list knownProviders := provider.KnownProviders() for providerID, providerConfig := range cfg.Providers { - if providerConfig.Disabled { + if providerConfig.Disable { continue } // Check if this provider is not in the known providers list - if !slices.Contains(knownProviders, providerID) { + if !slices.Contains(knownProviders, provider.InferenceProvider(providerID)) { // Convert config provider to provider.Provider format configProvider := provider.Provider{ Name: string(providerID), // Use provider ID as name for unknown providers - ID: providerID, + ID: provider.InferenceProvider(providerID), Models: make([]provider.Model, len(providerConfig.Models)), } @@ -263,7 +267,7 @@ func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd { DefaultMaxTokens: model.DefaultMaxTokens, CanReason: model.CanReason, HasReasoningEffort: model.HasReasoningEffort, - DefaultReasoningEffort: model.ReasoningEffort, + DefaultReasoningEffort: model.DefaultReasoningEffort, SupportsImages: model.SupportsImages, } } @@ -279,7 +283,7 @@ func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd { Provider: configProvider, Model: model, })) - if model.ID == currentModel.ModelID && configProvider.ID == currentModel.Provider { + if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider { selectIndex = len(modelItems) - 1 // Set the selected index to the current model } } @@ -290,12 +294,12 @@ func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd { // Then add the known providers from the predefined list for _, provider := range providers { // Skip if we already added this provider as an unknown provider - if addedProviders[provider.ID] { + if addedProviders[string(provider.ID)] { continue } // Check if this provider is configured and not disabled - if providerConfig, exists := cfg.Providers[provider.ID]; exists && providerConfig.Disabled { + if providerConfig, exists := cfg.Providers[string(provider.ID)]; exists && providerConfig.Disable { continue } @@ -309,7 +313,7 @@ func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd { Provider: provider, Model: model, })) - if model.ID == currentModel.ModelID && provider.ID == currentModel.Provider { + if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider { selectIndex = len(modelItems) - 1 // Set the selected index to the current model } } diff --git a/internal/tui/page/chat/chat.go b/internal/tui/page/chat/chat.go index 44d623847765175d3c38eb81122fa3d55abc430d..cc72dbc520bb8fe3e5aacc0412cf2e09118b8bd7 100644 --- a/internal/tui/page/chat/chat.go +++ b/internal/tui/page/chat/chat.go @@ -170,7 +170,8 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { util.CmdHandler(ChatFocusedMsg{Focused: false}), ) case key.Matches(msg, p.keyMap.AddAttachment): - model := config.GetAgentModel(config.AgentCoder) + agentCfg := config.Get().Agents["coder"] + model := config.Get().GetModelByType(agentCfg.Model) if model.SupportsImages { return p, util.CmdHandler(OpenFilePickerMsg{}) } else { diff --git a/internal/tui/tui.go b/internal/tui/tui.go index f7b81b4dd393f10775090066a9969a117ac3f618..fcef527c812e9f5d74801b009050c69daec7fd20 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -177,14 +177,14 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // Update the agent with the new model/provider configuration if err := a.app.UpdateAgentModel(); err != nil { logging.ErrorPersist(fmt.Sprintf("Failed to update agent model: %v", err)) - return a, util.ReportError(fmt.Errorf("model changed to %s but failed to update agent: %v", msg.Model.ModelID, err)) + return a, util.ReportError(fmt.Errorf("model changed to %s but failed to update agent: %v", msg.Model.Model, err)) } modelTypeName := "large" - if msg.ModelType == config.SmallModel { + if msg.ModelType == config.SelectedModelTypeSmall { modelTypeName = "small" } - return a, util.ReportInfo(fmt.Sprintf("%s model changed to %s", modelTypeName, msg.Model.ModelID)) + return a, util.ReportInfo(fmt.Sprintf("%s model changed to %s", modelTypeName, msg.Model.Model)) // File Picker case chat.OpenFilePickerMsg: diff --git a/pkg/config/config.go b/pkg/config/config.go deleted file mode 100644 index b2dd6e2ed1b625106e8c6b3727853e2e2d7d6bd5..0000000000000000000000000000000000000000 --- a/pkg/config/config.go +++ /dev/null @@ -1,224 +0,0 @@ -package config - -import ( - "slices" - "strings" - - "github.com/charmbracelet/crush/internal/fur/provider" -) - -const ( - appName = "crush" - defaultDataDirectory = ".crush" - defaultLogLevel = "info" -) - -var defaultContextPaths = []string{ - ".github/copilot-instructions.md", - ".cursorrules", - ".cursor/rules/", - "CLAUDE.md", - "CLAUDE.local.md", - "GEMINI.md", - "gemini.md", - "crush.md", - "crush.local.md", - "Crush.md", - "Crush.local.md", - "CRUSH.md", - "CRUSH.local.md", -} - -type SelectedModelType string - -const ( - SelectedModelTypeLarge SelectedModelType = "large" - SelectedModelTypeSmall SelectedModelType = "small" -) - -type SelectedModel struct { - // The model id as used by the provider API. - // Required. - Model string `json:"model"` - // The model provider, same as the key/id used in the providers config. - // Required. - Provider string `json:"provider"` - - // Only used by models that use the openai provider and need this set. - ReasoningEffort string `json:"reasoning_effort,omitempty"` - - // Overrides the default model configuration. - MaxTokens int64 `json:"max_tokens,omitempty"` - - // Used by anthropic models that can reason to indicate if the model should think. - Think bool `json:"think,omitempty"` -} - -type ProviderConfig struct { - // The provider's id. - ID string `json:"id,omitempty"` - // The provider's API endpoint. - BaseURL string `json:"base_url,omitempty"` - // The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai. - Type provider.Type `json:"type,omitempty"` - // The provider's API key. - APIKey string `json:"api_key,omitempty"` - // Marks the provider as disabled. - Disable bool `json:"disable,omitempty"` - - // Extra headers to send with each request to the provider. - ExtraHeaders map[string]string - - // Used to pass extra parameters to the provider. - ExtraParams map[string]string `json:"-"` - - // The provider models - Models []provider.Model `json:"models,omitempty"` -} - -type MCPType string - -const ( - MCPStdio MCPType = "stdio" - MCPSse MCPType = "sse" - MCPHttp MCPType = "http" -) - -type MCPConfig struct { - Command string `json:"command,omitempty" ` - Env []string `json:"env,omitempty"` - Args []string `json:"args,omitempty"` - Type MCPType `json:"type"` - URL string `json:"url,omitempty"` - - // TODO: maybe make it possible to get the value from the env - Headers map[string]string `json:"headers,omitempty"` -} - -type LSPConfig struct { - Disabled bool `json:"enabled,omitempty"` - Command string `json:"command"` - Args []string `json:"args,omitempty"` - Options any `json:"options,omitempty"` -} - -type TUIOptions struct { - CompactMode bool `json:"compact_mode,omitempty"` - // Here we can add themes later or any TUI related options -} - -type Options struct { - ContextPaths []string `json:"context_paths,omitempty"` - TUI *TUIOptions `json:"tui,omitempty"` - Debug bool `json:"debug,omitempty"` - DebugLSP bool `json:"debug_lsp,omitempty"` - DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty"` - // Relative to the cwd - DataDirectory string `json:"data_directory,omitempty"` -} - -type MCPs map[string]MCPConfig - -type MCP struct { - Name string `json:"name"` - MCP MCPConfig `json:"mcp"` -} - -func (m MCPs) Sorted() []MCP { - sorted := make([]MCP, 0, len(m)) - for k, v := range m { - sorted = append(sorted, MCP{ - Name: k, - MCP: v, - }) - } - slices.SortFunc(sorted, func(a, b MCP) int { - return strings.Compare(a.Name, b.Name) - }) - return sorted -} - -type LSPs map[string]LSPConfig - -type LSP struct { - Name string `json:"name"` - LSP LSPConfig `json:"lsp"` -} - -func (l LSPs) Sorted() []LSP { - sorted := make([]LSP, 0, len(l)) - for k, v := range l { - sorted = append(sorted, LSP{ - Name: k, - LSP: v, - }) - } - slices.SortFunc(sorted, func(a, b LSP) int { - return strings.Compare(a.Name, b.Name) - }) - return sorted -} - -// Config holds the configuration for crush. -type Config struct { - // We currently only support large/small as values here. - Models map[SelectedModelType]SelectedModel `json:"models,omitempty"` - - // The providers that are configured - Providers map[string]ProviderConfig `json:"providers,omitempty"` - - MCP MCPs `json:"mcp,omitempty"` - - LSP LSPs `json:"lsp,omitempty"` - - Options *Options `json:"options,omitempty"` - - // Internal - workingDir string `json:"-"` -} - -func (c *Config) WorkingDir() string { - return c.workingDir -} - -func (c *Config) EnabledProviders() []ProviderConfig { - enabled := make([]ProviderConfig, 0, len(c.Providers)) - for _, p := range c.Providers { - if !p.Disable { - enabled = append(enabled, p) - } - } - return enabled -} - -// IsConfigured return true if at least one provider is configured -func (c *Config) IsConfigured() bool { - return len(c.EnabledProviders()) > 0 -} - -func (c *Config) GetModel(provider, model string) *provider.Model { - if providerConfig, ok := c.Providers[provider]; ok { - for _, m := range providerConfig.Models { - if m.ID == model { - return &m - } - } - } - return nil -} - -func (c *Config) LargeModel() *provider.Model { - model, ok := c.Models[SelectedModelTypeLarge] - if !ok { - return nil - } - return c.GetModel(model.Provider, model.Model) -} - -func (c *Config) SmallModel() *provider.Model { - model, ok := c.Models[SelectedModelTypeSmall] - if !ok { - return nil - } - return c.GetModel(model.Provider, model.Model) -} diff --git a/pkg/config/provider.go b/pkg/config/provider.go deleted file mode 100644 index 953959dece9e0714c108fc9cff43267fdc2487bc..0000000000000000000000000000000000000000 --- a/pkg/config/provider.go +++ /dev/null @@ -1,93 +0,0 @@ -package config - -import ( - "encoding/json" - "os" - "path/filepath" - "runtime" - "sync" - - "github.com/charmbracelet/crush/internal/fur/provider" -) - -type ProviderClient interface { - GetProviders() ([]provider.Provider, error) -} - -var ( - providerOnce sync.Once - providerList []provider.Provider -) - -// file to cache provider data -func providerCacheFileData() string { - xdgDataHome := os.Getenv("XDG_DATA_HOME") - if xdgDataHome != "" { - return filepath.Join(xdgDataHome, appName) - } - - // return the path to the main data directory - // for windows, it should be in `%LOCALAPPDATA%/crush/` - // for linux and macOS, it should be in `$HOME/.local/share/crush/` - if runtime.GOOS == "windows" { - localAppData := os.Getenv("LOCALAPPDATA") - if localAppData == "" { - localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local") - } - return filepath.Join(localAppData, appName) - } - - return filepath.Join(os.Getenv("HOME"), ".local", "share", appName, "providers.json") -} - -func saveProvidersInCache(path string, providers []provider.Provider) error { - dir := filepath.Dir(path) - if err := os.MkdirAll(dir, 0o755); err != nil { - return err - } - - data, err := json.MarshalIndent(providers, "", " ") - if err != nil { - return err - } - - return os.WriteFile(path, data, 0o644) -} - -func loadProvidersFromCache(path string) ([]provider.Provider, error) { - data, err := os.ReadFile(path) - if err != nil { - return nil, err - } - - var providers []provider.Provider - err = json.Unmarshal(data, &providers) - return providers, err -} - -func loadProviders(path string, client ProviderClient) ([]provider.Provider, error) { - providers, err := client.GetProviders() - if err != nil { - fallbackToCache, err := loadProvidersFromCache(path) - if err != nil { - return nil, err - } - providers = fallbackToCache - } else { - if err := saveProvidersInCache(path, providerList); err != nil { - return nil, err - } - } - return providers, nil -} - -func LoadProviders(client ProviderClient) ([]provider.Provider, error) { - var err error - providerOnce.Do(func() { - providerList, err = loadProviders(providerCacheFileData(), client) - }) - if err != nil { - return nil, err - } - return providerList, nil -} diff --git a/pkg/config/provider_test.go b/pkg/config/provider_test.go deleted file mode 100644 index a3562838c7103239aa303c906c866220164a4ba0..0000000000000000000000000000000000000000 --- a/pkg/config/provider_test.go +++ /dev/null @@ -1,73 +0,0 @@ -package config - -import ( - "encoding/json" - "errors" - "os" - "testing" - - "github.com/charmbracelet/crush/internal/fur/provider" - "github.com/stretchr/testify/assert" -) - -type mockProviderClient struct { - shouldFail bool -} - -func (m *mockProviderClient) GetProviders() ([]provider.Provider, error) { - if m.shouldFail { - return nil, errors.New("failed to load providers") - } - return []provider.Provider{ - { - Name: "Mock", - }, - }, nil -} - -func TestProvider_loadProvidersNoIssues(t *testing.T) { - client := &mockProviderClient{shouldFail: false} - tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(tmpPath, client) - assert.NoError(t, err) - assert.NotNil(t, providers) - assert.Len(t, providers, 1) - - // check if file got saved - fileInfo, err := os.Stat(tmpPath) - assert.NoError(t, err) - assert.False(t, fileInfo.IsDir(), "Expected a file, not a directory") -} - -func TestProvider_loadProvidersWithIssues(t *testing.T) { - client := &mockProviderClient{shouldFail: true} - tmpPath := t.TempDir() + "/providers.json" - // store providers to a temporary file - oldProviders := []provider.Provider{ - { - Name: "OldProvider", - }, - } - data, err := json.Marshal(oldProviders) - if err != nil { - t.Fatalf("Failed to marshal old providers: %v", err) - } - - err = os.WriteFile(tmpPath, data, 0o644) - if err != nil { - t.Fatalf("Failed to write old providers to file: %v", err) - } - providers, err := loadProviders(tmpPath, client) - assert.NoError(t, err) - assert.NotNil(t, providers) - assert.Len(t, providers, 1) - assert.Equal(t, "OldProvider", providers[0].Name, "Expected to keep old provider when loading fails") -} - -func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) { - client := &mockProviderClient{shouldFail: true} - tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(tmpPath, client) - assert.Error(t, err) - assert.Nil(t, providers, "Expected nil providers when loading fails and no cache exists") -}