cmd/logs.go | 2
cmd/schema/main.go | 155 -
crush-schema.json | 700 -----
crush.json | 1
go.mod | 16
go.sum | 6
internal/app/app.go | 3
internal/app/lsp.go | 4
internal/config/config.go | 1513 +-----------
internal/config/config_test.go | 2075 -----------------
internal/config/fs.go | 71
internal/config/init.go | 48
internal/config/load.go | 35
internal/config/load_test.go | 2
internal/config/merge.go | 0
internal/config/merge_test.go | 0
internal/config/provider.go | 87
internal/config/provider_mock.go | 293 --
internal/config/provider_test.go | 112
internal/config/resolve.go | 2
internal/config/resolve_test.go | 2
internal/config/shell.go | 73
internal/config/validation_test.go | 462 ---
internal/diff/diff.go | 2
internal/env/env.go | 0
internal/env/env_test.go | 0
internal/llm/agent/agent.go | 84
internal/llm/agent/mcp-tools.go | 8
internal/llm/prompt/coder.go | 8
internal/llm/prompt/prompt.go | 11
internal/llm/prompt/summarizer.go | 6
internal/llm/prompt/task.go | 4
internal/llm/prompt/title.go | 6
internal/llm/provider/anthropic.go | 10
internal/llm/provider/bedrock.go | 9
internal/llm/provider/gemini.go | 17
internal/llm/provider/openai.go | 15
internal/llm/provider/provider.go | 22
internal/llm/tools/bash.go | 4
internal/llm/tools/edit.go | 8
internal/llm/tools/fetch.go | 2
internal/llm/tools/glob.go | 2
internal/llm/tools/grep.go | 2
internal/llm/tools/ls.go | 4
internal/llm/tools/view.go | 2
internal/llm/tools/write.go | 4
internal/log/log.go | 0
internal/lsp/client.go | 4
internal/permission/permission.go | 2
internal/tui/components/chat/header/header.go | 5
internal/tui/components/chat/messages/messages.go | 3
internal/tui/components/chat/sidebar/sidebar.go | 7
internal/tui/components/dialogs/models/models.go | 46
internal/tui/page/chat/chat.go | 3
internal/tui/tui.go | 6
pkg/config/config.go | 224 -
pkg/config/provider.go | 93
pkg/config/provider_test.go | 73
58 files changed, 488 insertions(+), 5,870 deletions(-)
cmd/logs.go
@@ -7,7 +7,7 @@ import (
"slices"
"time"
- "github.com/charmbracelet/crush/pkg/config"
+ "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/log/v2"
"github.com/nxadm/tail"
"github.com/spf13/cobra"
cmd/schema/main.go
@@ -1,155 +0,0 @@
-package main
-
-import (
- "encoding/json"
- "fmt"
- "os"
-
- "github.com/charmbracelet/crush/internal/config"
- "github.com/invopop/jsonschema"
-)
-
-func main() {
- // Create a new reflector
- r := &jsonschema.Reflector{
- // Use anonymous schemas to avoid ID conflicts
- Anonymous: true,
- // Expand the root struct instead of referencing it
- ExpandedStruct: true,
- AllowAdditionalProperties: true,
- }
-
- // Generate schema for the main Config struct
- schema := r.Reflect(&config.Config{})
-
- // Enhance the schema with additional information
- enhanceSchema(schema)
-
- // Set the schema metadata
- schema.Version = "https://json-schema.org/draft/2020-12/schema"
- schema.Title = "Crush Configuration"
- schema.Description = "Configuration schema for the Crush application"
-
- // Pretty print the schema
- encoder := json.NewEncoder(os.Stdout)
- encoder.SetIndent("", " ")
- if err := encoder.Encode(schema); err != nil {
- fmt.Fprintf(os.Stderr, "Error encoding schema: %v\n", err)
- os.Exit(1)
- }
-}
-
-// enhanceSchema adds additional enhancements to the generated schema
-func enhanceSchema(schema *jsonschema.Schema) {
- // Add provider enums
- addProviderEnums(schema)
-
- // Add model enums
- addModelEnums(schema)
-
- // Add tool enums
- addToolEnums(schema)
-
- // Add default context paths
- addDefaultContextPaths(schema)
-}
-
-// addProviderEnums adds provider enums to the schema
-func addProviderEnums(schema *jsonschema.Schema) {
- providers := config.Providers()
- var providerIDs []any
- for _, p := range providers {
- providerIDs = append(providerIDs, string(p.ID))
- }
-
- // Add to PreferredModel provider field
- if schema.Definitions != nil {
- if preferredModelDef, exists := schema.Definitions["PreferredModel"]; exists {
- if providerProp, exists := preferredModelDef.Properties.Get("provider"); exists {
- providerProp.Enum = providerIDs
- }
- }
-
- // Add to ProviderConfig ID field
- if providerConfigDef, exists := schema.Definitions["ProviderConfig"]; exists {
- if idProp, exists := providerConfigDef.Properties.Get("id"); exists {
- idProp.Enum = providerIDs
- }
- }
- }
-}
-
-// addModelEnums adds model enums to the schema
-func addModelEnums(schema *jsonschema.Schema) {
- providers := config.Providers()
- var modelIDs []any
- for _, p := range providers {
- for _, m := range p.Models {
- modelIDs = append(modelIDs, m.ID)
- }
- }
-
- // Add to PreferredModel model_id field
- if schema.Definitions != nil {
- if preferredModelDef, exists := schema.Definitions["PreferredModel"]; exists {
- if modelIDProp, exists := preferredModelDef.Properties.Get("model_id"); exists {
- modelIDProp.Enum = modelIDs
- }
- }
- }
-}
-
-// addToolEnums adds tool enums to the schema
-func addToolEnums(schema *jsonschema.Schema) {
- tools := []any{
- "bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent",
- }
-
- if schema.Definitions != nil {
- if agentDef, exists := schema.Definitions["Agent"]; exists {
- if allowedToolsProp, exists := agentDef.Properties.Get("allowed_tools"); exists {
- if allowedToolsProp.Items != nil {
- allowedToolsProp.Items.Enum = tools
- }
- }
- }
- }
-}
-
-// addDefaultContextPaths adds default context paths to the schema
-func addDefaultContextPaths(schema *jsonschema.Schema) {
- defaultContextPaths := []any{
- ".github/copilot-instructions.md",
- ".cursorrules",
- ".cursor/rules/",
- "CLAUDE.md",
- "CLAUDE.local.md",
- "GEMINI.md",
- "gemini.md",
- "crush.md",
- "crush.local.md",
- "Crush.md",
- "Crush.local.md",
- "CRUSH.md",
- "CRUSH.local.md",
- }
-
- if schema.Definitions != nil {
- if optionsDef, exists := schema.Definitions["Options"]; exists {
- if contextPathsProp, exists := optionsDef.Properties.Get("context_paths"); exists {
- contextPathsProp.Default = defaultContextPaths
- }
- }
- }
-
- // Also add to root properties if they exist
- if schema.Properties != nil {
- if optionsProp, exists := schema.Properties.Get("options"); exists {
- if optionsProp.Properties != nil {
- if contextPathsProp, exists := optionsProp.Properties.Get("context_paths"); exists {
- contextPathsProp.Default = defaultContextPaths
- }
- }
- }
- }
-}
crush-schema.json
@@ -1,700 +0,0 @@
-{
- "$schema": "https://json-schema.org/draft/2020-12/schema",
- "$defs": {
- "Agent": {
- "properties": {
- "id": {
- "type": "string",
- "enum": [
- "coder",
- "task",
- "coder",
- "task"
- ],
- "title": "Agent ID",
- "description": "Unique identifier for the agent"
- },
- "name": {
- "type": "string",
- "title": "Name",
- "description": "Display name of the agent"
- },
- "description": {
- "type": "string",
- "title": "Description",
- "description": "Description of what the agent does"
- },
- "disabled": {
- "type": "boolean",
- "title": "Disabled",
- "description": "Whether this agent is disabled",
- "default": false
- },
- "model": {
- "type": "string",
- "enum": [
- "large",
- "small",
- "large",
- "small"
- ],
- "title": "Model Type",
- "description": "Type of model to use (large or small)"
- },
- "allowed_tools": {
- "items": {
- "type": "string",
- "enum": [
- "bash",
- "edit",
- "fetch",
- "glob",
- "grep",
- "ls",
- "sourcegraph",
- "view",
- "write",
- "agent"
- ]
- },
- "type": "array",
- "title": "Allowed Tools",
- "description": "List of tools this agent is allowed to use (if nil all tools are allowed)"
- },
- "allowed_mcp": {
- "additionalProperties": {
- "items": {
- "type": "string"
- },
- "type": "array"
- },
- "type": "object",
- "title": "Allowed MCP",
- "description": "Map of MCP servers this agent can use and their allowed tools"
- },
- "allowed_lsp": {
- "items": {
- "type": "string"
- },
- "type": "array",
- "title": "Allowed LSP",
- "description": "List of LSP servers this agent can use (if nil all LSPs are allowed)"
- },
- "context_paths": {
- "items": {
- "type": "string"
- },
- "type": "array",
- "title": "Context Paths",
- "description": "Custom context paths for this agent (additive to global context paths)"
- }
- },
- "type": "object",
- "required": [
- "model"
- ]
- },
- "LSPConfig": {
- "properties": {
- "enabled": {
- "type": "boolean",
- "title": "Enabled",
- "description": "Whether this LSP server is enabled",
- "default": true
- },
- "command": {
- "type": "string",
- "title": "Command",
- "description": "Command to execute for the LSP server"
- },
- "args": {
- "items": {
- "type": "string"
- },
- "type": "array",
- "title": "Arguments",
- "description": "Command line arguments for the LSP server"
- },
- "options": {
- "title": "Options",
- "description": "LSP server specific options"
- }
- },
- "type": "object",
- "required": [
- "command"
- ]
- },
- "MCP": {
- "properties": {
- "command": {
- "type": "string",
- "title": "Command",
- "description": "Command to execute for stdio MCP servers"
- },
- "env": {
- "items": {
- "type": "string"
- },
- "type": "array",
- "title": "Environment",
- "description": "Environment variables for the MCP server"
- },
- "args": {
- "items": {
- "type": "string"
- },
- "type": "array",
- "title": "Arguments",
- "description": "Command line arguments for the MCP server"
- },
- "type": {
- "type": "string",
- "enum": [
- "stdio",
- "sse",
- "stdio",
- "sse",
- "http"
- ],
- "title": "Type",
- "description": "Type of MCP connection",
- "default": "stdio"
- },
- "url": {
- "type": "string",
- "title": "URL",
- "description": "URL for SSE MCP servers"
- },
- "headers": {
- "additionalProperties": {
- "type": "string"
- },
- "type": "object",
- "title": "Headers",
- "description": "HTTP headers for SSE MCP servers"
- }
- },
- "type": "object",
- "required": [
- "type"
- ]
- },
- "Model": {
- "properties": {
- "id": {
- "type": "string",
- "title": "Model ID",
- "description": "Unique identifier for the model"
- },
- "name": {
- "type": "string",
- "title": "Model Name",
- "description": "Display name of the model"
- },
- "cost_per_1m_in": {
- "type": "number",
- "minimum": 0,
- "title": "Input Cost",
- "description": "Cost per 1 million input tokens"
- },
- "cost_per_1m_out": {
- "type": "number",
- "minimum": 0,
- "title": "Output Cost",
- "description": "Cost per 1 million output tokens"
- },
- "cost_per_1m_in_cached": {
- "type": "number",
- "minimum": 0,
- "title": "Cached Input Cost",
- "description": "Cost per 1 million cached input tokens"
- },
- "cost_per_1m_out_cached": {
- "type": "number",
- "minimum": 0,
- "title": "Cached Output Cost",
- "description": "Cost per 1 million cached output tokens"
- },
- "context_window": {
- "type": "integer",
- "minimum": 1,
- "title": "Context Window",
- "description": "Maximum context window size in tokens"
- },
- "default_max_tokens": {
- "type": "integer",
- "minimum": 1,
- "title": "Default Max Tokens",
- "description": "Default maximum tokens for responses"
- },
- "can_reason": {
- "type": "boolean",
- "title": "Can Reason",
- "description": "Whether the model supports reasoning capabilities"
- },
- "reasoning_effort": {
- "type": "string",
- "title": "Reasoning Effort",
- "description": "Default reasoning effort level for reasoning models"
- },
- "has_reasoning_effort": {
- "type": "boolean",
- "title": "Has Reasoning Effort",
- "description": "Whether the model supports reasoning effort configuration"
- },
- "supports_attachments": {
- "type": "boolean",
- "title": "Supports Images",
- "description": "Whether the model supports image attachments"
- }
- },
- "type": "object",
- "required": [
- "id",
- "name",
- "context_window",
- "default_max_tokens"
- ]
- },
- "Options": {
- "properties": {
- "context_paths": {
- "items": {
- "type": "string"
- },
- "type": "array",
- "title": "Context Paths",
- "description": "List of paths to search for context files",
- "default": [
- ".github/copilot-instructions.md",
- ".cursorrules",
- ".cursor/rules/",
- "CLAUDE.md",
- "CLAUDE.local.md",
- "GEMINI.md",
- "gemini.md",
- "crush.md",
- "crush.local.md",
- "Crush.md",
- "Crush.local.md",
- "CRUSH.md",
- "CRUSH.local.md"
- ]
- },
- "tui": {
- "$ref": "#/$defs/TUIOptions",
- "title": "TUI Options",
- "description": "Terminal UI configuration options"
- },
- "debug": {
- "type": "boolean",
- "title": "Debug",
- "description": "Enable debug logging",
- "default": false
- },
- "debug_lsp": {
- "type": "boolean",
- "title": "Debug LSP",
- "description": "Enable LSP debug logging",
- "default": false
- },
- "disable_auto_summarize": {
- "type": "boolean",
- "title": "Disable Auto Summarize",
- "description": "Disable automatic conversation summarization",
- "default": false
- },
- "data_directory": {
- "type": "string",
- "title": "Data Directory",
- "description": "Directory for storing application data",
- "default": ".crush"
- }
- },
- "type": "object"
- },
- "PreferredModel": {
- "properties": {
- "model_id": {
- "type": "string",
- "enum": [
- "claude-opus-4-20250514",
- "claude-sonnet-4-20250514",
- "claude-3-7-sonnet-20250219",
- "claude-3-5-haiku-20241022",
- "claude-3-5-sonnet-20240620",
- "claude-3-5-sonnet-20241022",
- "codex-mini-latest",
- "o4-mini",
- "o3",
- "o3-pro",
- "gpt-4.1",
- "gpt-4.1-mini",
- "gpt-4.1-nano",
- "gpt-4.5-preview",
- "o3-mini",
- "gpt-4o",
- "gpt-4o-mini",
- "gemini-2.5-pro",
- "gemini-2.5-flash",
- "codex-mini-latest",
- "o4-mini",
- "o3",
- "o3-pro",
- "gpt-4.1",
- "gpt-4.1-mini",
- "gpt-4.1-nano",
- "gpt-4.5-preview",
- "o3-mini",
- "gpt-4o",
- "gpt-4o-mini",
- "anthropic.claude-opus-4-20250514-v1:0",
- "anthropic.claude-sonnet-4-20250514-v1:0",
- "anthropic.claude-3-7-sonnet-20250219-v1:0",
- "anthropic.claude-3-5-haiku-20241022-v1:0",
- "gemini-2.5-pro",
- "gemini-2.5-flash",
- "grok-3-mini",
- "grok-3",
- "mistralai/mistral-small-3.2-24b-instruct:free",
- "mistralai/mistral-small-3.2-24b-instruct",
- "minimax/minimax-m1:extended",
- "minimax/minimax-m1",
- "google/gemini-2.5-flash-lite-preview-06-17",
- "google/gemini-2.5-flash",
- "google/gemini-2.5-pro",
- "openai/o3-pro",
- "x-ai/grok-3-mini",
- "x-ai/grok-3",
- "mistralai/magistral-small-2506",
- "mistralai/magistral-medium-2506",
- "mistralai/magistral-medium-2506:thinking",
- "google/gemini-2.5-pro-preview",
- "deepseek/deepseek-r1-0528",
- "anthropic/claude-opus-4",
- "anthropic/claude-sonnet-4",
- "mistralai/devstral-small:free",
- "mistralai/devstral-small",
- "google/gemini-2.5-flash-preview-05-20",
- "google/gemini-2.5-flash-preview-05-20:thinking",
- "openai/codex-mini",
- "mistralai/mistral-medium-3",
- "google/gemini-2.5-pro-preview-05-06",
- "arcee-ai/caller-large",
- "arcee-ai/virtuoso-large",
- "arcee-ai/virtuoso-medium-v2",
- "qwen/qwen3-30b-a3b",
- "qwen/qwen3-14b",
- "qwen/qwen3-32b",
- "qwen/qwen3-235b-a22b",
- "google/gemini-2.5-flash-preview",
- "google/gemini-2.5-flash-preview:thinking",
- "openai/o4-mini-high",
- "openai/o3",
- "openai/o4-mini",
- "openai/gpt-4.1",
- "openai/gpt-4.1-mini",
- "openai/gpt-4.1-nano",
- "x-ai/grok-3-mini-beta",
- "x-ai/grok-3-beta",
- "meta-llama/llama-4-maverick",
- "meta-llama/llama-4-scout",
- "all-hands/openhands-lm-32b-v0.1",
- "google/gemini-2.5-pro-exp-03-25",
- "deepseek/deepseek-chat-v3-0324:free",
- "deepseek/deepseek-chat-v3-0324",
- "mistralai/mistral-small-3.1-24b-instruct:free",
- "mistralai/mistral-small-3.1-24b-instruct",
- "ai21/jamba-1.6-large",
- "ai21/jamba-1.6-mini",
- "openai/gpt-4.5-preview",
- "google/gemini-2.0-flash-lite-001",
- "anthropic/claude-3.7-sonnet",
- "anthropic/claude-3.7-sonnet:beta",
- "anthropic/claude-3.7-sonnet:thinking",
- "mistralai/mistral-saba",
- "openai/o3-mini-high",
- "google/gemini-2.0-flash-001",
- "qwen/qwen-turbo",
- "qwen/qwen-plus",
- "qwen/qwen-max",
- "openai/o3-mini",
- "mistralai/mistral-small-24b-instruct-2501",
- "deepseek/deepseek-r1-distill-llama-70b",
- "deepseek/deepseek-r1",
- "mistralai/codestral-2501",
- "deepseek/deepseek-chat",
- "openai/o1",
- "x-ai/grok-2-1212",
- "meta-llama/llama-3.3-70b-instruct",
- "amazon/nova-lite-v1",
- "amazon/nova-micro-v1",
- "amazon/nova-pro-v1",
- "openai/gpt-4o-2024-11-20",
- "mistralai/mistral-large-2411",
- "mistralai/mistral-large-2407",
- "mistralai/pixtral-large-2411",
- "thedrummer/unslopnemo-12b",
- "anthropic/claude-3.5-haiku:beta",
- "anthropic/claude-3.5-haiku",
- "anthropic/claude-3.5-haiku-20241022:beta",
- "anthropic/claude-3.5-haiku-20241022",
- "anthropic/claude-3.5-sonnet:beta",
- "anthropic/claude-3.5-sonnet",
- "x-ai/grok-beta",
- "mistralai/ministral-8b",
- "mistralai/ministral-3b",
- "nvidia/llama-3.1-nemotron-70b-instruct",
- "google/gemini-flash-1.5-8b",
- "meta-llama/llama-3.2-11b-vision-instruct",
- "meta-llama/llama-3.2-3b-instruct",
- "qwen/qwen-2.5-72b-instruct",
- "mistralai/pixtral-12b",
- "cohere/command-r-plus-08-2024",
- "cohere/command-r-08-2024",
- "microsoft/phi-3.5-mini-128k-instruct",
- "nousresearch/hermes-3-llama-3.1-70b",
- "openai/gpt-4o-2024-08-06",
- "meta-llama/llama-3.1-405b-instruct",
- "meta-llama/llama-3.1-70b-instruct",
- "meta-llama/llama-3.1-8b-instruct",
- "mistralai/mistral-nemo",
- "openai/gpt-4o-mini",
- "openai/gpt-4o-mini-2024-07-18",
- "anthropic/claude-3.5-sonnet-20240620:beta",
- "anthropic/claude-3.5-sonnet-20240620",
- "mistralai/mistral-7b-instruct-v0.3",
- "mistralai/mistral-7b-instruct:free",
- "mistralai/mistral-7b-instruct",
- "microsoft/phi-3-mini-128k-instruct",
- "microsoft/phi-3-medium-128k-instruct",
- "google/gemini-flash-1.5",
- "openai/gpt-4o-2024-05-13",
- "openai/gpt-4o",
- "openai/gpt-4o:extended",
- "meta-llama/llama-3-8b-instruct",
- "meta-llama/llama-3-70b-instruct",
- "mistralai/mixtral-8x22b-instruct",
- "openai/gpt-4-turbo",
- "google/gemini-pro-1.5",
- "cohere/command-r-plus",
- "cohere/command-r-plus-04-2024",
- "cohere/command-r",
- "anthropic/claude-3-haiku:beta",
- "anthropic/claude-3-haiku",
- "anthropic/claude-3-opus:beta",
- "anthropic/claude-3-opus",
- "anthropic/claude-3-sonnet:beta",
- "anthropic/claude-3-sonnet",
- "cohere/command-r-03-2024",
- "mistralai/mistral-large",
- "openai/gpt-3.5-turbo-0613",
- "openai/gpt-4-turbo-preview",
- "mistralai/mistral-small",
- "mistralai/mistral-tiny",
- "mistralai/mixtral-8x7b-instruct",
- "openai/gpt-4-1106-preview",
- "mistralai/mistral-7b-instruct-v0.1",
- "openai/gpt-3.5-turbo-16k",
- "openai/gpt-4",
- "openai/gpt-4-0314"
- ],
- "title": "Model ID",
- "description": "ID of the preferred model"
- },
- "provider": {
- "type": "string",
- "enum": [
- "anthropic",
- "openai",
- "gemini",
- "azure",
- "bedrock",
- "vertex",
- "xai",
- "openrouter"
- ],
- "title": "Provider",
- "description": "Provider for the preferred model"
- },
- "reasoning_effort": {
- "type": "string",
- "title": "Reasoning Effort",
- "description": "Override reasoning effort for this model"
- },
- "max_tokens": {
- "type": "integer",
- "minimum": 1,
- "title": "Max Tokens",
- "description": "Override max tokens for this model"
- },
- "think": {
- "type": "boolean",
- "title": "Think",
- "description": "Enable thinking for reasoning models",
- "default": false
- }
- },
- "type": "object",
- "required": [
- "model_id",
- "provider"
- ]
- },
- "PreferredModels": {
- "properties": {
- "large": {
- "$ref": "#/$defs/PreferredModel",
- "title": "Large Model",
- "description": "Preferred model configuration for large model type"
- },
- "small": {
- "$ref": "#/$defs/PreferredModel",
- "title": "Small Model",
- "description": "Preferred model configuration for small model type"
- }
- },
- "type": "object"
- },
- "ProviderConfig": {
- "properties": {
- "id": {
- "type": "string",
- "enum": [
- "anthropic",
- "openai",
- "gemini",
- "azure",
- "bedrock",
- "vertex",
- "xai",
- "openrouter"
- ],
- "title": "Provider ID",
- "description": "Unique identifier for the provider"
- },
- "base_url": {
- "type": "string",
- "title": "Base URL",
- "description": "Base URL for the provider API (required for custom providers)"
- },
- "provider_type": {
- "type": "string",
- "title": "Provider Type",
- "description": "Type of the provider (openai"
- },
- "api_key": {
- "type": "string",
- "title": "API Key",
- "description": "API key for authenticating with the provider"
- },
- "disabled": {
- "type": "boolean",
- "title": "Disabled",
- "description": "Whether this provider is disabled",
- "default": false
- },
- "extra_headers": {
- "additionalProperties": {
- "type": "string"
- },
- "type": "object",
- "title": "Extra Headers",
- "description": "Additional HTTP headers to send with requests"
- },
- "extra_params": {
- "additionalProperties": {
- "type": "string"
- },
- "type": "object",
- "title": "Extra Parameters",
- "description": "Additional provider-specific parameters"
- },
- "default_large_model": {
- "type": "string",
- "title": "Default Large Model",
- "description": "Default model ID for large model type"
- },
- "default_small_model": {
- "type": "string",
- "title": "Default Small Model",
- "description": "Default model ID for small model type"
- },
- "models": {
- "items": {
- "$ref": "#/$defs/Model"
- },
- "type": "array",
- "title": "Models",
- "description": "List of available models for this provider"
- }
- },
- "type": "object",
- "required": [
- "provider_type"
- ]
- },
- "TUIOptions": {
- "properties": {
- "compact_mode": {
- "type": "boolean",
- "title": "Compact Mode",
- "description": "Enable compact mode for the TUI",
- "default": false
- }
- },
- "type": "object",
- "required": [
- "compact_mode"
- ]
- }
- },
- "properties": {
- "models": {
- "$ref": "#/$defs/PreferredModels",
- "title": "Models",
- "description": "Preferred model configurations for large and small model types"
- },
- "providers": {
- "additionalProperties": {
- "$ref": "#/$defs/ProviderConfig"
- },
- "type": "object",
- "title": "Providers",
- "description": "LLM provider configurations"
- },
- "agents": {
- "additionalProperties": {
- "$ref": "#/$defs/Agent"
- },
- "type": "object",
- "title": "Agents",
- "description": "Agent configurations for different tasks"
- },
- "mcp": {
- "additionalProperties": {
- "$ref": "#/$defs/MCP"
- },
- "type": "object",
- "title": "MCP",
- "description": "Model Control Protocol server configurations"
- },
- "lsp": {
- "additionalProperties": {
- "$ref": "#/$defs/LSPConfig"
- },
- "type": "object",
- "title": "LSP",
- "description": "Language Server Protocol configurations"
- },
- "options": {
- "$ref": "#/$defs/Options",
- "title": "Options",
- "description": "General application options and settings"
- }
- },
- "type": "object",
- "title": "Crush Configuration",
- "description": "Configuration schema for the Crush application"
-}
crush.json
@@ -1,5 +1,4 @@
{
- "$schema": "./crush-schema.json",
"lsp": {
"go": {
"command": "gopls"
go.mod
@@ -17,6 +17,7 @@ require (
github.com/charmbracelet/fang v0.1.0
github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe
github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.2.0.20250703152125-8e1c474f8a71
+ github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706
github.com/charmbracelet/x/ansi v0.9.3-0.20250602153603-fb931ed90413
github.com/charmbracelet/x/exp/charmtone v0.0.0-20250627134340-c144409e381c
github.com/charmbracelet/x/exp/golden v0.0.0-20250207160936-21c02780d27a
@@ -25,31 +26,28 @@ require (
github.com/go-logfmt/logfmt v0.6.0
github.com/google/uuid v1.6.0
github.com/invopop/jsonschema v0.13.0
+ github.com/joho/godotenv v1.5.1
github.com/mark3labs/mcp-go v0.32.0
github.com/muesli/termenv v0.16.0
github.com/ncruces/go-sqlite3 v0.25.0
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
+ github.com/nxadm/tail v1.4.11
github.com/openai/openai-go v1.8.2
github.com/pressly/goose/v3 v3.24.2
+ github.com/qjebbs/go-jsons v0.0.0-20221222033332-a534c5fc1c4c
github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06
github.com/sahilm/fuzzy v0.1.1
github.com/spf13/cobra v1.9.1
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef
github.com/stretchr/testify v1.10.0
+ golang.org/x/exp v0.0.0-20250305212735-054e65f0b394
+ gopkg.in/natefinch/lumberjack.v2 v2.2.1
mvdan.cc/sh/v3 v3.11.0
)
require (
- github.com/charmbracelet/lipgloss v1.1.0 // indirect
- github.com/charmbracelet/log v0.4.2 // indirect
- github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706 // indirect
- github.com/joho/godotenv v1.5.1 // indirect
- github.com/nxadm/tail v1.4.11 // indirect
- github.com/qjebbs/go-jsons v0.0.0-20221222033332-a534c5fc1c4c // indirect
github.com/spf13/cast v1.7.1 // indirect
- golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect
- gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
)
@@ -84,7 +82,7 @@ require (
github.com/charmbracelet/x/cellbuf v0.0.14-0.20250516160309-24eee56f89fa // indirect
github.com/charmbracelet/x/exp/slice v0.0.0-20250611152503-f53cdd7e01ef
github.com/charmbracelet/x/input v0.3.5-0.20250509021451-13796e822d86 // indirect
- github.com/charmbracelet/x/term v0.2.1 // indirect
+ github.com/charmbracelet/x/term v0.2.1
github.com/charmbracelet/x/windows v0.2.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/disintegration/gift v1.1.2 // indirect
go.sum
@@ -82,14 +82,8 @@ github.com/charmbracelet/fang v0.1.0 h1:SlZS2crf3/zQh7Mr4+W+7QR1k+L08rrPX5rm5z3d
github.com/charmbracelet/fang v0.1.0/go.mod h1:Zl/zeUQ8EtQuGyiV0ZKZlZPDowKRTzu8s/367EpN/fc=
github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe h1:i6ce4CcAlPpTj2ER69m1DBeLZ3RRcHnKExuwhKa3GfY=
github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe/go.mod h1:p3Q+aN4eQKeM5jhrmXPMgPrlKbmc59rWSnMsSA3udhk=
-github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
-github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
-github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.1.0.20250523195325-2d1af06b557c h1:177KMz8zHRlEZJsWzafbKYh6OdjgvTspoH+UjaxgIXY=
-github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.1.0.20250523195325-2d1af06b557c/go.mod h1:EJWvaCrhOhNGVZMvcjc0yVryl4qqpMs8tz0r9WyEkdQ=
github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.2.0.20250703152125-8e1c474f8a71 h1:X0tsNa2UHCKNw+illiavosasVzqioRo32SRV35iwr2I=
github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.2.0.20250703152125-8e1c474f8a71/go.mod h1:EJWvaCrhOhNGVZMvcjc0yVryl4qqpMs8tz0r9WyEkdQ=
-github.com/charmbracelet/log v0.4.2 h1:hYt8Qj6a8yLnvR+h7MwsJv/XvmBJXiueUcI3cIxsyig=
-github.com/charmbracelet/log v0.4.2/go.mod h1:qifHGX/tc7eluv2R6pWIpyHDDrrb/AG71Pf2ysQu5nw=
github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706 h1:WkwO6Ks3mSIGnGuSdKl9qDSyfbYK50z2wc2gGMggegE=
github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706/go.mod h1:mjJGp00cxcfvD5xdCa+bso251Jt4owrQvuimJtVmEmM=
github.com/charmbracelet/x/ansi v0.9.3-0.20250602153603-fb931ed90413 h1:L07QkDqRF274IZ2UJ/mCTL8DR95efU9BNWLYCDXEjvQ=
internal/app/app.go
@@ -57,7 +57,8 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) {
cfg := config.Get()
- coderAgentCfg := cfg.Agents[config.AgentCoder]
+ // TODO: remove the concept of agent config most likely
+ coderAgentCfg := cfg.Agents["coder"]
if coderAgentCfg.ID == "" {
return nil, fmt.Errorf("coder agent configuration is missing")
}
internal/app/lsp.go
@@ -38,7 +38,7 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, comman
defer cancel()
// Initialize with the initialization context
- _, err = lspClient.InitializeLSPClient(initCtx, config.WorkingDirectory())
+ _, err = lspClient.InitializeLSPClient(initCtx, config.Get().WorkingDir())
if err != nil {
logging.Error("Initialize failed", "name", name, "error", err)
// Clean up the client to prevent resource leaks
@@ -91,7 +91,7 @@ func (app *App) runWorkspaceWatcher(ctx context.Context, name string, workspaceW
app.restartLSPClient(ctx, name)
})
- workspaceWatcher.WatchWorkspace(ctx, config.WorkingDirectory())
+ workspaceWatcher.WatchWorkspace(ctx, config.Get().WorkingDir())
logging.Info("Workspace watcher stopped", "client", name)
}
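Both call sites above now reach the working directory through the value returned by config.Get() rather than a package-level WorkingDirectory() helper. The new Get() is not part of this diff; below is a minimal sketch of the accessor pattern the call sites assume, where only the workingDir field and WorkingDir() method are taken from the config.go hunk that follows.

// Sketch of the singleton accessor implied by config.Get().WorkingDir().
// Init and Get here are illustrative stand-ins for the real ones, which are
// not shown in this diff.
package config

import "sync"

type Config struct {
	workingDir string
}

func (c *Config) WorkingDir() string { return c.workingDir }

var (
	instance *Config
	once     sync.Once
)

// Init records the working directory once at startup.
func Init(workingDir string) *Config {
	once.Do(func() {
		instance = &Config{workingDir: workingDir}
	})
	return instance
}

// Get returns the configuration loaded by Init; callers such as
// createAndStartLSPClient assume it has already been initialized.
func Get() *Config {
	return instance
}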
internal/config/config.go
@@ -1,28 +1,17 @@
package config
import (
- "encoding/json"
- "errors"
"fmt"
- "log/slog"
- "maps"
- "os"
- "path/filepath"
"slices"
"strings"
- "sync"
"github.com/charmbracelet/crush/internal/fur/provider"
- "github.com/charmbracelet/crush/internal/logging"
- "github.com/invopop/jsonschema"
)
const (
+ appName = "crush"
defaultDataDirectory = ".crush"
defaultLogLevel = "info"
- appName = "crush"
-
- MaxTokensFallbackDefault = 4096
)
var defaultContextPaths = []string{
@@ -41,82 +30,51 @@ var defaultContextPaths = []string{
"CRUSH.local.md",
}
-type AgentID string
+type SelectedModelType string
const (
- AgentCoder AgentID = "coder"
- AgentTask AgentID = "task"
+ SelectedModelTypeLarge SelectedModelType = "large"
+ SelectedModelTypeSmall SelectedModelType = "small"
)
-type ModelType string
+type SelectedModel struct {
+ // The model id as used by the provider API.
+ // Required.
+ Model string `json:"model"`
+ // The model provider, same as the key/id used in the providers config.
+ // Required.
+ Provider string `json:"provider"`
-const (
- LargeModel ModelType = "large"
- SmallModel ModelType = "small"
-)
+ // Only used by models that use the openai provider and need this set.
+ ReasoningEffort string `json:"reasoning_effort,omitempty"`
-type Model struct {
- ID string `json:"id" jsonschema:"title=Model ID,description=Unique identifier for the model, the API model"`
- Name string `json:"name" jsonschema:"title=Model Name,description=Display name of the model"`
- CostPer1MIn float64 `json:"cost_per_1m_in,omitempty" jsonschema:"title=Input Cost,description=Cost per 1 million input tokens,minimum=0"`
- CostPer1MOut float64 `json:"cost_per_1m_out,omitempty" jsonschema:"title=Output Cost,description=Cost per 1 million output tokens,minimum=0"`
- CostPer1MInCached float64 `json:"cost_per_1m_in_cached,omitempty" jsonschema:"title=Cached Input Cost,description=Cost per 1 million cached input tokens,minimum=0"`
- CostPer1MOutCached float64 `json:"cost_per_1m_out_cached,omitempty" jsonschema:"title=Cached Output Cost,description=Cost per 1 million cached output tokens,minimum=0"`
- ContextWindow int64 `json:"context_window" jsonschema:"title=Context Window,description=Maximum context window size in tokens,minimum=1"`
- DefaultMaxTokens int64 `json:"default_max_tokens" jsonschema:"title=Default Max Tokens,description=Default maximum tokens for responses,minimum=1"`
- CanReason bool `json:"can_reason,omitempty" jsonschema:"title=Can Reason,description=Whether the model supports reasoning capabilities"`
- ReasoningEffort string `json:"reasoning_effort,omitempty" jsonschema:"title=Reasoning Effort,description=Default reasoning effort level for reasoning models"`
- HasReasoningEffort bool `json:"has_reasoning_effort,omitempty" jsonschema:"title=Has Reasoning Effort,description=Whether the model supports reasoning effort configuration"`
- SupportsImages bool `json:"supports_attachments,omitempty" jsonschema:"title=Supports Images,description=Whether the model supports image attachments"`
-}
+ // Overrides the default model configuration.
+ MaxTokens int64 `json:"max_tokens,omitempty"`
-type VertexAIOptions struct {
- APIKey string `json:"api_key,omitempty"`
- Project string `json:"project,omitempty"`
- Location string `json:"location,omitempty"`
+ // Used by anthropic models that can reason to indicate if the model should think.
+ Think bool `json:"think,omitempty"`
}
type ProviderConfig struct {
- ID provider.InferenceProvider `json:"id,omitempty" jsonschema:"title=Provider ID,description=Unique identifier for the provider"`
- BaseURL string `json:"base_url,omitempty" jsonschema:"title=Base URL,description=Base URL for the provider API (required for custom providers)"`
- ProviderType provider.Type `json:"provider_type" jsonschema:"title=Provider Type,description=Type of the provider (openai, anthropic, etc.)"`
- APIKey string `json:"api_key,omitempty" jsonschema:"title=API Key,description=API key for authenticating with the provider"`
- Disabled bool `json:"disabled,omitempty" jsonschema:"title=Disabled,description=Whether this provider is disabled,default=false"`
- ExtraHeaders map[string]string `json:"extra_headers,omitempty" jsonschema:"title=Extra Headers,description=Additional HTTP headers to send with requests"`
- // used for e.x for vertex to set the project
- ExtraParams map[string]string `json:"extra_params,omitempty" jsonschema:"title=Extra Parameters,description=Additional provider-specific parameters"`
-
- DefaultLargeModel string `json:"default_large_model,omitempty" jsonschema:"title=Default Large Model,description=Default model ID for large model type"`
- DefaultSmallModel string `json:"default_small_model,omitempty" jsonschema:"title=Default Small Model,description=Default model ID for small model type"`
-
- Models []Model `json:"models,omitempty" jsonschema:"title=Models,description=List of available models for this provider"`
-}
-
-type Agent struct {
- ID AgentID `json:"id,omitempty" jsonschema:"title=Agent ID,description=Unique identifier for the agent,enum=coder,enum=task"`
- Name string `json:"name,omitempty" jsonschema:"title=Name,description=Display name of the agent"`
- Description string `json:"description,omitempty" jsonschema:"title=Description,description=Description of what the agent does"`
- // This is the id of the system prompt used by the agent
- Disabled bool `json:"disabled,omitempty" jsonschema:"title=Disabled,description=Whether this agent is disabled,default=false"`
-
- Model ModelType `json:"model" jsonschema:"title=Model Type,description=Type of model to use (large or small),enum=large,enum=small"`
+ // The provider's id.
+ ID string `json:"id,omitempty"`
+ // The provider's API endpoint.
+ BaseURL string `json:"base_url,omitempty"`
+ // The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
+ Type provider.Type `json:"type,omitempty"`
+ // The provider's API key.
+ APIKey string `json:"api_key,omitempty"`
+ // Marks the provider as disabled.
+ Disable bool `json:"disable,omitempty"`
- // The available tools for the agent
- // if this is nil, all tools are available
- AllowedTools []string `json:"allowed_tools,omitempty" jsonschema:"title=Allowed Tools,description=List of tools this agent is allowed to use (if nil all tools are allowed)"`
+ // Extra headers to send with each request to the provider.
+ ExtraHeaders map[string]string
- // this tells us which MCPs are available for this agent
- // if this is empty all mcps are available
- // the string array is the list of tools from the AllowedMCP the agent has available
- // if the string array is nil, all tools from the AllowedMCP are available
- AllowedMCP map[string][]string `json:"allowed_mcp,omitempty" jsonschema:"title=Allowed MCP,description=Map of MCP servers this agent can use and their allowed tools"`
+ // Used to pass extra parameters to the provider.
+ ExtraParams map[string]string `json:"-"`
- // The list of LSPs that this agent can use
- // if this is nil, all LSPs are available
- AllowedLSP []string `json:"allowed_lsp,omitempty" jsonschema:"title=Allowed LSP,description=List of LSP servers this agent can use (if nil all LSPs are allowed)"`
-
- // Overrides the context paths for this agent
- ContextPaths []string `json:"context_paths,omitempty" jsonschema:"title=Context Paths,description=Custom context paths for this agent (additive to global context paths)"`
+ // The provider models
+ Models []provider.Model `json:"models,omitempty"`
}
type MCPType string
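Model selection changes shape here: the SelectedModel introduced above replaces the PreferredModel removed further down in this file, which carried a model_id/provider pair inside a fixed large/small wrapper struct; SelectedModel instead uses a model key and is stored in a map keyed by the "large"/"small" model type (see the new Config struct later in the diff). A small sketch of what each form serializes to; the gpt-4o/openai values are illustrative.

// Sketch only: struct tags copied from the removed PreferredModel and the
// new SelectedModel; the model/provider values are illustrative.
package main

import (
	"encoding/json"
	"fmt"
)

// Old shape: model_id key, fixed large/small wrapper struct.
type preferredModel struct {
	ModelID  string `json:"model_id"`
	Provider string `json:"provider"`
}

// New shape: model key, stored in a map keyed by "large"/"small".
type selectedModel struct {
	Model    string `json:"model"`
	Provider string `json:"provider"`
}

func main() {
	oldLarge, _ := json.Marshal(preferredModel{ModelID: "gpt-4o", Provider: "openai"})
	newModels, _ := json.Marshal(map[string]selectedModel{
		"large": {Model: "gpt-4o", Provider: "openai"},
	})
	fmt.Println(string(oldLarge))  // {"model_id":"gpt-4o","provider":"openai"}
	fmt.Println(string(newModels)) // {"large":{"model":"gpt-4o","provider":"openai"}}
}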
@@ -127,1358 +85,205 @@ const (
MCPHttp MCPType = "http"
)
-type MCP struct {
- Command string `json:"command,omitempty" jsonschema:"title=Command,description=Command to execute for stdio MCP servers"`
- Env []string `json:"env,omitempty" jsonschema:"title=Environment,description=Environment variables for the MCP server"`
- Args []string `json:"args,omitempty" jsonschema:"title=Arguments,description=Command line arguments for the MCP server"`
- Type MCPType `json:"type" jsonschema:"title=Type,description=Type of MCP connection,enum=stdio,enum=sse,enum=http,default=stdio"`
- URL string `json:"url,omitempty" jsonschema:"title=URL,description=URL for SSE MCP servers"`
+type MCPConfig struct {
+ Command string `json:"command,omitempty" `
+ Env []string `json:"env,omitempty"`
+ Args []string `json:"args,omitempty"`
+ Type MCPType `json:"type"`
+ URL string `json:"url,omitempty"`
+
// TODO: maybe make it possible to get the value from the env
- Headers map[string]string `json:"headers,omitempty" jsonschema:"title=Headers,description=HTTP headers for SSE MCP servers"`
+ Headers map[string]string `json:"headers,omitempty"`
}
type LSPConfig struct {
- Disabled bool `json:"enabled,omitempty" jsonschema:"title=Enabled,description=Whether this LSP server is enabled,default=true"`
- Command string `json:"command" jsonschema:"title=Command,description=Command to execute for the LSP server"`
- Args []string `json:"args,omitempty" jsonschema:"title=Arguments,description=Command line arguments for the LSP server"`
- Options any `json:"options,omitempty" jsonschema:"title=Options,description=LSP server specific options"`
+ Disabled bool `json:"enabled,omitempty"`
+ Command string `json:"command"`
+ Args []string `json:"args,omitempty"`
+ Options any `json:"options,omitempty"`
}
type TUIOptions struct {
- CompactMode bool `json:"compact_mode" jsonschema:"title=Compact Mode,description=Enable compact mode for the TUI,default=false"`
+ CompactMode bool `json:"compact_mode,omitempty"`
// Here we can add themes later or any TUI related options
}
type Options struct {
- ContextPaths []string `json:"context_paths,omitempty" jsonschema:"title=Context Paths,description=List of paths to search for context files"`
- TUI TUIOptions `json:"tui,omitempty" jsonschema:"title=TUI Options,description=Terminal UI configuration options"`
- Debug bool `json:"debug,omitempty" jsonschema:"title=Debug,description=Enable debug logging,default=false"`
- DebugLSP bool `json:"debug_lsp,omitempty" jsonschema:"title=Debug LSP,description=Enable LSP debug logging,default=false"`
- DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty" jsonschema:"title=Disable Auto Summarize,description=Disable automatic conversation summarization,default=false"`
+ ContextPaths []string `json:"context_paths,omitempty"`
+ TUI *TUIOptions `json:"tui,omitempty"`
+ Debug bool `json:"debug,omitempty"`
+ DebugLSP bool `json:"debug_lsp,omitempty"`
+ DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty"`
// Relative to the cwd
- DataDirectory string `json:"data_directory,omitempty" jsonschema:"title=Data Directory,description=Directory for storing application data,default=.crush"`
+ DataDirectory string `json:"data_directory,omitempty"`
}
-type PreferredModel struct {
- ModelID string `json:"model_id" jsonschema:"title=Model ID,description=ID of the preferred model"`
- Provider provider.InferenceProvider `json:"provider" jsonschema:"title=Provider,description=Provider for the preferred model"`
- // ReasoningEffort overrides the default reasoning effort for this model
- ReasoningEffort string `json:"reasoning_effort,omitempty" jsonschema:"title=Reasoning Effort,description=Override reasoning effort for this model"`
- // MaxTokens overrides the default max tokens for this model
- MaxTokens int64 `json:"max_tokens,omitempty" jsonschema:"title=Max Tokens,description=Override max tokens for this model,minimum=1"`
+type MCPs map[string]MCPConfig
- // Think indicates if the model should think, only applicable for anthropic reasoning models
- Think bool `json:"think,omitempty" jsonschema:"title=Think,description=Enable thinking for reasoning models,default=false"`
-}
-
-type PreferredModels struct {
- Large PreferredModel `json:"large,omitempty" jsonschema:"title=Large Model,description=Preferred model configuration for large model type"`
- Small PreferredModel `json:"small,omitempty" jsonschema:"title=Small Model,description=Preferred model configuration for small model type"`
-}
-
-type Config struct {
- Models PreferredModels `json:"models,omitempty" jsonschema:"title=Models,description=Preferred model configurations for large and small model types"`
- // List of configured providers
- Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty" jsonschema:"title=Providers,description=LLM provider configurations"`
-
- // List of configured agents
- Agents map[AgentID]Agent `json:"agents,omitempty" jsonschema:"title=Agents,description=Agent configurations for different tasks"`
-
- // List of configured MCPs
- MCP map[string]MCP `json:"mcp,omitempty" jsonschema:"title=MCP,description=Model Control Protocol server configurations"`
-
- // List of configured LSPs
- LSP map[string]LSPConfig `json:"lsp,omitempty" jsonschema:"title=LSP,description=Language Server Protocol configurations"`
-
- // Miscellaneous options
- Options Options `json:"options,omitempty" jsonschema:"title=Options,description=General application options and settings"`
-}
-
-var (
- instance *Config // The single instance of the Singleton
- cwd string
- once sync.Once // Ensures the initialization happens only once
-
-)
-
-func readConfigFile(path string) (*Config, error) {
- var cfg *Config
- if _, err := os.Stat(path); err != nil && !os.IsNotExist(err) {
- // some other error occurred while checking the file
- return nil, err
- } else if err == nil {
- // config file exists, read it
- file, err := os.ReadFile(path)
- if err != nil {
- return nil, err
- }
- cfg = &Config{}
- if err := json.Unmarshal(file, cfg); err != nil {
- return nil, err
- }
- } else {
- // config file does not exist, create a new one
- cfg = &Config{}
- }
- return cfg, nil
+type MCP struct {
+ Name string `json:"name"`
+ MCP MCPConfig `json:"mcp"`
}
-func loadConfig(cwd string, debug bool) (*Config, error) {
- // First read the global config file
- cfgPath := ConfigPath()
-
- cfg := defaultConfigBasedOnEnv()
- cfg.Options.Debug = debug
- defaultLevel := slog.LevelInfo
- if cfg.Options.Debug {
- defaultLevel = slog.LevelDebug
+func (m MCPs) Sorted() []MCP {
+ sorted := make([]MCP, 0, len(m))
+ for k, v := range m {
+ sorted = append(sorted, MCP{
+ Name: k,
+ MCP: v,
+ })
}
- if os.Getenv("CRUSH_DEV_DEBUG") == "true" {
- loggingFile := fmt.Sprintf("%s/%s", cfg.Options.DataDirectory, "debug.log")
-
- // if file does not exist create it
- if _, err := os.Stat(loggingFile); os.IsNotExist(err) {
- if err := os.MkdirAll(cfg.Options.DataDirectory, 0o755); err != nil {
- return cfg, fmt.Errorf("failed to create directory: %w", err)
- }
- if _, err := os.Create(loggingFile); err != nil {
- return cfg, fmt.Errorf("failed to create log file: %w", err)
- }
- }
-
- messagesPath := fmt.Sprintf("%s/%s", cfg.Options.DataDirectory, "messages")
-
- if _, err := os.Stat(messagesPath); os.IsNotExist(err) {
- if err := os.MkdirAll(messagesPath, 0o756); err != nil {
- return cfg, fmt.Errorf("failed to create directory: %w", err)
- }
- }
- logging.MessageDir = messagesPath
-
- sloggingFileWriter, err := os.OpenFile(loggingFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666)
- if err != nil {
- return cfg, fmt.Errorf("failed to open log file: %w", err)
- }
- // Configure logger
- logger := slog.New(slog.NewTextHandler(sloggingFileWriter, &slog.HandlerOptions{
- Level: defaultLevel,
- }))
- slog.SetDefault(logger)
- } else {
- // Configure logger
- logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{
- Level: defaultLevel,
- }))
- slog.SetDefault(logger)
- }
-
- priorityOrderedConfigFiles := []string{
- cfgPath, // Global config file
- filepath.Join(cwd, "crush.json"), // Local config file
- filepath.Join(cwd, ".crush.json"), // Local config file
- }
-
- configs := make([]*Config, 0)
- for _, path := range priorityOrderedConfigFiles {
- localConfig, err := readConfigFile(path)
- if err != nil {
- return nil, fmt.Errorf("failed to read config file %s: %w", path, err)
- }
- if localConfig != nil {
- // If the config file was read successfully, add it to the list
- configs = append(configs, localConfig)
- }
- }
-
- // merge options
- mergeOptions(cfg, configs...)
-
- mergeProviderConfigs(cfg, configs...)
- // no providers found the app is not initialized yet
- if len(cfg.Providers) == 0 {
- return cfg, nil
- }
- preferredProvider := getPreferredProvider(cfg.Providers)
- if preferredProvider != nil {
- cfg.Models = PreferredModels{
- Large: PreferredModel{
- ModelID: preferredProvider.DefaultLargeModel,
- Provider: preferredProvider.ID,
- },
- Small: PreferredModel{
- ModelID: preferredProvider.DefaultSmallModel,
- Provider: preferredProvider.ID,
- },
- }
- } else {
- // No valid providers found, set empty models
- cfg.Models = PreferredModels{}
- }
-
- mergeModels(cfg, configs...)
-
- agents := map[AgentID]Agent{
- AgentCoder: {
- ID: AgentCoder,
- Name: "Coder",
- Description: "An agent that helps with executing coding tasks.",
- Model: LargeModel,
- ContextPaths: cfg.Options.ContextPaths,
- // All tools allowed
- },
- AgentTask: {
- ID: AgentTask,
- Name: "Task",
- Description: "An agent that helps with searching for context and finding implementation details.",
- Model: LargeModel,
- ContextPaths: cfg.Options.ContextPaths,
- AllowedTools: []string{
- "glob",
- "grep",
- "ls",
- "sourcegraph",
- "view",
- },
- // NO MCPs or LSPs by default
- AllowedMCP: map[string][]string{},
- AllowedLSP: []string{},
- },
- }
- cfg.Agents = agents
- mergeAgents(cfg, configs...)
- mergeMCPs(cfg, configs...)
- mergeLSPs(cfg, configs...)
-
- // Validate the final configuration
- if err := cfg.Validate(); err != nil {
- return cfg, fmt.Errorf("configuration validation failed: %w", err)
- }
-
- return cfg, nil
-}
-
-func Init(workingDir string, debug bool) (*Config, error) {
- var err error
- once.Do(func() {
- cwd = workingDir
- instance, err = loadConfig(cwd, debug)
- if err != nil {
- logging.Error("Failed to load config", "error", err)
- }
+ slices.SortFunc(sorted, func(a, b MCP) int {
+ return strings.Compare(a.Name, b.Name)
})
-
- return instance, err
+ return sorted
}
-func Get() *Config {
- if instance == nil {
- // TODO: Handle this better
- panic("Config not initialized. Call InitConfig first.")
- }
- return instance
-}
-
-func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
- providers := Providers()
- for _, p := range providers {
- if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
- return &providerConfig
- }
- }
- // if none found return the first configured provider
- for _, providerConfig := range configuredProviders {
- if !providerConfig.Disabled {
- return &providerConfig
- }
- }
- return nil
-}
-
-func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
- if other.APIKey != "" {
- base.APIKey = other.APIKey
- }
- // Only change these options if the provider is not a known provider
- if !slices.Contains(provider.KnownProviders(), p) {
- if other.BaseURL != "" {
- base.BaseURL = other.BaseURL
- }
- if other.ProviderType != "" {
- base.ProviderType = other.ProviderType
- }
- if len(other.ExtraHeaders) > 0 {
- if base.ExtraHeaders == nil {
- base.ExtraHeaders = make(map[string]string)
- }
- maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
- }
- if len(other.ExtraParams) > 0 {
- if base.ExtraParams == nil {
- base.ExtraParams = make(map[string]string)
- }
- maps.Copy(base.ExtraParams, other.ExtraParams)
- }
- }
-
- if other.Disabled {
- base.Disabled = other.Disabled
- }
-
- if other.DefaultLargeModel != "" {
- base.DefaultLargeModel = other.DefaultLargeModel
- }
- // Add new models if they don't exist
- if other.Models != nil {
- for _, model := range other.Models {
- // check if the model already exists
- exists := false
- for _, existingModel := range base.Models {
- if existingModel.ID == model.ID {
- exists = true
- break
- }
- }
- if !exists {
- base.Models = append(base.Models, model)
- }
- }
- }
-
- return base
-}
+type LSPs map[string]LSPConfig
-func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
- if !slices.Contains(provider.KnownProviders(), p) {
- if providerConfig.ProviderType != provider.TypeOpenAI {
- return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
- }
- if providerConfig.BaseURL == "" {
- return errors.New("base URL must be set for custom providers")
- }
- if providerConfig.APIKey == "" {
- return errors.New("API key must be set for custom providers")
- }
- }
- return nil
+type LSP struct {
+ Name string `json:"name"`
+ LSP LSPConfig `json:"lsp"`
}
-func mergeModels(base *Config, others ...*Config) {
- for _, cfg := range others {
- if cfg == nil {
- continue
- }
- if cfg.Models.Large.ModelID != "" && cfg.Models.Large.Provider != "" {
- base.Models.Large = cfg.Models.Large
- }
-
- if cfg.Models.Small.ModelID != "" && cfg.Models.Small.Provider != "" {
- base.Models.Small = cfg.Models.Small
- }
+func (l LSPs) Sorted() []LSP {
+ sorted := make([]LSP, 0, len(l))
+ for k, v := range l {
+ sorted = append(sorted, LSP{
+ Name: k,
+ LSP: v,
+ })
}
+ slices.SortFunc(sorted, func(a, b LSP) int {
+ return strings.Compare(a.Name, b.Name)
+ })
+ return sorted
}
-func mergeOptions(base *Config, others ...*Config) {
- for _, cfg := range others {
- if cfg == nil {
- continue
- }
- baseOptions := base.Options
- other := cfg.Options
- if len(other.ContextPaths) > 0 {
- baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
- }
+type Agent struct {
+ ID string `json:"id,omitempty"`
+ Name string `json:"name,omitempty"`
+ Description string `json:"description,omitempty"`
+ // This is the id of the system prompt used by the agent
+ Disabled bool `json:"disabled,omitempty"`
- if other.TUI.CompactMode {
- baseOptions.TUI.CompactMode = other.TUI.CompactMode
- }
+ Model SelectedModelType `json:"model"`
- if other.Debug {
- baseOptions.Debug = other.Debug
- }
+ // The available tools for the agent
+ // if this is nil, all tools are available
+ AllowedTools []string `json:"allowed_tools,omitempty"`
- if other.DebugLSP {
- baseOptions.DebugLSP = other.DebugLSP
- }
+ // this tells us which MCPs are available for this agent
+ // if this is empty all mcps are available
+ // the string array is the list of tools from the AllowedMCP the agent has available
+ // if the string array is nil, all tools from the AllowedMCP are available
+ AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"`
- if other.DisableAutoSummarize {
- baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
- }
+ // The list of LSPs that this agent can use
+ // if this is nil, all LSPs are available
+ AllowedLSP []string `json:"allowed_lsp,omitempty"`
- if other.DataDirectory != "" {
- baseOptions.DataDirectory = other.DataDirectory
- }
- base.Options = baseOptions
- }
+ // Overrides the context paths for this agent
+ ContextPaths []string `json:"context_paths,omitempty"`
}
-func mergeAgents(base *Config, others ...*Config) {
- for _, cfg := range others {
- if cfg == nil {
- continue
- }
- for agentID, newAgent := range cfg.Agents {
- if _, ok := base.Agents[agentID]; !ok {
- newAgent.ID = agentID
- if newAgent.Model == "" {
- newAgent.Model = LargeModel
- }
- if len(newAgent.ContextPaths) > 0 {
- newAgent.ContextPaths = append(base.Options.ContextPaths, newAgent.ContextPaths...)
- } else {
- newAgent.ContextPaths = base.Options.ContextPaths
- }
- base.Agents[agentID] = newAgent
- } else {
- baseAgent := base.Agents[agentID]
+// Config holds the configuration for crush.
+type Config struct {
+ // We currently only support large/small as values here.
+ Models map[SelectedModelType]SelectedModel `json:"models,omitempty"`
- if agentID == AgentCoder || agentID == AgentTask {
- if newAgent.Model != "" {
- baseAgent.Model = newAgent.Model
- }
- if newAgent.AllowedMCP != nil {
- baseAgent.AllowedMCP = newAgent.AllowedMCP
- }
- if newAgent.AllowedLSP != nil {
- baseAgent.AllowedLSP = newAgent.AllowedLSP
- }
- // Context paths are additive for known agents too
- if len(newAgent.ContextPaths) > 0 {
- baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...)
- }
- } else {
- if newAgent.Name != "" {
- baseAgent.Name = newAgent.Name
- }
- if newAgent.Description != "" {
- baseAgent.Description = newAgent.Description
- }
- if newAgent.Model != "" {
- baseAgent.Model = newAgent.Model
- } else if baseAgent.Model == "" {
- baseAgent.Model = LargeModel
- }
+ // The providers that are configured
+ Providers map[string]ProviderConfig `json:"providers,omitempty"`
- baseAgent.Disabled = newAgent.Disabled
+ MCP MCPs `json:"mcp,omitempty"`
- if newAgent.AllowedTools != nil {
- baseAgent.AllowedTools = newAgent.AllowedTools
- }
- if newAgent.AllowedMCP != nil {
- baseAgent.AllowedMCP = newAgent.AllowedMCP
- }
- if newAgent.AllowedLSP != nil {
- baseAgent.AllowedLSP = newAgent.AllowedLSP
- }
- if len(newAgent.ContextPaths) > 0 {
- baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...)
- }
- }
+ LSP LSPs `json:"lsp,omitempty"`
- base.Agents[agentID] = baseAgent
- }
- }
- }
-}
+ Options *Options `json:"options,omitempty"`
-func mergeMCPs(base *Config, others ...*Config) {
- for _, cfg := range others {
- if cfg == nil {
- continue
- }
- maps.Copy(base.MCP, cfg.MCP)
- }
+ // Internal
+ workingDir string `json:"-"`
+ // TODO: most likely remove this concept when I come back to it
+ Agents map[string]Agent `json:"-"`
+ // TODO: find a better way to do this this should probably not be part of the config
+ resolver VariableResolver
}
-func mergeLSPs(base *Config, others ...*Config) {
- for _, cfg := range others {
- if cfg == nil {
- continue
- }
- maps.Copy(base.LSP, cfg.LSP)
- }
+func (c *Config) WorkingDir() string {
+ return c.workingDir
}
-func mergeProviderConfigs(base *Config, others ...*Config) {
- for _, cfg := range others {
- if cfg == nil {
- continue
- }
- for providerName, p := range cfg.Providers {
- p.ID = providerName
- if _, ok := base.Providers[providerName]; !ok {
- if slices.Contains(provider.KnownProviders(), providerName) {
- providers := Providers()
- for _, providerDef := range providers {
- if providerDef.ID == providerName {
- logging.Info("Using default provider config for", "provider", providerName)
- baseProvider := getDefaultProviderConfig(providerDef, providerDef.APIKey)
- base.Providers[providerName] = mergeProviderConfig(providerName, baseProvider, p)
- break
- }
- }
- } else {
- base.Providers[providerName] = p
- }
- } else {
- base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], p)
- }
- }
- }
-
- finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
- for providerName, providerConfig := range base.Providers {
- err := validateProvider(providerName, providerConfig)
- if err != nil {
- logging.Warn("Skipping provider", "name", providerName, "error", err)
- continue // Skip invalid providers
+func (c *Config) EnabledProviders() []ProviderConfig {
+ enabled := make([]ProviderConfig, 0, len(c.Providers))
+ for _, p := range c.Providers {
+ if !p.Disable {
+ enabled = append(enabled, p)
}
- finalProviders[providerName] = providerConfig
}
- base.Providers = finalProviders
+ return enabled
}
-func providerDefaultConfig(providerID provider.InferenceProvider) ProviderConfig {
- switch providerID {
- case provider.InferenceProviderAnthropic:
- return ProviderConfig{
- ID: providerID,
- ProviderType: provider.TypeAnthropic,
- }
- case provider.InferenceProviderOpenAI:
- return ProviderConfig{
- ID: providerID,
- ProviderType: provider.TypeOpenAI,
- }
- case provider.InferenceProviderGemini:
- return ProviderConfig{
- ID: providerID,
- ProviderType: provider.TypeGemini,
- }
- case provider.InferenceProviderBedrock:
- return ProviderConfig{
- ID: providerID,
- ProviderType: provider.TypeBedrock,
- }
- case provider.InferenceProviderAzure:
- return ProviderConfig{
- ID: providerID,
- ProviderType: provider.TypeAzure,
- }
- case provider.InferenceProviderOpenRouter:
- return ProviderConfig{
- ID: providerID,
- ProviderType: provider.TypeOpenAI,
- BaseURL: "https://openrouter.ai/api/v1",
- ExtraHeaders: map[string]string{
- "HTTP-Referer": "crush.charm.land",
- "X-Title": "Crush",
- },
- }
- case provider.InferenceProviderXAI:
- return ProviderConfig{
- ID: providerID,
- ProviderType: provider.TypeXAI,
- BaseURL: "https://api.x.ai/v1",
- }
- case provider.InferenceProviderVertexAI:
- return ProviderConfig{
- ID: providerID,
- ProviderType: provider.TypeVertexAI,
- }
- default:
- return ProviderConfig{
- ID: providerID,
- ProviderType: provider.TypeOpenAI,
- }
- }
+// IsConfigured returns true if at least one provider is configured.
+func (c *Config) IsConfigured() bool {
+ return len(c.EnabledProviders()) > 0
}
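For context, a minimal usage sketch of the new accessors (not taken from this change; it assumes config.Get() still requires a prior Init call, as in the previous API):

package main

import (
	"fmt"

	"github.com/charmbracelet/crush/internal/config"
)

func main() {
	// Assumption: config.Init(...) has already run elsewhere in the real app.
	cfg := config.Get()
	if !cfg.IsConfigured() {
		fmt.Println("no enabled providers; configure at least one in crush.json")
		return
	}
	fmt.Printf("%d provider(s) enabled\n", len(cfg.EnabledProviders()))
}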
-func getDefaultProviderConfig(p provider.Provider, apiKey string) ProviderConfig {
- providerConfig := providerDefaultConfig(p.ID)
- providerConfig.APIKey = apiKey
- providerConfig.DefaultLargeModel = p.DefaultLargeModelID
- providerConfig.DefaultSmallModel = p.DefaultSmallModelID
- baseURL := p.APIEndpoint
- if strings.HasPrefix(baseURL, "$") {
- envVar := strings.TrimPrefix(baseURL, "$")
- baseURL = os.Getenv(envVar)
- }
- providerConfig.BaseURL = baseURL
- for _, model := range p.Models {
- configModel := Model{
- ID: model.ID,
- Name: model.Name,
- CostPer1MIn: model.CostPer1MIn,
- CostPer1MOut: model.CostPer1MOut,
- CostPer1MInCached: model.CostPer1MInCached,
- CostPer1MOutCached: model.CostPer1MOutCached,
- ContextWindow: model.ContextWindow,
- DefaultMaxTokens: model.DefaultMaxTokens,
- CanReason: model.CanReason,
- SupportsImages: model.SupportsImages,
- }
- // Set reasoning effort for reasoning models
- if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
- configModel.HasReasoningEffort = model.HasReasoningEffort
- configModel.ReasoningEffort = model.DefaultReasoningEffort
- }
- providerConfig.Models = append(providerConfig.Models, configModel)
- }
- return providerConfig
-}
-
-func defaultConfigBasedOnEnv() *Config {
- cfg := &Config{
- Options: Options{
- DataDirectory: defaultDataDirectory,
- ContextPaths: defaultContextPaths,
- },
- Providers: make(map[provider.InferenceProvider]ProviderConfig),
- Agents: make(map[AgentID]Agent),
- LSP: make(map[string]LSPConfig),
- MCP: make(map[string]MCP),
- }
-
- providers := Providers()
-
- for _, p := range providers {
- if strings.HasPrefix(p.APIKey, "$") {
- envVar := strings.TrimPrefix(p.APIKey, "$")
- if apiKey := os.Getenv(envVar); apiKey != "" {
- cfg.Providers[p.ID] = getDefaultProviderConfig(p, apiKey)
- }
- }
- }
- // TODO: support local models
-
- if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" {
- providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
- providerConfig.ExtraParams = map[string]string{
- "project": os.Getenv("GOOGLE_CLOUD_PROJECT"),
- "location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
- }
- // Find the VertexAI provider definition to get default models
- for _, p := range providers {
- if p.ID == provider.InferenceProviderVertexAI {
- providerConfig.DefaultLargeModel = p.DefaultLargeModelID
- providerConfig.DefaultSmallModel = p.DefaultSmallModelID
- for _, model := range p.Models {
- configModel := Model{
- ID: model.ID,
- Name: model.Name,
- CostPer1MIn: model.CostPer1MIn,
- CostPer1MOut: model.CostPer1MOut,
- CostPer1MInCached: model.CostPer1MInCached,
- CostPer1MOutCached: model.CostPer1MOutCached,
- ContextWindow: model.ContextWindow,
- DefaultMaxTokens: model.DefaultMaxTokens,
- CanReason: model.CanReason,
- SupportsImages: model.SupportsImages,
- }
- // Set reasoning effort for reasoning models
- if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
- configModel.HasReasoningEffort = model.HasReasoningEffort
- configModel.ReasoningEffort = model.DefaultReasoningEffort
- }
- providerConfig.Models = append(providerConfig.Models, configModel)
- }
- break
- }
- }
- cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
- }
-
- if hasAWSCredentials() {
- providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
- providerConfig.ExtraParams = map[string]string{
- "region": os.Getenv("AWS_DEFAULT_REGION"),
- }
- if providerConfig.ExtraParams["region"] == "" {
- providerConfig.ExtraParams["region"] = os.Getenv("AWS_REGION")
- }
- // Find the Bedrock provider definition to get default models
- for _, p := range providers {
- if p.ID == provider.InferenceProviderBedrock {
- providerConfig.DefaultLargeModel = p.DefaultLargeModelID
- providerConfig.DefaultSmallModel = p.DefaultSmallModelID
- for _, model := range p.Models {
- configModel := Model{
- ID: model.ID,
- Name: model.Name,
- CostPer1MIn: model.CostPer1MIn,
- CostPer1MOut: model.CostPer1MOut,
- CostPer1MInCached: model.CostPer1MInCached,
- CostPer1MOutCached: model.CostPer1MOutCached,
- ContextWindow: model.ContextWindow,
- DefaultMaxTokens: model.DefaultMaxTokens,
- CanReason: model.CanReason,
- SupportsImages: model.SupportsImages,
- }
- // Set reasoning effort for reasoning models
- if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
- configModel.HasReasoningEffort = model.HasReasoningEffort
- configModel.ReasoningEffort = model.DefaultReasoningEffort
- }
- providerConfig.Models = append(providerConfig.Models, configModel)
- }
- break
+func (c *Config) GetModel(provider, model string) *provider.Model {
+ if providerConfig, ok := c.Providers[provider]; ok {
+ for _, m := range providerConfig.Models {
+ if m.ID == model {
+ return &m
}
}
- cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
- }
- return cfg
-}
-
-func hasAWSCredentials() bool {
- if os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "" {
- return true
}
-
- if os.Getenv("AWS_PROFILE") != "" || os.Getenv("AWS_DEFAULT_PROFILE") != "" {
- return true
- }
-
- if os.Getenv("AWS_REGION") != "" || os.Getenv("AWS_DEFAULT_REGION") != "" {
- return true
- }
-
- if os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
- os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
- return true
- }
-
- return false
-}
-
-func WorkingDirectory() string {
- return cwd
+ return nil
}
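A hedged sketch of calling GetModel; "openai" and "gpt-4o" are placeholder identifiers, and the nil return is the miss case shown above (imports as in the earlier sketch):

// Sketch only: look up a model definition by provider key and model ID.
cfg := config.Get()
if m := cfg.GetModel("openai", "gpt-4o"); m != nil {
	fmt.Println("context window:", m.ContextWindow)
} else {
	fmt.Println("model not found in provider config")
}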
-// TODO: Handle error state
-
-func GetAgentModel(agentID AgentID) Model {
- cfg := Get()
- agent, ok := cfg.Agents[agentID]
- if !ok {
- logging.Error("Agent not found", "agent_id", agentID)
- return Model{}
- }
-
- var model PreferredModel
- switch agent.Model {
- case LargeModel:
- model = cfg.Models.Large
- case SmallModel:
- model = cfg.Models.Small
- default:
- logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
- model = cfg.Models.Large // Fallback to large model
- }
- providerConfig, ok := cfg.Providers[model.Provider]
+func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
+ model, ok := c.Models[modelType]
if !ok {
- logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
- return Model{}
+ return nil
}
-
- for _, m := range providerConfig.Models {
- if m.ID == model.ModelID {
- return m
- }
+ if providerConfig, ok := c.Providers[model.Provider]; ok {
+ return &providerConfig
}
-
- logging.Error("Model not found for agent", "agent_id", agentID, "model", agent.Model)
- return Model{}
+ return nil
}
-func GetAgentProvider(agentID AgentID) ProviderConfig {
- cfg := Get()
- agent, ok := cfg.Agents[agentID]
+func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model {
+ model, ok := c.Models[modelType]
if !ok {
- logging.Error("Agent not found", "agent_id", agentID)
- return ProviderConfig{}
+ return nil
}
-
- var model PreferredModel
- switch agent.Model {
- case LargeModel:
- model = cfg.Models.Large
- case SmallModel:
- model = cfg.Models.Small
- default:
- logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
- model = cfg.Models.Large // Fallback to large model
- }
-
- providerConfig, ok := cfg.Providers[model.Provider]
- if !ok {
- logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
- return ProviderConfig{}
- }
-
- return providerConfig
+ return c.GetModel(model.Provider, model.Model)
}
-func GetProviderModel(provider provider.InferenceProvider, modelID string) Model {
- cfg := Get()
- providerConfig, ok := cfg.Providers[provider]
+func (c *Config) LargeModel() *provider.Model {
+ model, ok := c.Models[SelectedModelTypeLarge]
if !ok {
- logging.Error("Provider not found", "provider", provider)
- return Model{}
+ return nil
}
-
- for _, model := range providerConfig.Models {
- if model.ID == modelID {
- return model
- }
- }
-
- logging.Error("Model not found for provider", "provider", provider, "model_id", modelID)
- return Model{}
+ return c.GetModel(model.Provider, model.Model)
}
-func GetModel(modelType ModelType) Model {
- cfg := Get()
- var model PreferredModel
- switch modelType {
- case LargeModel:
- model = cfg.Models.Large
- case SmallModel:
- model = cfg.Models.Small
- default:
- model = cfg.Models.Large // Fallback to large model
- }
- providerConfig, ok := cfg.Providers[model.Provider]
+func (c *Config) SmallModel() *provider.Model {
+ model, ok := c.Models[SelectedModelTypeSmall]
if !ok {
- return Model{}
+ return nil
}
-
- for _, m := range providerConfig.Models {
- if m.ID == model.ModelID {
- return m
- }
- }
- return Model{}
+ return c.GetModel(model.Provider, model.Model)
}
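Similarly, the selected-model helpers could be exercised like this (sketch only; SelectedModelTypeSmall comes from the code above, the rest is illustrative):

// Sketch only: resolve the currently selected large and small models.
cfg := config.Get()
if large := cfg.LargeModel(); large != nil {
	fmt.Println("large model:", large.ID)
}
if small := cfg.GetModelByType(config.SelectedModelTypeSmall); small != nil {
	fmt.Println("small model:", small.ID)
}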
-func UpdatePreferredModel(modelType ModelType, model PreferredModel) error {
- cfg := Get()
- switch modelType {
- case LargeModel:
- cfg.Models.Large = model
- case SmallModel:
- cfg.Models.Small = model
- default:
- return fmt.Errorf("unknown model type: %s", modelType)
+func (c *Config) Resolve(key string) (string, error) {
+ if c.resolver == nil {
+ return "", fmt.Errorf("no variable resolver configured")
}
- return nil
-}
-
-// ValidationError represents a configuration validation error
-type ValidationError struct {
- Field string
- Message string
+ return c.resolver.ResolveValue(key)
}
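Only the ResolveValue call is visible here, so the shape of VariableResolver is partly an assumption; a hypothetical environment-backed resolver satisfying that one method might look like this (imports os, strings, fmt assumed; the real interface may require more):

// Hypothetical resolver: implements only the ResolveValue method used by
// Config.Resolve above; the actual VariableResolver interface may be larger.
type envResolver struct{}

func (envResolver) ResolveValue(key string) (string, error) {
	if v, ok := os.LookupEnv(strings.TrimPrefix(key, "$")); ok {
		return v, nil
	}
	return "", fmt.Errorf("unresolved variable %q", key)
}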
-func (e ValidationError) Error() string {
- return fmt.Sprintf("validation error in %s: %s", e.Field, e.Message)
-}
-
-// ValidationErrors represents multiple validation errors
-type ValidationErrors []ValidationError
-
-func (e ValidationErrors) Error() string {
- if len(e) == 0 {
- return "no validation errors"
- }
- if len(e) == 1 {
- return e[0].Error()
- }
-
- var messages []string
- for _, err := range e {
- messages = append(messages, err.Error())
- }
- return fmt.Sprintf("multiple validation errors: %s", strings.Join(messages, "; "))
-}
-
-// HasErrors returns true if there are any validation errors
-func (e ValidationErrors) HasErrors() bool {
- return len(e) > 0
-}
-
-// Add appends a new validation error
-func (e *ValidationErrors) Add(field, message string) {
- *e = append(*e, ValidationError{Field: field, Message: message})
-}
-
-// Validate performs comprehensive validation of the configuration
-func (c *Config) Validate() error {
- var errors ValidationErrors
-
- // Validate providers
- c.validateProviders(&errors)
-
- // Validate models
- c.validateModels(&errors)
-
- // Validate agents
- c.validateAgents(&errors)
-
- // Validate options
- c.validateOptions(&errors)
-
- // Validate MCP configurations
- c.validateMCPs(&errors)
-
- // Validate LSP configurations
- c.validateLSPs(&errors)
-
- // Validate cross-references
- c.validateCrossReferences(&errors)
-
- // Validate completeness
- c.validateCompleteness(&errors)
-
- if errors.HasErrors() {
- return errors
- }
-
+// TODO: maybe handle this better
+func UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
+ cfg := Get()
+ cfg.Models[modelType] = model
return nil
}
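A sketch of switching the preferred large model via the retained package-level helper; the Provider and Model field names are inferred from how SelectedModel is read elsewhere in this diff, and the provider/model values are placeholders:

// Sketch only: point the "large" slot at a different provider/model pair.
err := config.UpdatePreferredModel(config.SelectedModelTypeLarge, config.SelectedModel{
	Provider: "anthropic",
	Model:    "claude-3-5-sonnet",
})
if err != nil {
	fmt.Println("failed to update preferred model:", err)
}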
-
-// validateProviders validates all provider configurations
-func (c *Config) validateProviders(errors *ValidationErrors) {
- if c.Providers == nil {
- c.Providers = make(map[provider.InferenceProvider]ProviderConfig)
- }
-
- knownProviders := provider.KnownProviders()
- validTypes := []provider.Type{
- provider.TypeOpenAI,
- provider.TypeAnthropic,
- provider.TypeGemini,
- provider.TypeAzure,
- provider.TypeBedrock,
- provider.TypeVertexAI,
- provider.TypeXAI,
- }
-
- for providerID, providerConfig := range c.Providers {
- fieldPrefix := fmt.Sprintf("providers.%s", providerID)
-
- // Validate API key for non-disabled providers
- if !providerConfig.Disabled && providerConfig.APIKey == "" {
- // Special case for AWS Bedrock and VertexAI which may use other auth methods
- if providerID != provider.InferenceProviderBedrock && providerID != provider.InferenceProviderVertexAI {
- errors.Add(fieldPrefix+".api_key", "API key is required for non-disabled providers")
- }
- }
-
- // Validate provider type
- validType := slices.Contains(validTypes, providerConfig.ProviderType)
- if !validType {
- errors.Add(fieldPrefix+".provider_type", fmt.Sprintf("invalid provider type: %s", providerConfig.ProviderType))
- }
-
- // Validate custom providers
- isKnownProvider := slices.Contains(knownProviders, providerID)
-
- if !isKnownProvider {
- // Custom provider validation
- if providerConfig.BaseURL == "" {
- errors.Add(fieldPrefix+".base_url", "BaseURL is required for custom providers")
- }
- if providerConfig.ProviderType != provider.TypeOpenAI {
- errors.Add(fieldPrefix+".provider_type", "custom providers currently only support OpenAI type")
- }
- }
-
- // Validate models
- modelIDs := make(map[string]bool)
- for i, model := range providerConfig.Models {
- modelFieldPrefix := fmt.Sprintf("%s.models[%d]", fieldPrefix, i)
-
- // Check for duplicate model IDs
- if modelIDs[model.ID] {
- errors.Add(modelFieldPrefix+".id", fmt.Sprintf("duplicate model ID: %s", model.ID))
- }
- modelIDs[model.ID] = true
-
- // Validate required model fields
- if model.ID == "" {
- errors.Add(modelFieldPrefix+".id", "model ID is required")
- }
- if model.Name == "" {
- errors.Add(modelFieldPrefix+".name", "model name is required")
- }
- if model.ContextWindow <= 0 {
- errors.Add(modelFieldPrefix+".context_window", "context window must be positive")
- }
- if model.DefaultMaxTokens <= 0 {
- errors.Add(modelFieldPrefix+".default_max_tokens", "default max tokens must be positive")
- }
- if model.DefaultMaxTokens > model.ContextWindow {
- errors.Add(modelFieldPrefix+".default_max_tokens", "default max tokens cannot exceed context window")
- }
-
- // Validate cost fields
- if model.CostPer1MIn < 0 {
- errors.Add(modelFieldPrefix+".cost_per_1m_in", "cost per 1M input tokens cannot be negative")
- }
- if model.CostPer1MOut < 0 {
- errors.Add(modelFieldPrefix+".cost_per_1m_out", "cost per 1M output tokens cannot be negative")
- }
- if model.CostPer1MInCached < 0 {
- errors.Add(modelFieldPrefix+".cost_per_1m_in_cached", "cached cost per 1M input tokens cannot be negative")
- }
- if model.CostPer1MOutCached < 0 {
- errors.Add(modelFieldPrefix+".cost_per_1m_out_cached", "cached cost per 1M output tokens cannot be negative")
- }
- }
-
- // Validate default model references
- if providerConfig.DefaultLargeModel != "" {
- if !modelIDs[providerConfig.DefaultLargeModel] {
- errors.Add(fieldPrefix+".default_large_model", fmt.Sprintf("default large model '%s' not found in provider models", providerConfig.DefaultLargeModel))
- }
- }
- if providerConfig.DefaultSmallModel != "" {
- if !modelIDs[providerConfig.DefaultSmallModel] {
- errors.Add(fieldPrefix+".default_small_model", fmt.Sprintf("default small model '%s' not found in provider models", providerConfig.DefaultSmallModel))
- }
- }
-
- // Validate provider-specific requirements
- c.validateProviderSpecific(providerID, providerConfig, errors)
- }
-}
-
-// validateProviderSpecific validates provider-specific requirements
-func (c *Config) validateProviderSpecific(providerID provider.InferenceProvider, providerConfig ProviderConfig, errors *ValidationErrors) {
- fieldPrefix := fmt.Sprintf("providers.%s", providerID)
-
- switch providerID {
- case provider.InferenceProviderVertexAI:
- if !providerConfig.Disabled {
- if providerConfig.ExtraParams == nil {
- errors.Add(fieldPrefix+".extra_params", "VertexAI requires extra_params configuration")
- } else {
- if providerConfig.ExtraParams["project"] == "" {
- errors.Add(fieldPrefix+".extra_params.project", "VertexAI requires project parameter")
- }
- if providerConfig.ExtraParams["location"] == "" {
- errors.Add(fieldPrefix+".extra_params.location", "VertexAI requires location parameter")
- }
- }
- }
- case provider.InferenceProviderBedrock:
- if !providerConfig.Disabled {
- if providerConfig.ExtraParams == nil || providerConfig.ExtraParams["region"] == "" {
- errors.Add(fieldPrefix+".extra_params.region", "Bedrock requires region parameter")
- }
- // Check for AWS credentials in environment
- if !hasAWSCredentials() {
- errors.Add(fieldPrefix, "Bedrock requires AWS credentials in environment")
- }
- }
- }
-}
-
-// validateModels validates preferred model configurations
-func (c *Config) validateModels(errors *ValidationErrors) {
- // Validate large model
- if c.Models.Large.ModelID != "" || c.Models.Large.Provider != "" {
- if c.Models.Large.ModelID == "" {
- errors.Add("models.large.model_id", "large model ID is required when provider is set")
- }
- if c.Models.Large.Provider == "" {
- errors.Add("models.large.provider", "large model provider is required when model ID is set")
- }
-
- // Check if provider exists and is not disabled
- if providerConfig, exists := c.Providers[c.Models.Large.Provider]; exists {
- if providerConfig.Disabled {
- errors.Add("models.large.provider", "large model provider is disabled")
- }
-
- // Check if model exists in provider
- modelExists := false
- for _, model := range providerConfig.Models {
- if model.ID == c.Models.Large.ModelID {
- modelExists = true
- break
- }
- }
- if !modelExists {
- errors.Add("models.large.model_id", fmt.Sprintf("large model '%s' not found in provider '%s'", c.Models.Large.ModelID, c.Models.Large.Provider))
- }
- } else {
- errors.Add("models.large.provider", fmt.Sprintf("large model provider '%s' not found", c.Models.Large.Provider))
- }
- }
-
- // Validate small model
- if c.Models.Small.ModelID != "" || c.Models.Small.Provider != "" {
- if c.Models.Small.ModelID == "" {
- errors.Add("models.small.model_id", "small model ID is required when provider is set")
- }
- if c.Models.Small.Provider == "" {
- errors.Add("models.small.provider", "small model provider is required when model ID is set")
- }
-
- // Check if provider exists and is not disabled
- if providerConfig, exists := c.Providers[c.Models.Small.Provider]; exists {
- if providerConfig.Disabled {
- errors.Add("models.small.provider", "small model provider is disabled")
- }
-
- // Check if model exists in provider
- modelExists := false
- for _, model := range providerConfig.Models {
- if model.ID == c.Models.Small.ModelID {
- modelExists = true
- break
- }
- }
- if !modelExists {
- errors.Add("models.small.model_id", fmt.Sprintf("small model '%s' not found in provider '%s'", c.Models.Small.ModelID, c.Models.Small.Provider))
- }
- } else {
- errors.Add("models.small.provider", fmt.Sprintf("small model provider '%s' not found", c.Models.Small.Provider))
- }
- }
-}
-
-// validateAgents validates agent configurations
-func (c *Config) validateAgents(errors *ValidationErrors) {
- if c.Agents == nil {
- c.Agents = make(map[AgentID]Agent)
- }
-
- validTools := []string{
- "bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent",
- }
-
- for agentID, agent := range c.Agents {
- fieldPrefix := fmt.Sprintf("agents.%s", agentID)
-
- // Validate agent ID consistency
- if agent.ID != agentID {
- errors.Add(fieldPrefix+".id", fmt.Sprintf("agent ID mismatch: expected '%s', got '%s'", agentID, agent.ID))
- }
-
- // Validate required fields
- if agent.ID == "" {
- errors.Add(fieldPrefix+".id", "agent ID is required")
- }
- if agent.Name == "" {
- errors.Add(fieldPrefix+".name", "agent name is required")
- }
-
- // Validate model type
- if agent.Model != LargeModel && agent.Model != SmallModel {
- errors.Add(fieldPrefix+".model", fmt.Sprintf("invalid model type: %s (must be 'large' or 'small')", agent.Model))
- }
-
- // Validate allowed tools
- if agent.AllowedTools != nil {
- for i, tool := range agent.AllowedTools {
- validTool := slices.Contains(validTools, tool)
- if !validTool {
- errors.Add(fmt.Sprintf("%s.allowed_tools[%d]", fieldPrefix, i), fmt.Sprintf("unknown tool: %s", tool))
- }
- }
- }
-
- // Validate MCP references
- if agent.AllowedMCP != nil {
- for mcpName := range agent.AllowedMCP {
- if _, exists := c.MCP[mcpName]; !exists {
- errors.Add(fieldPrefix+".allowed_mcp", fmt.Sprintf("referenced MCP '%s' not found", mcpName))
- }
- }
- }
-
- // Validate LSP references
- if agent.AllowedLSP != nil {
- for _, lspName := range agent.AllowedLSP {
- if _, exists := c.LSP[lspName]; !exists {
- errors.Add(fieldPrefix+".allowed_lsp", fmt.Sprintf("referenced LSP '%s' not found", lspName))
- }
- }
- }
-
- // Validate context paths (basic path validation)
- for i, contextPath := range agent.ContextPaths {
- if contextPath == "" {
- errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path cannot be empty")
- }
- // Check for invalid characters in path
- if strings.Contains(contextPath, "\x00") {
- errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path contains invalid characters")
- }
- }
-
- // Validate known agents maintain their core properties
- if agentID == AgentCoder {
- if agent.Name != "Coder" {
- errors.Add(fieldPrefix+".name", "coder agent name cannot be changed")
- }
- if agent.Description != "An agent that helps with executing coding tasks." {
- errors.Add(fieldPrefix+".description", "coder agent description cannot be changed")
- }
- } else if agentID == AgentTask {
- if agent.Name != "Task" {
- errors.Add(fieldPrefix+".name", "task agent name cannot be changed")
- }
- if agent.Description != "An agent that helps with searching for context and finding implementation details." {
- errors.Add(fieldPrefix+".description", "task agent description cannot be changed")
- }
- expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"}
- if agent.AllowedTools != nil && !slices.Equal(agent.AllowedTools, expectedTools) {
- errors.Add(fieldPrefix+".allowed_tools", "task agent allowed tools cannot be changed")
- }
- }
- }
-}
-
-// validateOptions validates configuration options
-func (c *Config) validateOptions(errors *ValidationErrors) {
- // Validate data directory
- if c.Options.DataDirectory == "" {
- errors.Add("options.data_directory", "data directory is required")
- }
-
- // Validate context paths
- for i, contextPath := range c.Options.ContextPaths {
- if contextPath == "" {
- errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path cannot be empty")
- }
- if strings.Contains(contextPath, "\x00") {
- errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path contains invalid characters")
- }
- }
-}
-
-// validateMCPs validates MCP configurations
-func (c *Config) validateMCPs(errors *ValidationErrors) {
- if c.MCP == nil {
- c.MCP = make(map[string]MCP)
- }
-
- for mcpName, mcpConfig := range c.MCP {
- fieldPrefix := fmt.Sprintf("mcp.%s", mcpName)
-
- // Validate MCP type
- if mcpConfig.Type != MCPStdio && mcpConfig.Type != MCPSse && mcpConfig.Type != MCPHttp {
- errors.Add(fieldPrefix+".type", fmt.Sprintf("invalid MCP type: %s (must be 'stdio' or 'sse' or 'http')", mcpConfig.Type))
- }
-
- // Validate based on type
- if mcpConfig.Type == MCPStdio {
- if mcpConfig.Command == "" {
- errors.Add(fieldPrefix+".command", "command is required for stdio MCP")
- }
- } else if mcpConfig.Type == MCPSse {
- if mcpConfig.URL == "" {
- errors.Add(fieldPrefix+".url", "URL is required for SSE MCP")
- }
- }
- }
-}
-
-// validateLSPs validates LSP configurations
-func (c *Config) validateLSPs(errors *ValidationErrors) {
- if c.LSP == nil {
- c.LSP = make(map[string]LSPConfig)
- }
-
- for lspName, lspConfig := range c.LSP {
- fieldPrefix := fmt.Sprintf("lsp.%s", lspName)
-
- if lspConfig.Command == "" {
- errors.Add(fieldPrefix+".command", "command is required for LSP")
- }
- }
-}
-
-// validateCrossReferences validates cross-references between different config sections
-func (c *Config) validateCrossReferences(errors *ValidationErrors) {
- // Validate that agents can use their assigned model types
- for agentID, agent := range c.Agents {
- fieldPrefix := fmt.Sprintf("agents.%s", agentID)
-
- var preferredModel PreferredModel
- switch agent.Model {
- case LargeModel:
- preferredModel = c.Models.Large
- case SmallModel:
- preferredModel = c.Models.Small
- }
-
- if preferredModel.Provider != "" {
- if providerConfig, exists := c.Providers[preferredModel.Provider]; exists {
- if providerConfig.Disabled {
- errors.Add(fieldPrefix+".model", fmt.Sprintf("agent cannot use model type '%s' because provider '%s' is disabled", agent.Model, preferredModel.Provider))
- }
- }
- }
- }
-}
-
-// validateCompleteness validates that the configuration is complete and usable
-func (c *Config) validateCompleteness(errors *ValidationErrors) {
- // Check for at least one valid, non-disabled provider
- hasValidProvider := false
- for _, providerConfig := range c.Providers {
- if !providerConfig.Disabled {
- hasValidProvider = true
- break
- }
- }
- if !hasValidProvider {
- errors.Add("providers", "at least one non-disabled provider is required")
- }
-
- // Check that default agents exist
- if _, exists := c.Agents[AgentCoder]; !exists {
- errors.Add("agents", "coder agent is required")
- }
- if _, exists := c.Agents[AgentTask]; !exists {
- errors.Add("agents", "task agent is required")
- }
-
- // Check that preferred models are set if providers exist
- if hasValidProvider {
- if c.Models.Large.ModelID == "" || c.Models.Large.Provider == "" {
- errors.Add("models.large", "large preferred model must be configured when providers are available")
- }
- if c.Models.Small.ModelID == "" || c.Models.Small.Provider == "" {
- errors.Add("models.small", "small preferred model must be configured when providers are available")
- }
- }
-}
-
-// JSONSchemaExtend adds custom schema properties for AgentID
-func (AgentID) JSONSchemaExtend(schema *jsonschema.Schema) {
- schema.Enum = []any{
- string(AgentCoder),
- string(AgentTask),
- }
-}
-
-// JSONSchemaExtend adds custom schema properties for ModelType
-func (ModelType) JSONSchemaExtend(schema *jsonschema.Schema) {
- schema.Enum = []any{
- string(LargeModel),
- string(SmallModel),
- }
-}
-
-// JSONSchemaExtend adds custom schema properties for MCPType
-func (MCPType) JSONSchemaExtend(schema *jsonschema.Schema) {
- schema.Enum = []any{
- string(MCPStdio),
- string(MCPSse),
- }
-}
@@ -1,2075 +0,0 @@
-package config
-
-import (
- "encoding/json"
- "os"
- "path/filepath"
- "sync"
- "testing"
-
- "github.com/charmbracelet/crush/internal/fur/provider"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-func reset() {
- // Clear all environment variables that could affect config
- envVarsToUnset := []string{
- // API Keys
- "ANTHROPIC_API_KEY",
- "OPENAI_API_KEY",
- "GEMINI_API_KEY",
- "XAI_API_KEY",
- "OPENROUTER_API_KEY",
-
- // Google Cloud / VertexAI
- "GOOGLE_GENAI_USE_VERTEXAI",
- "GOOGLE_CLOUD_PROJECT",
- "GOOGLE_CLOUD_LOCATION",
-
- // AWS Credentials
- "AWS_ACCESS_KEY_ID",
- "AWS_SECRET_ACCESS_KEY",
- "AWS_REGION",
- "AWS_DEFAULT_REGION",
- "AWS_PROFILE",
- "AWS_DEFAULT_PROFILE",
- "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI",
- "AWS_CONTAINER_CREDENTIALS_FULL_URI",
-
- // Other
- "CRUSH_DEV_DEBUG",
- }
-
- for _, envVar := range envVarsToUnset {
- os.Unsetenv(envVar)
- }
-
- // Reset singleton
- once = sync.Once{}
- instance = nil
- cwd = ""
- testConfigDir = ""
-
- // Enable mock providers for all tests to avoid API calls
- UseMockProviders = true
- ResetProviders()
-}
-
-// Core Configuration Loading Tests
-
-func TestInit_ValidWorkingDirectory(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- assert.NotNil(t, cfg)
- assert.Equal(t, cwdDir, WorkingDirectory())
- assert.Equal(t, defaultDataDirectory, cfg.Options.DataDirectory)
- assert.Equal(t, defaultContextPaths, cfg.Options.ContextPaths)
-}
-
-func TestInit_WithDebugFlag(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- cfg, err := Init(cwdDir, true)
-
- require.NoError(t, err)
- assert.True(t, cfg.Options.Debug)
-}
-
-func TestInit_SingletonBehavior(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- cfg1, err1 := Init(cwdDir, false)
- cfg2, err2 := Init(cwdDir, false)
-
- require.NoError(t, err1)
- require.NoError(t, err2)
- assert.Same(t, cfg1, cfg2)
-}
-
-func TestGet_BeforeInitialization(t *testing.T) {
- reset()
-
- assert.Panics(t, func() {
- Get()
- })
-}
-
-func TestGet_AfterInitialization(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- cfg1, err := Init(cwdDir, false)
- require.NoError(t, err)
-
- cfg2 := Get()
- assert.Same(t, cfg1, cfg2)
-}
-
-func TestLoadConfig_NoConfigFiles(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- assert.Len(t, cfg.Providers, 0)
- assert.Equal(t, defaultContextPaths, cfg.Options.ContextPaths)
-}
-
-func TestLoadConfig_OnlyGlobalConfig(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- globalConfig := Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderOpenAI: {
- ID: provider.InferenceProviderOpenAI,
- APIKey: "test-key",
- ProviderType: provider.TypeOpenAI,
- DefaultLargeModel: "gpt-4",
- DefaultSmallModel: "gpt-3.5-turbo",
- Models: []Model{
- {
- ID: "gpt-4",
- Name: "GPT-4",
- CostPer1MIn: 30.0,
- CostPer1MOut: 60.0,
- ContextWindow: 8192,
- DefaultMaxTokens: 4096,
- },
- {
- ID: "gpt-3.5-turbo",
- Name: "GPT-3.5 Turbo",
- CostPer1MIn: 1.0,
- CostPer1MOut: 2.0,
- ContextWindow: 4096,
- DefaultMaxTokens: 4096,
- },
- },
- },
- },
- Options: Options{
- ContextPaths: []string{"custom-context.md"},
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
-
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- assert.Len(t, cfg.Providers, 1)
- assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI)
- assert.Contains(t, cfg.Options.ContextPaths, "custom-context.md")
-}
-
-func TestLoadConfig_OnlyLocalConfig(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- localConfig := Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderAnthropic: {
- ID: provider.InferenceProviderAnthropic,
- APIKey: "local-key",
- ProviderType: provider.TypeAnthropic,
- DefaultLargeModel: "claude-3-opus",
- DefaultSmallModel: "claude-3-haiku",
- Models: []Model{
- {
- ID: "claude-3-opus",
- Name: "Claude 3 Opus",
- CostPer1MIn: 15.0,
- CostPer1MOut: 75.0,
- ContextWindow: 200000,
- DefaultMaxTokens: 4096,
- },
- {
- ID: "claude-3-haiku",
- Name: "Claude 3 Haiku",
- CostPer1MIn: 0.25,
- CostPer1MOut: 1.25,
- ContextWindow: 200000,
- DefaultMaxTokens: 4096,
- },
- },
- },
- },
- Options: Options{
- TUI: TUIOptions{CompactMode: true},
- },
- }
-
- localConfigPath := filepath.Join(cwdDir, "crush.json")
- data, err := json.Marshal(localConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- assert.Len(t, cfg.Providers, 1)
- assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic)
- assert.True(t, cfg.Options.TUI.CompactMode)
-}
-
-func TestLoadConfig_BothGlobalAndLocal(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- globalConfig := Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderOpenAI: {
- ID: provider.InferenceProviderOpenAI,
- APIKey: "global-key",
- ProviderType: provider.TypeOpenAI,
- DefaultLargeModel: "gpt-4",
- DefaultSmallModel: "gpt-3.5-turbo",
- Models: []Model{
- {
- ID: "gpt-4",
- Name: "GPT-4",
- CostPer1MIn: 30.0,
- CostPer1MOut: 60.0,
- ContextWindow: 8192,
- DefaultMaxTokens: 4096,
- },
- {
- ID: "gpt-3.5-turbo",
- Name: "GPT-3.5 Turbo",
- CostPer1MIn: 1.0,
- CostPer1MOut: 2.0,
- ContextWindow: 4096,
- DefaultMaxTokens: 4096,
- },
- },
- },
- },
- Options: Options{
- ContextPaths: []string{"global-context.md"},
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- localConfig := Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderOpenAI: {
- APIKey: "local-key", // Override global
- },
- provider.InferenceProviderAnthropic: {
- ID: provider.InferenceProviderAnthropic,
- APIKey: "anthropic-key",
- ProviderType: provider.TypeAnthropic,
- DefaultLargeModel: "claude-3-opus",
- DefaultSmallModel: "claude-3-haiku",
- Models: []Model{
- {
- ID: "claude-3-opus",
- Name: "Claude 3 Opus",
- CostPer1MIn: 15.0,
- CostPer1MOut: 75.0,
- ContextWindow: 200000,
- DefaultMaxTokens: 4096,
- },
- {
- ID: "claude-3-haiku",
- Name: "Claude 3 Haiku",
- CostPer1MIn: 0.25,
- CostPer1MOut: 1.25,
- ContextWindow: 200000,
- DefaultMaxTokens: 4096,
- },
- },
- },
- },
- Options: Options{
- ContextPaths: []string{"local-context.md"},
- TUI: TUIOptions{CompactMode: true},
- },
- }
-
- localConfigPath := filepath.Join(cwdDir, "crush.json")
- data, err = json.Marshal(localConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- assert.Len(t, cfg.Providers, 2)
-
- openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
- assert.Equal(t, "local-key", openaiProvider.APIKey)
-
- assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic)
-
- assert.Contains(t, cfg.Options.ContextPaths, "global-context.md")
- assert.Contains(t, cfg.Options.ContextPaths, "local-context.md")
- assert.True(t, cfg.Options.TUI.CompactMode)
-}
-
-func TestLoadConfig_MalformedGlobalJSON(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- require.NoError(t, os.WriteFile(configPath, []byte(`{invalid json`), 0o644))
-
- _, err := Init(cwdDir, false)
- assert.Error(t, err)
-}
-
-func TestLoadConfig_MalformedLocalJSON(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- localConfigPath := filepath.Join(cwdDir, "crush.json")
- require.NoError(t, os.WriteFile(localConfigPath, []byte(`{invalid json`), 0o644))
-
- _, err := Init(cwdDir, false)
- assert.Error(t, err)
-}
-
-func TestConfigWithoutEnv(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- cfg, _ := Init(cwdDir, false)
- assert.Len(t, cfg.Providers, 0)
-}
-
-func TestConfigWithEnv(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key")
- os.Setenv("OPENAI_API_KEY", "test-openai-key")
- os.Setenv("GEMINI_API_KEY", "test-gemini-key")
- os.Setenv("XAI_API_KEY", "test-xai-key")
- os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key")
-
- cfg, _ := Init(cwdDir, false)
- assert.Len(t, cfg.Providers, 5)
-}
-
-// Environment Variable Tests
-
-func TestEnvVars_NoEnvironmentVariables(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- assert.Len(t, cfg.Providers, 0)
-}
-
-func TestEnvVars_AllSupportedAPIKeys(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key")
- os.Setenv("OPENAI_API_KEY", "test-openai-key")
- os.Setenv("GEMINI_API_KEY", "test-gemini-key")
- os.Setenv("XAI_API_KEY", "test-xai-key")
- os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key")
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- assert.Len(t, cfg.Providers, 5)
-
- anthropicProvider := cfg.Providers[provider.InferenceProviderAnthropic]
- assert.Equal(t, "test-anthropic-key", anthropicProvider.APIKey)
- assert.Equal(t, provider.TypeAnthropic, anthropicProvider.ProviderType)
-
- openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
- assert.Equal(t, "test-openai-key", openaiProvider.APIKey)
- assert.Equal(t, provider.TypeOpenAI, openaiProvider.ProviderType)
-
- geminiProvider := cfg.Providers[provider.InferenceProviderGemini]
- assert.Equal(t, "test-gemini-key", geminiProvider.APIKey)
- assert.Equal(t, provider.TypeGemini, geminiProvider.ProviderType)
-
- xaiProvider := cfg.Providers[provider.InferenceProviderXAI]
- assert.Equal(t, "test-xai-key", xaiProvider.APIKey)
- assert.Equal(t, provider.TypeXAI, xaiProvider.ProviderType)
-
- openrouterProvider := cfg.Providers[provider.InferenceProviderOpenRouter]
- assert.Equal(t, "test-openrouter-key", openrouterProvider.APIKey)
- assert.Equal(t, provider.TypeOpenAI, openrouterProvider.ProviderType)
- assert.Equal(t, "https://openrouter.ai/api/v1", openrouterProvider.BaseURL)
-}
-
-func TestEnvVars_PartialEnvironmentVariables(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key")
- os.Setenv("OPENAI_API_KEY", "test-openai-key")
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- assert.Len(t, cfg.Providers, 2)
- assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic)
- assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI)
- assert.NotContains(t, cfg.Providers, provider.InferenceProviderGemini)
-}
-
-func TestEnvVars_VertexAIConfiguration(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("GOOGLE_GENAI_USE_VERTEXAI", "true")
- os.Setenv("GOOGLE_CLOUD_PROJECT", "test-project")
- os.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1")
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- assert.Contains(t, cfg.Providers, provider.InferenceProviderVertexAI)
-
- vertexProvider := cfg.Providers[provider.InferenceProviderVertexAI]
- assert.Equal(t, provider.TypeVertexAI, vertexProvider.ProviderType)
- assert.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
- assert.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
-}
-
-func TestEnvVars_VertexAIWithoutUseFlag(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("GOOGLE_CLOUD_PROJECT", "test-project")
- os.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1")
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- assert.NotContains(t, cfg.Providers, provider.InferenceProviderVertexAI)
-}
-
-func TestEnvVars_AWSBedrockWithAccessKeys(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("AWS_ACCESS_KEY_ID", "test-access-key")
- os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key")
- os.Setenv("AWS_DEFAULT_REGION", "us-east-1")
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- assert.Contains(t, cfg.Providers, provider.InferenceProviderBedrock)
-
- bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock]
- assert.Equal(t, provider.TypeBedrock, bedrockProvider.ProviderType)
- assert.Equal(t, "us-east-1", bedrockProvider.ExtraParams["region"])
-}
-
-func TestEnvVars_AWSBedrockWithProfile(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("AWS_PROFILE", "test-profile")
- os.Setenv("AWS_REGION", "eu-west-1")
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- assert.Contains(t, cfg.Providers, provider.InferenceProviderBedrock)
-
- bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock]
- assert.Equal(t, "eu-west-1", bedrockProvider.ExtraParams["region"])
-}
-
-func TestEnvVars_AWSBedrockWithContainerCredentials(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/v2/credentials/test")
- os.Setenv("AWS_DEFAULT_REGION", "ap-southeast-1")
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- assert.Contains(t, cfg.Providers, provider.InferenceProviderBedrock)
-}
-
-func TestEnvVars_AWSBedrockRegionPriority(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("AWS_ACCESS_KEY_ID", "test-key")
- os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret")
- os.Setenv("AWS_DEFAULT_REGION", "us-west-2")
- os.Setenv("AWS_REGION", "us-east-1")
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock]
- assert.Equal(t, "us-west-2", bedrockProvider.ExtraParams["region"])
-}
-
-func TestEnvVars_AWSBedrockFallbackRegion(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("AWS_ACCESS_KEY_ID", "test-key")
- os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret")
- os.Setenv("AWS_REGION", "us-east-1")
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock]
- assert.Equal(t, "us-east-1", bedrockProvider.ExtraParams["region"])
-}
-
-func TestEnvVars_NoAWSCredentials(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- assert.NotContains(t, cfg.Providers, provider.InferenceProviderBedrock)
-}
-
-func TestEnvVars_CustomEnvironmentVariables(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("ANTHROPIC_API_KEY", "resolved-anthropic-key")
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- if len(cfg.Providers) > 0 {
- if anthropicProvider, exists := cfg.Providers[provider.InferenceProviderAnthropic]; exists {
- assert.Equal(t, "resolved-anthropic-key", anthropicProvider.APIKey)
- }
- }
-}
-
-func TestEnvVars_CombinedEnvironmentVariables(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("ANTHROPIC_API_KEY", "test-anthropic")
- os.Setenv("OPENAI_API_KEY", "test-openai")
- os.Setenv("GOOGLE_GENAI_USE_VERTEXAI", "true")
- os.Setenv("GOOGLE_CLOUD_PROJECT", "test-project")
- os.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1")
- os.Setenv("AWS_ACCESS_KEY_ID", "test-aws-key")
- os.Setenv("AWS_SECRET_ACCESS_KEY", "test-aws-secret")
- os.Setenv("AWS_DEFAULT_REGION", "us-west-1")
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
-
- expectedProviders := []provider.InferenceProvider{
- provider.InferenceProviderAnthropic,
- provider.InferenceProviderOpenAI,
- provider.InferenceProviderVertexAI,
- provider.InferenceProviderBedrock,
- }
-
- for _, expectedProvider := range expectedProviders {
- assert.Contains(t, cfg.Providers, expectedProvider)
- }
-}
-
-func TestHasAWSCredentials_AccessKeys(t *testing.T) {
- reset()
-
- os.Setenv("AWS_ACCESS_KEY_ID", "test-key")
- os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret")
-
- assert.True(t, hasAWSCredentials())
-}
-
-func TestHasAWSCredentials_Profile(t *testing.T) {
- reset()
-
- os.Setenv("AWS_PROFILE", "test-profile")
-
- assert.True(t, hasAWSCredentials())
-}
-
-func TestHasAWSCredentials_DefaultProfile(t *testing.T) {
- reset()
-
- os.Setenv("AWS_DEFAULT_PROFILE", "default")
-
- assert.True(t, hasAWSCredentials())
-}
-
-func TestHasAWSCredentials_Region(t *testing.T) {
- reset()
-
- os.Setenv("AWS_REGION", "us-east-1")
-
- assert.True(t, hasAWSCredentials())
-}
-
-func TestHasAWSCredentials_ContainerCredentials(t *testing.T) {
- reset()
-
- os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/v2/credentials/test")
-
- assert.True(t, hasAWSCredentials())
-}
-
-func TestHasAWSCredentials_NoCredentials(t *testing.T) {
- reset()
-
- assert.False(t, hasAWSCredentials())
-}
-
-func TestProviderMerging_GlobalToBase(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- globalConfig := Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderOpenAI: {
- ID: provider.InferenceProviderOpenAI,
- APIKey: "global-openai-key",
- ProviderType: provider.TypeOpenAI,
- DefaultLargeModel: "gpt-4",
- DefaultSmallModel: "gpt-3.5-turbo",
- Models: []Model{
- {
- ID: "gpt-4",
- Name: "GPT-4",
- ContextWindow: 8192,
- DefaultMaxTokens: 4096,
- },
- {
- ID: "gpt-3.5-turbo",
- Name: "GPT-3.5 Turbo",
- ContextWindow: 4096,
- DefaultMaxTokens: 2048,
- },
- },
- },
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- assert.Len(t, cfg.Providers, 1)
-
- openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
- assert.Equal(t, "global-openai-key", openaiProvider.APIKey)
- assert.Equal(t, "gpt-4", openaiProvider.DefaultLargeModel)
- assert.Equal(t, "gpt-4o", openaiProvider.DefaultSmallModel)
- assert.GreaterOrEqual(t, len(openaiProvider.Models), 2)
-}
-
-func TestProviderMerging_LocalToBase(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- localConfig := Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderAnthropic: {
- ID: provider.InferenceProviderAnthropic,
- APIKey: "local-anthropic-key",
- ProviderType: provider.TypeAnthropic,
- DefaultLargeModel: "claude-3-opus",
- DefaultSmallModel: "claude-3-haiku",
- Models: []Model{
- {
- ID: "claude-3-opus",
- Name: "Claude 3 Opus",
- ContextWindow: 200000,
- DefaultMaxTokens: 4096,
- CostPer1MIn: 15.0,
- CostPer1MOut: 75.0,
- },
- {
- ID: "claude-3-haiku",
- Name: "Claude 3 Haiku",
- ContextWindow: 200000,
- DefaultMaxTokens: 4096,
- CostPer1MIn: 0.25,
- CostPer1MOut: 1.25,
- },
- },
- },
- },
- }
-
- localConfigPath := filepath.Join(cwdDir, "crush.json")
- data, err := json.Marshal(localConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- assert.Len(t, cfg.Providers, 1)
-
- anthropicProvider := cfg.Providers[provider.InferenceProviderAnthropic]
- assert.Equal(t, "local-anthropic-key", anthropicProvider.APIKey)
- assert.Equal(t, "claude-3-opus", anthropicProvider.DefaultLargeModel)
- assert.Equal(t, "claude-3-5-haiku-20241022", anthropicProvider.DefaultSmallModel)
- assert.GreaterOrEqual(t, len(anthropicProvider.Models), 2)
-}
-
-func TestProviderMerging_ConflictingSettings(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- globalConfig := Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderOpenAI: {
- ID: provider.InferenceProviderOpenAI,
- APIKey: "global-key",
- ProviderType: provider.TypeOpenAI,
- DefaultLargeModel: "gpt-4",
- DefaultSmallModel: "gpt-3.5-turbo",
- Models: []Model{
- {
- ID: "gpt-4",
- Name: "GPT-4",
- ContextWindow: 8192,
- DefaultMaxTokens: 4096,
- },
- {
- ID: "gpt-3.5-turbo",
- Name: "GPT-3.5 Turbo",
- ContextWindow: 4096,
- DefaultMaxTokens: 2048,
- },
- {
- ID: "gpt-4-turbo",
- Name: "GPT-4 Turbo",
- ContextWindow: 128000,
- DefaultMaxTokens: 4096,
- },
- },
- },
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- // Create local config that overrides
- localConfig := Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderOpenAI: {
- APIKey: "local-key",
- DefaultLargeModel: "gpt-4-turbo",
- },
- },
- }
-
- localConfigPath := filepath.Join(cwdDir, "crush.json")
- data, err = json.Marshal(localConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
-
- openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
- assert.Equal(t, "local-key", openaiProvider.APIKey)
- assert.Equal(t, "gpt-4-turbo", openaiProvider.DefaultLargeModel)
- assert.False(t, openaiProvider.Disabled)
- assert.Equal(t, "gpt-4o", openaiProvider.DefaultSmallModel)
-}
-
-func TestProviderMerging_CustomVsKnownProviders(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- customProviderID := provider.InferenceProvider("custom-provider")
-
- globalConfig := Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderOpenAI: {
- ID: provider.InferenceProviderOpenAI,
- APIKey: "openai-key",
- BaseURL: "should-not-override",
- ProviderType: provider.TypeAnthropic,
- DefaultLargeModel: "gpt-4",
- DefaultSmallModel: "gpt-3.5-turbo",
- Models: []Model{
- {
- ID: "gpt-4",
- Name: "GPT-4",
- ContextWindow: 8192,
- DefaultMaxTokens: 4096,
- },
- {
- ID: "gpt-3.5-turbo",
- Name: "GPT-3.5 Turbo",
- ContextWindow: 4096,
- DefaultMaxTokens: 2048,
- },
- },
- },
- customProviderID: {
- ID: customProviderID,
- APIKey: "custom-key",
- BaseURL: "https://custom.api.com",
- ProviderType: provider.TypeOpenAI,
- DefaultLargeModel: "custom-large",
- DefaultSmallModel: "custom-small",
- Models: []Model{
- {
- ID: "custom-large",
- Name: "Custom Large",
- ContextWindow: 8192,
- DefaultMaxTokens: 4096,
- },
- {
- ID: "custom-small",
- Name: "Custom Small",
- ContextWindow: 4096,
- DefaultMaxTokens: 2048,
- },
- },
- },
- },
- }
-
- localConfig := Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderOpenAI: {
- BaseURL: "https://should-not-change.com",
- ProviderType: provider.TypeGemini, // Should not change
- },
- customProviderID: {
- BaseURL: "https://updated-custom.api.com",
- ProviderType: provider.TypeOpenAI,
- },
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- localConfigPath := filepath.Join(cwdDir, "crush.json")
- data, err = json.Marshal(localConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
-
- openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
- assert.NotEqual(t, "https://should-not-change.com", openaiProvider.BaseURL)
- assert.NotEqual(t, provider.TypeGemini, openaiProvider.ProviderType)
-
- customProvider := cfg.Providers[customProviderID]
- assert.Equal(t, "custom-key", customProvider.APIKey)
- assert.Equal(t, "https://updated-custom.api.com", customProvider.BaseURL)
- assert.Equal(t, provider.TypeOpenAI, customProvider.ProviderType)
-}
-
-func TestProviderValidation_CustomProviderMissingBaseURL(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- customProviderID := provider.InferenceProvider("custom-provider")
-
- globalConfig := Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- customProviderID: {
- ID: customProviderID,
- APIKey: "custom-key",
- ProviderType: provider.TypeOpenAI,
- },
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- assert.NotContains(t, cfg.Providers, customProviderID)
-}
-
-func TestProviderValidation_CustomProviderMissingAPIKey(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- customProviderID := provider.InferenceProvider("custom-provider")
-
- globalConfig := Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- customProviderID: {
- ID: customProviderID,
- BaseURL: "https://custom.api.com",
- ProviderType: provider.TypeOpenAI,
- },
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- assert.NotContains(t, cfg.Providers, customProviderID)
-}
-
-func TestProviderValidation_CustomProviderInvalidType(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- customProviderID := provider.InferenceProvider("custom-provider")
-
- globalConfig := Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- customProviderID: {
- ID: customProviderID,
- APIKey: "custom-key",
- BaseURL: "https://custom.api.com",
- ProviderType: provider.Type("invalid-type"),
- },
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- assert.NotContains(t, cfg.Providers, customProviderID)
-}
-
-func TestProviderValidation_KnownProviderValid(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- globalConfig := Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderOpenAI: {
- ID: provider.InferenceProviderOpenAI,
- APIKey: "openai-key",
- ProviderType: provider.TypeOpenAI,
- DefaultLargeModel: "gpt-4",
- DefaultSmallModel: "gpt-3.5-turbo",
- Models: []Model{
- {
- ID: "gpt-4",
- Name: "GPT-4",
- ContextWindow: 8192,
- DefaultMaxTokens: 4096,
- },
- {
- ID: "gpt-3.5-turbo",
- Name: "GPT-3.5 Turbo",
- ContextWindow: 4096,
- DefaultMaxTokens: 2048,
- },
- },
- },
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI)
-}
-
-func TestProviderValidation_DisabledProvider(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- globalConfig := Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderOpenAI: {
- ID: provider.InferenceProviderOpenAI,
- APIKey: "openai-key",
- ProviderType: provider.TypeOpenAI,
- Disabled: true,
- DefaultLargeModel: "gpt-4",
- DefaultSmallModel: "gpt-3.5-turbo",
- Models: []Model{
- {
- ID: "gpt-4",
- Name: "GPT-4",
- ContextWindow: 8192,
- DefaultMaxTokens: 4096,
- },
- {
- ID: "gpt-3.5-turbo",
- Name: "GPT-3.5 Turbo",
- ContextWindow: 4096,
- DefaultMaxTokens: 2048,
- },
- },
- },
- provider.InferenceProviderAnthropic: {
- ID: provider.InferenceProviderAnthropic,
- APIKey: "anthropic-key",
- ProviderType: provider.TypeAnthropic,
- Disabled: false, // This one is enabled
- DefaultLargeModel: "claude-3-opus",
- DefaultSmallModel: "claude-3-haiku",
- Models: []Model{
- {
- ID: "claude-3-opus",
- Name: "Claude 3 Opus",
- ContextWindow: 200000,
- DefaultMaxTokens: 4096,
- },
- {
- ID: "claude-3-haiku",
- Name: "Claude 3 Haiku",
- ContextWindow: 200000,
- DefaultMaxTokens: 4096,
- },
- },
- },
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI)
- assert.True(t, cfg.Providers[provider.InferenceProviderOpenAI].Disabled)
- assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic)
- assert.False(t, cfg.Providers[provider.InferenceProviderAnthropic].Disabled)
-}
-
-func TestProviderModels_AddingNewModels(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- globalConfig := Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderOpenAI: {
- ID: provider.InferenceProviderOpenAI,
- APIKey: "openai-key",
- ProviderType: provider.TypeOpenAI,
- DefaultLargeModel: "gpt-4",
- DefaultSmallModel: "gpt-4-turbo",
- Models: []Model{
- {
- ID: "gpt-4",
- Name: "GPT-4",
- ContextWindow: 8192,
- DefaultMaxTokens: 4096,
- },
- },
- },
- },
- }
-
- localConfig := Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderOpenAI: {
- Models: []Model{
- {
- ID: "gpt-4-turbo",
- Name: "GPT-4 Turbo",
- ContextWindow: 128000,
- DefaultMaxTokens: 4096,
- },
- },
- },
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- localConfigPath := filepath.Join(cwdDir, "crush.json")
- data, err = json.Marshal(localConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
-
- openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
- assert.GreaterOrEqual(t, len(openaiProvider.Models), 2)
-
- modelIDs := make([]string, len(openaiProvider.Models))
- for i, model := range openaiProvider.Models {
- modelIDs[i] = model.ID
- }
- assert.Contains(t, modelIDs, "gpt-4")
- assert.Contains(t, modelIDs, "gpt-4-turbo")
-}
-
-func TestProviderModels_DuplicateModelHandling(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- globalConfig := Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderOpenAI: {
- ID: provider.InferenceProviderOpenAI,
- APIKey: "openai-key",
- ProviderType: provider.TypeOpenAI,
- DefaultLargeModel: "gpt-4",
- DefaultSmallModel: "gpt-4",
- Models: []Model{
- {
- ID: "gpt-4",
- Name: "GPT-4",
- ContextWindow: 8192,
- DefaultMaxTokens: 4096,
- },
- },
- },
- },
- }
-
- localConfig := Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderOpenAI: {
- Models: []Model{
- {
- ID: "gpt-4",
- Name: "GPT-4 Updated",
- ContextWindow: 16384,
- DefaultMaxTokens: 8192,
- },
- },
- },
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- localConfigPath := filepath.Join(cwdDir, "crush.json")
- data, err = json.Marshal(localConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
-
- openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
- assert.GreaterOrEqual(t, len(openaiProvider.Models), 1)
-
- // Find the first model that matches our test data
- var testModel *Model
- for _, model := range openaiProvider.Models {
- if model.ID == "gpt-4" {
- testModel = &model
- break
- }
- }
-
- // If gpt-4 not found, use the first available model
- if testModel == nil {
- testModel = &openaiProvider.Models[0]
- }
-
- assert.NotEmpty(t, testModel.ID)
- assert.NotEmpty(t, testModel.Name)
- assert.Greater(t, testModel.ContextWindow, int64(0))
-}
-
-func TestProviderModels_ModelCostAndCapabilities(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- globalConfig := Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderOpenAI: {
- ID: provider.InferenceProviderOpenAI,
- APIKey: "openai-key",
- ProviderType: provider.TypeOpenAI,
- DefaultLargeModel: "gpt-4",
- DefaultSmallModel: "gpt-4",
- Models: []Model{
- {
- ID: "gpt-4",
- Name: "GPT-4",
- CostPer1MIn: 30.0,
- CostPer1MOut: 60.0,
- CostPer1MInCached: 15.0,
- CostPer1MOutCached: 30.0,
- ContextWindow: 8192,
- DefaultMaxTokens: 4096,
- CanReason: true,
- ReasoningEffort: "medium",
- SupportsImages: true,
- },
- },
- },
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
-
- openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
- require.GreaterOrEqual(t, len(openaiProvider.Models), 1)
-
- // Find the test model or use the first one
- var testModel *Model
- for _, model := range openaiProvider.Models {
- if model.ID == "gpt-4" {
- testModel = &model
- break
- }
- }
-
- if testModel == nil {
- testModel = &openaiProvider.Models[0]
- }
-
- // Only test the custom properties if this is actually our test model
- if testModel.ID == "gpt-4" {
- assert.Equal(t, 30.0, testModel.CostPer1MIn)
- assert.Equal(t, 60.0, testModel.CostPer1MOut)
- assert.Equal(t, 15.0, testModel.CostPer1MInCached)
- assert.Equal(t, 30.0, testModel.CostPer1MOutCached)
- assert.True(t, testModel.CanReason)
- assert.Equal(t, "medium", testModel.ReasoningEffort)
- assert.True(t, testModel.SupportsImages)
- }
-}
-
-func TestDefaultAgents_CoderAgent(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("ANTHROPIC_API_KEY", "test-key")
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- assert.Contains(t, cfg.Agents, AgentCoder)
-
- coderAgent := cfg.Agents[AgentCoder]
- assert.Equal(t, AgentCoder, coderAgent.ID)
- assert.Equal(t, "Coder", coderAgent.Name)
- assert.Equal(t, "An agent that helps with executing coding tasks.", coderAgent.Description)
- assert.Equal(t, LargeModel, coderAgent.Model)
- assert.False(t, coderAgent.Disabled)
- assert.Equal(t, cfg.Options.ContextPaths, coderAgent.ContextPaths)
- assert.Nil(t, coderAgent.AllowedTools)
-}
-
-func TestDefaultAgents_TaskAgent(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("ANTHROPIC_API_KEY", "test-key")
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- assert.Contains(t, cfg.Agents, AgentTask)
-
- taskAgent := cfg.Agents[AgentTask]
- assert.Equal(t, AgentTask, taskAgent.ID)
- assert.Equal(t, "Task", taskAgent.Name)
- assert.Equal(t, "An agent that helps with searching for context and finding implementation details.", taskAgent.Description)
- assert.Equal(t, LargeModel, taskAgent.Model)
- assert.False(t, taskAgent.Disabled)
- assert.Equal(t, cfg.Options.ContextPaths, taskAgent.ContextPaths)
-
- expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"}
- assert.Equal(t, expectedTools, taskAgent.AllowedTools)
-
- assert.Equal(t, map[string][]string{}, taskAgent.AllowedMCP)
- assert.Equal(t, []string{}, taskAgent.AllowedLSP)
-}
-
-func TestAgentMerging_CustomAgent(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("ANTHROPIC_API_KEY", "test-key")
-
- globalConfig := Config{
- Agents: map[AgentID]Agent{
- AgentID("custom-agent"): {
- ID: AgentID("custom-agent"),
- Name: "Custom Agent",
- Description: "A custom agent for testing",
- Model: SmallModel,
- AllowedTools: []string{"glob", "grep"},
- AllowedMCP: map[string][]string{"mcp1": {"tool1", "tool2"}},
- AllowedLSP: []string{"typescript", "go"},
- ContextPaths: []string{"custom-context.md"},
- },
- },
- MCP: map[string]MCP{
- "mcp1": {
- Type: MCPStdio,
- Command: "test-mcp-command",
- Args: []string{"--test"},
- },
- },
- LSP: map[string]LSPConfig{
- "typescript": {
- Command: "typescript-language-server",
- Args: []string{"--stdio"},
- },
- "go": {
- Command: "gopls",
- Args: []string{},
- },
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
-
- assert.Contains(t, cfg.Agents, AgentCoder)
- assert.Contains(t, cfg.Agents, AgentTask)
- assert.Contains(t, cfg.Agents, AgentID("custom-agent"))
-
- customAgent := cfg.Agents[AgentID("custom-agent")]
- assert.Equal(t, "Custom Agent", customAgent.Name)
- assert.Equal(t, "A custom agent for testing", customAgent.Description)
- assert.Equal(t, SmallModel, customAgent.Model)
- assert.Equal(t, []string{"glob", "grep"}, customAgent.AllowedTools)
- assert.Equal(t, map[string][]string{"mcp1": {"tool1", "tool2"}}, customAgent.AllowedMCP)
- assert.Equal(t, []string{"typescript", "go"}, customAgent.AllowedLSP)
- expectedContextPaths := append(defaultContextPaths, "custom-context.md")
- assert.Equal(t, expectedContextPaths, customAgent.ContextPaths)
-}
-
-func TestAgentMerging_ModifyDefaultCoderAgent(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("ANTHROPIC_API_KEY", "test-key")
-
- globalConfig := Config{
- Agents: map[AgentID]Agent{
- AgentCoder: {
- Model: SmallModel,
- AllowedMCP: map[string][]string{"mcp1": {"tool1"}},
- AllowedLSP: []string{"typescript"},
- ContextPaths: []string{"coder-specific.md"},
- },
- },
- MCP: map[string]MCP{
- "mcp1": {
- Type: MCPStdio,
- Command: "test-mcp-command",
- Args: []string{"--test"},
- },
- },
- LSP: map[string]LSPConfig{
- "typescript": {
- Command: "typescript-language-server",
- Args: []string{"--stdio"},
- },
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
-
- coderAgent := cfg.Agents[AgentCoder]
- assert.Equal(t, AgentCoder, coderAgent.ID)
- assert.Equal(t, "Coder", coderAgent.Name)
- assert.Equal(t, "An agent that helps with executing coding tasks.", coderAgent.Description)
-
- expectedContextPaths := append(cfg.Options.ContextPaths, "coder-specific.md")
- assert.Equal(t, expectedContextPaths, coderAgent.ContextPaths)
-
- assert.Equal(t, SmallModel, coderAgent.Model)
- assert.Equal(t, map[string][]string{"mcp1": {"tool1"}}, coderAgent.AllowedMCP)
- assert.Equal(t, []string{"typescript"}, coderAgent.AllowedLSP)
-}
-
-func TestAgentMerging_ModifyDefaultTaskAgent(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("ANTHROPIC_API_KEY", "test-key")
-
- globalConfig := Config{
- Agents: map[AgentID]Agent{
- AgentTask: {
- Model: SmallModel,
- AllowedMCP: map[string][]string{"search-mcp": nil},
- AllowedLSP: []string{"python"},
- Name: "Search Agent",
- Description: "Custom search agent",
- Disabled: true,
- AllowedTools: []string{"glob", "grep", "view"},
- },
- },
- MCP: map[string]MCP{
- "search-mcp": {
- Type: MCPStdio,
- Command: "search-mcp-command",
- Args: []string{"--search"},
- },
- },
- LSP: map[string]LSPConfig{
- "python": {
- Command: "pylsp",
- Args: []string{},
- },
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
-
- taskAgent := cfg.Agents[AgentTask]
- assert.Equal(t, "Task", taskAgent.Name)
- assert.Equal(t, "An agent that helps with searching for context and finding implementation details.", taskAgent.Description)
- assert.False(t, taskAgent.Disabled)
- assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
-
- assert.Equal(t, SmallModel, taskAgent.Model)
- assert.Equal(t, map[string][]string{"search-mcp": nil}, taskAgent.AllowedMCP)
- assert.Equal(t, []string{"python"}, taskAgent.AllowedLSP)
-}
-
-func TestAgentMerging_LocalOverridesGlobal(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("ANTHROPIC_API_KEY", "test-key")
-
- globalConfig := Config{
- Agents: map[AgentID]Agent{
- AgentID("test-agent"): {
- ID: AgentID("test-agent"),
- Name: "Global Agent",
- Description: "Global description",
- Model: LargeModel,
- AllowedTools: []string{"glob"},
- },
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- // Create local config that overrides
- localConfig := Config{
- Agents: map[AgentID]Agent{
- AgentID("test-agent"): {
- Name: "Local Agent",
- Description: "Local description",
- Model: SmallModel,
- Disabled: true,
- AllowedTools: []string{"grep", "view"},
- AllowedMCP: map[string][]string{"local-mcp": {"tool1"}},
- },
- },
- MCP: map[string]MCP{
- "local-mcp": {
- Type: MCPStdio,
- Command: "local-mcp-command",
- Args: []string{"--local"},
- },
- },
- }
-
- localConfigPath := filepath.Join(cwdDir, "crush.json")
- data, err = json.Marshal(localConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
-
- testAgent := cfg.Agents[AgentID("test-agent")]
- assert.Equal(t, "Local Agent", testAgent.Name)
- assert.Equal(t, "Local description", testAgent.Description)
- assert.Equal(t, SmallModel, testAgent.Model)
- assert.True(t, testAgent.Disabled)
- assert.Equal(t, []string{"grep", "view"}, testAgent.AllowedTools)
- assert.Equal(t, map[string][]string{"local-mcp": {"tool1"}}, testAgent.AllowedMCP)
-}
-
-func TestAgentModelTypeAssignment(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("ANTHROPIC_API_KEY", "test-key")
-
- globalConfig := Config{
- Agents: map[AgentID]Agent{
- AgentID("large-agent"): {
- ID: AgentID("large-agent"),
- Name: "Large Model Agent",
- Model: LargeModel,
- },
- AgentID("small-agent"): {
- ID: AgentID("small-agent"),
- Name: "Small Model Agent",
- Model: SmallModel,
- },
- AgentID("default-agent"): {
- ID: AgentID("default-agent"),
- Name: "Default Model Agent",
- },
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
-
- assert.Equal(t, LargeModel, cfg.Agents[AgentID("large-agent")].Model)
- assert.Equal(t, SmallModel, cfg.Agents[AgentID("small-agent")].Model)
- assert.Equal(t, LargeModel, cfg.Agents[AgentID("default-agent")].Model)
-}
-
-func TestAgentContextPathOverrides(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("ANTHROPIC_API_KEY", "test-key")
-
- globalConfig := Config{
- Options: Options{
- ContextPaths: []string{"global-context.md", "shared-context.md"},
- },
- Agents: map[AgentID]Agent{
- AgentID("custom-context-agent"): {
- ID: AgentID("custom-context-agent"),
- Name: "Custom Context Agent",
- ContextPaths: []string{"agent-specific.md", "custom.md"},
- },
- AgentID("default-context-agent"): {
- ID: AgentID("default-context-agent"),
- Name: "Default Context Agent",
- },
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
-
- customAgent := cfg.Agents[AgentID("custom-context-agent")]
- expectedCustomPaths := append(defaultContextPaths, "global-context.md", "shared-context.md", "agent-specific.md", "custom.md")
- assert.Equal(t, expectedCustomPaths, customAgent.ContextPaths)
-
- defaultAgent := cfg.Agents[AgentID("default-context-agent")]
- expectedContextPaths := append(defaultContextPaths, "global-context.md", "shared-context.md")
- assert.Equal(t, expectedContextPaths, defaultAgent.ContextPaths)
-
- coderAgent := cfg.Agents[AgentCoder]
- assert.Equal(t, expectedContextPaths, coderAgent.ContextPaths)
-}
-
-func TestOptionsMerging_ContextPaths(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("ANTHROPIC_API_KEY", "test-key")
-
- globalConfig := Config{
- Options: Options{
- ContextPaths: []string{"global1.md", "global2.md"},
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- localConfig := Config{
- Options: Options{
- ContextPaths: []string{"local1.md", "local2.md"},
- },
- }
-
- localConfigPath := filepath.Join(cwdDir, "crush.json")
- data, err = json.Marshal(localConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
-
- expectedContextPaths := append(defaultContextPaths, "global1.md", "global2.md", "local1.md", "local2.md")
- assert.Equal(t, expectedContextPaths, cfg.Options.ContextPaths)
-}
-
-func TestOptionsMerging_TUIOptions(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("ANTHROPIC_API_KEY", "test-key")
-
- globalConfig := Config{
- Options: Options{
- TUI: TUIOptions{
- CompactMode: false,
- },
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- localConfig := Config{
- Options: Options{
- TUI: TUIOptions{
- CompactMode: true,
- },
- },
- }
-
- localConfigPath := filepath.Join(cwdDir, "crush.json")
- data, err = json.Marshal(localConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
-
- assert.True(t, cfg.Options.TUI.CompactMode)
-}
-
-func TestOptionsMerging_DebugFlags(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("ANTHROPIC_API_KEY", "test-key")
-
- globalConfig := Config{
- Options: Options{
- Debug: false,
- DebugLSP: false,
- DisableAutoSummarize: false,
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- localConfig := Config{
- Options: Options{
- DebugLSP: true,
- DisableAutoSummarize: true,
- },
- }
-
- localConfigPath := filepath.Join(cwdDir, "crush.json")
- data, err = json.Marshal(localConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
-
- assert.False(t, cfg.Options.Debug)
- assert.True(t, cfg.Options.DebugLSP)
- assert.True(t, cfg.Options.DisableAutoSummarize)
-}
-
-func TestOptionsMerging_DataDirectory(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("ANTHROPIC_API_KEY", "test-key")
-
- globalConfig := Config{
- Options: Options{
- DataDirectory: "global-data",
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- localConfig := Config{
- Options: Options{
- DataDirectory: "local-data",
- },
- }
-
- localConfigPath := filepath.Join(cwdDir, "crush.json")
- data, err = json.Marshal(localConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
-
- assert.Equal(t, "local-data", cfg.Options.DataDirectory)
-}
-
-func TestOptionsMerging_DefaultValues(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("ANTHROPIC_API_KEY", "test-key")
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
-
- assert.Equal(t, defaultDataDirectory, cfg.Options.DataDirectory)
- assert.Equal(t, defaultContextPaths, cfg.Options.ContextPaths)
- assert.False(t, cfg.Options.TUI.CompactMode)
- assert.False(t, cfg.Options.Debug)
- assert.False(t, cfg.Options.DebugLSP)
- assert.False(t, cfg.Options.DisableAutoSummarize)
-}
-
-func TestOptionsMerging_DebugFlagFromInit(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("ANTHROPIC_API_KEY", "test-key")
-
- globalConfig := Config{
- Options: Options{
- Debug: false,
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- cfg, err := Init(cwdDir, true)
-
- require.NoError(t, err)
-
- // Debug flag from Init should take precedence
- assert.True(t, cfg.Options.Debug)
-}
-
-func TestOptionsMerging_ComplexScenario(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- // Set up a provider
- os.Setenv("ANTHROPIC_API_KEY", "test-key")
-
- // Create global config with various options
- globalConfig := Config{
- Options: Options{
- ContextPaths: []string{"global-context.md"},
- DataDirectory: "global-data",
- Debug: false,
- DebugLSP: false,
- DisableAutoSummarize: false,
- TUI: TUIOptions{
- CompactMode: false,
- },
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- // Create local config that partially overrides
- localConfig := Config{
- Options: Options{
- ContextPaths: []string{"local-context.md"},
- DebugLSP: true, // Override
- DisableAutoSummarize: true, // Override
- TUI: TUIOptions{
- CompactMode: true, // Override
- },
- // DataDirectory and Debug not specified - should keep global values
- },
- }
-
- localConfigPath := filepath.Join(cwdDir, "crush.json")
- data, err = json.Marshal(localConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
-
- // Check merged results
- expectedContextPaths := append(defaultContextPaths, "global-context.md", "local-context.md")
- assert.Equal(t, expectedContextPaths, cfg.Options.ContextPaths)
- assert.Equal(t, "global-data", cfg.Options.DataDirectory) // From global
- assert.False(t, cfg.Options.Debug) // From global
- assert.True(t, cfg.Options.DebugLSP) // From local
- assert.True(t, cfg.Options.DisableAutoSummarize) // From local
- assert.True(t, cfg.Options.TUI.CompactMode) // From local
-}
-
-// Model Selection Tests
-
-func TestModelSelection_PreferredModelSelection(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- // Set up multiple providers to test selection logic
- os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key")
- os.Setenv("OPENAI_API_KEY", "test-openai-key")
-
- cfg, err := Init(cwdDir, false)
-
- require.NoError(t, err)
- require.Len(t, cfg.Providers, 2)
-
- // Should have preferred models set
- assert.NotEmpty(t, cfg.Models.Large.ModelID)
- assert.NotEmpty(t, cfg.Models.Large.Provider)
- assert.NotEmpty(t, cfg.Models.Small.ModelID)
- assert.NotEmpty(t, cfg.Models.Small.Provider)
-
- // Both should use the same provider (first available)
- assert.Equal(t, cfg.Models.Large.Provider, cfg.Models.Small.Provider)
-}
-
-func TestValidation_InvalidModelReference(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- globalConfig := Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderOpenAI: {
- ID: provider.InferenceProviderOpenAI,
- APIKey: "test-key",
- ProviderType: provider.TypeOpenAI,
- DefaultLargeModel: "non-existent-model",
- DefaultSmallModel: "gpt-3.5-turbo",
- Models: []Model{
- {
- ID: "gpt-3.5-turbo",
- Name: "GPT-3.5 Turbo",
- ContextWindow: 4096,
- DefaultMaxTokens: 2048,
- },
- },
- },
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- _, err = Init(cwdDir, false)
- assert.Error(t, err)
-}
-
-func TestValidation_InvalidAgentModelType(t *testing.T) {
- reset()
- testConfigDir = t.TempDir()
- cwdDir := t.TempDir()
-
- os.Setenv("ANTHROPIC_API_KEY", "test-key")
-
- globalConfig := Config{
- Agents: map[AgentID]Agent{
- AgentID("invalid-agent"): {
- ID: AgentID("invalid-agent"),
- Name: "Invalid Agent",
- Model: ModelType("invalid"),
- },
- },
- }
-
- configPath := filepath.Join(testConfigDir, "crush.json")
- require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
- data, err := json.Marshal(globalConfig)
- require.NoError(t, err)
- require.NoError(t, os.WriteFile(configPath, data, 0o644))
-
- _, err = Init(cwdDir, false)
- assert.Error(t, err)
-}
@@ -1,71 +0,0 @@
-package config
-
-import (
- "fmt"
- "os"
- "path/filepath"
- "runtime"
-)
-
-var testConfigDir string
-
-func baseConfigPath() string {
- if testConfigDir != "" {
- return testConfigDir
- }
-
- xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
- if xdgConfigHome != "" {
- return filepath.Join(xdgConfigHome, "crush")
- }
-
- // return the path to the main config directory
- // for windows, it should be in `%LOCALAPPDATA%/crush/`
- // for linux and macOS, it should be in `$HOME/.config/crush/`
- if runtime.GOOS == "windows" {
- localAppData := os.Getenv("LOCALAPPDATA")
- if localAppData == "" {
- localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
- }
- return filepath.Join(localAppData, appName)
- }
-
- return filepath.Join(os.Getenv("HOME"), ".config", appName)
-}
-
-func baseDataPath() string {
- if testConfigDir != "" {
- return testConfigDir
- }
-
- xdgDataHome := os.Getenv("XDG_DATA_HOME")
- if xdgDataHome != "" {
- return filepath.Join(xdgDataHome, appName)
- }
-
- // return the path to the main data directory
- // for windows, it should be in `%LOCALAPPDATA%/crush/`
- // for linux and macOS, it should be in `$HOME/.local/share/crush/`
- if runtime.GOOS == "windows" {
- localAppData := os.Getenv("LOCALAPPDATA")
- if localAppData == "" {
- localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
- }
- return filepath.Join(localAppData, appName)
- }
-
- return filepath.Join(os.Getenv("HOME"), ".local", "share", appName)
-}
-
-func ConfigPath() string {
- return filepath.Join(baseConfigPath(), fmt.Sprintf("%s.json", appName))
-}
-
-func CrushInitialized() bool {
- cfgPath := ConfigPath()
- if _, err := os.Stat(cfgPath); os.IsNotExist(err) {
- // config file does not exist, so Crush is not initialized
- return false
- }
- return true
-}
@@ -5,27 +5,53 @@ import (
"os"
"path/filepath"
"strings"
+ "sync"
+ "sync/atomic"
+
+ "github.com/charmbracelet/crush/internal/logging"
)
const (
- // InitFlagFilename is the name of the file that indicates whether the project has been initialized
InitFlagFilename = "init"
)
-// ProjectInitFlag represents the initialization status for a project directory
type ProjectInitFlag struct {
Initialized bool `json:"initialized"`
}
-// ProjectNeedsInitialization checks if the current project needs initialization
+// TODO: remove the global config instance; we keep it for now only until everything has been migrated.
+var (
+ instance atomic.Pointer[Config]
+ cwd string
+ once sync.Once // Ensures the initialization happens only once
+)
+
+func Init(workingDir string, debug bool) (*Config, error) {
+	var err error
+	once.Do(func() {
+		cwd = workingDir
+		var cfg *Config
+		// assign to the outer err so Init reports load failures to its caller
+		cfg, err = Load(cwd, debug)
+		if err != nil {
+			logging.Error("Failed to load config", "error", err)
+		}
+		instance.Store(cfg)
+	})
+
+	return instance.Load(), err
+}
+
+func Get() *Config {
+ return instance.Load()
+}
+
func ProjectNeedsInitialization() (bool, error) {
- if instance == nil {
+ cfg := Get()
+ if cfg == nil {
return false, fmt.Errorf("config not loaded")
}
- flagFilePath := filepath.Join(instance.Options.DataDirectory, InitFlagFilename)
+ flagFilePath := filepath.Join(cfg.Options.DataDirectory, InitFlagFilename)
- // Check if the flag file exists
_, err := os.Stat(flagFilePath)
if err == nil {
return false, nil
@@ -35,8 +61,7 @@ func ProjectNeedsInitialization() (bool, error) {
return false, fmt.Errorf("failed to check init flag file: %w", err)
}
- // Check if any variation of CRUSH.md already exists in working directory
- crushExists, err := crushMdExists(WorkingDirectory())
+ crushExists, err := crushMdExists(cfg.WorkingDir())
if err != nil {
return false, fmt.Errorf("failed to check for CRUSH.md files: %w", err)
}
@@ -47,7 +72,6 @@ func ProjectNeedsInitialization() (bool, error) {
return true, nil
}
-// crushMdExists checks if any case variation of crush.md exists in the directory
func crushMdExists(dir string) (bool, error) {
entries, err := os.ReadDir(dir)
if err != nil {
@@ -68,12 +92,12 @@ func crushMdExists(dir string) (bool, error) {
return false, nil
}
-// MarkProjectInitialized marks the current project as initialized
func MarkProjectInitialized() error {
- if instance == nil {
+ cfg := Get()
+ if cfg == nil {
return fmt.Errorf("config not loaded")
}
- flagFilePath := filepath.Join(instance.Options.DataDirectory, InitFlagFilename)
+ flagFilePath := filepath.Join(cfg.Options.DataDirectory, InitFlagFilename)
file, err := os.Create(flagFilePath)
if err != nil {
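// A minimal usage sketch of the new Init/Get singleton: Init, Get, and
// Options.DataDirectory come from the hunks above, while the main function,
// working directory, and log output are illustrative assumptions.
package main

import (
	"log"

	"github.com/charmbracelet/crush/internal/config"
)

func main() {
	// Init loads the configuration exactly once for the given working directory.
	cfg, err := config.Init(".", false)
	if err != nil {
		log.Fatal(err)
	}
	log.Println("data directory:", cfg.Options.DataDirectory)

	// Later callers fetch the already-loaded instance without re-loading.
	if config.Get() == nil {
		log.Fatal("config not loaded")
	}
}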
@@ -10,10 +10,10 @@ import (
"slices"
"strings"
+ "github.com/charmbracelet/crush/internal/env"
"github.com/charmbracelet/crush/internal/fur/client"
"github.com/charmbracelet/crush/internal/fur/provider"
- "github.com/charmbracelet/crush/pkg/env"
- "github.com/charmbracelet/crush/pkg/log"
+ "github.com/charmbracelet/crush/internal/log"
"golang.org/x/exp/slog"
)
@@ -68,6 +68,7 @@ func Load(workingDir string, debug bool) (*Config, error) {
env := env.New()
// Configure providers
valueResolver := NewShellVariableResolver(env)
+ cfg.resolver = valueResolver
if err := cfg.configureProviders(env, valueResolver, providers); err != nil {
return nil, fmt.Errorf("failed to configure providers: %w", err)
}
@@ -81,6 +82,36 @@ func Load(workingDir string, debug bool) (*Config, error) {
return nil, fmt.Errorf("failed to configure selected models: %w", err)
}
+ // TODO: remove the agents concept from the config
+ agents := map[string]Agent{
+ "coder": {
+ ID: "coder",
+ Name: "Coder",
+ Description: "An agent that helps with executing coding tasks.",
+ Model: SelectedModelTypeLarge,
+ ContextPaths: cfg.Options.ContextPaths,
+ // All tools allowed
+ },
+ "task": {
+ ID: "task",
+ Name: "Task",
+ Description: "An agent that helps with searching for context and finding implementation details.",
+ Model: SelectedModelTypeLarge,
+ ContextPaths: cfg.Options.ContextPaths,
+ AllowedTools: []string{
+ "glob",
+ "grep",
+ "ls",
+ "sourcegraph",
+ "view",
+ },
+ // NO MCPs or LSPs by default
+ AllowedMCP: map[string][]string{},
+ AllowedLSP: []string{},
+ },
+ }
+ cfg.Agents = agents
+
return cfg, nil
}
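// A short sketch of consuming the agents that Load now hard-codes: the "task"
// key and the AllowedTools field are taken from the hunk above, while the
// helper function below is an illustrative assumption.
package config

import "fmt"

func printTaskAgentTools(cfg *Config) {
	// The task agent is limited to read-only discovery tools by default; the
	// coder agent sets no AllowedTools, meaning every tool is allowed.
	if task, ok := cfg.Agents["task"]; ok {
		for _, tool := range task.AllowedTools {
			fmt.Println("task agent may use:", tool)
		}
	}
}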
@@ -8,7 +8,7 @@ import (
"testing"
"github.com/charmbracelet/crush/internal/fur/provider"
- "github.com/charmbracelet/crush/pkg/env"
+ "github.com/charmbracelet/crush/internal/env"
"github.com/stretchr/testify/assert"
)
@@ -4,27 +4,44 @@ import (
"encoding/json"
"os"
"path/filepath"
+ "runtime"
"sync"
"github.com/charmbracelet/crush/internal/fur/client"
"github.com/charmbracelet/crush/internal/fur/provider"
)
-var fur = client.New()
+type ProviderClient interface {
+ GetProviders() ([]provider.Provider, error)
+}
var (
- providerOnc sync.Once // Ensures the initialization happens only once
+ providerOnce sync.Once
providerList []provider.Provider
- // UseMockProviders can be set to true in tests to avoid API calls
- UseMockProviders bool
)
-func providersPath() string {
- return filepath.Join(baseDataPath(), "providers.json")
+// providerCacheFileData returns the path of the file used to cache provider data.
+func providerCacheFileData() string {
+	xdgDataHome := os.Getenv("XDG_DATA_HOME")
+	if xdgDataHome != "" {
+		return filepath.Join(xdgDataHome, appName, "providers.json")
+	}
+
+	// return the path to the providers cache file inside the main data directory
+	// for windows, it should be in `%LOCALAPPDATA%/crush/`
+	// for linux and macOS, it should be in `$HOME/.local/share/crush/`
+	if runtime.GOOS == "windows" {
+		localAppData := os.Getenv("LOCALAPPDATA")
+		if localAppData == "" {
+			localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
+		}
+		return filepath.Join(localAppData, appName, "providers.json")
+	}
+
+	return filepath.Join(os.Getenv("HOME"), ".local", "share", appName, "providers.json")
}
-func saveProviders(providers []provider.Provider) error {
- path := providersPath()
+func saveProvidersInCache(path string, providers []provider.Provider) error {
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0o755); err != nil {
return err
@@ -38,8 +55,7 @@ func saveProviders(providers []provider.Provider) error {
return os.WriteFile(path, data, 0o644)
}
-func loadProviders() ([]provider.Provider, error) {
- path := providersPath()
+func loadProvidersFromCache(path string) ([]provider.Provider, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
@@ -50,34 +66,33 @@ func loadProviders() ([]provider.Provider, error) {
return providers, err
}
-func Providers() []provider.Provider {
- providerOnc.Do(func() {
- // Use mock providers when testing
- if UseMockProviders {
- providerList = MockProviders()
- return
+func loadProviders(path string, client ProviderClient) ([]provider.Provider, error) {
+ providers, err := client.GetProviders()
+ if err != nil {
+ fallbackToCache, err := loadProvidersFromCache(path)
+ if err != nil {
+ return nil, err
}
-
- // Try to get providers from upstream API
- if providers, err := fur.GetProviders(); err == nil {
- providerList = providers
- // Save providers locally for future fallback
- _ = saveProviders(providers)
- } else {
- // If upstream fails, try to load from local cache
- if localProviders, localErr := loadProviders(); localErr == nil {
- providerList = localProviders
- } else {
- // If both fail, return empty list
- providerList = []provider.Provider{}
- }
+ providers = fallbackToCache
+ } else {
+		if err := saveProvidersInCache(path, providers); err != nil {
+ return nil, err
}
- })
- return providerList
+ }
+ return providers, nil
+}
+
+func Providers() ([]provider.Provider, error) {
+ return LoadProviders(client.New())
}
-// ResetProviders resets the provider cache. Useful for testing.
-func ResetProviders() {
- providerOnc = sync.Once{}
- providerList = nil
+func LoadProviders(client ProviderClient) ([]provider.Provider, error) {
+ var err error
+ providerOnce.Do(func() {
+ providerList, err = loadProviders(providerCacheFileData(), client)
+ })
+ if err != nil {
+ return nil, err
+ }
+ return providerList, nil
}
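// A hedged sketch of the ProviderClient seam introduced above: any
// implementation can be injected (for example to avoid network access),
// mirroring the mock used in provider_test.go later in this patch; the stub
// type and the example function are illustrative assumptions.
package config

import (
	"fmt"

	"github.com/charmbracelet/crush/internal/fur/provider"
)

type stubProviderClient struct{}

func (stubProviderClient) GetProviders() ([]provider.Provider, error) {
	return []provider.Provider{{Name: "Stub"}}, nil
}

func exampleLoadProviders() error {
	// LoadProviders memoizes via providerOnce, so only the first call fetches;
	// on a fetch failure it falls back to the cached providers.json and errors
	// only when that cache is missing as well.
	providers, err := LoadProviders(stubProviderClient{})
	if err != nil {
		return err
	}
	for _, p := range providers {
		fmt.Println("provider:", p.Name)
	}
	return nil
}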
@@ -1,293 +0,0 @@
-package config
-
-import (
- "github.com/charmbracelet/crush/internal/fur/provider"
-)
-
-// MockProviders returns a mock list of providers for testing.
-// This avoids making API calls during tests and provides consistent test data.
-// Simplified version with only default models from each provider.
-func MockProviders() []provider.Provider {
- return []provider.Provider{
- {
- Name: "Anthropic",
- ID: provider.InferenceProviderAnthropic,
- APIKey: "$ANTHROPIC_API_KEY",
- APIEndpoint: "$ANTHROPIC_API_ENDPOINT",
- Type: provider.TypeAnthropic,
- DefaultLargeModelID: "claude-sonnet-4-20250514",
- DefaultSmallModelID: "claude-3-5-haiku-20241022",
- Models: []provider.Model{
- {
- ID: "claude-sonnet-4-20250514",
- Name: "Claude Sonnet 4",
- CostPer1MIn: 3.0,
- CostPer1MOut: 15.0,
- CostPer1MInCached: 3.75,
- CostPer1MOutCached: 0.3,
- ContextWindow: 200000,
- DefaultMaxTokens: 50000,
- CanReason: true,
- SupportsImages: true,
- },
- {
- ID: "claude-3-5-haiku-20241022",
- Name: "Claude 3.5 Haiku",
- CostPer1MIn: 0.8,
- CostPer1MOut: 4.0,
- CostPer1MInCached: 1.0,
- CostPer1MOutCached: 0.08,
- ContextWindow: 200000,
- DefaultMaxTokens: 5000,
- CanReason: false,
- SupportsImages: true,
- },
- },
- },
- {
- Name: "OpenAI",
- ID: provider.InferenceProviderOpenAI,
- APIKey: "$OPENAI_API_KEY",
- APIEndpoint: "$OPENAI_API_ENDPOINT",
- Type: provider.TypeOpenAI,
- DefaultLargeModelID: "codex-mini-latest",
- DefaultSmallModelID: "gpt-4o",
- Models: []provider.Model{
- {
- ID: "codex-mini-latest",
- Name: "Codex Mini",
- CostPer1MIn: 1.5,
- CostPer1MOut: 6.0,
- CostPer1MInCached: 0.0,
- CostPer1MOutCached: 0.375,
- ContextWindow: 200000,
- DefaultMaxTokens: 50000,
- CanReason: true,
- HasReasoningEffort: true,
- DefaultReasoningEffort: "medium",
- SupportsImages: true,
- },
- {
- ID: "gpt-4o",
- Name: "GPT-4o",
- CostPer1MIn: 2.5,
- CostPer1MOut: 10.0,
- CostPer1MInCached: 0.0,
- CostPer1MOutCached: 1.25,
- ContextWindow: 128000,
- DefaultMaxTokens: 20000,
- CanReason: false,
- SupportsImages: true,
- },
- },
- },
- {
- Name: "Google Gemini",
- ID: provider.InferenceProviderGemini,
- APIKey: "$GEMINI_API_KEY",
- APIEndpoint: "$GEMINI_API_ENDPOINT",
- Type: provider.TypeGemini,
- DefaultLargeModelID: "gemini-2.5-pro",
- DefaultSmallModelID: "gemini-2.5-flash",
- Models: []provider.Model{
- {
- ID: "gemini-2.5-pro",
- Name: "Gemini 2.5 Pro",
- CostPer1MIn: 1.25,
- CostPer1MOut: 10.0,
- CostPer1MInCached: 1.625,
- CostPer1MOutCached: 0.31,
- ContextWindow: 1048576,
- DefaultMaxTokens: 50000,
- CanReason: true,
- SupportsImages: true,
- },
- {
- ID: "gemini-2.5-flash",
- Name: "Gemini 2.5 Flash",
- CostPer1MIn: 0.3,
- CostPer1MOut: 2.5,
- CostPer1MInCached: 0.3833,
- CostPer1MOutCached: 0.075,
- ContextWindow: 1048576,
- DefaultMaxTokens: 50000,
- CanReason: true,
- SupportsImages: true,
- },
- },
- },
- {
- Name: "xAI",
- ID: provider.InferenceProviderXAI,
- APIKey: "$XAI_API_KEY",
- APIEndpoint: "https://api.x.ai/v1",
- Type: provider.TypeXAI,
- DefaultLargeModelID: "grok-3",
- DefaultSmallModelID: "grok-3-mini",
- Models: []provider.Model{
- {
- ID: "grok-3",
- Name: "Grok 3",
- CostPer1MIn: 3.0,
- CostPer1MOut: 15.0,
- CostPer1MInCached: 0.0,
- CostPer1MOutCached: 0.75,
- ContextWindow: 131072,
- DefaultMaxTokens: 20000,
- CanReason: false,
- SupportsImages: false,
- },
- {
- ID: "grok-3-mini",
- Name: "Grok 3 Mini",
- CostPer1MIn: 0.3,
- CostPer1MOut: 0.5,
- CostPer1MInCached: 0.0,
- CostPer1MOutCached: 0.075,
- ContextWindow: 131072,
- DefaultMaxTokens: 20000,
- CanReason: true,
- SupportsImages: false,
- },
- },
- },
- {
- Name: "Azure OpenAI",
- ID: provider.InferenceProviderAzure,
- APIKey: "$AZURE_OPENAI_API_KEY",
- APIEndpoint: "$AZURE_OPENAI_API_ENDPOINT",
- Type: provider.TypeAzure,
- DefaultLargeModelID: "o4-mini",
- DefaultSmallModelID: "gpt-4o",
- Models: []provider.Model{
- {
- ID: "o4-mini",
- Name: "o4 Mini",
- CostPer1MIn: 1.1,
- CostPer1MOut: 4.4,
- CostPer1MInCached: 0.0,
- CostPer1MOutCached: 0.275,
- ContextWindow: 200000,
- DefaultMaxTokens: 50000,
- CanReason: true,
- HasReasoningEffort: false,
- DefaultReasoningEffort: "medium",
- SupportsImages: true,
- },
- {
- ID: "gpt-4o",
- Name: "GPT-4o",
- CostPer1MIn: 2.5,
- CostPer1MOut: 10.0,
- CostPer1MInCached: 0.0,
- CostPer1MOutCached: 1.25,
- ContextWindow: 128000,
- DefaultMaxTokens: 20000,
- CanReason: false,
- SupportsImages: true,
- },
- },
- },
- {
- Name: "AWS Bedrock",
- ID: provider.InferenceProviderBedrock,
- Type: provider.TypeBedrock,
- DefaultLargeModelID: "anthropic.claude-sonnet-4-20250514-v1:0",
- DefaultSmallModelID: "anthropic.claude-3-5-haiku-20241022-v1:0",
- Models: []provider.Model{
- {
- ID: "anthropic.claude-sonnet-4-20250514-v1:0",
- Name: "AWS Claude Sonnet 4",
- CostPer1MIn: 3.0,
- CostPer1MOut: 15.0,
- CostPer1MInCached: 3.75,
- CostPer1MOutCached: 0.3,
- ContextWindow: 200000,
- DefaultMaxTokens: 50000,
- CanReason: true,
- SupportsImages: true,
- },
- {
- ID: "anthropic.claude-3-5-haiku-20241022-v1:0",
- Name: "AWS Claude 3.5 Haiku",
- CostPer1MIn: 0.8,
- CostPer1MOut: 4.0,
- CostPer1MInCached: 1.0,
- CostPer1MOutCached: 0.08,
- ContextWindow: 200000,
- DefaultMaxTokens: 50000,
- CanReason: false,
- SupportsImages: true,
- },
- },
- },
- {
- Name: "Google Vertex AI",
- ID: provider.InferenceProviderVertexAI,
- Type: provider.TypeVertexAI,
- DefaultLargeModelID: "gemini-2.5-pro",
- DefaultSmallModelID: "gemini-2.5-flash",
- Models: []provider.Model{
- {
- ID: "gemini-2.5-pro",
- Name: "Gemini 2.5 Pro",
- CostPer1MIn: 1.25,
- CostPer1MOut: 10.0,
- CostPer1MInCached: 1.625,
- CostPer1MOutCached: 0.31,
- ContextWindow: 1048576,
- DefaultMaxTokens: 50000,
- CanReason: true,
- SupportsImages: true,
- },
- {
- ID: "gemini-2.5-flash",
- Name: "Gemini 2.5 Flash",
- CostPer1MIn: 0.3,
- CostPer1MOut: 2.5,
- CostPer1MInCached: 0.3833,
- CostPer1MOutCached: 0.075,
- ContextWindow: 1048576,
- DefaultMaxTokens: 50000,
- CanReason: true,
- SupportsImages: true,
- },
- },
- },
- {
- Name: "OpenRouter",
- ID: provider.InferenceProviderOpenRouter,
- APIKey: "$OPENROUTER_API_KEY",
- APIEndpoint: "https://openrouter.ai/api/v1",
- Type: provider.TypeOpenAI,
- DefaultLargeModelID: "anthropic/claude-sonnet-4",
- DefaultSmallModelID: "anthropic/claude-haiku-3.5",
- Models: []provider.Model{
- {
- ID: "anthropic/claude-sonnet-4",
- Name: "Anthropic: Claude Sonnet 4",
- CostPer1MIn: 3.0,
- CostPer1MOut: 15.0,
- CostPer1MInCached: 3.75,
- CostPer1MOutCached: 0.3,
- ContextWindow: 200000,
- DefaultMaxTokens: 32000,
- CanReason: true,
- SupportsImages: true,
- },
- {
- ID: "anthropic/claude-haiku-3.5",
- Name: "Anthropic: Claude 3.5 Haiku",
- CostPer1MIn: 0.8,
- CostPer1MOut: 4.0,
- CostPer1MInCached: 1.0,
- CostPer1MOutCached: 0.08,
- ContextWindow: 200000,
- DefaultMaxTokens: 4096,
- CanReason: false,
- SupportsImages: true,
- },
- },
- },
- }
-}
@@ -1,81 +1,73 @@
package config
import (
+ "encoding/json"
+ "errors"
+ "os"
"testing"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
)
-func TestProviders_MockEnabled(t *testing.T) {
- originalUseMock := UseMockProviders
- UseMockProviders = true
- defer func() {
- UseMockProviders = originalUseMock
- ResetProviders()
- }()
-
- ResetProviders()
- providers := Providers()
- require.NotEmpty(t, providers)
+type mockProviderClient struct {
+ shouldFail bool
+}
- providerIDs := make(map[provider.InferenceProvider]bool)
- for _, p := range providers {
- providerIDs[p.ID] = true
+func (m *mockProviderClient) GetProviders() ([]provider.Provider, error) {
+ if m.shouldFail {
+ return nil, errors.New("failed to load providers")
}
-
- assert.True(t, providerIDs[provider.InferenceProviderAnthropic])
- assert.True(t, providerIDs[provider.InferenceProviderOpenAI])
- assert.True(t, providerIDs[provider.InferenceProviderGemini])
+ return []provider.Provider{
+ {
+ Name: "Mock",
+ },
+ }, nil
}
-func TestProviders_ResetFunctionality(t *testing.T) {
- UseMockProviders = true
- defer func() {
- UseMockProviders = false
- ResetProviders()
- }()
-
- providers1 := Providers()
- require.NotEmpty(t, providers1)
-
- ResetProviders()
- providers2 := Providers()
- require.NotEmpty(t, providers2)
+func TestProvider_loadProvidersNoIssues(t *testing.T) {
+ client := &mockProviderClient{shouldFail: false}
+ tmpPath := t.TempDir() + "/providers.json"
+ providers, err := loadProviders(tmpPath, client)
+ assert.NoError(t, err)
+ assert.NotNil(t, providers)
+ assert.Len(t, providers, 1)
- assert.Equal(t, len(providers1), len(providers2))
+	// check that the cache file was written
+ fileInfo, err := os.Stat(tmpPath)
+ assert.NoError(t, err)
+ assert.False(t, fileInfo.IsDir(), "Expected a file, not a directory")
}
-func TestProviders_ModelCapabilities(t *testing.T) {
- originalUseMock := UseMockProviders
- UseMockProviders = true
- defer func() {
- UseMockProviders = originalUseMock
- ResetProviders()
- }()
-
- ResetProviders()
- providers := Providers()
-
- var openaiProvider provider.Provider
- for _, p := range providers {
- if p.ID == provider.InferenceProviderOpenAI {
- openaiProvider = p
- break
- }
+func TestProvider_loadProvidersWithIssues(t *testing.T) {
+ client := &mockProviderClient{shouldFail: true}
+ tmpPath := t.TempDir() + "/providers.json"
+	// seed the cache file so the fallback path has providers to load
+ oldProviders := []provider.Provider{
+ {
+ Name: "OldProvider",
+ },
+ }
+ data, err := json.Marshal(oldProviders)
+ if err != nil {
+ t.Fatalf("Failed to marshal old providers: %v", err)
}
- require.NotEmpty(t, openaiProvider.ID)
- var foundReasoning, foundNonReasoning bool
- for _, model := range openaiProvider.Models {
- if model.CanReason && model.HasReasoningEffort {
- foundReasoning = true
- } else if !model.CanReason {
- foundNonReasoning = true
- }
+ err = os.WriteFile(tmpPath, data, 0o644)
+ if err != nil {
+ t.Fatalf("Failed to write old providers to file: %v", err)
}
+ providers, err := loadProviders(tmpPath, client)
+ assert.NoError(t, err)
+ assert.NotNil(t, providers)
+ assert.Len(t, providers, 1)
+ assert.Equal(t, "OldProvider", providers[0].Name, "Expected to keep old provider when loading fails")
+}
- assert.True(t, foundReasoning)
- assert.True(t, foundNonReasoning)
+func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) {
+ client := &mockProviderClient{shouldFail: true}
+ tmpPath := t.TempDir() + "/providers.json"
+ providers, err := loadProviders(tmpPath, client)
+ assert.Error(t, err)
+ assert.Nil(t, providers, "Expected nil providers when loading fails and no cache exists")
}
@@ -7,7 +7,7 @@ import (
"time"
"github.com/charmbracelet/crush/internal/shell"
- "github.com/charmbracelet/crush/pkg/env"
+ "github.com/charmbracelet/crush/internal/env"
)
type VariableResolver interface {
@@ -5,7 +5,7 @@ import (
"errors"
"testing"
- "github.com/charmbracelet/crush/pkg/env"
+ "github.com/charmbracelet/crush/internal/env"
"github.com/stretchr/testify/assert"
)
@@ -1,73 +0,0 @@
-package config
-
-import (
- "context"
- "fmt"
- "os"
- "strings"
- "time"
-
- "github.com/charmbracelet/crush/internal/logging"
- "github.com/charmbracelet/crush/internal/shell"
-)
-
-// ExecuteCommand executes a shell command and returns the output
-// This is a shared utility that can be used by both provider config and tools
-func ExecuteCommand(ctx context.Context, command string, workingDir string) (string, error) {
- if workingDir == "" {
- workingDir = WorkingDirectory()
- }
-
- persistentShell := shell.NewShell(&shell.Options{WorkingDir: workingDir})
-
- stdout, stderr, err := persistentShell.Exec(ctx, command)
- if err != nil {
- logging.Debug("Command execution failed", "command", command, "error", err, "stderr", stderr)
- return "", fmt.Errorf("command execution failed: %w", err)
- }
-
- return strings.TrimSpace(stdout), nil
-}
-
-// ResolveAPIKey resolves an API key that can be either:
-// - A direct string value
-// - An environment variable (prefixed with $)
-// - A shell command (wrapped in $(...))
-func ResolveAPIKey(apiKey string) (string, error) {
- if !strings.HasPrefix(apiKey, "$") {
- return apiKey, nil
- }
-
- if strings.HasPrefix(apiKey, "$(") && strings.HasSuffix(apiKey, ")") {
- command := strings.TrimSuffix(strings.TrimPrefix(apiKey, "$("), ")")
- logging.Debug("Resolving API key from command", "command", command)
- return resolveCommandAPIKey(command)
- }
-
- envVar := strings.TrimPrefix(apiKey, "$")
- if value := os.Getenv(envVar); value != "" {
- logging.Debug("Resolved environment variable", "envVar", envVar, "value", value)
- return value, nil
- }
-
- logging.Debug("Environment variable not found", "envVar", envVar)
-
- return "", fmt.Errorf("environment variable %s not found", envVar)
-}
-
-// resolveCommandAPIKey executes a command to get an API key, with caching support
-func resolveCommandAPIKey(command string) (string, error) {
- ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
- defer cancel()
-
- logging.Debug("Executing command for API key", "command", command)
-
- workingDir := WorkingDirectory()
-
- result, err := ExecuteCommand(ctx, command, workingDir)
- if err != nil {
- return "", fmt.Errorf("failed to execute API key command: %w", err)
- }
- logging.Debug("Command executed successfully", "command", command, "result", result)
- return result, nil
-}
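// The deleted ResolveAPIKey/ExecuteCommand helpers are superseded by the
// VariableResolver that Load wires up via NewShellVariableResolver(env). A
// hedged sketch of how that resolver is presumably used follows; the
// ResolveValue method name and the example value are assumptions, not taken
// from this patch.
package config

import "github.com/charmbracelet/crush/internal/env"

func exampleResolveAPIKey() (string, error) {
	resolver := NewShellVariableResolver(env.New())
	// Plain strings pass through, "$VAR" reads the environment, and "$(cmd)"
	// runs a shell command: the same cases ResolveAPIKey handled before.
	return resolver.ResolveValue("$ANTHROPIC_API_KEY")
}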
@@ -1,462 +0,0 @@
-package config
-
-import (
- "testing"
-
- "github.com/charmbracelet/crush/internal/fur/provider"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-func TestConfig_Validate_ValidConfig(t *testing.T) {
- cfg := &Config{
- Models: PreferredModels{
- Large: PreferredModel{
- ModelID: "gpt-4",
- Provider: provider.InferenceProviderOpenAI,
- },
- Small: PreferredModel{
- ModelID: "gpt-3.5-turbo",
- Provider: provider.InferenceProviderOpenAI,
- },
- },
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderOpenAI: {
- ID: provider.InferenceProviderOpenAI,
- APIKey: "test-key",
- ProviderType: provider.TypeOpenAI,
- DefaultLargeModel: "gpt-4",
- DefaultSmallModel: "gpt-3.5-turbo",
- Models: []Model{
- {
- ID: "gpt-4",
- Name: "GPT-4",
- ContextWindow: 8192,
- DefaultMaxTokens: 4096,
- CostPer1MIn: 30.0,
- CostPer1MOut: 60.0,
- },
- {
- ID: "gpt-3.5-turbo",
- Name: "GPT-3.5 Turbo",
- ContextWindow: 4096,
- DefaultMaxTokens: 2048,
- CostPer1MIn: 1.5,
- CostPer1MOut: 2.0,
- },
- },
- },
- },
- Agents: map[AgentID]Agent{
- AgentCoder: {
- ID: AgentCoder,
- Name: "Coder",
- Description: "An agent that helps with executing coding tasks.",
- Model: LargeModel,
- ContextPaths: []string{"CRUSH.md"},
- },
- AgentTask: {
- ID: AgentTask,
- Name: "Task",
- Description: "An agent that helps with searching for context and finding implementation details.",
- Model: LargeModel,
- ContextPaths: []string{"CRUSH.md"},
- AllowedTools: []string{"glob", "grep", "ls", "sourcegraph", "view"},
- AllowedMCP: map[string][]string{},
- AllowedLSP: []string{},
- },
- },
- MCP: map[string]MCP{},
- LSP: map[string]LSPConfig{},
- Options: Options{
- DataDirectory: ".crush",
- ContextPaths: []string{"CRUSH.md"},
- },
- }
-
- err := cfg.Validate()
- assert.NoError(t, err)
-}
-
-func TestConfig_Validate_MissingAPIKey(t *testing.T) {
- cfg := &Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderOpenAI: {
- ID: provider.InferenceProviderOpenAI,
- ProviderType: provider.TypeOpenAI,
- // Missing APIKey
- },
- },
- Options: Options{
- DataDirectory: ".crush",
- ContextPaths: []string{"CRUSH.md"},
- },
- }
-
- err := cfg.Validate()
- require.Error(t, err)
- assert.Contains(t, err.Error(), "API key is required")
-}
-
-func TestConfig_Validate_InvalidProviderType(t *testing.T) {
- cfg := &Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderOpenAI: {
- ID: provider.InferenceProviderOpenAI,
- APIKey: "test-key",
- ProviderType: provider.Type("invalid"),
- },
- },
- Options: Options{
- DataDirectory: ".crush",
- ContextPaths: []string{"CRUSH.md"},
- },
- }
-
- err := cfg.Validate()
- require.Error(t, err)
- assert.Contains(t, err.Error(), "invalid provider type")
-}
-
-func TestConfig_Validate_CustomProviderMissingBaseURL(t *testing.T) {
- customProvider := provider.InferenceProvider("custom-provider")
- cfg := &Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- customProvider: {
- ID: customProvider,
- APIKey: "test-key",
- ProviderType: provider.TypeOpenAI,
- // Missing BaseURL for custom provider
- },
- },
- Options: Options{
- DataDirectory: ".crush",
- ContextPaths: []string{"CRUSH.md"},
- },
- }
-
- err := cfg.Validate()
- require.Error(t, err)
- assert.Contains(t, err.Error(), "BaseURL is required for custom providers")
-}
-
-func TestConfig_Validate_DuplicateModelIDs(t *testing.T) {
- cfg := &Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderOpenAI: {
- ID: provider.InferenceProviderOpenAI,
- APIKey: "test-key",
- ProviderType: provider.TypeOpenAI,
- Models: []Model{
- {
- ID: "gpt-4",
- Name: "GPT-4",
- ContextWindow: 8192,
- DefaultMaxTokens: 4096,
- },
- {
- ID: "gpt-4", // Duplicate ID
- Name: "GPT-4 Duplicate",
- ContextWindow: 8192,
- DefaultMaxTokens: 4096,
- },
- },
- },
- },
- Options: Options{
- DataDirectory: ".crush",
- ContextPaths: []string{"CRUSH.md"},
- },
- }
-
- err := cfg.Validate()
- require.Error(t, err)
- assert.Contains(t, err.Error(), "duplicate model ID")
-}
-
-func TestConfig_Validate_InvalidModelFields(t *testing.T) {
- cfg := &Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderOpenAI: {
- ID: provider.InferenceProviderOpenAI,
- APIKey: "test-key",
- ProviderType: provider.TypeOpenAI,
- Models: []Model{
- {
- ID: "", // Empty ID
- Name: "GPT-4",
- ContextWindow: 0, // Invalid context window
- DefaultMaxTokens: -1, // Invalid max tokens
- CostPer1MIn: -5.0, // Negative cost
- },
- },
- },
- },
- Options: Options{
- DataDirectory: ".crush",
- ContextPaths: []string{"CRUSH.md"},
- },
- }
-
- err := cfg.Validate()
- require.Error(t, err)
- validationErr := err.(ValidationErrors)
- assert.True(t, len(validationErr) >= 4) // Should have multiple validation errors
-}
-
-func TestConfig_Validate_DefaultModelNotFound(t *testing.T) {
- cfg := &Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderOpenAI: {
- ID: provider.InferenceProviderOpenAI,
- APIKey: "test-key",
- ProviderType: provider.TypeOpenAI,
- DefaultLargeModel: "nonexistent-model",
- Models: []Model{
- {
- ID: "gpt-4",
- Name: "GPT-4",
- ContextWindow: 8192,
- DefaultMaxTokens: 4096,
- },
- },
- },
- },
- Options: Options{
- DataDirectory: ".crush",
- ContextPaths: []string{"CRUSH.md"},
- },
- }
-
- err := cfg.Validate()
- require.Error(t, err)
- assert.Contains(t, err.Error(), "default large model 'nonexistent-model' not found")
-}
-
-func TestConfig_Validate_AgentIDMismatch(t *testing.T) {
- cfg := &Config{
- Agents: map[AgentID]Agent{
- AgentCoder: {
- ID: AgentTask, // Wrong ID
- Name: "Coder",
- },
- },
- Options: Options{
- DataDirectory: ".crush",
- ContextPaths: []string{"CRUSH.md"},
- },
- }
-
- err := cfg.Validate()
- require.Error(t, err)
- assert.Contains(t, err.Error(), "agent ID mismatch")
-}
-
-func TestConfig_Validate_InvalidAgentModelType(t *testing.T) {
- cfg := &Config{
- Agents: map[AgentID]Agent{
- AgentCoder: {
- ID: AgentCoder,
- Name: "Coder",
- Model: ModelType("invalid"),
- },
- },
- Options: Options{
- DataDirectory: ".crush",
- ContextPaths: []string{"CRUSH.md"},
- },
- }
-
- err := cfg.Validate()
- require.Error(t, err)
- assert.Contains(t, err.Error(), "invalid model type")
-}
-
-func TestConfig_Validate_UnknownTool(t *testing.T) {
- cfg := &Config{
- Agents: map[AgentID]Agent{
- AgentID("custom-agent"): {
- ID: AgentID("custom-agent"),
- Name: "Custom Agent",
- Model: LargeModel,
- AllowedTools: []string{"unknown-tool"},
- },
- },
- Options: Options{
- DataDirectory: ".crush",
- ContextPaths: []string{"CRUSH.md"},
- },
- }
-
- err := cfg.Validate()
- require.Error(t, err)
- assert.Contains(t, err.Error(), "unknown tool")
-}
-
-func TestConfig_Validate_MCPReference(t *testing.T) {
- cfg := &Config{
- Agents: map[AgentID]Agent{
- AgentID("custom-agent"): {
- ID: AgentID("custom-agent"),
- Name: "Custom Agent",
- Model: LargeModel,
- AllowedMCP: map[string][]string{"nonexistent-mcp": nil},
- },
- },
- MCP: map[string]MCP{}, // Empty MCP map
- Options: Options{
- DataDirectory: ".crush",
- ContextPaths: []string{"CRUSH.md"},
- },
- }
-
- err := cfg.Validate()
- require.Error(t, err)
- assert.Contains(t, err.Error(), "referenced MCP 'nonexistent-mcp' not found")
-}
-
-func TestConfig_Validate_InvalidMCPType(t *testing.T) {
- cfg := &Config{
- MCP: map[string]MCP{
- "test-mcp": {
- Type: MCPType("invalid"),
- },
- },
- Options: Options{
- DataDirectory: ".crush",
- ContextPaths: []string{"CRUSH.md"},
- },
- }
-
- err := cfg.Validate()
- require.Error(t, err)
- assert.Contains(t, err.Error(), "invalid MCP type")
-}
-
-func TestConfig_Validate_MCPMissingCommand(t *testing.T) {
- cfg := &Config{
- MCP: map[string]MCP{
- "test-mcp": {
- Type: MCPStdio,
- // Missing Command
- },
- },
- Options: Options{
- DataDirectory: ".crush",
- ContextPaths: []string{"CRUSH.md"},
- },
- }
-
- err := cfg.Validate()
- require.Error(t, err)
- assert.Contains(t, err.Error(), "command is required for stdio MCP")
-}
-
-func TestConfig_Validate_LSPMissingCommand(t *testing.T) {
- cfg := &Config{
- LSP: map[string]LSPConfig{
- "test-lsp": {
- // Missing Command
- },
- },
- Options: Options{
- DataDirectory: ".crush",
- ContextPaths: []string{"CRUSH.md"},
- },
- }
-
- err := cfg.Validate()
- require.Error(t, err)
- assert.Contains(t, err.Error(), "command is required for LSP")
-}
-
-func TestConfig_Validate_NoValidProviders(t *testing.T) {
- cfg := &Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderOpenAI: {
- ID: provider.InferenceProviderOpenAI,
- APIKey: "test-key",
- ProviderType: provider.TypeOpenAI,
- Disabled: true, // Disabled
- },
- },
- Options: Options{
- DataDirectory: ".crush",
- ContextPaths: []string{"CRUSH.md"},
- },
- }
-
- err := cfg.Validate()
- require.Error(t, err)
- assert.Contains(t, err.Error(), "at least one non-disabled provider is required")
-}
-
-func TestConfig_Validate_MissingDefaultAgents(t *testing.T) {
- cfg := &Config{
- Providers: map[provider.InferenceProvider]ProviderConfig{
- provider.InferenceProviderOpenAI: {
- ID: provider.InferenceProviderOpenAI,
- APIKey: "test-key",
- ProviderType: provider.TypeOpenAI,
- },
- },
- Agents: map[AgentID]Agent{}, // Missing default agents
- Options: Options{
- DataDirectory: ".crush",
- ContextPaths: []string{"CRUSH.md"},
- },
- }
-
- err := cfg.Validate()
- require.Error(t, err)
- assert.Contains(t, err.Error(), "coder agent is required")
- assert.Contains(t, err.Error(), "task agent is required")
-}
-
-func TestConfig_Validate_KnownAgentProtection(t *testing.T) {
- cfg := &Config{
- Agents: map[AgentID]Agent{
- AgentCoder: {
- ID: AgentCoder,
- Name: "Modified Coder", // Should not be allowed
- Description: "Modified description", // Should not be allowed
- Model: LargeModel,
- },
- },
- Options: Options{
- DataDirectory: ".crush",
- ContextPaths: []string{"CRUSH.md"},
- },
- }
-
- err := cfg.Validate()
- require.Error(t, err)
- assert.Contains(t, err.Error(), "coder agent name cannot be changed")
- assert.Contains(t, err.Error(), "coder agent description cannot be changed")
-}
-
-func TestConfig_Validate_EmptyDataDirectory(t *testing.T) {
- cfg := &Config{
- Options: Options{
- DataDirectory: "", // Empty
- ContextPaths: []string{"CRUSH.md"},
- },
- }
-
- err := cfg.Validate()
- require.Error(t, err)
- assert.Contains(t, err.Error(), "data directory is required")
-}
-
-func TestConfig_Validate_EmptyContextPath(t *testing.T) {
- cfg := &Config{
- Options: Options{
- DataDirectory: ".crush",
- ContextPaths: []string{""}, // Empty context path
- },
- }
-
- err := cfg.Validate()
- require.Error(t, err)
- assert.Contains(t, err.Error(), "context path cannot be empty")
-}
@@ -11,7 +11,7 @@ import (
func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, int) {
// remove the cwd prefix and ensure consistent path format
// this prevents issues with absolute paths in different environments
- cwd := config.WorkingDirectory()
+ cwd := config.Get().WorkingDir()
fileName = strings.TrimPrefix(fileName, cwd)
fileName = strings.TrimPrefix(fileName, "/")
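
The working-directory helper is now a method on the loaded config rather than a package-level function. A minimal sketch of the new call pattern, assuming code living inside the crush module (internal/config is not importable from outside) and that the global config has already been loaded:

package main

import (
	"fmt"
	"strings"

	"github.com/charmbracelet/crush/internal/config"
)

// relativize mirrors the normalization in GenerateDiff above: strip the
// configured working directory so paths render consistently.
func relativize(fileName string) string {
	cwd := config.Get().WorkingDir() // previously config.WorkingDirectory()
	fileName = strings.TrimPrefix(fileName, cwd)
	return strings.TrimPrefix(fileName, "/")
}

func main() {
	fmt.Println(relativize("/home/user/project/internal/diff/diff.go"))
}
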
@@ -10,6 +10,7 @@ import (
"time"
"github.com/charmbracelet/crush/internal/config"
+ fur "github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/history"
"github.com/charmbracelet/crush/internal/llm/prompt"
"github.com/charmbracelet/crush/internal/llm/provider"
@@ -49,7 +50,7 @@ type AgentEvent struct {
type Service interface {
pubsub.Suscriber[AgentEvent]
- Model() config.Model
+ Model() fur.Model
Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
Cancel(sessionID string)
CancelAll()
@@ -76,9 +77,9 @@ type agent struct {
activeRequests sync.Map
}
-var agentPromptMap = map[config.AgentID]prompt.PromptID{
- config.AgentCoder: prompt.PromptCoder,
- config.AgentTask: prompt.PromptTask,
+var agentPromptMap = map[string]prompt.PromptID{
+ "coder": prompt.PromptCoder,
+ "task": prompt.PromptTask,
}
func NewAgent(
@@ -109,8 +110,8 @@ func NewAgent(
tools.NewWriteTool(lspClients, permissions, history),
}
- if agentCfg.ID == config.AgentCoder {
- taskAgentCfg := config.Get().Agents[config.AgentTask]
+ if agentCfg.ID == "coder" {
+ taskAgentCfg := config.Get().Agents["task"]
if taskAgentCfg.ID == "" {
return nil, fmt.Errorf("task agent not found in config")
}
@@ -130,13 +131,13 @@ func NewAgent(
}
allTools = append(allTools, otherTools...)
- providerCfg := config.GetAgentProvider(agentCfg.ID)
- if providerCfg.ID == "" {
+ providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
+ if providerCfg == nil {
return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
}
- model := config.GetAgentModel(agentCfg.ID)
+ model := config.Get().GetModelByType(agentCfg.Model)
- if model.ID == "" {
+ if model == nil {
return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
}
@@ -148,51 +149,40 @@ func NewAgent(
provider.WithModel(agentCfg.Model),
provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
}
- agentProvider, err := provider.NewProvider(providerCfg, opts...)
+ agentProvider, err := provider.NewProvider(*providerCfg, opts...)
if err != nil {
return nil, err
}
- smallModelCfg := cfg.Models.Small
- var smallModel config.Model
-
- var smallModelProviderCfg config.ProviderConfig
+ smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
+ var smallModelProviderCfg *config.ProviderConfig
if smallModelCfg.Provider == providerCfg.ID {
smallModelProviderCfg = providerCfg
} else {
- for _, p := range cfg.Providers {
- if p.ID == smallModelCfg.Provider {
- smallModelProviderCfg = p
- break
- }
- }
+ smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
+
if smallModelProviderCfg.ID == "" {
return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
}
}
- for _, m := range smallModelProviderCfg.Models {
- if m.ID == smallModelCfg.ModelID {
- smallModel = m
- break
- }
- }
+ smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
if smallModel.ID == "" {
- return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
+ return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
}
titleOpts := []provider.ProviderClientOption{
- provider.WithModel(config.SmallModel),
+ provider.WithModel(config.SelectedModelTypeSmall),
provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
}
- titleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
+ titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
if err != nil {
return nil, err
}
summarizeOpts := []provider.ProviderClientOption{
- provider.WithModel(config.SmallModel),
+ provider.WithModel(config.SelectedModelTypeSmall),
provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
}
- summarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
+ summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
if err != nil {
return nil, err
}
@@ -225,8 +215,8 @@ func NewAgent(
return agent, nil
}
-func (a *agent) Model() config.Model {
- return config.GetAgentModel(a.agentCfg.ID)
+func (a *agent) Model() fur.Model {
+ return *config.Get().GetModelByType(a.agentCfg.Model)
}
func (a *agent) Cancel(sessionID string) {
@@ -610,7 +600,7 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg
return nil
}
-func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error {
+func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
sess, err := a.sessions.Get(ctx, sessionID)
if err != nil {
return fmt.Errorf("failed to get session: %w", err)
@@ -819,7 +809,7 @@ func (a *agent) UpdateModel() error {
cfg := config.Get()
// Get current provider configuration
- currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID)
+ currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
if currentProviderCfg.ID == "" {
return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
}
@@ -827,7 +817,7 @@ func (a *agent) UpdateModel() error {
// Check if provider has changed
if string(currentProviderCfg.ID) != a.providerID {
// Provider changed, need to recreate the main provider
- model := config.GetAgentModel(a.agentCfg.ID)
+ model := cfg.GetModelByType(a.agentCfg.Model)
if model.ID == "" {
return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
}
@@ -842,7 +832,7 @@ func (a *agent) UpdateModel() error {
provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
}
- newProvider, err := provider.NewProvider(currentProviderCfg, opts...)
+ newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
if err != nil {
return fmt.Errorf("failed to create new provider: %w", err)
}
@@ -853,7 +843,7 @@ func (a *agent) UpdateModel() error {
}
// Check if small model provider has changed (affects title and summarize providers)
- smallModelCfg := cfg.Models.Small
+ smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
var smallModelProviderCfg config.ProviderConfig
for _, p := range cfg.Providers {
@@ -869,20 +859,14 @@ func (a *agent) UpdateModel() error {
// Check if summarize provider has changed
if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
- var smallModel config.Model
- for _, m := range smallModelProviderCfg.Models {
- if m.ID == smallModelCfg.ModelID {
- smallModel = m
- break
- }
- }
- if smallModel.ID == "" {
- return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
+ smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
+ if smallModel == nil {
+ return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
}
// Recreate title provider
titleOpts := []provider.ProviderClientOption{
- provider.WithModel(config.SmallModel),
+ provider.WithModel(config.SelectedModelTypeSmall),
provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
// We want the title to be short, so we limit the max tokens
provider.WithMaxTokens(40),
@@ -894,7 +878,7 @@ func (a *agent) UpdateModel() error {
// Recreate summarize provider
summarizeOpts := []provider.ProviderClientOption{
- provider.WithModel(config.SmallModel),
+ provider.WithModel(config.SelectedModelTypeSmall),
provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
}
newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
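
Agent construction now resolves its provider and model through the new config accessors, both of which return pointers that must be nil-checked. A condensed sketch of that lookup flow under the same assumptions as above (inside the crush module, config already loaded):

package agent

import (
	"fmt"

	"github.com/charmbracelet/crush/internal/config"
	fur "github.com/charmbracelet/crush/internal/fur/provider"
)

// resolveAgentModel follows the same checks as NewAgent above: find the
// provider backing the agent's selected model type, then the model itself.
func resolveAgentModel(agentName string, modelType config.SelectedModelType) (*config.ProviderConfig, *fur.Model, error) {
	cfg := config.Get()

	providerCfg := cfg.GetProviderForModel(modelType)
	if providerCfg == nil {
		return nil, nil, fmt.Errorf("provider for agent %s not found in config", agentName)
	}

	model := cfg.GetModelByType(modelType)
	if model == nil {
		return nil, nil, fmt.Errorf("model not found for agent %s", agentName)
	}
	return providerCfg, model, nil
}
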
@@ -19,7 +19,7 @@ import (
type mcpTool struct {
mcpName string
tool mcp.Tool
- mcpConfig config.MCP
+ mcpConfig config.MCPConfig
permissions permission.Service
}
@@ -97,7 +97,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
p := b.permissions.Request(
permission.CreatePermissionRequest{
SessionID: sessionID,
- Path: config.WorkingDirectory(),
+ Path: config.Get().WorkingDir(),
ToolName: b.Info().Name,
Action: "execute",
Description: permissionDescription,
@@ -142,7 +142,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
return tools.NewTextErrorResponse("invalid mcp type"), nil
}
-func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCP) tools.BaseTool {
+func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig) tools.BaseTool {
return &mcpTool{
mcpName: name,
tool: tool,
@@ -153,7 +153,7 @@ func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpC
var mcpTools []tools.BaseTool
-func getTools(ctx context.Context, name string, m config.MCP, permissions permission.Service, c MCPClient) []tools.BaseTool {
+func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient) []tools.BaseTool {
var stdioTools []tools.BaseTool
initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
@@ -14,12 +14,12 @@ import (
"github.com/charmbracelet/crush/internal/logging"
)
-func CoderPrompt(p provider.InferenceProvider, contextFiles ...string) string {
+func CoderPrompt(p string, contextFiles ...string) string {
var basePrompt string
switch p {
- case provider.InferenceProviderOpenAI:
+ case string(provider.InferenceProviderOpenAI):
basePrompt = baseOpenAICoderPrompt
- case provider.InferenceProviderGemini, provider.InferenceProviderVertexAI:
+ case string(provider.InferenceProviderGemini), string(provider.InferenceProviderVertexAI):
basePrompt = baseGeminiCoderPrompt
default:
basePrompt = baseAnthropicCoderPrompt
@@ -380,7 +380,7 @@ Your core function is efficient and safe assistance. Balance extreme conciseness
`
func getEnvironmentInfo() string {
- cwd := config.WorkingDirectory()
+ cwd := config.Get().WorkingDir()
isGit := isGitRepo(cwd)
platform := runtime.GOOS
date := time.Now().Format("1/2/2006")
@@ -7,7 +7,6 @@ import (
"sync"
"github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/fur/provider"
)
type PromptID string
@@ -20,17 +19,17 @@ const (
PromptDefault PromptID = "default"
)
-func GetPrompt(promptID PromptID, provider provider.InferenceProvider, contextPaths ...string) string {
+func GetPrompt(promptID PromptID, provider string, contextPaths ...string) string {
basePrompt := ""
switch promptID {
case PromptCoder:
basePrompt = CoderPrompt(provider)
case PromptTitle:
- basePrompt = TitlePrompt(provider)
+ basePrompt = TitlePrompt()
case PromptTask:
- basePrompt = TaskPrompt(provider)
+ basePrompt = TaskPrompt()
case PromptSummarizer:
- basePrompt = SummarizerPrompt(provider)
+ basePrompt = SummarizerPrompt()
default:
basePrompt = "You are a helpful assistant"
}
@@ -38,7 +37,7 @@ func GetPrompt(promptID PromptID, provider provider.InferenceProvider, contextPa
}
func getContextFromPaths(contextPaths []string) string {
- return processContextPaths(config.WorkingDirectory(), contextPaths)
+ return processContextPaths(config.Get().WorkingDir(), contextPaths)
}
func processContextPaths(workDir string, paths []string) string {
@@ -1,10 +1,6 @@
package prompt
-import (
- "github.com/charmbracelet/crush/internal/fur/provider"
-)
-
-func SummarizerPrompt(_ provider.InferenceProvider) string {
+func SummarizerPrompt() string {
return `You are a helpful AI assistant tasked with summarizing conversations.
When asked to summarize, provide a detailed but concise summary of the conversation.
@@ -2,11 +2,9 @@ package prompt
import (
"fmt"
-
- "github.com/charmbracelet/crush/internal/fur/provider"
)
-func TaskPrompt(_ provider.InferenceProvider) string {
+func TaskPrompt() string {
agentPrompt := `You are an agent for Crush. Given the user's prompt, you should use the tools available to you to answer the user's question.
Notes:
1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
@@ -1,10 +1,6 @@
package prompt
-import (
- "github.com/charmbracelet/crush/internal/fur/provider"
-)
-
-func TitlePrompt(_ provider.InferenceProvider) string {
+func TitlePrompt() string {
return `you will generate a short title based on the first message a user begins a conversation with
- ensure it is not more than 50 characters long
- the title should be a summary of the user's message
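
The prompt helpers now take the provider ID as a plain string, and the title, task, and summarizer prompts no longer take a provider at all. A small sketch of the new signatures; the "openai" ID string is an assumption here:

package main

import (
	"fmt"

	"github.com/charmbracelet/crush/internal/llm/prompt"
)

func main() {
	// Provider-specific coder prompt; unrecognized IDs fall back to the Anthropic variant.
	system := prompt.GetPrompt(prompt.PromptCoder, "openai")

	// Title and summarizer prompts are provider-independent after this change.
	title := prompt.GetPrompt(prompt.PromptTitle, "openai")

	fmt.Println(len(system), len(title))
}
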
@@ -153,9 +153,9 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
model := a.providerOptions.model(a.providerOptions.modelType)
var thinkingParam anthropic.ThinkingConfigParamUnion
cfg := config.Get()
- modelConfig := cfg.Models.Large
- if a.providerOptions.modelType == config.SmallModel {
- modelConfig = cfg.Models.Small
+ modelConfig := cfg.Models[config.SelectedModelTypeLarge]
+ if a.providerOptions.modelType == config.SelectedModelTypeSmall {
+ modelConfig = cfg.Models[config.SelectedModelTypeSmall]
}
temperature := anthropic.Float(0)
@@ -399,7 +399,7 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err
}
if apiErr.StatusCode == 401 {
- a.providerOptions.apiKey, err = config.ResolveAPIKey(a.providerOptions.config.APIKey)
+ a.providerOptions.apiKey, err = config.Get().Resolve(a.providerOptions.config.APIKey)
if err != nil {
return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
}
@@ -490,6 +490,6 @@ func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
}
}
-func (a *anthropicClient) Model() config.Model {
+func (a *anthropicClient) Model() provider.Model {
return a.providerOptions.model(a.providerOptions.modelType)
}
@@ -7,6 +7,7 @@ import (
"strings"
"github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/message"
)
@@ -31,14 +32,14 @@ func newBedrockClient(opts providerClientOptions) BedrockClient {
}
}
- opts.model = func(modelType config.ModelType) config.Model {
- model := config.GetModel(modelType)
+ opts.model = func(modelType config.SelectedModelType) provider.Model {
+ model := config.Get().GetModelByType(modelType)
// Prefix the model name with region
regionPrefix := region[:2]
modelName := model.ID
model.ID = fmt.Sprintf("%s.%s", regionPrefix, modelName)
- return model
+ return *model
}
model := opts.model(opts.modelType)
@@ -87,6 +88,6 @@ func (b *bedrockClient) stream(ctx context.Context, messages []message.Message,
return b.childProvider.stream(ctx, messages, tools)
}
-func (b *bedrockClient) Model() config.Model {
+func (b *bedrockClient) Model() provider.Model {
return b.providerOptions.model(b.providerOptions.modelType)
}
@@ -10,6 +10,7 @@ import (
"time"
"github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/logging"
"github.com/charmbracelet/crush/internal/message"
@@ -170,9 +171,9 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
logging.Debug("Prepared messages", "messages", string(jsonData))
}
- modelConfig := cfg.Models.Large
- if g.providerOptions.modelType == config.SmallModel {
- modelConfig = cfg.Models.Small
+ modelConfig := cfg.Models[config.SelectedModelTypeLarge]
+ if g.providerOptions.modelType == config.SelectedModelTypeSmall {
+ modelConfig = cfg.Models[config.SelectedModelTypeSmall]
}
maxTokens := model.DefaultMaxTokens
@@ -268,9 +269,9 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
logging.Debug("Prepared messages", "messages", string(jsonData))
}
- modelConfig := cfg.Models.Large
- if g.providerOptions.modelType == config.SmallModel {
- modelConfig = cfg.Models.Small
+ modelConfig := cfg.Models[config.SelectedModelTypeLarge]
+ if g.providerOptions.modelType == config.SelectedModelTypeSmall {
+ modelConfig = cfg.Models[config.SelectedModelTypeSmall]
}
maxTokens := model.DefaultMaxTokens
if modelConfig.MaxTokens > 0 {
@@ -424,7 +425,7 @@ func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error)
// Check for token expiration (401 Unauthorized)
if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
- g.providerOptions.apiKey, err = config.ResolveAPIKey(g.providerOptions.config.APIKey)
+ g.providerOptions.apiKey, err = config.Get().Resolve(g.providerOptions.config.APIKey)
if err != nil {
return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
}
@@ -462,7 +463,7 @@ func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
}
}
-func (g *geminiClient) Model() config.Model {
+func (g *geminiClient) Model() provider.Model {
return g.providerOptions.model(g.providerOptions.modelType)
}
@@ -148,15 +148,12 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar
model := o.providerOptions.model(o.providerOptions.modelType)
cfg := config.Get()
- modelConfig := cfg.Models.Large
- if o.providerOptions.modelType == config.SmallModel {
- modelConfig = cfg.Models.Small
+ modelConfig := cfg.Models[config.SelectedModelTypeLarge]
+ if o.providerOptions.modelType == config.SelectedModelTypeSmall {
+ modelConfig = cfg.Models[config.SelectedModelTypeSmall]
}
- reasoningEffort := model.ReasoningEffort
- if modelConfig.ReasoningEffort != "" {
- reasoningEffort = modelConfig.ReasoningEffort
- }
+ reasoningEffort := modelConfig.ReasoningEffort
params := openai.ChatCompletionNewParams{
Model: openai.ChatModel(model.ID),
@@ -363,7 +360,7 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error)
// Check for token expiration (401 Unauthorized)
if apiErr.StatusCode == 401 {
- o.providerOptions.apiKey, err = config.ResolveAPIKey(o.providerOptions.config.APIKey)
+ o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
if err != nil {
return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
}
@@ -420,6 +417,6 @@ func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
}
}
-func (a *openaiClient) Model() config.Model {
+func (a *openaiClient) Model() provider.Model {
return a.providerOptions.model(a.providerOptions.modelType)
}
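
All three clients now read per-slot model options from the Models map keyed by SelectedModelType instead of the old Large/Small struct fields. The lookup, isolated into a helper (a sketch under the same module assumptions as above):

package main

import (
	"fmt"

	"github.com/charmbracelet/crush/internal/config"
)

// selectedModelFor returns the user's selected-model options (max tokens,
// reasoning effort, think flag) for the requested slot, defaulting to large.
func selectedModelFor(modelType config.SelectedModelType) config.SelectedModel {
	cfg := config.Get()
	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}
	return modelConfig
}

func main() {
	fmt.Println(selectedModelFor(config.SelectedModelTypeSmall).Model)
}
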
@@ -55,15 +55,15 @@ type Provider interface {
StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
- Model() config.Model
+ Model() provider.Model
}
type providerClientOptions struct {
baseURL string
config config.ProviderConfig
apiKey string
- modelType config.ModelType
- model func(config.ModelType) config.Model
+ modelType config.SelectedModelType
+ model func(config.SelectedModelType) provider.Model
disableCache bool
systemMessage string
maxTokens int64
@@ -77,7 +77,7 @@ type ProviderClient interface {
send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
- Model() config.Model
+ Model() provider.Model
}
type baseProvider[C ProviderClient] struct {
@@ -106,11 +106,11 @@ func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message
return p.client.stream(ctx, messages, tools)
}
-func (p *baseProvider[C]) Model() config.Model {
+func (p *baseProvider[C]) Model() provider.Model {
return p.client.Model()
}
-func WithModel(model config.ModelType) ProviderClientOption {
+func WithModel(model config.SelectedModelType) ProviderClientOption {
return func(options *providerClientOptions) {
options.modelType = model
}
@@ -135,7 +135,7 @@ func WithMaxTokens(maxTokens int64) ProviderClientOption {
}
func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
- resolvedAPIKey, err := config.ResolveAPIKey(cfg.APIKey)
+ resolvedAPIKey, err := config.Get().Resolve(cfg.APIKey)
if err != nil {
return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
}
@@ -145,14 +145,14 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi
config: cfg,
apiKey: resolvedAPIKey,
extraHeaders: cfg.ExtraHeaders,
- model: func(tp config.ModelType) config.Model {
- return config.GetModel(tp)
+ model: func(tp config.SelectedModelType) provider.Model {
+ return *config.Get().GetModelByType(tp)
},
}
for _, o := range opts {
o(&clientOptions)
}
- switch cfg.ProviderType {
+ switch cfg.Type {
case provider.TypeAnthropic:
return &baseProvider[AnthropicClient]{
options: clientOptions,
@@ -190,5 +190,5 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi
client: newOpenAIClient(clientOptions),
}, nil
}
- return nil, fmt.Errorf("provider not supported: %s", cfg.ProviderType)
+ return nil, fmt.Errorf("provider not supported: %s", cfg.Type)
}
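
NewProvider now switches on the Type field (previously ProviderType) and resolves the API key through config.Get().Resolve. A hedged usage sketch; the field values are placeholders and key resolution happens inside NewProvider:

package main

import (
	"fmt"

	"github.com/charmbracelet/crush/internal/config"
	fur "github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/charmbracelet/crush/internal/llm/provider"
)

func main() {
	// Placeholder entry; in practice this comes from config.Get().Providers.
	cfg := config.ProviderConfig{
		ID:     "openai",
		Type:   fur.TypeOpenAI,
		APIKey: "placeholder-key", // resolved via config.Get().Resolve inside NewProvider
	}

	p, err := provider.NewProvider(cfg,
		provider.WithModel(config.SelectedModelTypeLarge),
		provider.WithSystemMessage("You are a helpful assistant"),
	)
	if err != nil {
		fmt.Println("provider setup failed:", err)
		return
	}
	_ = p
}
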
@@ -317,7 +317,7 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
p := b.permissions.Request(
permission.CreatePermissionRequest{
SessionID: sessionID,
- Path: config.WorkingDirectory(),
+ Path: config.Get().WorkingDir(),
ToolName: BashToolName,
Action: "execute",
Description: fmt.Sprintf("Execute command: %s", params.Command),
@@ -337,7 +337,7 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
defer cancel()
}
stdout, stderr, err := shell.
- GetPersistentShell(config.WorkingDirectory()).
+ GetPersistentShell(config.Get().WorkingDir()).
Exec(ctx, params.Command)
interrupted := shell.IsInterrupt(err)
exitCode := shell.ExitCode(err)
@@ -143,7 +143,7 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
}
if !filepath.IsAbs(params.FilePath) {
- wd := config.WorkingDirectory()
+ wd := config.Get().WorkingDir()
params.FilePath = filepath.Join(wd, params.FilePath)
}
@@ -207,7 +207,7 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string)
content,
filePath,
)
- rootDir := config.WorkingDirectory()
+ rootDir := config.Get().WorkingDir()
permissionPath := filepath.Dir(filePath)
if strings.HasPrefix(filePath, rootDir) {
permissionPath = rootDir
@@ -320,7 +320,7 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string
filePath,
)
- rootDir := config.WorkingDirectory()
+ rootDir := config.Get().WorkingDir()
permissionPath := filepath.Dir(filePath)
if strings.HasPrefix(filePath, rootDir) {
permissionPath = rootDir
@@ -442,7 +442,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
newContent,
filePath,
)
- rootDir := config.WorkingDirectory()
+ rootDir := config.Get().WorkingDir()
permissionPath := filepath.Dir(filePath)
if strings.HasPrefix(filePath, rootDir) {
permissionPath = rootDir
@@ -133,7 +133,7 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
p := t.permissions.Request(
permission.CreatePermissionRequest{
SessionID: sessionID,
- Path: config.WorkingDirectory(),
+ Path: config.Get().WorkingDir(),
ToolName: FetchToolName,
Action: "fetch",
Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
@@ -108,7 +108,7 @@ func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
searchPath := params.Path
if searchPath == "" {
- searchPath = config.WorkingDirectory()
+ searchPath = config.Get().WorkingDir()
}
files, truncated, err := globFiles(params.Pattern, searchPath, 100)
@@ -200,7 +200,7 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
searchPath := params.Path
if searchPath == "" {
- searchPath = config.WorkingDirectory()
+ searchPath = config.Get().WorkingDir()
}
matches, truncated, err := searchFiles(searchPattern, searchPath, params.Include, 100)
@@ -107,11 +107,11 @@ func (l *lsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
searchPath := params.Path
if searchPath == "" {
- searchPath = config.WorkingDirectory()
+ searchPath = config.Get().WorkingDir()
}
if !filepath.IsAbs(searchPath) {
- searchPath = filepath.Join(config.WorkingDirectory(), searchPath)
+ searchPath = filepath.Join(config.Get().WorkingDir(), searchPath)
}
if _, err := os.Stat(searchPath); os.IsNotExist(err) {
@@ -117,7 +117,7 @@ func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
// Handle relative paths
filePath := params.FilePath
if !filepath.IsAbs(filePath) {
- filePath = filepath.Join(config.WorkingDirectory(), filePath)
+ filePath = filepath.Join(config.Get().WorkingDir(), filePath)
}
// Check if file exists
@@ -122,7 +122,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
filePath := params.FilePath
if !filepath.IsAbs(filePath) {
- filePath = filepath.Join(config.WorkingDirectory(), filePath)
+ filePath = filepath.Join(config.Get().WorkingDir(), filePath)
}
fileInfo, err := os.Stat(filePath)
@@ -170,7 +170,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
filePath,
)
- rootDir := config.WorkingDirectory()
+ rootDir := config.Get().WorkingDir()
permissionPath := filepath.Dir(filePath)
if strings.HasPrefix(filePath, rootDir) {
permissionPath = rootDir
@@ -376,7 +376,7 @@ func (c *Client) detectServerType() ServerType {
// openKeyConfigFiles opens important configuration files that help initialize the server
func (c *Client) openKeyConfigFiles(ctx context.Context) {
- workDir := config.WorkingDirectory()
+ workDir := config.Get().WorkingDir()
serverType := c.detectServerType()
var filesToOpen []string
@@ -464,7 +464,7 @@ func (c *Client) pingTypeScriptServer(ctx context.Context) error {
}
// If we have no open TypeScript files, try to find and open one
- workDir := config.WorkingDirectory()
+ workDir := config.Get().WorkingDir()
err := filepath.WalkDir(workDir, func(path string, d os.DirEntry, err error) error {
if err != nil {
return err
@@ -87,7 +87,7 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {
dir := filepath.Dir(opts.Path)
if dir == "." {
- dir = config.WorkingDirectory()
+ dir = config.Get().WorkingDir()
}
permission := PermissionRequest{
ID: uuid.New().String(),
@@ -91,7 +91,7 @@ func (p *header) View() tea.View {
func (h *header) details() string {
t := styles.CurrentTheme()
- cwd := fsext.DirTrim(fsext.PrettyPath(config.WorkingDirectory()), 4)
+ cwd := fsext.DirTrim(fsext.PrettyPath(config.Get().WorkingDir()), 4)
parts := []string{
t.S().Muted.Render(cwd),
}
@@ -111,7 +111,8 @@ func (h *header) details() string {
parts = append(parts, t.S().Error.Render(fmt.Sprintf("%s%d", styles.ErrorIcon, errorCount)))
}
- model := config.GetAgentModel(config.AgentCoder)
+ agentCfg := config.Get().Agents["coder"]
+ model := config.Get().GetModelByType(agentCfg.Model)
percentage := (float64(h.session.CompletionTokens+h.session.PromptTokens) / float64(model.ContextWindow)) * 100
formattedPercentage := t.S().Muted.Render(fmt.Sprintf("%d%%", int(percentage)))
parts = append(parts, formattedPercentage)
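
The header derives context-window usage from the coder agent's configured model. The same calculation, isolated and runnable on its own:

package main

import "fmt"

// contextUsagePercent mirrors the header math above: tokens consumed so far
// as a percentage of the model's context window.
func contextUsagePercent(completionTokens, promptTokens, contextWindow int64) int {
	if contextWindow <= 0 {
		return 0
	}
	return int(float64(completionTokens+promptTokens) / float64(contextWindow) * 100)
}

func main() {
	fmt.Printf("%d%%\n", contextUsagePercent(12000, 4000, 200000))
}
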
@@ -10,7 +10,6 @@ import (
"github.com/charmbracelet/lipgloss/v2"
"github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/tui/components/anim"
"github.com/charmbracelet/crush/internal/tui/components/core"
@@ -296,7 +295,7 @@ func (m *assistantSectionModel) View() tea.View {
duration := finishTime.Sub(m.lastUserMessageTime)
infoMsg := t.S().Subtle.Render(duration.String())
icon := t.S().Subtle.Render(styles.ModelIcon)
- model := config.GetProviderModel(provider.InferenceProvider(m.message.Provider), m.message.Model)
+ model := config.Get().GetModel(m.message.Provider, m.message.Model)
modelFormatted := t.S().Muted.Render(model.Name)
assistant := fmt.Sprintf("%s %s %s", icon, modelFormatted, infoMsg)
return tea.NewView(
@@ -297,7 +297,7 @@ func (m *sidebarCmp) filesBlock() string {
}
extraContent := strings.Join(statusParts, " ")
- cwd := config.WorkingDirectory() + string(os.PathSeparator)
+ cwd := config.Get().WorkingDir() + string(os.PathSeparator)
filePath := file.FilePath
filePath = strings.TrimPrefix(filePath, cwd)
filePath = fsext.DirTrim(fsext.PrettyPath(filePath), 2)
@@ -474,7 +474,8 @@ func formatTokensAndCost(tokens, contextWindow int64, cost float64) string {
}
func (s *sidebarCmp) currentModelBlock() string {
- model := config.GetAgentModel(config.AgentCoder)
+ agentCfg := config.Get().Agents["coder"]
+ model := config.Get().GetModelByType(agentCfg.Model)
t := styles.CurrentTheme()
@@ -507,7 +508,7 @@ func (m *sidebarCmp) SetSession(session session.Session) tea.Cmd {
}
func cwd() string {
- cwd := config.WorkingDirectory()
+ cwd := config.Get().WorkingDir()
t := styles.CurrentTheme()
// Replace home directory with ~, unless we're at the top level of the
// home directory).
@@ -31,8 +31,8 @@ const (
// ModelSelectedMsg is sent when a model is selected
type ModelSelectedMsg struct {
- Model config.PreferredModel
- ModelType config.ModelType
+ Model config.SelectedModel
+ ModelType config.SelectedModelType
}
// CloseModelDialogMsg is sent when a model is selected
@@ -115,19 +115,19 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
items := m.modelList.Items()
selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption)
- var modelType config.ModelType
+ var modelType config.SelectedModelType
if m.modelType == LargeModelType {
- modelType = config.LargeModel
+ modelType = config.SelectedModelTypeLarge
} else {
- modelType = config.SmallModel
+ modelType = config.SelectedModelTypeSmall
}
return m, tea.Sequence(
util.CmdHandler(dialogs.CloseDialogMsg{}),
util.CmdHandler(ModelSelectedMsg{
- Model: config.PreferredModel{
- ModelID: selectedItem.Model.ID,
- Provider: selectedItem.Provider.ID,
+ Model: config.SelectedModel{
+ Model: selectedItem.Model.ID,
+ Provider: string(selectedItem.Provider.ID),
},
ModelType: modelType,
}),
@@ -218,35 +218,39 @@ func (m *modelDialogCmp) modelTypeRadio() string {
func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
m.modelType = modelType
- providers := config.Providers()
+ providers, err := config.Providers()
+ if err != nil {
+ return util.ReportError(err)
+ }
+
modelItems := []util.Model{}
selectIndex := 0
cfg := config.Get()
- var currentModel config.PreferredModel
+ var currentModel config.SelectedModel
if m.modelType == LargeModelType {
- currentModel = cfg.Models.Large
+ currentModel = cfg.Models[config.SelectedModelTypeLarge]
} else {
- currentModel = cfg.Models.Small
+ currentModel = cfg.Models[config.SelectedModelTypeSmall]
}
// Create a map to track which providers we've already added
- addedProviders := make(map[provider.InferenceProvider]bool)
+ addedProviders := make(map[string]bool)
// First, add any configured providers that are not in the known providers list
// These should appear at the top of the list
knownProviders := provider.KnownProviders()
for providerID, providerConfig := range cfg.Providers {
- if providerConfig.Disabled {
+ if providerConfig.Disable {
continue
}
// Check if this provider is not in the known providers list
- if !slices.Contains(knownProviders, providerID) {
+ if !slices.Contains(knownProviders, provider.InferenceProvider(providerID)) {
// Convert config provider to provider.Provider format
configProvider := provider.Provider{
Name: string(providerID), // Use provider ID as name for unknown providers
- ID: providerID,
+ ID: provider.InferenceProvider(providerID),
Models: make([]provider.Model, len(providerConfig.Models)),
}
@@ -263,7 +267,7 @@ func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
DefaultMaxTokens: model.DefaultMaxTokens,
CanReason: model.CanReason,
HasReasoningEffort: model.HasReasoningEffort,
- DefaultReasoningEffort: model.ReasoningEffort,
+ DefaultReasoningEffort: model.DefaultReasoningEffort,
SupportsImages: model.SupportsImages,
}
}
@@ -279,7 +283,7 @@ func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
Provider: configProvider,
Model: model,
}))
- if model.ID == currentModel.ModelID && configProvider.ID == currentModel.Provider {
+ if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
selectIndex = len(modelItems) - 1 // Set the selected index to the current model
}
}
@@ -290,12 +294,12 @@ func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
// Then add the known providers from the predefined list
for _, provider := range providers {
// Skip if we already added this provider as an unknown provider
- if addedProviders[provider.ID] {
+ if addedProviders[string(provider.ID)] {
continue
}
// Check if this provider is configured and not disabled
- if providerConfig, exists := cfg.Providers[provider.ID]; exists && providerConfig.Disabled {
+ if providerConfig, exists := cfg.Providers[string(provider.ID)]; exists && providerConfig.Disable {
continue
}
@@ -309,7 +313,7 @@ func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
Provider: provider,
Model: model,
}))
- if model.ID == currentModel.ModelID && provider.ID == currentModel.Provider {
+ if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
selectIndex = len(modelItems) - 1 // Set the selected index to the current model
}
}
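
Model selection now carries a config.SelectedModel (plain Model and Provider strings) plus the slot it applies to, replacing the old PreferredModel with its ModelID field. A small sketch of the payload; the IDs are placeholders:

package main

import (
	"fmt"

	"github.com/charmbracelet/crush/internal/config"
)

func main() {
	// What ModelSelectedMsg now wraps: the chosen model and the slot it fills.
	selected := config.SelectedModel{
		Model:    "gpt-4o", // placeholder model ID
		Provider: "openai", // placeholder provider ID
	}
	slot := config.SelectedModelTypeLarge
	fmt.Printf("%s model changed to %s\n", slot, selected.Model)
}
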
@@ -170,7 +170,8 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
util.CmdHandler(ChatFocusedMsg{Focused: false}),
)
case key.Matches(msg, p.keyMap.AddAttachment):
- model := config.GetAgentModel(config.AgentCoder)
+ agentCfg := config.Get().Agents["coder"]
+ model := config.Get().GetModelByType(agentCfg.Model)
if model.SupportsImages {
return p, util.CmdHandler(OpenFilePickerMsg{})
} else {
@@ -177,14 +177,14 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// Update the agent with the new model/provider configuration
if err := a.app.UpdateAgentModel(); err != nil {
logging.ErrorPersist(fmt.Sprintf("Failed to update agent model: %v", err))
- return a, util.ReportError(fmt.Errorf("model changed to %s but failed to update agent: %v", msg.Model.ModelID, err))
+ return a, util.ReportError(fmt.Errorf("model changed to %s but failed to update agent: %v", msg.Model.Model, err))
}
modelTypeName := "large"
- if msg.ModelType == config.SmallModel {
+ if msg.ModelType == config.SelectedModelTypeSmall {
modelTypeName = "small"
}
- return a, util.ReportInfo(fmt.Sprintf("%s model changed to %s", modelTypeName, msg.Model.ModelID))
+ return a, util.ReportInfo(fmt.Sprintf("%s model changed to %s", modelTypeName, msg.Model.Model))
// File Picker
case chat.OpenFilePickerMsg:
@@ -1,224 +0,0 @@
-package config
-
-import (
- "slices"
- "strings"
-
- "github.com/charmbracelet/crush/internal/fur/provider"
-)
-
-const (
- appName = "crush"
- defaultDataDirectory = ".crush"
- defaultLogLevel = "info"
-)
-
-var defaultContextPaths = []string{
- ".github/copilot-instructions.md",
- ".cursorrules",
- ".cursor/rules/",
- "CLAUDE.md",
- "CLAUDE.local.md",
- "GEMINI.md",
- "gemini.md",
- "crush.md",
- "crush.local.md",
- "Crush.md",
- "Crush.local.md",
- "CRUSH.md",
- "CRUSH.local.md",
-}
-
-type SelectedModelType string
-
-const (
- SelectedModelTypeLarge SelectedModelType = "large"
- SelectedModelTypeSmall SelectedModelType = "small"
-)
-
-type SelectedModel struct {
- // The model id as used by the provider API.
- // Required.
- Model string `json:"model"`
- // The model provider, same as the key/id used in the providers config.
- // Required.
- Provider string `json:"provider"`
-
- // Only used by models that use the openai provider and need this set.
- ReasoningEffort string `json:"reasoning_effort,omitempty"`
-
- // Overrides the default model configuration.
- MaxTokens int64 `json:"max_tokens,omitempty"`
-
- // Used by anthropic models that can reason to indicate if the model should think.
- Think bool `json:"think,omitempty"`
-}
-
-type ProviderConfig struct {
- // The provider's id.
- ID string `json:"id,omitempty"`
- // The provider's API endpoint.
- BaseURL string `json:"base_url,omitempty"`
- // The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
- Type provider.Type `json:"type,omitempty"`
- // The provider's API key.
- APIKey string `json:"api_key,omitempty"`
- // Marks the provider as disabled.
- Disable bool `json:"disable,omitempty"`
-
- // Extra headers to send with each request to the provider.
- ExtraHeaders map[string]string
-
- // Used to pass extra parameters to the provider.
- ExtraParams map[string]string `json:"-"`
-
- // The provider models
- Models []provider.Model `json:"models,omitempty"`
-}
-
-type MCPType string
-
-const (
- MCPStdio MCPType = "stdio"
- MCPSse MCPType = "sse"
- MCPHttp MCPType = "http"
-)
-
-type MCPConfig struct {
- Command string `json:"command,omitempty" `
- Env []string `json:"env,omitempty"`
- Args []string `json:"args,omitempty"`
- Type MCPType `json:"type"`
- URL string `json:"url,omitempty"`
-
- // TODO: maybe make it possible to get the value from the env
- Headers map[string]string `json:"headers,omitempty"`
-}
-
-type LSPConfig struct {
- Disabled bool `json:"enabled,omitempty"`
- Command string `json:"command"`
- Args []string `json:"args,omitempty"`
- Options any `json:"options,omitempty"`
-}
-
-type TUIOptions struct {
- CompactMode bool `json:"compact_mode,omitempty"`
- // Here we can add themes later or any TUI related options
-}
-
-type Options struct {
- ContextPaths []string `json:"context_paths,omitempty"`
- TUI *TUIOptions `json:"tui,omitempty"`
- Debug bool `json:"debug,omitempty"`
- DebugLSP bool `json:"debug_lsp,omitempty"`
- DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty"`
- // Relative to the cwd
- DataDirectory string `json:"data_directory,omitempty"`
-}
-
-type MCPs map[string]MCPConfig
-
-type MCP struct {
- Name string `json:"name"`
- MCP MCPConfig `json:"mcp"`
-}
-
-func (m MCPs) Sorted() []MCP {
- sorted := make([]MCP, 0, len(m))
- for k, v := range m {
- sorted = append(sorted, MCP{
- Name: k,
- MCP: v,
- })
- }
- slices.SortFunc(sorted, func(a, b MCP) int {
- return strings.Compare(a.Name, b.Name)
- })
- return sorted
-}
-
-type LSPs map[string]LSPConfig
-
-type LSP struct {
- Name string `json:"name"`
- LSP LSPConfig `json:"lsp"`
-}
-
-func (l LSPs) Sorted() []LSP {
- sorted := make([]LSP, 0, len(l))
- for k, v := range l {
- sorted = append(sorted, LSP{
- Name: k,
- LSP: v,
- })
- }
- slices.SortFunc(sorted, func(a, b LSP) int {
- return strings.Compare(a.Name, b.Name)
- })
- return sorted
-}
-
-// Config holds the configuration for crush.
-type Config struct {
- // We currently only support large/small as values here.
- Models map[SelectedModelType]SelectedModel `json:"models,omitempty"`
-
- // The providers that are configured
- Providers map[string]ProviderConfig `json:"providers,omitempty"`
-
- MCP MCPs `json:"mcp,omitempty"`
-
- LSP LSPs `json:"lsp,omitempty"`
-
- Options *Options `json:"options,omitempty"`
-
- // Internal
- workingDir string `json:"-"`
-}
-
-func (c *Config) WorkingDir() string {
- return c.workingDir
-}
-
-func (c *Config) EnabledProviders() []ProviderConfig {
- enabled := make([]ProviderConfig, 0, len(c.Providers))
- for _, p := range c.Providers {
- if !p.Disable {
- enabled = append(enabled, p)
- }
- }
- return enabled
-}
-
-// IsConfigured return true if at least one provider is configured
-func (c *Config) IsConfigured() bool {
- return len(c.EnabledProviders()) > 0
-}
-
-func (c *Config) GetModel(provider, model string) *provider.Model {
- if providerConfig, ok := c.Providers[provider]; ok {
- for _, m := range providerConfig.Models {
- if m.ID == model {
- return &m
- }
- }
- }
- return nil
-}
-
-func (c *Config) LargeModel() *provider.Model {
- model, ok := c.Models[SelectedModelTypeLarge]
- if !ok {
- return nil
- }
- return c.GetModel(model.Provider, model.Model)
-}
-
-func (c *Config) SmallModel() *provider.Model {
- model, ok := c.Models[SelectedModelTypeSmall]
- if !ok {
- return nil
- }
- return c.GetModel(model.Provider, model.Model)
-}
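
These definitions move under internal/config with the same shape: providers keyed by plain string IDs, a Disable flag, per-slot selected models, and pointer-returning lookups. A lookup sketch against the relocated package, assuming the config has been loaded:

package main

import (
	"fmt"

	"github.com/charmbracelet/crush/internal/config"
)

func main() {
	cfg := config.Get()

	// Enumerate enabled providers, as the model dialog does.
	for id, p := range cfg.Providers {
		if p.Disable {
			continue
		}
		fmt.Printf("provider %s with %d models\n", id, len(p.Models))
	}

	// Resolve the currently selected large model to its full metadata.
	if sel, ok := cfg.Models[config.SelectedModelTypeLarge]; ok {
		if m := cfg.GetModel(sel.Provider, sel.Model); m != nil {
			fmt.Println("large model:", m.Name, "context window:", m.ContextWindow)
		}
	}
}
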
@@ -1,93 +0,0 @@
-package config
-
-import (
- "encoding/json"
- "os"
- "path/filepath"
- "runtime"
- "sync"
-
- "github.com/charmbracelet/crush/internal/fur/provider"
-)
-
-type ProviderClient interface {
- GetProviders() ([]provider.Provider, error)
-}
-
-var (
- providerOnce sync.Once
- providerList []provider.Provider
-)
-
-// file to cache provider data
-func providerCacheFileData() string {
- xdgDataHome := os.Getenv("XDG_DATA_HOME")
- if xdgDataHome != "" {
- return filepath.Join(xdgDataHome, appName)
- }
-
- // return the path to the main data directory
- // for windows, it should be in `%LOCALAPPDATA%/crush/`
- // for linux and macOS, it should be in `$HOME/.local/share/crush/`
- if runtime.GOOS == "windows" {
- localAppData := os.Getenv("LOCALAPPDATA")
- if localAppData == "" {
- localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
- }
- return filepath.Join(localAppData, appName)
- }
-
- return filepath.Join(os.Getenv("HOME"), ".local", "share", appName, "providers.json")
-}
-
-func saveProvidersInCache(path string, providers []provider.Provider) error {
- dir := filepath.Dir(path)
- if err := os.MkdirAll(dir, 0o755); err != nil {
- return err
- }
-
- data, err := json.MarshalIndent(providers, "", " ")
- if err != nil {
- return err
- }
-
- return os.WriteFile(path, data, 0o644)
-}
-
-func loadProvidersFromCache(path string) ([]provider.Provider, error) {
- data, err := os.ReadFile(path)
- if err != nil {
- return nil, err
- }
-
- var providers []provider.Provider
- err = json.Unmarshal(data, &providers)
- return providers, err
-}
-
-func loadProviders(path string, client ProviderClient) ([]provider.Provider, error) {
- providers, err := client.GetProviders()
- if err != nil {
- fallbackToCache, err := loadProvidersFromCache(path)
- if err != nil {
- return nil, err
- }
- providers = fallbackToCache
- } else {
- if err := saveProvidersInCache(path, providerList); err != nil {
- return nil, err
- }
- }
- return providers, nil
-}
-
-func LoadProviders(client ProviderClient) ([]provider.Provider, error) {
- var err error
- providerOnce.Do(func() {
- providerList, err = loadProviders(providerCacheFileData(), client)
- })
- if err != nil {
- return nil, err
- }
- return providerList, nil
-}
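
The removed provider cache fetched the catalog from a client and fell back to the on-disk copy when the fetch failed, as the tests below exercise. A self-contained re-implementation sketch of that fetch-then-cache pattern (standalone types, not the crush API):

package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"os"
	"path/filepath"
)

// Provider is a stand-in for the catalog entries cached by the removed code.
type Provider struct {
	Name string `json:"name"`
}

// fetchWithCacheFallback tries the remote fetch first, persists the result on
// success, and falls back to the cached file when the fetch fails.
func fetchWithCacheFallback(path string, fetch func() ([]Provider, error)) ([]Provider, error) {
	providers, err := fetch()
	if err != nil {
		data, readErr := os.ReadFile(path)
		if readErr != nil {
			return nil, readErr
		}
		if err := json.Unmarshal(data, &providers); err != nil {
			return nil, err
		}
		return providers, nil
	}
	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
		return nil, err
	}
	data, err := json.MarshalIndent(providers, "", "  ")
	if err != nil {
		return nil, err
	}
	return providers, os.WriteFile(path, data, 0o644)
}

func main() {
	path := filepath.Join(os.TempDir(), "crush-providers.json")
	providers, err := fetchWithCacheFallback(path, func() ([]Provider, error) {
		return nil, errors.New("network unavailable") // force the cache-fallback path
	})
	fmt.Println(providers, err)
}

Note that the deleted loadProviders passed the package-level providerList to saveProvidersInCache rather than the freshly fetched slice, so on a cold start the cache write appears to persist an empty list; the sketch above writes the fetched result instead.
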
@@ -1,73 +0,0 @@
-package config
-
-import (
- "encoding/json"
- "errors"
- "os"
- "testing"
-
- "github.com/charmbracelet/crush/internal/fur/provider"
- "github.com/stretchr/testify/assert"
-)
-
-type mockProviderClient struct {
- shouldFail bool
-}
-
-func (m *mockProviderClient) GetProviders() ([]provider.Provider, error) {
- if m.shouldFail {
- return nil, errors.New("failed to load providers")
- }
- return []provider.Provider{
- {
- Name: "Mock",
- },
- }, nil
-}
-
-func TestProvider_loadProvidersNoIssues(t *testing.T) {
- client := &mockProviderClient{shouldFail: false}
- tmpPath := t.TempDir() + "/providers.json"
- providers, err := loadProviders(tmpPath, client)
- assert.NoError(t, err)
- assert.NotNil(t, providers)
- assert.Len(t, providers, 1)
-
- // check if file got saved
- fileInfo, err := os.Stat(tmpPath)
- assert.NoError(t, err)
- assert.False(t, fileInfo.IsDir(), "Expected a file, not a directory")
-}
-
-func TestProvider_loadProvidersWithIssues(t *testing.T) {
- client := &mockProviderClient{shouldFail: true}
- tmpPath := t.TempDir() + "/providers.json"
- // store providers to a temporary file
- oldProviders := []provider.Provider{
- {
- Name: "OldProvider",
- },
- }
- data, err := json.Marshal(oldProviders)
- if err != nil {
- t.Fatalf("Failed to marshal old providers: %v", err)
- }
-
- err = os.WriteFile(tmpPath, data, 0o644)
- if err != nil {
- t.Fatalf("Failed to write old providers to file: %v", err)
- }
- providers, err := loadProviders(tmpPath, client)
- assert.NoError(t, err)
- assert.NotNil(t, providers)
- assert.Len(t, providers, 1)
- assert.Equal(t, "OldProvider", providers[0].Name, "Expected to keep old provider when loading fails")
-}
-
-func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) {
- client := &mockProviderClient{shouldFail: true}
- tmpPath := t.TempDir() + "/providers.json"
- providers, err := loadProviders(tmpPath, client)
- assert.Error(t, err)
- assert.Nil(t, providers, "Expected nil providers when loading fails and no cache exists")
-}