diff --git a/CRUSH.md b/CRUSH.md
index e2f6053c110f317ee85113a2f08359673342e645..c308db631e006dd1c3834b6b470a02f4c41ff53b 100644
--- a/CRUSH.md
+++ b/CRUSH.md
@@ -24,6 +24,29 @@
- **File permissions**: Use octal notation (0o755, 0o644) for file permissions
- **Comments**: End comments in periods unless comments are at the end of the line.
+## Testing with Mock Providers
+
+When writing tests that involve provider configurations, use the mock providers to avoid API calls:
+
+```go
+func TestYourFunction(t *testing.T) {
+ // Enable mock providers for testing.
+ originalUseMock := config.UseMockProviders
+ config.UseMockProviders = true
+ defer func() {
+ config.UseMockProviders = originalUseMock
+ config.ResetProviders()
+ }()
+
+ // Reset providers to ensure fresh mock data.
+ config.ResetProviders()
+
+ // Your test code here - providers will now return mock data.
+ providers := config.Providers()
+ // ... test logic
+}
+```
+
## Formatting
- ALWAYS format any Go code you write.
diff --git a/Taskfile.yaml b/Taskfile.yaml
index 7e181d84513bc1abc5d5a32807f88a75898633ea..077c994ea7c6ce79fabbc18f76cfb3ef27af8d5e 100644
--- a/Taskfile.yaml
+++ b/Taskfile.yaml
@@ -6,12 +6,12 @@ tasks:
lint:
desc: Run base linters
cmds:
- - golangci-lint run
+ - golangci-lint run --path-mode=abs --config=".golangci.yml" --timeout=5m
lint-fix:
desc: Run base linters and fix issues
cmds:
- - golangci-lint run --fix
+ - golangci-lint run --path-mode=abs --config=".golangci.yml" --timeout=5m --fix
test:
desc: Run tests
@@ -44,3 +44,9 @@ tasks:
desc: Allocations profile
cmds:
- go tool pprof -http :6061 'http://localhost:6060/debug/pprof/allocs'
+
+ schema:
+ desc: Generate JSON schema for configuration
+ cmds:
+ - go run cmd/schema/main.go > crush-schema.json
+ - echo "Generated crush-schema.json"
diff --git a/cmd/root.go b/cmd/root.go
index 2b5f79cf0337c386196d783ad9d18e2e1380aa5b..d741b859178e6c524b4b4e3a61863f144840812c 100644
--- a/cmd/root.go
+++ b/cmd/root.go
@@ -72,7 +72,8 @@ to assist developers in writing, debugging, and understanding code directly from
}
cwd = c
}
- _, err := config.Load(cwd, debug)
+
+ _, err := config.Init(cwd, debug)
if err != nil {
return err
}
diff --git a/cmd/schema/README.md b/cmd/schema/README.md
index 517fdb4d20fb9f2b819051bd72e6c33f5dea2195..c7f0102e590dcd98ce150e506c28f9016ef50bbc 100644
--- a/cmd/schema/README.md
+++ b/cmd/schema/README.md
@@ -1,64 +1,182 @@
# Crush Configuration Schema Generator
-This tool generates a JSON Schema for the Crush configuration file. The schema can be used to validate configuration files and provide autocompletion in editors that support JSON Schema.
+This tool automatically generates a JSON Schema for the Crush configuration file by using Go reflection to analyze the configuration structs. The schema provides validation, autocompletion, and documentation for configuration files.
+
+## Features
+
+- **Automated Generation**: Uses reflection to automatically generate schemas from Go structs
+- **Always Up-to-Date**: Schema stays in sync with code changes automatically
+- **Comprehensive**: Includes all configuration options, types, and validation rules
+- **Enhanced**: Adds provider enums, model lists, and custom descriptions
+- **Extensible**: Easy to add new fields and modify existing ones
## Usage
```bash
+# Generate the schema
go run cmd/schema/main.go > crush-schema.json
+
+# Or use the task runner
+task schema
```
-This will generate a JSON Schema file that can be used to validate configuration files.
+## How It Works
+
+The generator performs these steps (a minimal sketch follows the list):
+
+1. **Reflects on Config Structs**: Analyzes the `config.Config` struct and all related types
+2. **Generates Base Schema**: Creates JSON Schema definitions for all struct fields
+3. **Enhances with Runtime Data**: Adds provider lists, model enums, and tool lists from the actual codebase
+4. **Adds Custom Descriptions**: Provides meaningful descriptions for configuration options
+5. **Sets Default Values**: Includes appropriate defaults for optional fields
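+
+A minimal sketch of steps 1 and 2, stripped of the enhancement passes that `cmd/schema/main.go` performs:
+
+```go
+package main
+
+import (
+    "encoding/json"
+    "os"
+
+    "github.com/charmbracelet/crush/internal/config"
+    "github.com/invopop/jsonschema"
+)
+
+func main() {
+    // Reflect the Config struct into a JSON Schema.
+    r := &jsonschema.Reflector{ExpandedStruct: true}
+    schema := r.Reflect(&config.Config{})
+
+    // Pretty-print the schema to stdout.
+    enc := json.NewEncoder(os.Stdout)
+    enc.SetIndent("", "  ")
+    _ = enc.Encode(schema)
+}
+```
+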
## Schema Features
The generated schema includes:
-- All configuration options with descriptions
-- Default values where applicable
-- Validation for enum values (e.g., model IDs, provider types)
-- Required fields
-- Type checking
+- **Type Safety**: Proper type definitions for all configuration fields
+- **Validation**: Required fields, enum constraints, and format validation
+- **Documentation**: Descriptions for all configuration options
+- **Defaults**: Default values for optional settings
+- **Provider Enums**: Current list of supported providers (see the fragment below)
+- **Model Enums**: Available models from all configured providers
+- **Tool Lists**: Valid tool names for agent configurations
+- **Cross-References**: Proper relationships between different config sections
+
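+For example, the `provider` enum generated from the current provider list looks like this fragment:
+
+```json
+"provider": {
+  "type": "string",
+  "enum": ["anthropic", "openai", "gemini", "azure", "bedrock", "vertex", "xai", "openrouter"]
+}
+```
+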
+## Adding New Configuration Fields
+
+To add new configuration options:
+
+1. **Add to Config Structs**: Add the field to the appropriate struct in `internal/config/`
+2. **Add JSON Tags**: Include proper JSON tags with field names
+3. **Regenerate Schema**: Run the schema generator to update the JSON schema
+4. **Update Validation**: Add any custom validation logic if needed
+
+Example:
+```go
+type Options struct {
+ // ... existing fields ...
+
+ // New field with JSON tag and description
+ NewFeature bool `json:"new_feature,omitempty"`
+}
+```
+
+The schema generator will automatically:
+- Detect the new field
+- Generate appropriate JSON schema
+- Add type information
+- Include it in schema validation (an example fragment follows)
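+
+For the hypothetical `new_feature` boolean above, the generated fragment would look roughly like:
+
+```json
+"new_feature": {
+  "type": "boolean"
+}
+```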
## Using the Schema
-You can use the generated schema in several ways:
+### Editor Integration
+
+Most modern editors support JSON Schema:
+
+**VS Code**: Add to your workspace settings:
+```json
+{
+ "json.schemas": [
+ {
+ "fileMatch": ["crush.json", ".crush.json"],
+ "url": "./crush-schema.json"
+ }
+ ]
+}
+```
-1. **Editor Integration**: Many editors (VS Code, JetBrains IDEs, etc.) support JSON Schema for validation and autocompletion. You can configure your editor to use the generated schema for `.crush.json` files.
+**JetBrains IDEs**: Configure in Settings → Languages & Frameworks → Schemas and DTDs → JSON Schema Mappings
-2. **Validation Tools**: You can use tools like [jsonschema](https://github.com/Julian/jsonschema) to validate your configuration files against the schema.
+### Validation Tools
-3. **Documentation**: The schema serves as documentation for the configuration options.
+```bash
+# Using jsonschema (Python)
+pip install jsonschema
+jsonschema -i crush.json crush-schema.json
-## Example Configuration
+# Using ajv-cli (Node.js)
+npm install -g ajv-cli
+ajv validate -s crush-schema.json -d crush.json
+```
-Here's an example configuration that conforms to the schema:
+### Configuration Example
```json
{
- "data": {
- "directory": ".crush"
+ "models": {
+ "large": {
+ "model_id": "claude-3-5-sonnet-20241022",
+ "provider": "anthropic",
+ "reasoning_effort": "medium",
+ "max_tokens": 8192
+ },
+ "small": {
+ "model_id": "claude-3-5-haiku-20241022",
+ "provider": "anthropic"
+ }
},
- "debug": false,
"providers": {
"anthropic": {
- "apiKey": "your-api-key"
+ "id": "anthropic",
+ "provider_type": "anthropic",
+ "api_key": "your-api-key",
+ "disabled": false
}
},
"agents": {
"coder": {
- "model": "claude-3.7-sonnet",
- "maxTokens": 5000,
- "reasoningEffort": "medium"
- },
- "task": {
- "model": "claude-3.7-sonnet",
- "maxTokens": 5000
+ "id": "coder",
+ "name": "Coder",
+ "model": "large",
+ "disabled": false
},
- "title": {
- "model": "claude-3.7-sonnet",
- "maxTokens": 80
+ "custom-agent": {
+ "id": "custom-agent",
+ "name": "Custom Agent",
+ "description": "A custom agent for specific tasks",
+ "model": "small",
+ "allowed_tools": ["glob", "grep", "view"],
+ "allowed_mcp": {
+ "filesystem": ["read", "write"]
+ }
+ }
+ },
+ "mcp": {
+ "filesystem": {
+ "command": "mcp-filesystem",
+ "args": ["--root", "/workspace"],
+ "type": "stdio"
+ }
+ },
+ "lsp": {
+ "typescript": {
+ "command": "typescript-language-server",
+ "args": ["--stdio"],
+ "enabled": true
+ }
+ },
+ "options": {
+ "context_paths": [
+ "README.md",
+ "docs/",
+ ".cursorrules"
+ ],
+ "data_directory": ".crush",
+ "debug": false,
+ "tui": {
+ "compact_mode": false
}
}
}
```
+
+## Maintenance
+
+The schema generator is designed to be low-maintenance. As long as:
+
+- Configuration structs have proper JSON tags
+- New enums are added to the enhancement functions
+- The generator is run after significant config changes
+
+the schema will stay current with the codebase automatically.
\ No newline at end of file
diff --git a/cmd/schema/main.go b/cmd/schema/main.go
index da5353c0ec7353bfa3ec9b35760b735ecc2c9ccd..43f361662cd5d357e4d3b736ba7b4f3af2222724 100644
--- a/cmd/schema/main.go
+++ b/cmd/schema/main.go
@@ -6,25 +6,29 @@ import (
"os"
"github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/llm/models"
+ "github.com/invopop/jsonschema"
)
-// JSONSchemaType represents a JSON Schema type
-type JSONSchemaType struct {
- Type string `json:"type,omitempty"`
- Description string `json:"description,omitempty"`
- Properties map[string]any `json:"properties,omitempty"`
- Required []string `json:"required,omitempty"`
- AdditionalProperties any `json:"additionalProperties,omitempty"`
- Enum []any `json:"enum,omitempty"`
- Items map[string]any `json:"items,omitempty"`
- OneOf []map[string]any `json:"oneOf,omitempty"`
- AnyOf []map[string]any `json:"anyOf,omitempty"`
- Default any `json:"default,omitempty"`
-}
-
func main() {
- schema := generateSchema()
+ // Create a new reflector
+ r := &jsonschema.Reflector{
+ // Use anonymous schemas to avoid ID conflicts
+ Anonymous: true,
+ // Expand the root struct instead of referencing it
+ ExpandedStruct: true,
+ AllowAdditionalProperties: true,
+ }
+
+ // Generate schema for the main Config struct
+ schema := r.Reflect(&config.Config{})
+
+ // Enhance the schema with additional information
+ enhanceSchema(schema)
+
+ // Set the schema metadata
+ schema.Version = "https://json-schema.org/draft/2020-12/schema"
+ schema.Title = "Crush Configuration"
+ schema.Description = "Configuration schema for the Crush application"
// Pretty print the schema
encoder := json.NewEncoder(os.Stdout)
@@ -35,276 +39,117 @@ func main() {
}
}
-func generateSchema() map[string]any {
- schema := map[string]any{
- "$schema": "http://json-schema.org/draft-07/schema#",
- "title": "Crush Configuration",
- "description": "Configuration schema for the Crush application",
- "type": "object",
- "properties": map[string]any{},
- }
+// enhanceSchema enriches the generated schema with runtime data such as provider, model, and tool enums.
+func enhanceSchema(schema *jsonschema.Schema) {
+ // Add provider enums
+ addProviderEnums(schema)
- // Add Data configuration
- schema["properties"].(map[string]any)["data"] = map[string]any{
- "type": "object",
- "description": "Storage configuration",
- "properties": map[string]any{
- "directory": map[string]any{
- "type": "string",
- "description": "Directory where application data is stored",
- "default": ".crush",
- },
- },
- "required": []string{"directory"},
- }
+ // Add model enums
+ addModelEnums(schema)
- // Add working directory
- schema["properties"].(map[string]any)["wd"] = map[string]any{
- "type": "string",
- "description": "Working directory for the application",
- }
-
- // Add debug flags
- schema["properties"].(map[string]any)["debug"] = map[string]any{
- "type": "boolean",
- "description": "Enable debug mode",
- "default": false,
- }
-
- schema["properties"].(map[string]any)["debugLSP"] = map[string]any{
- "type": "boolean",
- "description": "Enable LSP debug mode",
- "default": false,
- }
-
- schema["properties"].(map[string]any)["contextPaths"] = map[string]any{
- "type": "array",
- "description": "Context paths for the application",
- "items": map[string]any{
- "type": "string",
- },
- "default": []string{
- ".github/copilot-instructions.md",
- ".cursorrules",
- ".cursor/rules/",
- "CLAUDE.md",
- "CLAUDE.local.md",
- "GEMINI.md",
- "gemini.md",
- "crush.md",
- "crush.local.md",
- "Crush.md",
- "Crush.local.md",
- "CRUSH.md",
- "CRUSH.local.md",
- },
- }
-
- schema["properties"].(map[string]any)["tui"] = map[string]any{
- "type": "object",
- "description": "Terminal User Interface configuration",
- "properties": map[string]any{
- "theme": map[string]any{
- "type": "string",
- "description": "TUI theme name",
- "default": "crush",
- "enum": []string{
- "crush",
- "catppuccin",
- "dracula",
- "flexoki",
- "gruvbox",
- "monokai",
- "onedark",
- "tokyonight",
- "tron",
- },
- },
- },
- }
+ // Add tool enums
+ addToolEnums(schema)
- // Add MCP servers
- schema["properties"].(map[string]any)["mcpServers"] = map[string]any{
- "type": "object",
- "description": "Model Control Protocol server configurations",
- "additionalProperties": map[string]any{
- "type": "object",
- "description": "MCP server configuration",
- "properties": map[string]any{
- "command": map[string]any{
- "type": "string",
- "description": "Command to execute for the MCP server",
- },
- "env": map[string]any{
- "type": "array",
- "description": "Environment variables for the MCP server",
- "items": map[string]any{
- "type": "string",
- },
- },
- "args": map[string]any{
- "type": "array",
- "description": "Command arguments for the MCP server",
- "items": map[string]any{
- "type": "string",
- },
- },
- "type": map[string]any{
- "type": "string",
- "description": "Type of MCP server",
- "enum": []string{"stdio", "sse"},
- "default": "stdio",
- },
- "url": map[string]any{
- "type": "string",
- "description": "URL for SSE type MCP servers",
- },
- "headers": map[string]any{
- "type": "object",
- "description": "HTTP headers for SSE type MCP servers",
- "additionalProperties": map[string]any{
- "type": "string",
- },
- },
- },
- "required": []string{"command"},
- },
- }
+ // Add default context paths
+ addDefaultContextPaths(schema)
+}
- // Add providers
- providerSchema := map[string]any{
- "type": "object",
- "description": "LLM provider configurations",
- "additionalProperties": map[string]any{
- "type": "object",
- "description": "Provider configuration",
- "properties": map[string]any{
- "apiKey": map[string]any{
- "type": "string",
- "description": "API key for the provider",
- },
- "disabled": map[string]any{
- "type": "boolean",
- "description": "Whether the provider is disabled",
- "default": false,
- },
- },
- },
+// addProviderEnums adds provider ID enums to the schema.
+func addProviderEnums(schema *jsonschema.Schema) {
+ providers := config.Providers()
+ var providerIDs []any
+ for _, p := range providers {
+ providerIDs = append(providerIDs, string(p.ID))
}
- // Add known providers
- knownProviders := []string{
- string(models.ProviderAnthropic),
- string(models.ProviderOpenAI),
- string(models.ProviderGemini),
- string(models.ProviderGROQ),
- string(models.ProviderOpenRouter),
- string(models.ProviderBedrock),
- string(models.ProviderAzure),
- string(models.ProviderVertexAI),
- }
+ // Add to PreferredModel provider field
+ if schema.Definitions != nil {
+ if preferredModelDef, exists := schema.Definitions["PreferredModel"]; exists {
+ if providerProp, exists := preferredModelDef.Properties.Get("provider"); exists {
+ providerProp.Enum = providerIDs
+ }
+ }
- providerSchema["additionalProperties"].(map[string]any)["properties"].(map[string]any)["provider"] = map[string]any{
- "type": "string",
- "description": "Provider type",
- "enum": knownProviders,
+ // Add to ProviderConfig ID field
+ if providerConfigDef, exists := schema.Definitions["ProviderConfig"]; exists {
+ if idProp, exists := providerConfigDef.Properties.Get("id"); exists {
+ idProp.Enum = providerIDs
+ }
+ }
}
+}
- schema["properties"].(map[string]any)["providers"] = providerSchema
-
- // Add agents
- agentSchema := map[string]any{
- "type": "object",
- "description": "Agent configurations",
- "additionalProperties": map[string]any{
- "type": "object",
- "description": "Agent configuration",
- "properties": map[string]any{
- "model": map[string]any{
- "type": "string",
- "description": "Model ID for the agent",
- },
- "maxTokens": map[string]any{
- "type": "integer",
- "description": "Maximum tokens for the agent",
- "minimum": 1,
- },
- "reasoningEffort": map[string]any{
- "type": "string",
- "description": "Reasoning effort for models that support it (OpenAI, Anthropic)",
- "enum": []string{"low", "medium", "high"},
- },
- },
- "required": []string{"model"},
- },
+// addModelEnums adds model ID enums to the schema.
+func addModelEnums(schema *jsonschema.Schema) {
+ providers := config.Providers()
+ var modelIDs []any
+ for _, p := range providers {
+ for _, m := range p.Models {
+ modelIDs = append(modelIDs, m.ID)
+ }
}
- // Add model enum
- modelEnum := []string{}
- for modelID := range models.SupportedModels {
- modelEnum = append(modelEnum, string(modelID))
+ // Add to PreferredModel model_id field
+ if schema.Definitions != nil {
+ if preferredModelDef, exists := schema.Definitions["PreferredModel"]; exists {
+ if modelIDProp, exists := preferredModelDef.Properties.Get("model_id"); exists {
+ modelIDProp.Enum = modelIDs
+ }
+ }
}
- agentSchema["additionalProperties"].(map[string]any)["properties"].(map[string]any)["model"].(map[string]any)["enum"] = modelEnum
+}
- // Add specific agent properties
- agentProperties := map[string]any{}
- knownAgents := []string{
- string(config.AgentCoder),
- string(config.AgentTask),
- string(config.AgentTitle),
+// addToolEnums adds tool name enums to the schema.
+func addToolEnums(schema *jsonschema.Schema) {
+ tools := []any{
+ "bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent",
}
- for _, agentName := range knownAgents {
- agentProperties[agentName] = map[string]any{
- "$ref": "#/definitions/agent",
+ if schema.Definitions != nil {
+ if agentDef, exists := schema.Definitions["Agent"]; exists {
+ if allowedToolsProp, exists := agentDef.Properties.Get("allowed_tools"); exists {
+ if allowedToolsProp.Items != nil {
+ allowedToolsProp.Items.Enum = tools
+ }
+ }
}
}
+}
- // Create a combined schema that allows both specific agents and additional ones
- combinedAgentSchema := map[string]any{
- "type": "object",
- "description": "Agent configurations",
- "properties": agentProperties,
- "additionalProperties": agentSchema["additionalProperties"],
- }
-
- schema["properties"].(map[string]any)["agents"] = combinedAgentSchema
- schema["definitions"] = map[string]any{
- "agent": agentSchema["additionalProperties"],
+// addDefaultContextPaths sets the default context paths in the schema.
+func addDefaultContextPaths(schema *jsonschema.Schema) {
+ defaultContextPaths := []any{
+ ".github/copilot-instructions.md",
+ ".cursorrules",
+ ".cursor/rules/",
+ "CLAUDE.md",
+ "CLAUDE.local.md",
+ "GEMINI.md",
+ "gemini.md",
+ "crush.md",
+ "crush.local.md",
+ "Crush.md",
+ "Crush.local.md",
+ "CRUSH.md",
+ "CRUSH.local.md",
+ }
+
+ if schema.Definitions != nil {
+ if optionsDef, exists := schema.Definitions["Options"]; exists {
+ if contextPathsProp, exists := optionsDef.Properties.Get("context_paths"); exists {
+ contextPathsProp.Default = defaultContextPaths
+ }
+ }
}
- // Add LSP configuration
- schema["properties"].(map[string]any)["lsp"] = map[string]any{
- "type": "object",
- "description": "Language Server Protocol configurations",
- "additionalProperties": map[string]any{
- "type": "object",
- "description": "LSP configuration for a language",
- "properties": map[string]any{
- "disabled": map[string]any{
- "type": "boolean",
- "description": "Whether the LSP is disabled",
- "default": false,
- },
- "command": map[string]any{
- "type": "string",
- "description": "Command to execute for the LSP server",
- },
- "args": map[string]any{
- "type": "array",
- "description": "Command arguments for the LSP server",
- "items": map[string]any{
- "type": "string",
- },
- },
- "options": map[string]any{
- "type": "object",
- "description": "Additional options for the LSP server",
- },
- },
- "required": []string{"command"},
- },
+ // Also add to root properties if they exist
+ if schema.Properties != nil {
+ if optionsProp, exists := schema.Properties.Get("options"); exists {
+ if optionsProp.Properties != nil {
+ if contextPathsProp, exists := optionsProp.Properties.Get("context_paths"); exists {
+ contextPathsProp.Default = defaultContextPaths
+ }
+ }
+ }
}
-
- return schema
}
diff --git a/crush-schema.json b/crush-schema.json
index 5412a4badecb3e9d49022a69b3c7eb20fce0812b..680ba31196e276c290fd5040b36b23c26cb12414 100644
--- a/crush-schema.json
+++ b/crush-schema.json
@@ -1,383 +1,705 @@
{
- "$schema": "http://json-schema.org/draft-07/schema#",
- "definitions": {
- "agent": {
- "description": "Agent configuration",
+ "$schema": "https://json-schema.org/draft/2020-12/schema",
+ "$defs": {
+ "Agent": {
"properties": {
- "maxTokens": {
- "description": "Maximum tokens for the agent",
- "minimum": 1,
- "type": "integer"
+ "id": {
+ "type": "string",
+ "enum": [
+ "coder",
+ "task",
+ "coder",
+ "task"
+ ],
+ "title": "Agent ID",
+ "description": "Unique identifier for the agent"
+ },
+ "name": {
+ "type": "string",
+ "title": "Name",
+ "description": "Display name of the agent"
+ },
+ "description": {
+ "type": "string",
+ "title": "Description",
+ "description": "Description of what the agent does"
+ },
+ "disabled": {
+ "type": "boolean",
+ "title": "Disabled",
+ "description": "Whether this agent is disabled",
+ "default": false
},
"model": {
- "description": "Model ID for the agent",
+ "type": "string",
"enum": [
- "gpt-4.1",
- "llama-3.3-70b-versatile",
- "azure.gpt-4.1",
- "openrouter.gpt-4o",
- "openrouter.o1-mini",
- "openrouter.claude-3-haiku",
- "claude-3-opus",
- "gpt-4o",
- "gpt-4o-mini",
- "o1",
- "meta-llama/llama-4-maverick-17b-128e-instruct",
- "azure.o3-mini",
- "openrouter.gpt-4o-mini",
- "openrouter.o1",
- "claude-3.5-haiku",
- "o4-mini",
- "azure.gpt-4.1-mini",
- "openrouter.o3",
- "grok-3-beta",
- "o3-mini",
- "qwen-qwq",
- "azure.o1",
- "openrouter.gemini-2.5-flash",
- "openrouter.gemini-2.5",
- "o1-mini",
- "azure.gpt-4o",
- "openrouter.gpt-4.1-mini",
- "openrouter.claude-3.5-sonnet",
- "openrouter.o3-mini",
- "gpt-4.1-mini",
- "gpt-4.5-preview",
- "gpt-4.1-nano",
- "deepseek-r1-distill-llama-70b",
- "azure.gpt-4o-mini",
- "openrouter.gpt-4.1",
- "bedrock.claude-3.7-sonnet",
- "claude-3-haiku",
- "o3",
- "gemini-2.0-flash-lite",
- "azure.o3",
- "azure.gpt-4.5-preview",
- "openrouter.claude-3-opus",
- "grok-3-mini-fast-beta",
- "claude-4-sonnet",
- "azure.o4-mini",
- "grok-3-fast-beta",
- "claude-3.5-sonnet",
- "azure.o1-mini",
- "openrouter.claude-3.7-sonnet",
- "openrouter.gpt-4.5-preview",
- "grok-3-mini-beta",
- "claude-3.7-sonnet",
- "gemini-2.0-flash",
- "openrouter.deepseek-r1-free",
- "vertexai.gemini-2.5-flash",
- "vertexai.gemini-2.5",
- "o1-pro",
- "gemini-2.5",
- "meta-llama/llama-4-scout-17b-16e-instruct",
- "azure.gpt-4.1-nano",
- "openrouter.gpt-4.1-nano",
- "gemini-2.5-flash",
- "openrouter.o4-mini",
- "openrouter.claude-3.5-haiku",
- "claude-4-opus",
- "openrouter.o1-pro"
+ "large",
+ "small",
+ "large",
+ "small"
],
- "type": "string"
+ "title": "Model Type",
+ "description": "Type of model to use (large or small)"
},
- "reasoningEffort": {
- "description": "Reasoning effort for models that support it (OpenAI, Anthropic)",
- "enum": ["low", "medium", "high"],
- "type": "string"
- }
- },
- "required": ["model"],
- "type": "object"
- }
- },
- "description": "Configuration schema for the Crush application",
- "properties": {
- "agents": {
- "additionalProperties": {
- "description": "Agent configuration",
- "properties": {
- "maxTokens": {
- "description": "Maximum tokens for the agent",
- "minimum": 1,
- "type": "integer"
- },
- "model": {
- "description": "Model ID for the agent",
+ "allowed_tools": {
+ "items": {
+ "type": "string",
"enum": [
- "gpt-4.1",
- "llama-3.3-70b-versatile",
- "azure.gpt-4.1",
- "openrouter.gpt-4o",
- "openrouter.o1-mini",
- "openrouter.claude-3-haiku",
- "claude-3-opus",
- "gpt-4o",
- "gpt-4o-mini",
- "o1",
- "meta-llama/llama-4-maverick-17b-128e-instruct",
- "azure.o3-mini",
- "openrouter.gpt-4o-mini",
- "openrouter.o1",
- "claude-3.5-haiku",
- "o4-mini",
- "azure.gpt-4.1-mini",
- "openrouter.o3",
- "grok-3-beta",
- "o3-mini",
- "qwen-qwq",
- "azure.o1",
- "openrouter.gemini-2.5-flash",
- "openrouter.gemini-2.5",
- "o1-mini",
- "azure.gpt-4o",
- "openrouter.gpt-4.1-mini",
- "openrouter.claude-3.5-sonnet",
- "openrouter.o3-mini",
- "gpt-4.1-mini",
- "gpt-4.5-preview",
- "gpt-4.1-nano",
- "deepseek-r1-distill-llama-70b",
- "azure.gpt-4o-mini",
- "openrouter.gpt-4.1",
- "bedrock.claude-3.7-sonnet",
- "claude-3-haiku",
- "o3",
- "gemini-2.0-flash-lite",
- "azure.o3",
- "azure.gpt-4.5-preview",
- "openrouter.claude-3-opus",
- "grok-3-mini-fast-beta",
- "claude-4-sonnet",
- "azure.o4-mini",
- "grok-3-fast-beta",
- "claude-3.5-sonnet",
- "azure.o1-mini",
- "openrouter.claude-3.7-sonnet",
- "openrouter.gpt-4.5-preview",
- "grok-3-mini-beta",
- "claude-3.7-sonnet",
- "gemini-2.0-flash",
- "openrouter.deepseek-r1-free",
- "vertexai.gemini-2.5-flash",
- "vertexai.gemini-2.5",
- "o1-pro",
- "gemini-2.5",
- "meta-llama/llama-4-scout-17b-16e-instruct",
- "azure.gpt-4.1-nano",
- "openrouter.gpt-4.1-nano",
- "gemini-2.5-flash",
- "openrouter.o4-mini",
- "openrouter.claude-3.5-haiku",
- "claude-4-opus",
- "openrouter.o1-pro"
- ],
- "type": "string"
+ "bash",
+ "edit",
+ "fetch",
+ "glob",
+ "grep",
+ "ls",
+ "sourcegraph",
+ "view",
+ "write",
+ "agent"
+ ]
+ },
+ "type": "array",
+ "title": "Allowed Tools",
+ "description": "List of tools this agent is allowed to use (if nil all tools are allowed)"
+ },
+ "allowed_mcp": {
+ "additionalProperties": {
+ "items": {
+ "type": "string"
+ },
+ "type": "array"
},
- "reasoningEffort": {
- "description": "Reasoning effort for models that support it (OpenAI, Anthropic)",
- "enum": ["low", "medium", "high"],
+ "type": "object",
+ "title": "Allowed MCP",
+ "description": "Map of MCP servers this agent can use and their allowed tools"
+ },
+ "allowed_lsp": {
+ "items": {
"type": "string"
- }
+ },
+ "type": "array",
+ "title": "Allowed LSP",
+ "description": "List of LSP servers this agent can use (if nil all LSPs are allowed)"
},
- "required": ["model"],
- "type": "object"
+ "context_paths": {
+ "items": {
+ "type": "string"
+ },
+ "type": "array",
+ "title": "Context Paths",
+ "description": "Custom context paths for this agent (additive to global context paths)"
+ }
},
- "description": "Agent configurations",
+ "type": "object",
+ "required": [
+ "model"
+ ]
+ },
+ "LSPConfig": {
"properties": {
- "coder": {
- "$ref": "#/definitions/agent"
+ "enabled": {
+ "type": "boolean",
+ "title": "Enabled",
+ "description": "Whether this LSP server is enabled",
+ "default": true
},
- "task": {
- "$ref": "#/definitions/agent"
+ "command": {
+ "type": "string",
+ "title": "Command",
+ "description": "Command to execute for the LSP server"
+ },
+ "args": {
+ "items": {
+ "type": "string"
+ },
+ "type": "array",
+ "title": "Arguments",
+ "description": "Command line arguments for the LSP server"
},
- "title": {
- "$ref": "#/definitions/agent"
+ "options": {
+ "title": "Options",
+ "description": "LSP server specific options"
}
},
- "type": "object"
+ "type": "object",
+ "required": [
+ "command"
+ ]
},
- "contextPaths": {
- "default": [
- ".github/copilot-instructions.md",
- ".cursorrules",
- ".cursor/rules/",
- "CLAUDE.md",
- "CLAUDE.local.md",
- "GEMINI.md",
- "gemini.md",
- "crush.md",
- "crush.local.md",
- "Crush.md",
- "Crush.local.md",
- "CRUSH.md",
- "CRUSH.local.md"
- ],
- "description": "Context paths for the application",
- "items": {
- "type": "string"
+ "MCP": {
+ "properties": {
+ "command": {
+ "type": "string",
+ "title": "Command",
+ "description": "Command to execute for stdio MCP servers"
+ },
+ "env": {
+ "items": {
+ "type": "string"
+ },
+ "type": "array",
+ "title": "Environment",
+ "description": "Environment variables for the MCP server"
+ },
+ "args": {
+ "items": {
+ "type": "string"
+ },
+ "type": "array",
+ "title": "Arguments",
+ "description": "Command line arguments for the MCP server"
+ },
+ "type": {
+ "type": "string",
+ "enum": [
+ "stdio",
+ "sse",
+ "stdio",
+ "sse"
+ ],
+ "title": "Type",
+ "description": "Type of MCP connection",
+ "default": "stdio"
+ },
+ "url": {
+ "type": "string",
+ "title": "URL",
+ "description": "URL for SSE MCP servers"
+ },
+ "headers": {
+ "additionalProperties": {
+ "type": "string"
+ },
+ "type": "object",
+ "title": "Headers",
+ "description": "HTTP headers for SSE MCP servers"
+ }
},
- "type": "array"
+ "type": "object",
+ "required": [
+ "command",
+ "type"
+ ]
},
- "data": {
- "description": "Storage configuration",
+ "Model": {
"properties": {
- "directory": {
- "default": ".crush",
- "description": "Directory where application data is stored",
- "type": "string"
+ "id": {
+ "type": "string",
+ "title": "Model ID",
+ "description": "Unique identifier for the model"
+ },
+ "name": {
+ "type": "string",
+ "title": "Model Name",
+ "description": "Display name of the model"
+ },
+ "cost_per_1m_in": {
+ "type": "number",
+ "minimum": 0,
+ "title": "Input Cost",
+ "description": "Cost per 1 million input tokens"
+ },
+ "cost_per_1m_out": {
+ "type": "number",
+ "minimum": 0,
+ "title": "Output Cost",
+ "description": "Cost per 1 million output tokens"
+ },
+ "cost_per_1m_in_cached": {
+ "type": "number",
+ "minimum": 0,
+ "title": "Cached Input Cost",
+ "description": "Cost per 1 million cached input tokens"
+ },
+ "cost_per_1m_out_cached": {
+ "type": "number",
+ "minimum": 0,
+ "title": "Cached Output Cost",
+ "description": "Cost per 1 million cached output tokens"
+ },
+ "context_window": {
+ "type": "integer",
+ "minimum": 1,
+ "title": "Context Window",
+ "description": "Maximum context window size in tokens"
+ },
+ "default_max_tokens": {
+ "type": "integer",
+ "minimum": 1,
+ "title": "Default Max Tokens",
+ "description": "Default maximum tokens for responses"
+ },
+ "can_reason": {
+ "type": "boolean",
+ "title": "Can Reason",
+ "description": "Whether the model supports reasoning capabilities"
+ },
+ "reasoning_effort": {
+ "type": "string",
+ "title": "Reasoning Effort",
+ "description": "Default reasoning effort level for reasoning models"
+ },
+ "has_reasoning_effort": {
+ "type": "boolean",
+ "title": "Has Reasoning Effort",
+ "description": "Whether the model supports reasoning effort configuration"
+ },
+ "supports_attachments": {
+ "type": "boolean",
+ "title": "Supports Images",
+ "description": "Whether the model supports image attachments"
}
},
- "required": ["directory"],
- "type": "object"
- },
- "debug": {
- "default": false,
- "description": "Enable debug mode",
- "type": "boolean"
+ "type": "object",
+ "required": [
+ "id",
+ "name",
+ "cost_per_1m_out_cached",
+ "context_window",
+ "default_max_tokens",
+ "can_reason",
+ "reasoning_effort",
+ "has_reasoning_effort",
+ "supports_attachments"
+ ]
},
- "debugLSP": {
- "default": false,
- "description": "Enable LSP debug mode",
- "type": "boolean"
- },
- "lsp": {
- "additionalProperties": {
- "description": "LSP configuration for a language",
- "properties": {
- "args": {
- "description": "Command arguments for the LSP server",
- "items": {
- "type": "string"
- },
- "type": "array"
- },
- "command": {
- "description": "Command to execute for the LSP server",
+ "Options": {
+ "properties": {
+ "context_paths": {
+ "items": {
"type": "string"
},
- "disabled": {
- "default": false,
- "description": "Whether the LSP is disabled",
- "type": "boolean"
- },
- "options": {
- "description": "Additional options for the LSP server",
- "type": "object"
- }
+ "type": "array",
+ "title": "Context Paths",
+ "description": "List of paths to search for context files",
+ "default": [
+ ".github/copilot-instructions.md",
+ ".cursorrules",
+ ".cursor/rules/",
+ "CLAUDE.md",
+ "CLAUDE.local.md",
+ "GEMINI.md",
+ "gemini.md",
+ "crush.md",
+ "crush.local.md",
+ "Crush.md",
+ "Crush.local.md",
+ "CRUSH.md",
+ "CRUSH.local.md"
+ ]
+ },
+ "tui": {
+ "$ref": "#/$defs/TUIOptions",
+ "title": "TUI Options",
+ "description": "Terminal UI configuration options"
+ },
+ "debug": {
+ "type": "boolean",
+ "title": "Debug",
+ "description": "Enable debug logging",
+ "default": false
+ },
+ "debug_lsp": {
+ "type": "boolean",
+ "title": "Debug LSP",
+ "description": "Enable LSP debug logging",
+ "default": false
+ },
+ "disable_auto_summarize": {
+ "type": "boolean",
+ "title": "Disable Auto Summarize",
+ "description": "Disable automatic conversation summarization",
+ "default": false
},
- "required": ["command"],
- "type": "object"
+ "data_directory": {
+ "type": "string",
+ "title": "Data Directory",
+ "description": "Directory for storing application data",
+ "default": ".crush"
+ }
},
- "description": "Language Server Protocol configurations",
"type": "object"
},
- "mcpServers": {
- "additionalProperties": {
- "description": "MCP server configuration",
- "properties": {
- "args": {
- "description": "Command arguments for the MCP server",
- "items": {
- "type": "string"
- },
- "type": "array"
- },
- "command": {
- "description": "Command to execute for the MCP server",
- "type": "string"
- },
- "env": {
- "description": "Environment variables for the MCP server",
- "items": {
- "type": "string"
- },
- "type": "array"
- },
- "headers": {
- "additionalProperties": {
- "type": "string"
- },
- "description": "HTTP headers for SSE type MCP servers",
- "type": "object"
- },
- "type": {
- "default": "stdio",
- "description": "Type of MCP server",
- "enum": ["stdio", "sse"],
- "type": "string"
- },
- "url": {
- "description": "URL for SSE type MCP servers",
- "type": "string"
- }
+ "PreferredModel": {
+ "properties": {
+ "model_id": {
+ "type": "string",
+ "enum": [
+ "claude-opus-4-20250514",
+ "claude-sonnet-4-20250514",
+ "claude-3-7-sonnet-20250219",
+ "claude-3-5-haiku-20241022",
+ "claude-3-5-sonnet-20240620",
+ "claude-3-5-sonnet-20241022",
+ "codex-mini-latest",
+ "o4-mini",
+ "o3",
+ "o3-pro",
+ "gpt-4.1",
+ "gpt-4.1-mini",
+ "gpt-4.1-nano",
+ "gpt-4.5-preview",
+ "o3-mini",
+ "gpt-4o",
+ "gpt-4o-mini",
+ "gemini-2.5-pro",
+ "gemini-2.5-flash",
+ "codex-mini-latest",
+ "o4-mini",
+ "o3",
+ "o3-pro",
+ "gpt-4.1",
+ "gpt-4.1-mini",
+ "gpt-4.1-nano",
+ "gpt-4.5-preview",
+ "o3-mini",
+ "gpt-4o",
+ "gpt-4o-mini",
+ "anthropic.claude-opus-4-20250514-v1:0",
+ "anthropic.claude-sonnet-4-20250514-v1:0",
+ "anthropic.claude-3-7-sonnet-20250219-v1:0",
+ "anthropic.claude-3-5-haiku-20241022-v1:0",
+ "gemini-2.5-pro",
+ "gemini-2.5-flash",
+ "grok-3-mini",
+ "grok-3",
+ "mistralai/mistral-small-3.2-24b-instruct:free",
+ "mistralai/mistral-small-3.2-24b-instruct",
+ "minimax/minimax-m1:extended",
+ "minimax/minimax-m1",
+ "google/gemini-2.5-flash-lite-preview-06-17",
+ "google/gemini-2.5-flash",
+ "google/gemini-2.5-pro",
+ "openai/o3-pro",
+ "x-ai/grok-3-mini",
+ "x-ai/grok-3",
+ "mistralai/magistral-small-2506",
+ "mistralai/magistral-medium-2506",
+ "mistralai/magistral-medium-2506:thinking",
+ "google/gemini-2.5-pro-preview",
+ "deepseek/deepseek-r1-0528",
+ "anthropic/claude-opus-4",
+ "anthropic/claude-sonnet-4",
+ "mistralai/devstral-small:free",
+ "mistralai/devstral-small",
+ "google/gemini-2.5-flash-preview-05-20",
+ "google/gemini-2.5-flash-preview-05-20:thinking",
+ "openai/codex-mini",
+ "mistralai/mistral-medium-3",
+ "google/gemini-2.5-pro-preview-05-06",
+ "arcee-ai/caller-large",
+ "arcee-ai/virtuoso-large",
+ "arcee-ai/virtuoso-medium-v2",
+ "qwen/qwen3-30b-a3b",
+ "qwen/qwen3-14b",
+ "qwen/qwen3-32b",
+ "qwen/qwen3-235b-a22b",
+ "google/gemini-2.5-flash-preview",
+ "google/gemini-2.5-flash-preview:thinking",
+ "openai/o4-mini-high",
+ "openai/o3",
+ "openai/o4-mini",
+ "openai/gpt-4.1",
+ "openai/gpt-4.1-mini",
+ "openai/gpt-4.1-nano",
+ "x-ai/grok-3-mini-beta",
+ "x-ai/grok-3-beta",
+ "meta-llama/llama-4-maverick",
+ "meta-llama/llama-4-scout",
+ "all-hands/openhands-lm-32b-v0.1",
+ "google/gemini-2.5-pro-exp-03-25",
+ "deepseek/deepseek-chat-v3-0324:free",
+ "deepseek/deepseek-chat-v3-0324",
+ "mistralai/mistral-small-3.1-24b-instruct:free",
+ "mistralai/mistral-small-3.1-24b-instruct",
+ "ai21/jamba-1.6-large",
+ "ai21/jamba-1.6-mini",
+ "openai/gpt-4.5-preview",
+ "google/gemini-2.0-flash-lite-001",
+ "anthropic/claude-3.7-sonnet",
+ "anthropic/claude-3.7-sonnet:beta",
+ "anthropic/claude-3.7-sonnet:thinking",
+ "mistralai/mistral-saba",
+ "openai/o3-mini-high",
+ "google/gemini-2.0-flash-001",
+ "qwen/qwen-turbo",
+ "qwen/qwen-plus",
+ "qwen/qwen-max",
+ "openai/o3-mini",
+ "mistralai/mistral-small-24b-instruct-2501",
+ "deepseek/deepseek-r1-distill-llama-70b",
+ "deepseek/deepseek-r1",
+ "mistralai/codestral-2501",
+ "deepseek/deepseek-chat",
+ "openai/o1",
+ "x-ai/grok-2-1212",
+ "meta-llama/llama-3.3-70b-instruct",
+ "amazon/nova-lite-v1",
+ "amazon/nova-micro-v1",
+ "amazon/nova-pro-v1",
+ "openai/gpt-4o-2024-11-20",
+ "mistralai/mistral-large-2411",
+ "mistralai/mistral-large-2407",
+ "mistralai/pixtral-large-2411",
+ "thedrummer/unslopnemo-12b",
+ "anthropic/claude-3.5-haiku:beta",
+ "anthropic/claude-3.5-haiku",
+ "anthropic/claude-3.5-haiku-20241022:beta",
+ "anthropic/claude-3.5-haiku-20241022",
+ "anthropic/claude-3.5-sonnet:beta",
+ "anthropic/claude-3.5-sonnet",
+ "x-ai/grok-beta",
+ "mistralai/ministral-8b",
+ "mistralai/ministral-3b",
+ "nvidia/llama-3.1-nemotron-70b-instruct",
+ "google/gemini-flash-1.5-8b",
+ "meta-llama/llama-3.2-11b-vision-instruct",
+ "meta-llama/llama-3.2-3b-instruct",
+ "qwen/qwen-2.5-72b-instruct",
+ "mistralai/pixtral-12b",
+ "cohere/command-r-plus-08-2024",
+ "cohere/command-r-08-2024",
+ "microsoft/phi-3.5-mini-128k-instruct",
+ "nousresearch/hermes-3-llama-3.1-70b",
+ "openai/gpt-4o-2024-08-06",
+ "meta-llama/llama-3.1-405b-instruct",
+ "meta-llama/llama-3.1-70b-instruct",
+ "meta-llama/llama-3.1-8b-instruct",
+ "mistralai/mistral-nemo",
+ "openai/gpt-4o-mini",
+ "openai/gpt-4o-mini-2024-07-18",
+ "anthropic/claude-3.5-sonnet-20240620:beta",
+ "anthropic/claude-3.5-sonnet-20240620",
+ "mistralai/mistral-7b-instruct-v0.3",
+ "mistralai/mistral-7b-instruct:free",
+ "mistralai/mistral-7b-instruct",
+ "microsoft/phi-3-mini-128k-instruct",
+ "microsoft/phi-3-medium-128k-instruct",
+ "google/gemini-flash-1.5",
+ "openai/gpt-4o-2024-05-13",
+ "openai/gpt-4o",
+ "openai/gpt-4o:extended",
+ "meta-llama/llama-3-8b-instruct",
+ "meta-llama/llama-3-70b-instruct",
+ "mistralai/mixtral-8x22b-instruct",
+ "openai/gpt-4-turbo",
+ "google/gemini-pro-1.5",
+ "cohere/command-r-plus",
+ "cohere/command-r-plus-04-2024",
+ "cohere/command-r",
+ "anthropic/claude-3-haiku:beta",
+ "anthropic/claude-3-haiku",
+ "anthropic/claude-3-opus:beta",
+ "anthropic/claude-3-opus",
+ "anthropic/claude-3-sonnet:beta",
+ "anthropic/claude-3-sonnet",
+ "cohere/command-r-03-2024",
+ "mistralai/mistral-large",
+ "openai/gpt-3.5-turbo-0613",
+ "openai/gpt-4-turbo-preview",
+ "mistralai/mistral-small",
+ "mistralai/mistral-tiny",
+ "mistralai/mixtral-8x7b-instruct",
+ "openai/gpt-4-1106-preview",
+ "mistralai/mistral-7b-instruct-v0.1",
+ "openai/gpt-3.5-turbo-16k",
+ "openai/gpt-4",
+ "openai/gpt-4-0314"
+ ],
+ "title": "Model ID",
+ "description": "ID of the preferred model"
+ },
+ "provider": {
+ "type": "string",
+ "enum": [
+ "anthropic",
+ "openai",
+ "gemini",
+ "azure",
+ "bedrock",
+ "vertex",
+ "xai",
+ "openrouter"
+ ],
+ "title": "Provider",
+ "description": "Provider for the preferred model"
+ },
+ "reasoning_effort": {
+ "type": "string",
+ "title": "Reasoning Effort",
+ "description": "Override reasoning effort for this model"
+ },
+ "max_tokens": {
+ "type": "integer",
+ "minimum": 1,
+ "title": "Max Tokens",
+ "description": "Override max tokens for this model"
+ },
+ "think": {
+ "type": "boolean",
+ "title": "Think",
+ "description": "Enable thinking for reasoning models",
+ "default": false
+ }
+ },
+ "type": "object",
+ "required": [
+ "model_id",
+ "provider"
+ ]
+ },
+ "PreferredModels": {
+ "properties": {
+ "large": {
+ "$ref": "#/$defs/PreferredModel",
+ "title": "Large Model",
+ "description": "Preferred model configuration for large model type"
},
- "required": ["command"],
- "type": "object"
+ "small": {
+ "$ref": "#/$defs/PreferredModel",
+ "title": "Small Model",
+ "description": "Preferred model configuration for small model type"
+ }
},
- "description": "Model Control Protocol server configurations",
"type": "object"
},
- "providers": {
- "additionalProperties": {
- "description": "Provider configuration",
- "properties": {
- "apiKey": {
- "description": "API key for the provider",
+ "ProviderConfig": {
+ "properties": {
+ "id": {
+ "type": "string",
+ "enum": [
+ "anthropic",
+ "openai",
+ "gemini",
+ "azure",
+ "bedrock",
+ "vertex",
+ "xai",
+ "openrouter"
+ ],
+ "title": "Provider ID",
+ "description": "Unique identifier for the provider"
+ },
+ "base_url": {
+ "type": "string",
+ "title": "Base URL",
+ "description": "Base URL for the provider API (required for custom providers)"
+ },
+ "provider_type": {
+ "type": "string",
+ "title": "Provider Type",
+ "description": "Type of the provider (openai"
+ },
+ "api_key": {
+ "type": "string",
+ "title": "API Key",
+ "description": "API key for authenticating with the provider"
+ },
+ "disabled": {
+ "type": "boolean",
+ "title": "Disabled",
+ "description": "Whether this provider is disabled",
+ "default": false
+ },
+ "extra_headers": {
+ "additionalProperties": {
"type": "string"
},
- "disabled": {
- "default": false,
- "description": "Whether the provider is disabled",
- "type": "boolean"
- },
- "provider": {
- "description": "Provider type",
- "enum": [
- "anthropic",
- "openai",
- "gemini",
- "groq",
- "openrouter",
- "bedrock",
- "azure",
- "vertexai"
- ],
+ "type": "object",
+ "title": "Extra Headers",
+ "description": "Additional HTTP headers to send with requests"
+ },
+ "extra_params": {
+ "additionalProperties": {
"type": "string"
- }
+ },
+ "type": "object",
+ "title": "Extra Parameters",
+ "description": "Additional provider-specific parameters"
+ },
+ "default_large_model": {
+ "type": "string",
+ "title": "Default Large Model",
+ "description": "Default model ID for large model type"
+ },
+ "default_small_model": {
+ "type": "string",
+ "title": "Default Small Model",
+ "description": "Default model ID for small model type"
},
- "type": "object"
+ "models": {
+ "items": {
+ "$ref": "#/$defs/Model"
+ },
+ "type": "array",
+ "title": "Models",
+ "description": "List of available models for this provider"
+ }
},
- "description": "LLM provider configurations",
- "type": "object"
+ "type": "object",
+ "required": [
+ "provider_type"
+ ]
},
- "tui": {
- "description": "Terminal User Interface configuration",
+ "TUIOptions": {
"properties": {
- "theme": {
- "default": "crush",
- "description": "TUI theme name",
- "enum": [
- "crush",
- "catppuccin",
- "dracula",
- "flexoki",
- "gruvbox",
- "monokai",
- "onedark",
- "tokyonight",
- "tron"
- ],
- "type": "string"
+ "compact_mode": {
+ "type": "boolean",
+ "title": "Compact Mode",
+ "description": "Enable compact mode for the TUI",
+ "default": false
}
},
- "type": "object"
+ "type": "object",
+ "required": [
+ "compact_mode"
+ ]
+ }
+ },
+ "properties": {
+ "models": {
+ "$ref": "#/$defs/PreferredModels",
+ "title": "Models",
+ "description": "Preferred model configurations for large and small model types"
+ },
+ "providers": {
+ "additionalProperties": {
+ "$ref": "#/$defs/ProviderConfig"
+ },
+ "type": "object",
+ "title": "Providers",
+ "description": "LLM provider configurations"
+ },
+ "agents": {
+ "additionalProperties": {
+ "$ref": "#/$defs/Agent"
+ },
+ "type": "object",
+ "title": "Agents",
+ "description": "Agent configurations for different tasks"
+ },
+ "mcp": {
+ "additionalProperties": {
+ "$ref": "#/$defs/MCP"
+ },
+ "type": "object",
+ "title": "MCP",
+ "description": "Model Control Protocol server configurations"
+ },
+ "lsp": {
+ "additionalProperties": {
+ "$ref": "#/$defs/LSPConfig"
+ },
+ "type": "object",
+ "title": "LSP",
+ "description": "Language Server Protocol configurations"
},
- "wd": {
- "description": "Working directory for the application",
- "type": "string"
+ "options": {
+ "$ref": "#/$defs/Options",
+ "title": "Options",
+ "description": "General application options and settings"
}
},
+ "type": "object",
"title": "Crush Configuration",
- "type": "object"
+ "description": "Configuration schema for the Crush application"
}
diff --git a/.crush.json b/crush.json
similarity index 63%
rename from .crush.json
rename to crush.json
index 0edb084c3ee5f74c9b279b634b5838b3035cee95..4937665c513258840f1efb4f88fb2bdd73f6ff68 100644
--- a/.crush.json
+++ b/crush.json
@@ -1,11 +1,8 @@
{
"$schema": "./crush-schema.json",
"lsp": {
- "Go": {
+ "go": {
"command": "gopls"
}
- },
- "tui": {
- "theme": "crush"
}
}
diff --git a/cspell.json b/cspell.json
index 2595963182b8e2aa6fe575bbe9ca6a5da0f70a9b..d98b1326e54c8b62c7ad700fe19b4cbbe3e4f672 100644
--- a/cspell.json
+++ b/cspell.json
@@ -1,39 +1 @@
-{
- "flagWords": [],
- "words": [
- "afero",
- "alecthomas",
- "bubbletea",
- "charmbracelet",
- "charmtone",
- "Charple",
- "crush",
- "diffview",
- "Emph",
- "filepicker",
- "Focusable",
- "fsext",
- "GROQ",
- "Guac",
- "imageorient",
- "Lanczos",
- "lipgloss",
- "lsps",
- "lucasb",
- "nfnt",
- "oksvg",
- "Preproc",
- "rasterx",
- "rivo",
- "Sourcegraph",
- "srwiley",
- "Strikethrough",
- "termenv",
- "textinput",
- "trashhalo",
- "uniseg",
- "Unticked"
- ],
- "version": "0.2",
- "language": "en"
-}
+{"flagWords":[],"words":["afero","alecthomas","bubbletea","charmbracelet","charmtone","Charple","crush","diffview","Emph","filepicker","Focusable","fsext","GROQ","Guac","imageorient","Lanczos","lipgloss","lsps","lucasb","nfnt","oksvg","Preproc","rasterx","rivo","Sourcegraph","srwiley","Strikethrough","termenv","textinput","trashhalo","uniseg","Unticked","genai","jsonschema"],"version":"0.2","language":"en"}
\ No newline at end of file
diff --git a/go.mod b/go.mod
index 86b3a8541ddc01679f4a6e25a02ea6599b6b3754..a24a21f859d4a0ea3ba8ecf203252da84823ada7 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,6 @@
module github.com/charmbracelet/crush
-go 1.24.0
+go 1.24.3
require (
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0
@@ -24,6 +24,7 @@ require (
github.com/fsnotify/fsnotify v1.8.0
github.com/go-logfmt/logfmt v0.6.0
github.com/google/uuid v1.6.0
+ github.com/invopop/jsonschema v0.13.0
github.com/mark3labs/mcp-go v0.17.0
github.com/muesli/termenv v0.16.0
github.com/ncruces/go-sqlite3 v0.25.0
@@ -33,7 +34,6 @@ require (
github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06
github.com/sahilm/fuzzy v0.1.1
github.com/spf13/cobra v1.9.1
- github.com/spf13/viper v1.20.0
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef
github.com/stretchr/testify v1.10.0
@@ -65,6 +65,8 @@ require (
github.com/aws/smithy-go v1.20.3 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/aymerick/douceur v0.2.0 // indirect
+ github.com/bahlo/generic-list-go v0.2.0 // indirect
+ github.com/buger/jsonparser v1.1.1 // indirect
github.com/charmbracelet/colorprofile v0.3.1 // indirect
github.com/charmbracelet/x/cellbuf v0.0.14-0.20250516160309-24eee56f89fa // indirect
github.com/charmbracelet/x/exp/slice v0.0.0-20250611152503-f53cdd7e01ef
@@ -78,7 +80,6 @@ require (
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
- github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
github.com/golang-jwt/jwt/v5 v5.2.2 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/s2a-go v0.1.8 // indirect
@@ -89,6 +90,7 @@ require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/kylelemons/godebug v1.1.0 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0
+ github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/mfridman/interpolate v0.0.2 // indirect
@@ -99,23 +101,18 @@ require (
github.com/muesli/mango-pflag v0.1.0 // indirect
github.com/muesli/roff v0.1.0 // indirect
github.com/ncruces/julianday v1.0.0 // indirect
- github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.4.7
- github.com/sagikazarmark/locafero v0.7.0 // indirect
github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect
github.com/sethvargo/go-retry v0.3.0 // indirect
- github.com/sourcegraph/conc v0.3.0 // indirect
- github.com/spf13/afero v1.12.0
- github.com/spf13/cast v1.7.1 // indirect
github.com/spf13/pflag v1.0.6 // indirect
- github.com/subosito/gotenv v1.6.0 // indirect
github.com/tetratelabs/wazero v1.9.0 // indirect
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
+ github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
github.com/yuin/goldmark v1.7.8 // indirect
diff --git a/go.sum b/go.sum
index 3a0b0a6177b193cbbea1bb6d331e4b46fc19cb28..799dce9b3d9be5fc202a3bc5300e61bdb430ac2a 100644
--- a/go.sum
+++ b/go.sum
@@ -64,8 +64,12 @@ github.com/aymanbagabas/go-udiff v0.3.1 h1:LV+qyBQ2pqe0u42ZsUEtPiCaUoqgA9gYRDs3v
github.com/aymanbagabas/go-udiff v0.3.1/go.mod h1:G0fsKmG+P6ylD0r6N/KgQD/nWzgfnl8ZBcNLgcbrw8E=
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
+github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
+github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
github.com/bmatcuk/doublestar/v4 v4.8.1 h1:54Bopc5c2cAvhLRAzqOGCYHYyhcDHsFF4wWIR5wKP38=
github.com/bmatcuk/doublestar/v4 v4.8.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc=
+github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
+github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
github.com/charlievieth/fastwalk v1.0.11 h1:5sLT/q9+d9xMdpKExawLppqvXFZCVKf6JHnr2u/ufj8=
github.com/charlievieth/fastwalk v1.0.11/go.mod h1:yGy1zbxog41ZVMcKA/i8ojXLFsuayX5VvwhQVoj9PBI=
github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250607113720-eb5e1cf3b09e h1:99Ugtt633rqauFsXjZobZmtkNpeaWialfj8dl6COC6A=
@@ -84,8 +88,6 @@ github.com/charmbracelet/x/ansi v0.9.3-0.20250602153603-fb931ed90413 h1:L07QkDqR
github.com/charmbracelet/x/ansi v0.9.3-0.20250602153603-fb931ed90413/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
github.com/charmbracelet/x/cellbuf v0.0.14-0.20250516160309-24eee56f89fa h1:lphz0Z3rsiOtMYiz8axkT24i9yFiueDhJbzyNUADmME=
github.com/charmbracelet/x/cellbuf v0.0.14-0.20250516160309-24eee56f89fa/go.mod h1:xBlh2Yi3DL3zy/2n15kITpg0YZardf/aa/hgUaIM6Rk=
-github.com/charmbracelet/x/exp/charmtone v0.0.0-20250603201427-c31516f43444 h1:IJDiTgVE56gkAGfq0lBEloWgkXMk4hl/bmuPoicI4R0=
-github.com/charmbracelet/x/exp/charmtone v0.0.0-20250603201427-c31516f43444/go.mod h1:T9jr8CzFpjhFVHjNjKwbAD7KwBNyFnj2pntAO7F2zw0=
github.com/charmbracelet/x/exp/charmtone v0.0.0-20250627134340-c144409e381c h1:2GELBLPgfSbHU53bsQhR9XIgNuVZ6w+Rz8RWV5Lq+A4=
github.com/charmbracelet/x/exp/charmtone v0.0.0-20250627134340-c144409e381c/go.mod h1:T9jr8CzFpjhFVHjNjKwbAD7KwBNyFnj2pntAO7F2zw0=
github.com/charmbracelet/x/exp/golden v0.0.0-20250207160936-21c02780d27a h1:FsHEJ52OC4VuTzU8t+n5frMjLvpYWEznSr/u8tnkCYw=
@@ -114,8 +116,6 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
-github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
-github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M=
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4=
@@ -127,8 +127,6 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI=
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
-github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss=
-github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
@@ -151,6 +149,9 @@ github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUq
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
+github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=
+github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
+github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
@@ -162,6 +163,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
+github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
+github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mark3labs/mcp-go v0.17.0 h1:5Ps6T7qXr7De/2QTqs9h6BKeZ/qdeUeGrgM5lPzi930=
github.com/mark3labs/mcp-go v0.17.0/go.mod h1:KmJndYv7GIgcPVwEKJjNcbhVQ+hJGJhrCCB/9xITzpE=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
@@ -194,8 +197,6 @@ github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/openai/openai-go v0.1.0-beta.2 h1:Ra5nCFkbEl9w+UJwAciC4kqnIBUCcJazhmMA0/YN894=
github.com/openai/openai-go v0.1.0-beta.2/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
-github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
-github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@@ -213,8 +214,6 @@ github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 h1:OkMGxebDjyw0ULyrTYWeN0UNCCkmCWfjPnIA2W6oviI=
github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06/go.mod h1:+ePHsJ1keEjQtpvf9HHw0f4ZeJ0TLRsxhunSI2hYJSs=
-github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo=
-github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k=
github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA=
github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y=
github.com/sebdah/goldie/v2 v2.5.3 h1:9ES/mNN+HNUbNWpVAlrzuZ7jE+Nrczbj8uFRjM7624Y=
@@ -225,18 +224,10 @@ github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 h1:n661drycOFuPLCN
github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4=
github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE=
github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas=
-github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
-github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
-github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs=
-github.com/spf13/afero v1.12.0/go.mod h1:ZTlWwG4/ahT8W7T0WQ5uYmjI9duaLQGy3Q2OAl4sk/4=
-github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y=
-github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo=
github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0=
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
-github.com/spf13/viper v1.20.0 h1:zrxIyR3RQIOsarIrgL8+sAvALXul9jeEPa06Y0Ph6vY=
-github.com/spf13/viper v1.20.0/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4=
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c h1:km8GpoQut05eY3GiYWEedbTT0qnSxrCjsVbb7yKY1KE=
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c/go.mod h1:cNQ3dwVJtS5Hmnjxy6AgTPd0Inb3pW05ftPSX7NZO7Q=
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef h1:Ch6Q+AZUxDBCVqdkI8FSpFyZDtCVBc2VmejdNrm5rRQ=
@@ -247,8 +238,6 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
-github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
-github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/tetratelabs/wazero v1.9.0 h1:IcZ56OuxrtaEz8UYNRHBrUa9bYeX9oVY93KspZZBf/I=
github.com/tetratelabs/wazero v1.9.0/go.mod h1:TSbcXCfFP0L2FGkRPxHphadXPjo1T6W+CseNNY7EkjM=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
@@ -261,6 +250,8 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
+github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
+github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
diff --git a/internal/app/app.go b/internal/app/app.go
index e7472059a9f3fad360172c353f5d9a188529d177..6dd1b9916d593c6f0e053aaef6714723f8fd5c60 100644
--- a/internal/app/app.go
+++ b/internal/app/app.go
@@ -55,18 +55,21 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) {
// Initialize LSP clients in the background
go app.initLSPClients(ctx)
+ cfg := config.Get()
+
+ coderAgentCfg := cfg.Agents[config.AgentCoder]
+ if coderAgentCfg.ID == "" {
+ return nil, fmt.Errorf("coder agent configuration is missing")
+ }
+
var err error
app.CoderAgent, err = agent.NewAgent(
- config.AgentCoder,
+ coderAgentCfg,
+ app.Permissions,
app.Sessions,
app.Messages,
- agent.CoderAgentTools(
- app.Permissions,
- app.Sessions,
- app.Messages,
- app.History,
- app.LSPClients,
- ),
+ app.History,
+ app.LSPClients,
)
if err != nil {
logging.Error("Failed to create coder agent", err)
@@ -165,3 +168,7 @@ func (app *App) Shutdown() {
}
app.CoderAgent.CancelAll()
}
+
+func (app *App) UpdateAgentModel() error {
+ return app.CoderAgent.UpdateModel()
+}
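+
+// Usage sketch (hypothetical call site): a caller such as a model picker
+// could refresh the running agent after the preferred model changes:
+//
+//	if err := app.UpdateAgentModel(); err != nil {
+//		logging.Error("Failed to update agent model", err)
+//	}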
diff --git a/internal/config/config.go b/internal/config/config.go
index 3b794dee60cb13a0c42413aedce3199ea5998352..4a8c47dd8686ca562b60c97efa2c15c31daf88ad 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -1,95 +1,22 @@
-// Package config manages application configuration from various sources.
package config
import (
"encoding/json"
+ "errors"
"fmt"
"log/slog"
+ "maps"
"os"
"path/filepath"
+ "slices"
"strings"
+ "sync"
- "github.com/charmbracelet/crush/internal/llm/models"
+ "github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/logging"
- "github.com/spf13/afero"
- "github.com/spf13/viper"
+ "github.com/invopop/jsonschema"
)
-// MCPType defines the type of MCP (Model Control Protocol) server.
-type MCPType string
-
-// Supported MCP types
-const (
- MCPStdio MCPType = "stdio"
- MCPSse MCPType = "sse"
-)
-
-// MCPServer defines the configuration for a Model Control Protocol server.
-type MCPServer struct {
- Command string `json:"command"`
- Env []string `json:"env"`
- Args []string `json:"args"`
- Type MCPType `json:"type"`
- URL string `json:"url"`
- Headers map[string]string `json:"headers"`
-}
-
-type AgentName string
-
-const (
- AgentCoder AgentName = "coder"
- AgentSummarizer AgentName = "summarizer"
- AgentTask AgentName = "task"
- AgentTitle AgentName = "title"
-)
-
-// Agent defines configuration for different LLM models and their token limits.
-type Agent struct {
- Model models.ModelID `json:"model"`
- MaxTokens int64 `json:"maxTokens"`
- ReasoningEffort string `json:"reasoningEffort"` // For openai models low,medium,heigh
-}
-
-// Provider defines configuration for an LLM provider.
-type Provider struct {
- APIKey string `json:"apiKey"`
- Disabled bool `json:"disabled"`
-}
-
-// Data defines storage configuration.
-type Data struct {
- Directory string `json:"directory,omitempty"`
-}
-
-// LSPConfig defines configuration for Language Server Protocol integration.
-type LSPConfig struct {
- Disabled bool `json:"enabled"`
- Command string `json:"command"`
- Args []string `json:"args"`
- Options any `json:"options"`
-}
-
-// TUIConfig defines the configuration for the Terminal User Interface.
-type TUIConfig struct {
- Theme string `json:"theme,omitempty"`
-}
-
-// Config is the main configuration structure for the application.
-type Config struct {
- Data Data `json:"data"`
- WorkingDir string `json:"wd,omitempty"`
- MCPServers map[string]MCPServer `json:"mcpServers,omitempty"`
- Providers map[models.ModelProvider]Provider `json:"providers,omitempty"`
- LSP map[string]LSPConfig `json:"lsp,omitempty"`
- Agents map[AgentName]Agent `json:"agents,omitempty"`
- Debug bool `json:"debug,omitempty"`
- DebugLSP bool `json:"debugLSP,omitempty"`
- ContextPaths []string `json:"contextPaths,omitempty"`
- TUI TUIConfig `json:"tui"`
- AutoCompact bool `json:"autoCompact,omitempty"`
-}
-
-// Application constants
const (
defaultDataDirectory = ".crush"
defaultLogLevel = "info"
@@ -114,53 +41,181 @@ var defaultContextPaths = []string{
"CRUSH.local.md",
}
-// Global configuration instance
-var cfg *Config
+type AgentID string
-// Load initializes the configuration from environment variables and config files.
-// If debug is true, debug mode is enabled and log level is set to debug.
-// It returns an error if configuration loading fails.
-func Load(workingDir string, debug bool) (*Config, error) {
- if cfg != nil {
- return cfg, nil
- }
+const (
+ AgentCoder AgentID = "coder"
+ AgentTask AgentID = "task"
+)
- cfg = &Config{
- WorkingDir: workingDir,
- MCPServers: make(map[string]MCPServer),
- Providers: make(map[models.ModelProvider]Provider),
- LSP: make(map[string]LSPConfig),
- }
+type ModelType string
- configureViper()
- setDefaults(debug)
+const (
+ LargeModel ModelType = "large"
+ SmallModel ModelType = "small"
+)
- // Read global config
- if err := readConfig(viper.ReadInConfig()); err != nil {
- return cfg, err
- }
+type Model struct {
+ ID string `json:"id" jsonschema:"title=Model ID,description=Unique identifier for the model, the API model"`
+ Name string `json:"name" jsonschema:"title=Model Name,description=Display name of the model"`
+ CostPer1MIn float64 `json:"cost_per_1m_in,omitempty" jsonschema:"title=Input Cost,description=Cost per 1 million input tokens,minimum=0"`
+ CostPer1MOut float64 `json:"cost_per_1m_out,omitempty" jsonschema:"title=Output Cost,description=Cost per 1 million output tokens,minimum=0"`
+ CostPer1MInCached float64 `json:"cost_per_1m_in_cached,omitempty" jsonschema:"title=Cached Input Cost,description=Cost per 1 million cached input tokens,minimum=0"`
+ CostPer1MOutCached float64 `json:"cost_per_1m_out_cached,omitempty" jsonschema:"title=Cached Output Cost,description=Cost per 1 million cached output tokens,minimum=0"`
+ ContextWindow int64 `json:"context_window" jsonschema:"title=Context Window,description=Maximum context window size in tokens,minimum=1"`
+ DefaultMaxTokens int64 `json:"default_max_tokens" jsonschema:"title=Default Max Tokens,description=Default maximum tokens for responses,minimum=1"`
+ CanReason bool `json:"can_reason,omitempty" jsonschema:"title=Can Reason,description=Whether the model supports reasoning capabilities"`
+ ReasoningEffort string `json:"reasoning_effort,omitempty" jsonschema:"title=Reasoning Effort,description=Default reasoning effort level for reasoning models"`
+ HasReasoningEffort bool `json:"has_reasoning_effort,omitempty" jsonschema:"title=Has Reasoning Effort,description=Whether the model supports reasoning effort configuration"`
+ SupportsImages bool `json:"supports_attachments,omitempty" jsonschema:"title=Supports Images,description=Whether the model supports image attachments"`
+}
- // Load and merge local config
- mergeLocalConfig(workingDir)
+type VertexAIOptions struct {
+ APIKey string `json:"api_key,omitempty"`
+ Project string `json:"project,omitempty"`
+ Location string `json:"location,omitempty"`
+}
- setProviderDefaults()
+type ProviderConfig struct {
+ ID provider.InferenceProvider `json:"id,omitempty" jsonschema:"title=Provider ID,description=Unique identifier for the provider"`
+ BaseURL string `json:"base_url,omitempty" jsonschema:"title=Base URL,description=Base URL for the provider API (required for custom providers)"`
+ ProviderType provider.Type `json:"provider_type" jsonschema:"title=Provider Type,description=Type of the provider (openai, anthropic, etc.)"`
+ APIKey string `json:"api_key,omitempty" jsonschema:"title=API Key,description=API key for authenticating with the provider"`
+ Disabled bool `json:"disabled,omitempty" jsonschema:"title=Disabled,description=Whether this provider is disabled,default=false"`
+ ExtraHeaders map[string]string `json:"extra_headers,omitempty" jsonschema:"title=Extra Headers,description=Additional HTTP headers to send with requests"`
+	// Used, e.g., by Vertex AI to set the project.
+ ExtraParams map[string]string `json:"extra_params,omitempty" jsonschema:"title=Extra Parameters,description=Additional provider-specific parameters"`
- // Apply configuration to the struct
- if err := viper.Unmarshal(cfg); err != nil {
- return cfg, fmt.Errorf("failed to unmarshal config: %w", err)
- }
+ DefaultLargeModel string `json:"default_large_model,omitempty" jsonschema:"title=Default Large Model,description=Default model ID for large model type"`
+ DefaultSmallModel string `json:"default_small_model,omitempty" jsonschema:"title=Default Small Model,description=Default model ID for small model type"`
+
+ Models []Model `json:"models,omitempty" jsonschema:"title=Models,description=List of available models for this provider"`
+}
+
+type Agent struct {
+ ID AgentID `json:"id,omitempty" jsonschema:"title=Agent ID,description=Unique identifier for the agent,enum=coder,enum=task"`
+ Name string `json:"name,omitempty" jsonschema:"title=Name,description=Display name of the agent"`
+ Description string `json:"description,omitempty" jsonschema:"title=Description,description=Description of what the agent does"`
+	// This is the ID of the system prompt used by the agent.
+ Disabled bool `json:"disabled,omitempty" jsonschema:"title=Disabled,description=Whether this agent is disabled,default=false"`
+
+ Model ModelType `json:"model" jsonschema:"title=Model Type,description=Type of model to use (large or small),enum=large,enum=small"`
+
+	// The tools available to the agent.
+	// If this is nil, all tools are available.
+ AllowedTools []string `json:"allowed_tools,omitempty" jsonschema:"title=Allowed Tools,description=List of tools this agent is allowed to use (if nil all tools are allowed)"`
+
+	// This tells us which MCPs are available for this agent.
+	// If this is empty, all MCPs are available.
+	// Each value is the list of tools from that MCP the agent may use;
+	// if the list is nil, all tools from that MCP are available.
+ AllowedMCP map[string][]string `json:"allowed_mcp,omitempty" jsonschema:"title=Allowed MCP,description=Map of MCP servers this agent can use and their allowed tools"`
+
+	// The list of LSPs that this agent can use.
+	// If this is nil, all LSPs are available.
+ AllowedLSP []string `json:"allowed_lsp,omitempty" jsonschema:"title=Allowed LSP,description=List of LSP servers this agent can use (if nil all LSPs are allowed)"`
+
+	// Extra context paths for this agent, additive to the global context paths.
+ ContextPaths []string `json:"context_paths,omitempty" jsonschema:"title=Context Paths,description=Custom context paths for this agent (additive to global context paths)"`
+}
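+
+// Illustrative sketch (hypothetical agent) of how the allow-list fields
+// compose; nil means "allow everything" at that level:
+//
+//	reviewer := Agent{
+//		ID:           "reviewer",
+//		Name:         "Reviewer",
+//		Model:        SmallModel,
+//		AllowedTools: []string{"view", "grep"},           // nil would allow all tools
+//		AllowedMCP:   map[string][]string{"github": nil}, // nil list: every tool on that MCP
+//		AllowedLSP:   []string{},                         // empty list: no LSPs
+//	}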
+
+type MCPType string
+
+const (
+ MCPStdio MCPType = "stdio"
+ MCPSse MCPType = "sse"
+)
+
+type MCP struct {
+ Command string `json:"command" jsonschema:"title=Command,description=Command to execute for stdio MCP servers"`
+ Env []string `json:"env,omitempty" jsonschema:"title=Environment,description=Environment variables for the MCP server"`
+ Args []string `json:"args,omitempty" jsonschema:"title=Arguments,description=Command line arguments for the MCP server"`
+ Type MCPType `json:"type" jsonschema:"title=Type,description=Type of MCP connection,enum=stdio,enum=sse,default=stdio"`
+ URL string `json:"url,omitempty" jsonschema:"title=URL,description=URL for SSE MCP servers"`
+ // TODO: maybe make it possible to get the value from the env
+ Headers map[string]string `json:"headers,omitempty" jsonschema:"title=Headers,description=HTTP headers for SSE MCP servers"`
+}
+
+type LSPConfig struct {
+ Disabled bool `json:"enabled,omitempty" jsonschema:"title=Enabled,description=Whether this LSP server is enabled,default=true"`
+ Command string `json:"command" jsonschema:"title=Command,description=Command to execute for the LSP server"`
+ Args []string `json:"args,omitempty" jsonschema:"title=Arguments,description=Command line arguments for the LSP server"`
+ Options any `json:"options,omitempty" jsonschema:"title=Options,description=LSP server specific options"`
+}
+
+type TUIOptions struct {
+ CompactMode bool `json:"compact_mode" jsonschema:"title=Compact Mode,description=Enable compact mode for the TUI,default=false"`
+ // Here we can add themes later or any TUI related options
+}
+
+type Options struct {
+ ContextPaths []string `json:"context_paths,omitempty" jsonschema:"title=Context Paths,description=List of paths to search for context files"`
+ TUI TUIOptions `json:"tui,omitempty" jsonschema:"title=TUI Options,description=Terminal UI configuration options"`
+ Debug bool `json:"debug,omitempty" jsonschema:"title=Debug,description=Enable debug logging,default=false"`
+ DebugLSP bool `json:"debug_lsp,omitempty" jsonschema:"title=Debug LSP,description=Enable LSP debug logging,default=false"`
+ DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty" jsonschema:"title=Disable Auto Summarize,description=Disable automatic conversation summarization,default=false"`
+ // Relative to the cwd
+ DataDirectory string `json:"data_directory,omitempty" jsonschema:"title=Data Directory,description=Directory for storing application data,default=.crush"`
+}
+
+type PreferredModel struct {
+ ModelID string `json:"model_id" jsonschema:"title=Model ID,description=ID of the preferred model"`
+ Provider provider.InferenceProvider `json:"provider" jsonschema:"title=Provider,description=Provider for the preferred model"`
+ // ReasoningEffort overrides the default reasoning effort for this model
+ ReasoningEffort string `json:"reasoning_effort,omitempty" jsonschema:"title=Reasoning Effort,description=Override reasoning effort for this model"`
+ // MaxTokens overrides the default max tokens for this model
+ MaxTokens int64 `json:"max_tokens,omitempty" jsonschema:"title=Max Tokens,description=Override max tokens for this model,minimum=1"`
+
+ // Think indicates if the model should think, only applicable for anthropic reasoning models
+ Think bool `json:"think,omitempty" jsonschema:"title=Think,description=Enable thinking for reasoning models,default=false"`
+}
+
+type PreferredModels struct {
+ Large PreferredModel `json:"large,omitempty" jsonschema:"title=Large Model,description=Preferred model configuration for large model type"`
+ Small PreferredModel `json:"small,omitempty" jsonschema:"title=Small Model,description=Preferred model configuration for small model type"`
+}
+
+type Config struct {
+ Models PreferredModels `json:"models,omitempty" jsonschema:"title=Models,description=Preferred model configurations for large and small model types"`
+ // List of configured providers
+ Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty" jsonschema:"title=Providers,description=LLM provider configurations"`
+
+ // List of configured agents
+ Agents map[AgentID]Agent `json:"agents,omitempty" jsonschema:"title=Agents,description=Agent configurations for different tasks"`
+
+ // List of configured MCPs
+ MCP map[string]MCP `json:"mcp,omitempty" jsonschema:"title=MCP,description=Model Control Protocol server configurations"`
+
+ // List of configured LSPs
+ LSP map[string]LSPConfig `json:"lsp,omitempty" jsonschema:"title=LSP,description=Language Server Protocol configurations"`
+
+ // Miscellaneous options
+ Options Options `json:"options,omitempty" jsonschema:"title=Options,description=General application options and settings"`
+}
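+
+// Minimal construction sketch (hypothetical IDs); real configs are normally
+// unmarshaled from a crush.json file rather than built in code:
+//
+//	cfg := Config{
+//		Models: PreferredModels{
+//			Large: PreferredModel{ModelID: "big-model", Provider: provider.InferenceProviderOpenAI},
+//		},
+//		Providers: map[provider.InferenceProvider]ProviderConfig{
+//			provider.InferenceProviderOpenAI: {ProviderType: provider.TypeOpenAI, APIKey: "sk-example"},
+//		},
+//	}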
+
+var (
+ instance *Config // The single instance of the Singleton
+ cwd string
+ once sync.Once // Ensures the initialization happens only once
- applyDefaultValues()
+)
+
+func loadConfig(cwd string, debug bool) (*Config, error) {
+ // First read the global config file
+ cfgPath := ConfigPath()
+
+ cfg := defaultConfigBasedOnEnv()
+ cfg.Options.Debug = debug
defaultLevel := slog.LevelInfo
- if cfg.Debug {
+ if cfg.Options.Debug {
defaultLevel = slog.LevelDebug
}
if os.Getenv("CRUSH_DEV_DEBUG") == "true" {
- loggingFile := fmt.Sprintf("%s/%s", cfg.Data.Directory, "debug.log")
+ loggingFile := fmt.Sprintf("%s/%s", cfg.Options.DataDirectory, "debug.log")
// if file does not exist create it
if _, err := os.Stat(loggingFile); os.IsNotExist(err) {
- if err := os.MkdirAll(cfg.Data.Directory, 0o755); err != nil {
+ if err := os.MkdirAll(cfg.Options.DataDirectory, 0o755); err != nil {
return cfg, fmt.Errorf("failed to create directory: %w", err)
}
if _, err := os.Create(loggingFile); err != nil {
@@ -168,6 +223,15 @@ func Load(workingDir string, debug bool) (*Config, error) {
}
}
+ messagesPath := fmt.Sprintf("%s/%s", cfg.Options.DataDirectory, "messages")
+
+ if _, err := os.Stat(messagesPath); os.IsNotExist(err) {
+		if err := os.MkdirAll(messagesPath, 0o755); err != nil {
+ return cfg, fmt.Errorf("failed to create directory: %w", err)
+ }
+ }
+ logging.MessageDir = messagesPath
+
sloggingFileWriter, err := os.OpenFile(loggingFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666)
if err != nil {
return cfg, fmt.Errorf("failed to open log file: %w", err)
@@ -184,237 +248,579 @@ func Load(workingDir string, debug bool) (*Config, error) {
}))
slog.SetDefault(logger)
}
-
- // Validate configuration
- if err := Validate(); err != nil {
- return cfg, fmt.Errorf("config validation failed: %w", err)
+ var globalCfg *Config
+ if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) {
+ // some other error occurred while checking the file
+ return nil, err
+ } else if err == nil {
+ // config file exists, read it
+ file, err := os.ReadFile(cfgPath)
+ if err != nil {
+ return nil, err
+ }
+ globalCfg = &Config{}
+ if err := json.Unmarshal(file, globalCfg); err != nil {
+ return nil, err
+ }
+ } else {
+		// Config file does not exist; start with an empty config.
+ globalCfg = &Config{}
+ }
+
+ var localConfig *Config
+	// Global config loaded; now read the local config file.
+ localConfigPath := filepath.Join(cwd, "crush.json")
+ if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
+ // some other error occurred while checking the file
+ return nil, err
+ } else if err == nil {
+ // local config file exists, read it
+ file, err := os.ReadFile(localConfigPath)
+ if err != nil {
+ return nil, err
+ }
+ localConfig = &Config{}
+ if err := json.Unmarshal(file, localConfig); err != nil {
+ return nil, err
+ }
}
- if cfg.Agents == nil {
- cfg.Agents = make(map[AgentName]Agent)
- }
+ // merge options
+ mergeOptions(cfg, globalCfg, localConfig)
- // Override the max tokens for title agent
- cfg.Agents[AgentTitle] = Agent{
- Model: cfg.Agents[AgentTitle].Model,
- MaxTokens: 80,
+ mergeProviderConfigs(cfg, globalCfg, localConfig)
+	// No providers found; the app is not initialized yet.
+ if len(cfg.Providers) == 0 {
+ return cfg, nil
}
+ preferredProvider := getPreferredProvider(cfg.Providers)
+ if preferredProvider != nil {
+ cfg.Models = PreferredModels{
+ Large: PreferredModel{
+ ModelID: preferredProvider.DefaultLargeModel,
+ Provider: preferredProvider.ID,
+ },
+ Small: PreferredModel{
+ ModelID: preferredProvider.DefaultSmallModel,
+ Provider: preferredProvider.ID,
+ },
+ }
+ } else {
+		// No valid providers found; set empty models.
+ cfg.Models = PreferredModels{}
+ }
+
+ mergeModels(cfg, globalCfg, localConfig)
+
+ agents := map[AgentID]Agent{
+ AgentCoder: {
+ ID: AgentCoder,
+ Name: "Coder",
+ Description: "An agent that helps with executing coding tasks.",
+ Model: LargeModel,
+ ContextPaths: cfg.Options.ContextPaths,
+ // All tools allowed
+ },
+ AgentTask: {
+ ID: AgentTask,
+ Name: "Task",
+ Description: "An agent that helps with searching for context and finding implementation details.",
+ Model: LargeModel,
+ ContextPaths: cfg.Options.ContextPaths,
+ AllowedTools: []string{
+ "glob",
+ "grep",
+ "ls",
+ "sourcegraph",
+ "view",
+ },
+			// No MCPs or LSPs by default.
+ AllowedMCP: map[string][]string{},
+ AllowedLSP: []string{},
+ },
+ }
+ cfg.Agents = agents
+ mergeAgents(cfg, globalCfg, localConfig)
+ mergeMCPs(cfg, globalCfg, localConfig)
+ mergeLSPs(cfg, globalCfg, localConfig)
+
+ // Validate the final configuration
+ if err := cfg.Validate(); err != nil {
+ return cfg, fmt.Errorf("configuration validation failed: %w", err)
+ }
+
return cfg, nil
}
-type configFinder struct {
- appName string
- dotPrefix bool
- paths []string
+func Init(workingDir string, debug bool) (*Config, error) {
+ var err error
+ once.Do(func() {
+ cwd = workingDir
+ instance, err = loadConfig(cwd, debug)
+ if err != nil {
+ logging.Error("Failed to load config", "error", err)
+ }
+ })
+
+ return instance, err
+}
+
+func Get() *Config {
+ if instance == nil {
+		// TODO: Handle this better.
+		panic("Config not initialized. Call Init first.")
+ }
+ return instance
}
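+
+// Usage sketch: Init must run once at startup before any call to Get;
+// later calls to Init return the same instance because of the sync.Once:
+//
+//	cfg, err := config.Init(cwd, debug)
+//	if err != nil {
+//		return err
+//	}
+//	cfg = config.Get() // safe anywhere after Init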
-func (f configFinder) Find(fsys afero.Fs) ([]string, error) {
- var configFiles []string
- configName := fmt.Sprintf("%s.json", f.appName)
- if f.dotPrefix {
- configName = fmt.Sprintf(".%s.json", f.appName)
+func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
+ providers := Providers()
+ for _, p := range providers {
+ if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
+ return &providerConfig
+ }
}
- paths := []string{}
- for _, p := range f.paths {
- if p == "" {
- continue
+	// If none is found, return the first configured provider that is not disabled.
+ for _, providerConfig := range configuredProviders {
+ if !providerConfig.Disabled {
+ return &providerConfig
}
- paths = append(paths, os.ExpandEnv(p))
}
+ return nil
+}
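+
+// Selection sketch: the first provider in catalog order that is configured
+// and not disabled wins; only if none matches does the fallback loop return
+// an arbitrary enabled provider from the map.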
- for _, path := range paths {
- if path == "" {
- continue
+func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
+ if other.APIKey != "" {
+ base.APIKey = other.APIKey
+ }
+ // Only change these options if the provider is not a known provider
+ if !slices.Contains(provider.KnownProviders(), p) {
+ if other.BaseURL != "" {
+ base.BaseURL = other.BaseURL
}
-
- configPath := filepath.Join(path, configName)
- if exists, err := afero.Exists(fsys, configPath); err == nil && exists {
- configFiles = append(configFiles, configPath)
+ if other.ProviderType != "" {
+ base.ProviderType = other.ProviderType
+ }
+ if len(other.ExtraHeaders) > 0 {
+ if base.ExtraHeaders == nil {
+ base.ExtraHeaders = make(map[string]string)
+ }
+ maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
+ }
+ if len(other.ExtraParams) > 0 {
+ if base.ExtraParams == nil {
+ base.ExtraParams = make(map[string]string)
+ }
+ maps.Copy(base.ExtraParams, other.ExtraParams)
}
}
- return configFiles, nil
-}
-
-// configureViper sets up viper's configuration paths and environment variables.
-func configureViper() {
- viper.SetConfigType("json")
- // Create the three finders
- windowsFinder := configFinder{appName: appName, dotPrefix: false, paths: []string{
- "$USERPROFILE",
- fmt.Sprintf("$APPDATA/%s", appName),
- fmt.Sprintf("$LOCALAPPDATA/%s", appName),
- }}
+ if other.Disabled {
+ base.Disabled = other.Disabled
+ }
- unixFinder := configFinder{appName: appName, dotPrefix: false, paths: []string{
- "$HOME",
- fmt.Sprintf("$XDG_CONFIG_HOME/%s", appName),
- fmt.Sprintf("$HOME/.config/%s", appName),
- }}
+	if other.DefaultLargeModel != "" {
+		base.DefaultLargeModel = other.DefaultLargeModel
+	}
+	if other.DefaultSmallModel != "" {
+		base.DefaultSmallModel = other.DefaultSmallModel
+	}
+ // Add new models if they don't exist
+ if other.Models != nil {
+ for _, model := range other.Models {
+ // check if the model already exists
+ exists := false
+ for _, existingModel := range base.Models {
+ if existingModel.ID == model.ID {
+ exists = true
+ break
+ }
+ }
+ if !exists {
+ base.Models = append(base.Models, model)
+ }
+ }
+ }
- localFinder := configFinder{appName: appName, dotPrefix: true, paths: []string{
- ".",
- }}
+ return base
+}
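+
+// Worked example (hypothetical values) of the rules above: for a known
+// provider only the API key, disabled flag, defaults, and new models merge;
+// BaseURL and provider type changes are ignored.
+//
+//	base := ProviderConfig{APIKey: "env-key", DefaultLargeModel: "model-a"}
+//	over := ProviderConfig{APIKey: "file-key", BaseURL: "https://example.invalid"}
+//	out := mergeProviderConfig(provider.InferenceProviderOpenAI, base, over)
+//	// out.APIKey == "file-key"; out.BaseURL == "" (known provider, so ignored)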
- // Use all finders with viper
- viper.SetOptions(viper.WithFinder(viper.Finders(windowsFinder, unixFinder, localFinder)))
- viper.SetEnvPrefix(strings.ToUpper(appName))
- viper.AutomaticEnv()
+func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
+ if !slices.Contains(provider.KnownProviders(), p) {
+ if providerConfig.ProviderType != provider.TypeOpenAI {
+ return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
+ }
+ if providerConfig.BaseURL == "" {
+ return errors.New("base URL must be set for custom providers")
+ }
+ if providerConfig.APIKey == "" {
+ return errors.New("API key must be set for custom providers")
+ }
+ }
+ return nil
}
-// setDefaults configures default values for configuration options.
-func setDefaults(debug bool) {
- viper.SetDefault("data.directory", defaultDataDirectory)
- viper.SetDefault("contextPaths", defaultContextPaths)
- viper.SetDefault("tui.theme", "crush")
- viper.SetDefault("autoCompact", true)
+func mergeModels(base, global, local *Config) {
+ for _, cfg := range []*Config{global, local} {
+ if cfg == nil {
+ continue
+ }
+ if cfg.Models.Large.ModelID != "" && cfg.Models.Large.Provider != "" {
+ base.Models.Large = cfg.Models.Large
+ }
- if debug {
- viper.SetDefault("debug", true)
- viper.Set("log.level", "debug")
- } else {
- viper.SetDefault("debug", false)
- viper.SetDefault("log.level", defaultLogLevel)
+ if cfg.Models.Small.ModelID != "" && cfg.Models.Small.Provider != "" {
+ base.Models.Small = cfg.Models.Small
+ }
}
}
-// setProviderDefaults configures LLM provider defaults based on provider provided by
-// environment variables and configuration file.
-func setProviderDefaults() {
- // Set all API keys we can find in the environment
- if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" {
- viper.SetDefault("providers.anthropic.apiKey", apiKey)
- }
- if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" {
- viper.SetDefault("providers.openai.apiKey", apiKey)
- }
- if apiKey := os.Getenv("GEMINI_API_KEY"); apiKey != "" {
- viper.SetDefault("providers.gemini.apiKey", apiKey)
- }
- if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" {
- viper.SetDefault("providers.groq.apiKey", apiKey)
- }
- if apiKey := os.Getenv("OPENROUTER_API_KEY"); apiKey != "" {
- viper.SetDefault("providers.openrouter.apiKey", apiKey)
- }
- if apiKey := os.Getenv("XAI_API_KEY"); apiKey != "" {
- viper.SetDefault("providers.xai.apiKey", apiKey)
- }
- if apiKey := os.Getenv("AZURE_OPENAI_ENDPOINT"); apiKey != "" {
- // api-key may be empty when using Entra ID credentials – that's okay
- viper.SetDefault("providers.azure.apiKey", os.Getenv("AZURE_OPENAI_API_KEY"))
- }
+func mergeOptions(base, global, local *Config) {
+ for _, cfg := range []*Config{global, local} {
+ if cfg == nil {
+ continue
+ }
+ baseOptions := base.Options
+ other := cfg.Options
+ if len(other.ContextPaths) > 0 {
+ baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
+ }
- // Use this order to set the default models
- // 1. Anthropic
- // 2. OpenAI
- // 3. Google Gemini
- // 4. Groq
- // 5. OpenRouter
- // 6. AWS Bedrock
- // 7. Azure
- // 8. Google Cloud VertexAI
+ if other.TUI.CompactMode {
+ baseOptions.TUI.CompactMode = other.TUI.CompactMode
+ }
+
+ if other.Debug {
+ baseOptions.Debug = other.Debug
+ }
+
+ if other.DebugLSP {
+ baseOptions.DebugLSP = other.DebugLSP
+ }
+
+ if other.DisableAutoSummarize {
+ baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
+ }
- // Anthropic configuration
- if key := viper.GetString("providers.anthropic.apiKey"); strings.TrimSpace(key) != "" {
- viper.SetDefault("agents.coder.model", models.Claude4Sonnet)
- viper.SetDefault("agents.summarizer.model", models.Claude4Sonnet)
- viper.SetDefault("agents.task.model", models.Claude4Sonnet)
- viper.SetDefault("agents.title.model", models.Claude4Sonnet)
- return
+ if other.DataDirectory != "" {
+ baseOptions.DataDirectory = other.DataDirectory
+ }
+ base.Options = baseOptions
}
+}
- // OpenAI configuration
- if key := viper.GetString("providers.openai.apiKey"); strings.TrimSpace(key) != "" {
- viper.SetDefault("agents.coder.model", models.GPT41)
- viper.SetDefault("agents.summarizer.model", models.GPT41)
- viper.SetDefault("agents.task.model", models.GPT41Mini)
- viper.SetDefault("agents.title.model", models.GPT41Mini)
- return
+func mergeAgents(base, global, local *Config) {
+ for _, cfg := range []*Config{global, local} {
+ if cfg == nil {
+ continue
+ }
+ for agentID, newAgent := range cfg.Agents {
+ if _, ok := base.Agents[agentID]; !ok {
+ newAgent.ID = agentID
+ if newAgent.Model == "" {
+ newAgent.Model = LargeModel
+ }
+ if len(newAgent.ContextPaths) > 0 {
+ newAgent.ContextPaths = append(base.Options.ContextPaths, newAgent.ContextPaths...)
+ } else {
+ newAgent.ContextPaths = base.Options.ContextPaths
+ }
+ base.Agents[agentID] = newAgent
+ } else {
+ baseAgent := base.Agents[agentID]
+
+ if agentID == AgentCoder || agentID == AgentTask {
+ if newAgent.Model != "" {
+ baseAgent.Model = newAgent.Model
+ }
+ if newAgent.AllowedMCP != nil {
+ baseAgent.AllowedMCP = newAgent.AllowedMCP
+ }
+ if newAgent.AllowedLSP != nil {
+ baseAgent.AllowedLSP = newAgent.AllowedLSP
+ }
+ // Context paths are additive for known agents too
+ if len(newAgent.ContextPaths) > 0 {
+ baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...)
+ }
+ } else {
+ if newAgent.Name != "" {
+ baseAgent.Name = newAgent.Name
+ }
+ if newAgent.Description != "" {
+ baseAgent.Description = newAgent.Description
+ }
+ if newAgent.Model != "" {
+ baseAgent.Model = newAgent.Model
+ } else if baseAgent.Model == "" {
+ baseAgent.Model = LargeModel
+ }
+
+ baseAgent.Disabled = newAgent.Disabled
+
+ if newAgent.AllowedTools != nil {
+ baseAgent.AllowedTools = newAgent.AllowedTools
+ }
+ if newAgent.AllowedMCP != nil {
+ baseAgent.AllowedMCP = newAgent.AllowedMCP
+ }
+ if newAgent.AllowedLSP != nil {
+ baseAgent.AllowedLSP = newAgent.AllowedLSP
+ }
+ if len(newAgent.ContextPaths) > 0 {
+ baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...)
+ }
+ }
+
+ base.Agents[agentID] = baseAgent
+ }
+ }
}
+}
- // Google Gemini configuration
- if key := viper.GetString("providers.gemini.apiKey"); strings.TrimSpace(key) != "" {
- viper.SetDefault("agents.coder.model", models.Gemini25)
- viper.SetDefault("agents.summarizer.model", models.Gemini25)
- viper.SetDefault("agents.task.model", models.Gemini25Flash)
- viper.SetDefault("agents.title.model", models.Gemini25Flash)
- return
+func mergeMCPs(base, global, local *Config) {
+ for _, cfg := range []*Config{global, local} {
+ if cfg == nil {
+ continue
+ }
+ maps.Copy(base.MCP, cfg.MCP)
}
+}
- // Groq configuration
- if key := viper.GetString("providers.groq.apiKey"); strings.TrimSpace(key) != "" {
- viper.SetDefault("agents.coder.model", models.QWENQwq)
- viper.SetDefault("agents.summarizer.model", models.QWENQwq)
- viper.SetDefault("agents.task.model", models.QWENQwq)
- viper.SetDefault("agents.title.model", models.QWENQwq)
- return
+func mergeLSPs(base, global, local *Config) {
+ for _, cfg := range []*Config{global, local} {
+ if cfg == nil {
+ continue
+ }
+ maps.Copy(base.LSP, cfg.LSP)
}
+}
- // OpenRouter configuration
- if key := viper.GetString("providers.openrouter.apiKey"); strings.TrimSpace(key) != "" {
- viper.SetDefault("agents.coder.model", models.OpenRouterClaude37Sonnet)
- viper.SetDefault("agents.summarizer.model", models.OpenRouterClaude37Sonnet)
- viper.SetDefault("agents.task.model", models.OpenRouterClaude37Sonnet)
- viper.SetDefault("agents.title.model", models.OpenRouterClaude35Haiku)
- return
+func mergeProviderConfigs(base, global, local *Config) {
+ for _, cfg := range []*Config{global, local} {
+ if cfg == nil {
+ continue
+ }
+ for providerName, p := range cfg.Providers {
+ p.ID = providerName
+ if _, ok := base.Providers[providerName]; !ok {
+ base.Providers[providerName] = p
+ } else {
+ base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], p)
+ }
+ }
}
- // XAI configuration
- if key := viper.GetString("providers.xai.apiKey"); strings.TrimSpace(key) != "" {
- viper.SetDefault("agents.coder.model", models.XAIGrok3Beta)
- viper.SetDefault("agents.summarizer.model", models.XAIGrok3Beta)
- viper.SetDefault("agents.task.model", models.XAIGrok3Beta)
- viper.SetDefault("agents.title.model", models.XAiGrok3MiniFastBeta)
- return
+ finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
+ for providerName, providerConfig := range base.Providers {
+ err := validateProvider(providerName, providerConfig)
+ if err != nil {
+ logging.Warn("Skipping provider", "name", providerName, "error", err)
+ continue // Skip invalid providers
+ }
+ finalProviders[providerName] = providerConfig
}
+ base.Providers = finalProviders
+}
- // AWS Bedrock configuration
- if hasAWSCredentials() {
- viper.SetDefault("agents.coder.model", models.BedrockClaude37Sonnet)
- viper.SetDefault("agents.summarizer.model", models.BedrockClaude37Sonnet)
- viper.SetDefault("agents.task.model", models.BedrockClaude37Sonnet)
- viper.SetDefault("agents.title.model", models.BedrockClaude37Sonnet)
- return
+func providerDefaultConfig(providerID provider.InferenceProvider) ProviderConfig {
+ switch providerID {
+ case provider.InferenceProviderAnthropic:
+ return ProviderConfig{
+ ID: providerID,
+ ProviderType: provider.TypeAnthropic,
+ }
+ case provider.InferenceProviderOpenAI:
+ return ProviderConfig{
+ ID: providerID,
+ ProviderType: provider.TypeOpenAI,
+ }
+ case provider.InferenceProviderGemini:
+ return ProviderConfig{
+ ID: providerID,
+ ProviderType: provider.TypeGemini,
+ }
+ case provider.InferenceProviderBedrock:
+ return ProviderConfig{
+ ID: providerID,
+ ProviderType: provider.TypeBedrock,
+ }
+ case provider.InferenceProviderAzure:
+ return ProviderConfig{
+ ID: providerID,
+ ProviderType: provider.TypeAzure,
+ }
+ case provider.InferenceProviderOpenRouter:
+ return ProviderConfig{
+ ID: providerID,
+ ProviderType: provider.TypeOpenAI,
+ BaseURL: "https://openrouter.ai/api/v1",
+ ExtraHeaders: map[string]string{
+ "HTTP-Referer": "crush.charm.land",
+ "X-Title": "Crush",
+ },
+ }
+ case provider.InferenceProviderXAI:
+ return ProviderConfig{
+ ID: providerID,
+ ProviderType: provider.TypeXAI,
+ BaseURL: "https://api.x.ai/v1",
+ }
+ case provider.InferenceProviderVertexAI:
+ return ProviderConfig{
+ ID: providerID,
+ ProviderType: provider.TypeVertexAI,
+ }
+ default:
+ return ProviderConfig{
+ ID: providerID,
+ ProviderType: provider.TypeOpenAI,
+ }
}
+}
- // Azure OpenAI configuration
- if os.Getenv("AZURE_OPENAI_ENDPOINT") != "" {
- viper.SetDefault("agents.coder.model", models.AzureGPT41)
- viper.SetDefault("agents.summarizer.model", models.AzureGPT41)
- viper.SetDefault("agents.task.model", models.AzureGPT41Mini)
- viper.SetDefault("agents.title.model", models.AzureGPT41Mini)
- return
+func defaultConfigBasedOnEnv() *Config {
+ cfg := &Config{
+ Options: Options{
+ DataDirectory: defaultDataDirectory,
+ ContextPaths: defaultContextPaths,
+ },
+ Providers: make(map[provider.InferenceProvider]ProviderConfig),
+ Agents: make(map[AgentID]Agent),
+ LSP: make(map[string]LSPConfig),
+ MCP: make(map[string]MCP),
+ }
+
+ providers := Providers()
+
+ for _, p := range providers {
+ if strings.HasPrefix(p.APIKey, "$") {
+ envVar := strings.TrimPrefix(p.APIKey, "$")
+ if apiKey := os.Getenv(envVar); apiKey != "" {
+ providerConfig := providerDefaultConfig(p.ID)
+ providerConfig.APIKey = apiKey
+ providerConfig.DefaultLargeModel = p.DefaultLargeModelID
+ providerConfig.DefaultSmallModel = p.DefaultSmallModelID
+ baseURL := p.APIEndpoint
+ if strings.HasPrefix(baseURL, "$") {
+ envVar := strings.TrimPrefix(baseURL, "$")
+ baseURL = os.Getenv(envVar)
+ }
+ providerConfig.BaseURL = baseURL
+ for _, model := range p.Models {
+ configModel := Model{
+ ID: model.ID,
+ Name: model.Name,
+ CostPer1MIn: model.CostPer1MIn,
+ CostPer1MOut: model.CostPer1MOut,
+ CostPer1MInCached: model.CostPer1MInCached,
+ CostPer1MOutCached: model.CostPer1MOutCached,
+ ContextWindow: model.ContextWindow,
+ DefaultMaxTokens: model.DefaultMaxTokens,
+ CanReason: model.CanReason,
+ SupportsImages: model.SupportsImages,
+ }
+ // Set reasoning effort for reasoning models
+ if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
+ configModel.HasReasoningEffort = model.HasReasoningEffort
+ configModel.ReasoningEffort = model.DefaultReasoningEffort
+ }
+ providerConfig.Models = append(providerConfig.Models, configModel)
+ }
+ cfg.Providers[p.ID] = providerConfig
+ }
+ }
+ }
+ // TODO: support local models
+
+ if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" {
+ providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
+ providerConfig.ExtraParams = map[string]string{
+ "project": os.Getenv("GOOGLE_CLOUD_PROJECT"),
+ "location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
+ }
+ // Find the VertexAI provider definition to get default models
+ for _, p := range providers {
+ if p.ID == provider.InferenceProviderVertexAI {
+ providerConfig.DefaultLargeModel = p.DefaultLargeModelID
+ providerConfig.DefaultSmallModel = p.DefaultSmallModelID
+ for _, model := range p.Models {
+ configModel := Model{
+ ID: model.ID,
+ Name: model.Name,
+ CostPer1MIn: model.CostPer1MIn,
+ CostPer1MOut: model.CostPer1MOut,
+ CostPer1MInCached: model.CostPer1MInCached,
+ CostPer1MOutCached: model.CostPer1MOutCached,
+ ContextWindow: model.ContextWindow,
+ DefaultMaxTokens: model.DefaultMaxTokens,
+ CanReason: model.CanReason,
+ SupportsImages: model.SupportsImages,
+ }
+ // Set reasoning effort for reasoning models
+ if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
+ configModel.HasReasoningEffort = model.HasReasoningEffort
+ configModel.ReasoningEffort = model.DefaultReasoningEffort
+ }
+ providerConfig.Models = append(providerConfig.Models, configModel)
+ }
+ break
+ }
+ }
+ cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
}
- // Google Cloud VertexAI configuration
- if hasVertexAICredentials() {
- viper.SetDefault("agents.coder.model", models.VertexAIGemini25)
- viper.SetDefault("agents.summarizer.model", models.VertexAIGemini25)
- viper.SetDefault("agents.task.model", models.VertexAIGemini25Flash)
- viper.SetDefault("agents.title.model", models.VertexAIGemini25Flash)
- return
+ if hasAWSCredentials() {
+ providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
+ providerConfig.ExtraParams = map[string]string{
+ "region": os.Getenv("AWS_DEFAULT_REGION"),
+ }
+ if providerConfig.ExtraParams["region"] == "" {
+ providerConfig.ExtraParams["region"] = os.Getenv("AWS_REGION")
+ }
+ // Find the Bedrock provider definition to get default models
+ for _, p := range providers {
+ if p.ID == provider.InferenceProviderBedrock {
+ providerConfig.DefaultLargeModel = p.DefaultLargeModelID
+ providerConfig.DefaultSmallModel = p.DefaultSmallModelID
+ for _, model := range p.Models {
+ configModel := Model{
+ ID: model.ID,
+ Name: model.Name,
+ CostPer1MIn: model.CostPer1MIn,
+ CostPer1MOut: model.CostPer1MOut,
+ CostPer1MInCached: model.CostPer1MInCached,
+ CostPer1MOutCached: model.CostPer1MOutCached,
+ ContextWindow: model.ContextWindow,
+ DefaultMaxTokens: model.DefaultMaxTokens,
+ CanReason: model.CanReason,
+ SupportsImages: model.SupportsImages,
+ }
+ // Set reasoning effort for reasoning models
+ if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
+ configModel.HasReasoningEffort = model.HasReasoningEffort
+ configModel.ReasoningEffort = model.DefaultReasoningEffort
+ }
+ providerConfig.Models = append(providerConfig.Models, configModel)
+ }
+ break
+ }
+ }
+ cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
}
+ return cfg
}
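+
+// Sketch of the "$ENV_VAR" convention handled above, assuming the provider
+// catalog lists OpenAI with APIKey "$OPENAI_API_KEY" (hypothetical key):
+//
+//	os.Setenv("OPENAI_API_KEY", "sk-test")
+//	cfg := defaultConfigBasedOnEnv()
+//	_, ok := cfg.Providers[provider.InferenceProviderOpenAI] // ok == true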
-// hasAWSCredentials checks if AWS credentials are available in the environment.
func hasAWSCredentials() bool {
- // Check for explicit AWS credentials
if os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "" {
return true
}
- // Check for AWS profile
if os.Getenv("AWS_PROFILE") != "" || os.Getenv("AWS_DEFAULT_PROFILE") != "" {
return true
}
- // Check for AWS region
if os.Getenv("AWS_REGION") != "" || os.Getenv("AWS_DEFAULT_REGION") != "" {
return true
}
- // Check if running on EC2 with instance profile
if os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
return true
@@ -423,495 +829,673 @@ func hasAWSCredentials() bool {
return false
}
-// hasVertexAICredentials checks if VertexAI credentials are available in the environment.
-func hasVertexAICredentials() bool {
- // Check for explicit VertexAI parameters
- if os.Getenv("VERTEXAI_PROJECT") != "" && os.Getenv("VERTEXAI_LOCATION") != "" {
- return true
+func WorkingDirectory() string {
+ return cwd
+}
+
+// TODO: Handle error state
+
+func GetAgentModel(agentID AgentID) Model {
+ cfg := Get()
+ agent, ok := cfg.Agents[agentID]
+ if !ok {
+ logging.Error("Agent not found", "agent_id", agentID)
+ return Model{}
}
- // Check for Google Cloud project and location
- if os.Getenv("GOOGLE_CLOUD_PROJECT") != "" && (os.Getenv("GOOGLE_CLOUD_REGION") != "" || os.Getenv("GOOGLE_CLOUD_LOCATION") != "") {
- return true
+
+ var model PreferredModel
+ switch agent.Model {
+ case LargeModel:
+ model = cfg.Models.Large
+ case SmallModel:
+ model = cfg.Models.Small
+ default:
+ logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
+ model = cfg.Models.Large // Fallback to large model
}
- return false
+ providerConfig, ok := cfg.Providers[model.Provider]
+ if !ok {
+ logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
+ return Model{}
+ }
+
+ for _, m := range providerConfig.Models {
+ if m.ID == model.ModelID {
+ return m
+ }
+ }
+
+	logging.Error("Model not found for agent", "agent_id", agentID, "model_id", model.ModelID)
+ return Model{}
}
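+
+// Usage sketch: a zero-value Model (empty ID) signals a missing agent,
+// provider, or model rather than an error:
+//
+//	m := config.GetAgentModel(config.AgentCoder)
+//	if m.ID == "" {
+//		// configuration is incomplete; see the error logs above
+//	}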
-// readConfig handles the result of reading a configuration file.
-func readConfig(err error) error {
- if err == nil {
- return nil
+// GetAgentEffectiveMaxTokens returns the effective max tokens for an agent,
+// considering any overrides from the preferred model configuration
+func GetAgentEffectiveMaxTokens(agentID AgentID) int64 {
+ cfg := Get()
+ agent, ok := cfg.Agents[agentID]
+ if !ok {
+ logging.Error("Agent not found", "agent_id", agentID)
+ return 0
}
- // It's okay if the config file doesn't exist
- if _, ok := err.(viper.ConfigFileNotFoundError); ok {
- return nil
+ var preferredModel PreferredModel
+ switch agent.Model {
+ case LargeModel:
+ preferredModel = cfg.Models.Large
+ case SmallModel:
+ preferredModel = cfg.Models.Small
+ default:
+ logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
+ preferredModel = cfg.Models.Large // Fallback to large model
}
- return fmt.Errorf("failed to read config: %w", err)
-}
+ // Get the base model configuration
+ baseModel := GetAgentModel(agentID)
+ if baseModel.ID == "" {
+ return 0
+ }
-// mergeLocalConfig loads and merges configuration from the local directory.
-func mergeLocalConfig(workingDir string) {
- local := viper.New()
- local.SetConfigName(fmt.Sprintf(".%s", appName))
- local.SetConfigType("json")
- local.AddConfigPath(workingDir)
+ // Start with the default max tokens from the base model
+ maxTokens := baseModel.DefaultMaxTokens
- // Merge local config if it exists
- if err := local.ReadInConfig(); err == nil {
- viper.MergeConfigMap(local.AllSettings())
+ // Override with preferred model max tokens if set
+ if preferredModel.MaxTokens > 0 {
+ maxTokens = preferredModel.MaxTokens
}
+
+ return maxTokens
}
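+
+// Worked example (hypothetical numbers): if the resolved model's
+// DefaultMaxTokens is 4096 and the preferred model sets MaxTokens to 8192,
+// this returns 8192; with no override it returns 4096.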
-// applyDefaultValues sets default values for configuration fields that need processing.
-func applyDefaultValues() {
- // Set default MCP type if not specified
- for k, v := range cfg.MCPServers {
- if v.Type == "" {
- v.Type = MCPStdio
- cfg.MCPServers[k] = v
- }
+func GetAgentProvider(agentID AgentID) ProviderConfig {
+ cfg := Get()
+ agent, ok := cfg.Agents[agentID]
+ if !ok {
+ logging.Error("Agent not found", "agent_id", agentID)
+ return ProviderConfig{}
}
-}
-// It validates model IDs and providers, ensuring they are supported.
-func validateAgent(cfg *Config, name AgentName, agent Agent) error {
- // Check if model exists
- model, modelExists := models.SupportedModels[agent.Model]
- if !modelExists {
- logging.Warn("unsupported model configured, reverting to default",
- "agent", name,
- "configured_model", agent.Model)
+ var model PreferredModel
+ switch agent.Model {
+ case LargeModel:
+ model = cfg.Models.Large
+ case SmallModel:
+ model = cfg.Models.Small
+ default:
+ logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
+ model = cfg.Models.Large // Fallback to large model
+ }
- // Set default model based on available providers
- if setDefaultModelForAgent(name) {
- logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model)
- } else {
- return fmt.Errorf("no valid provider available for agent %s", name)
- }
- return nil
+ providerConfig, ok := cfg.Providers[model.Provider]
+ if !ok {
+ logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
+ return ProviderConfig{}
}
- // Check if provider for the model is configured
- provider := model.Provider
- providerCfg, providerExists := cfg.Providers[provider]
+ return providerConfig
+}
- if !providerExists {
- // Provider not configured, check if we have environment variables
- apiKey := getProviderAPIKey(provider)
- if apiKey == "" {
- logging.Warn("provider not configured for model, reverting to default",
- "agent", name,
- "model", agent.Model,
- "provider", provider)
+func GetProviderModel(provider provider.InferenceProvider, modelID string) Model {
+ cfg := Get()
+ providerConfig, ok := cfg.Providers[provider]
+ if !ok {
+ logging.Error("Provider not found", "provider", provider)
+ return Model{}
+ }
- // Set default model based on available providers
- if setDefaultModelForAgent(name) {
- logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model)
- } else {
- return fmt.Errorf("no valid provider available for agent %s", name)
- }
- } else {
- // Add provider with API key from environment
- cfg.Providers[provider] = Provider{
- APIKey: apiKey,
- }
- logging.Info("added provider from environment", "provider", provider)
- }
- } else if providerCfg.Disabled || providerCfg.APIKey == "" {
- // Provider is disabled or has no API key
- logging.Warn("provider is disabled or has no API key, reverting to default",
- "agent", name,
- "model", agent.Model,
- "provider", provider)
-
- // Set default model based on available providers
- if setDefaultModelForAgent(name) {
- logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model)
- } else {
- return fmt.Errorf("no valid provider available for agent %s", name)
+ for _, model := range providerConfig.Models {
+ if model.ID == modelID {
+ return model
}
}
- // Validate max tokens
- if agent.MaxTokens <= 0 {
- logging.Warn("invalid max tokens, setting to default",
- "agent", name,
- "model", agent.Model,
- "max_tokens", agent.MaxTokens)
+ logging.Error("Model not found for provider", "provider", provider, "model_id", modelID)
+ return Model{}
+}
- // Update the agent with default max tokens
- updatedAgent := cfg.Agents[name]
- if model.DefaultMaxTokens > 0 {
- updatedAgent.MaxTokens = model.DefaultMaxTokens
- } else {
- updatedAgent.MaxTokens = MaxTokensFallbackDefault
- }
- cfg.Agents[name] = updatedAgent
- } else if model.ContextWindow > 0 && agent.MaxTokens > model.ContextWindow/2 {
- // Ensure max tokens doesn't exceed half the context window (reasonable limit)
- logging.Warn("max tokens exceeds half the context window, adjusting",
- "agent", name,
- "model", agent.Model,
- "max_tokens", agent.MaxTokens,
- "context_window", model.ContextWindow)
-
- // Update the agent with adjusted max tokens
- updatedAgent := cfg.Agents[name]
- updatedAgent.MaxTokens = model.ContextWindow / 2
- cfg.Agents[name] = updatedAgent
- }
-
- // Validate reasoning effort for models that support reasoning
- if model.CanReason && provider == models.ProviderOpenAI || provider == models.ProviderLocal {
- if agent.ReasoningEffort == "" {
- // Set default reasoning effort for models that support it
- logging.Info("setting default reasoning effort for model that supports reasoning",
- "agent", name,
- "model", agent.Model)
-
- // Update the agent with default reasoning effort
- updatedAgent := cfg.Agents[name]
- updatedAgent.ReasoningEffort = "medium"
- cfg.Agents[name] = updatedAgent
- } else {
- // Check if reasoning effort is valid (low, medium, high)
- effort := strings.ToLower(agent.ReasoningEffort)
- if effort != "low" && effort != "medium" && effort != "high" {
- logging.Warn("invalid reasoning effort, setting to medium",
- "agent", name,
- "model", agent.Model,
- "reasoning_effort", agent.ReasoningEffort)
+func GetModel(modelType ModelType) Model {
+ cfg := Get()
+ var model PreferredModel
+ switch modelType {
+ case LargeModel:
+ model = cfg.Models.Large
+ case SmallModel:
+ model = cfg.Models.Small
+ default:
+ model = cfg.Models.Large // Fallback to large model
+ }
+ providerConfig, ok := cfg.Providers[model.Provider]
+ if !ok {
+ return Model{}
+ }
- // Update the agent with valid reasoning effort
- updatedAgent := cfg.Agents[name]
- updatedAgent.ReasoningEffort = "medium"
- cfg.Agents[name] = updatedAgent
- }
+ for _, m := range providerConfig.Models {
+ if m.ID == model.ModelID {
+ return m
}
- } else if !model.CanReason && agent.ReasoningEffort != "" {
- // Model doesn't support reasoning but reasoning effort is set
- logging.Warn("model doesn't support reasoning but reasoning effort is set, ignoring",
- "agent", name,
- "model", agent.Model,
- "reasoning_effort", agent.ReasoningEffort)
-
- // Update the agent to remove reasoning effort
- updatedAgent := cfg.Agents[name]
- updatedAgent.ReasoningEffort = ""
- cfg.Agents[name] = updatedAgent
}
+ return Model{}
+}
+func UpdatePreferredModel(modelType ModelType, model PreferredModel) error {
+ cfg := Get()
+ switch modelType {
+ case LargeModel:
+ cfg.Models.Large = model
+ case SmallModel:
+ cfg.Models.Small = model
+ default:
+ return fmt.Errorf("unknown model type: %s", modelType)
+ }
return nil
}
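+
+// Usage sketch (hypothetical model ID); note this mutates only the
+// in-memory config and does not persist the change to disk:
+//
+//	err := config.UpdatePreferredModel(config.LargeModel, config.PreferredModel{
+//		ModelID:  "big-model",
+//		Provider: provider.InferenceProviderOpenAI,
+//	})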
-// Validate checks if the configuration is valid and applies defaults where needed.
-func Validate() error {
- if cfg == nil {
- return fmt.Errorf("config not loaded")
+// ValidationError represents a configuration validation error
+type ValidationError struct {
+ Field string
+ Message string
+}
+
+func (e ValidationError) Error() string {
+ return fmt.Sprintf("validation error in %s: %s", e.Field, e.Message)
+}
+
+// ValidationErrors represents multiple validation errors
+type ValidationErrors []ValidationError
+
+func (e ValidationErrors) Error() string {
+ if len(e) == 0 {
+ return "no validation errors"
+ }
+ if len(e) == 1 {
+ return e[0].Error()
}
- // Validate agent models
- for name, agent := range cfg.Agents {
- if err := validateAgent(cfg, name, agent); err != nil {
- return err
- }
+ var messages []string
+ for _, err := range e {
+ messages = append(messages, err.Error())
}
+ return fmt.Sprintf("multiple validation errors: %s", strings.Join(messages, "; "))
+}
+
+// HasErrors returns true if there are any validation errors
+func (e ValidationErrors) HasErrors() bool {
+ return len(e) > 0
+}
+
+// Add appends a new validation error
+func (e *ValidationErrors) Add(field, message string) {
+ *e = append(*e, ValidationError{Field: field, Message: message})
+}
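+
+// Usage sketch: collect problems instead of failing on the first one, as
+// Validate below does:
+//
+//	var errs ValidationErrors
+//	errs.Add("providers.openai.api_key", "API key is required")
+//	if errs.HasErrors() {
+//		return errs
+//	}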
+
+// Validate performs comprehensive validation of the configuration
+func (c *Config) Validate() error {
+ var errors ValidationErrors
// Validate providers
- for provider, providerCfg := range cfg.Providers {
- if providerCfg.APIKey == "" && !providerCfg.Disabled {
- logging.Warn("provider has no API key, marking as disabled", "provider", provider)
- providerCfg.Disabled = true
- cfg.Providers[provider] = providerCfg
- }
- }
+ c.validateProviders(&errors)
+
+ // Validate models
+ c.validateModels(&errors)
+
+ // Validate agents
+ c.validateAgents(&errors)
+
+ // Validate options
+ c.validateOptions(&errors)
+
+ // Validate MCP configurations
+ c.validateMCPs(&errors)
// Validate LSP configurations
- for language, lspConfig := range cfg.LSP {
- if lspConfig.Command == "" && !lspConfig.Disabled {
- logging.Warn("LSP configuration has no command, marking as disabled", "language", language)
- lspConfig.Disabled = true
- cfg.LSP[language] = lspConfig
- }
+ c.validateLSPs(&errors)
+
+ // Validate cross-references
+ c.validateCrossReferences(&errors)
+
+ // Validate completeness
+ c.validateCompleteness(&errors)
+
+ if errors.HasErrors() {
+ return errors
}
return nil
}
-// getProviderAPIKey gets the API key for a provider from environment variables
-func getProviderAPIKey(provider models.ModelProvider) string {
- switch provider {
- case models.ProviderAnthropic:
- return os.Getenv("ANTHROPIC_API_KEY")
- case models.ProviderOpenAI:
- return os.Getenv("OPENAI_API_KEY")
- case models.ProviderGemini:
- return os.Getenv("GEMINI_API_KEY")
- case models.ProviderGROQ:
- return os.Getenv("GROQ_API_KEY")
- case models.ProviderAzure:
- return os.Getenv("AZURE_OPENAI_API_KEY")
- case models.ProviderOpenRouter:
- return os.Getenv("OPENROUTER_API_KEY")
- case models.ProviderBedrock:
- if hasAWSCredentials() {
- return "aws-credentials-available"
- }
- case models.ProviderVertexAI:
- if hasVertexAICredentials() {
- return "vertex-ai-credentials-available"
- }
- }
- return ""
-}
-
-// setDefaultModelForAgent sets a default model for an agent based on available providers
-func setDefaultModelForAgent(agent AgentName) bool {
- // Check providers in order of preference
- if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" {
- maxTokens := int64(5000)
- if agent == AgentTitle {
- maxTokens = 80
- }
- cfg.Agents[agent] = Agent{
- Model: models.Claude37Sonnet,
- MaxTokens: maxTokens,
- }
- return true
+// validateProviders validates all provider configurations.
+func (c *Config) validateProviders(errors *ValidationErrors) {
+ if c.Providers == nil {
+ c.Providers = make(map[provider.InferenceProvider]ProviderConfig)
}
- if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" {
- var model models.ModelID
- maxTokens := int64(5000)
- reasoningEffort := ""
+ knownProviders := provider.KnownProviders()
+ validTypes := []provider.Type{
+ provider.TypeOpenAI,
+ provider.TypeAnthropic,
+ provider.TypeGemini,
+ provider.TypeAzure,
+ provider.TypeBedrock,
+ provider.TypeVertexAI,
+ provider.TypeXAI,
+ }
- switch agent {
- case AgentTitle:
- model = models.GPT41Mini
- maxTokens = 80
- case AgentTask:
- model = models.GPT41Mini
- default:
- model = models.GPT41
- }
+ for providerID, providerConfig := range c.Providers {
+ fieldPrefix := fmt.Sprintf("providers.%s", providerID)
- // Check if model supports reasoning
- if modelInfo, ok := models.SupportedModels[model]; ok && modelInfo.CanReason {
- reasoningEffort = "medium"
+ // Validate API key for non-disabled providers
+ if !providerConfig.Disabled && providerConfig.APIKey == "" {
+ // Special case for AWS Bedrock and VertexAI, which may use other auth methods.
+ if providerID != provider.InferenceProviderBedrock && providerID != provider.InferenceProviderVertexAI {
+ errors.Add(fieldPrefix+".api_key", "API key is required for non-disabled providers")
+ }
}
- cfg.Agents[agent] = Agent{
- Model: model,
- MaxTokens: maxTokens,
- ReasoningEffort: reasoningEffort,
+ // Validate provider type
+ validType := slices.Contains(validTypes, providerConfig.ProviderType)
+ if !validType {
+ errors.Add(fieldPrefix+".provider_type", fmt.Sprintf("invalid provider type: %s", providerConfig.ProviderType))
}
- return true
- }
- if apiKey := os.Getenv("OPENROUTER_API_KEY"); apiKey != "" {
- var model models.ModelID
- maxTokens := int64(5000)
- reasoningEffort := ""
+ // Validate custom providers
+ isKnownProvider := slices.Contains(knownProviders, providerID)
- switch agent {
- case AgentTitle:
- model = models.OpenRouterClaude35Haiku
- maxTokens = 80
- case AgentTask:
- model = models.OpenRouterClaude37Sonnet
- default:
- model = models.OpenRouterClaude37Sonnet
+ if !isKnownProvider {
+ // Custom provider validation
+ if providerConfig.BaseURL == "" {
+ errors.Add(fieldPrefix+".base_url", "BaseURL is required for custom providers")
+ }
+ if providerConfig.ProviderType != provider.TypeOpenAI {
+ errors.Add(fieldPrefix+".provider_type", "custom providers currently only support OpenAI type")
+ }
}
- // Check if model supports reasoning
- if modelInfo, ok := models.SupportedModels[model]; ok && modelInfo.CanReason {
- reasoningEffort = "medium"
+ // Validate models
+ modelIDs := make(map[string]bool)
+ for i, model := range providerConfig.Models {
+ modelFieldPrefix := fmt.Sprintf("%s.models[%d]", fieldPrefix, i)
+
+ // Check for duplicate model IDs
+ if modelIDs[model.ID] {
+ errors.Add(modelFieldPrefix+".id", fmt.Sprintf("duplicate model ID: %s", model.ID))
+ }
+ modelIDs[model.ID] = true
+
+ // Validate required model fields
+ if model.ID == "" {
+ errors.Add(modelFieldPrefix+".id", "model ID is required")
+ }
+ if model.Name == "" {
+ errors.Add(modelFieldPrefix+".name", "model name is required")
+ }
+ if model.ContextWindow <= 0 {
+ errors.Add(modelFieldPrefix+".context_window", "context window must be positive")
+ }
+ if model.DefaultMaxTokens <= 0 {
+ errors.Add(modelFieldPrefix+".default_max_tokens", "default max tokens must be positive")
+ }
+ if model.DefaultMaxTokens > model.ContextWindow {
+ errors.Add(modelFieldPrefix+".default_max_tokens", "default max tokens cannot exceed context window")
+ }
+
+ // Validate cost fields
+ if model.CostPer1MIn < 0 {
+ errors.Add(modelFieldPrefix+".cost_per_1m_in", "cost per 1M input tokens cannot be negative")
+ }
+ if model.CostPer1MOut < 0 {
+ errors.Add(modelFieldPrefix+".cost_per_1m_out", "cost per 1M output tokens cannot be negative")
+ }
+ if model.CostPer1MInCached < 0 {
+ errors.Add(modelFieldPrefix+".cost_per_1m_in_cached", "cached cost per 1M input tokens cannot be negative")
+ }
+ if model.CostPer1MOutCached < 0 {
+ errors.Add(modelFieldPrefix+".cost_per_1m_out_cached", "cached cost per 1M output tokens cannot be negative")
+ }
}
- cfg.Agents[agent] = Agent{
- Model: model,
- MaxTokens: maxTokens,
- ReasoningEffort: reasoningEffort,
+ // Validate default model references
+ if providerConfig.DefaultLargeModel != "" {
+ if !modelIDs[providerConfig.DefaultLargeModel] {
+ errors.Add(fieldPrefix+".default_large_model", fmt.Sprintf("default large model '%s' not found in provider models", providerConfig.DefaultLargeModel))
+ }
}
- return true
+ if providerConfig.DefaultSmallModel != "" {
+ if !modelIDs[providerConfig.DefaultSmallModel] {
+ errors.Add(fieldPrefix+".default_small_model", fmt.Sprintf("default small model '%s' not found in provider models", providerConfig.DefaultSmallModel))
+ }
+ }
+
+ // Validate provider-specific requirements
+ c.validateProviderSpecific(providerID, providerConfig, errors)
}
+}
- if apiKey := os.Getenv("GEMINI_API_KEY"); apiKey != "" {
- var model models.ModelID
- maxTokens := int64(5000)
+// validateProviderSpecific validates provider-specific requirements.
+func (c *Config) validateProviderSpecific(providerID provider.InferenceProvider, providerConfig ProviderConfig, errors *ValidationErrors) {
+ fieldPrefix := fmt.Sprintf("providers.%s", providerID)
- if agent == AgentTitle {
- model = models.Gemini25Flash
- maxTokens = 80
- } else {
- model = models.Gemini25
+ switch providerID {
+ case provider.InferenceProviderVertexAI:
+ if !providerConfig.Disabled {
+ if providerConfig.ExtraParams == nil {
+ errors.Add(fieldPrefix+".extra_params", "VertexAI requires extra_params configuration")
+ } else {
+ if providerConfig.ExtraParams["project"] == "" {
+ errors.Add(fieldPrefix+".extra_params.project", "VertexAI requires project parameter")
+ }
+ if providerConfig.ExtraParams["location"] == "" {
+ errors.Add(fieldPrefix+".extra_params.location", "VertexAI requires location parameter")
+ }
+ }
}
-
- cfg.Agents[agent] = Agent{
- Model: model,
- MaxTokens: maxTokens,
+ case provider.InferenceProviderBedrock:
+ if !providerConfig.Disabled {
+ if providerConfig.ExtraParams == nil || providerConfig.ExtraParams["region"] == "" {
+ errors.Add(fieldPrefix+".extra_params.region", "Bedrock requires region parameter")
+ }
+ // Check for AWS credentials in the environment.
+ if !hasAWSCredentials() {
+ errors.Add(fieldPrefix, "Bedrock requires AWS credentials in environment")
+ }
}
- return true
}
+}
- if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" {
- maxTokens := int64(5000)
- if agent == AgentTitle {
- maxTokens = 80
+// validateModels validates preferred model configurations.
+func (c *Config) validateModels(errors *ValidationErrors) {
+ // Validate large model
+ if c.Models.Large.ModelID != "" || c.Models.Large.Provider != "" {
+ if c.Models.Large.ModelID == "" {
+ errors.Add("models.large.model_id", "large model ID is required when provider is set")
+ }
+ if c.Models.Large.Provider == "" {
+ errors.Add("models.large.provider", "large model provider is required when model ID is set")
}
- cfg.Agents[agent] = Agent{
- Model: models.QWENQwq,
- MaxTokens: maxTokens,
+ // Check if provider exists and is not disabled
+ if providerConfig, exists := c.Providers[c.Models.Large.Provider]; exists {
+ if providerConfig.Disabled {
+ errors.Add("models.large.provider", "large model provider is disabled")
+ }
+
+ // Check if model exists in provider
+ modelExists := false
+ for _, model := range providerConfig.Models {
+ if model.ID == c.Models.Large.ModelID {
+ modelExists = true
+ break
+ }
+ }
+ if !modelExists {
+ errors.Add("models.large.model_id", fmt.Sprintf("large model '%s' not found in provider '%s'", c.Models.Large.ModelID, c.Models.Large.Provider))
+ }
+ } else {
+ errors.Add("models.large.provider", fmt.Sprintf("large model provider '%s' not found", c.Models.Large.Provider))
}
- return true
}
- if hasAWSCredentials() {
- maxTokens := int64(5000)
- if agent == AgentTitle {
- maxTokens = 80
+ // Validate small model
+ if c.Models.Small.ModelID != "" || c.Models.Small.Provider != "" {
+ if c.Models.Small.ModelID == "" {
+ errors.Add("models.small.model_id", "small model ID is required when provider is set")
}
+ if c.Models.Small.Provider == "" {
+ errors.Add("models.small.provider", "small model provider is required when model ID is set")
+ }
+
+ // Check if provider exists and is not disabled
+ if providerConfig, exists := c.Providers[c.Models.Small.Provider]; exists {
+ if providerConfig.Disabled {
+ errors.Add("models.small.provider", "small model provider is disabled")
+ }
- cfg.Agents[agent] = Agent{
- Model: models.BedrockClaude37Sonnet,
- MaxTokens: maxTokens,
- ReasoningEffort: "medium", // Claude models support reasoning
+ // Check if model exists in provider
+ modelExists := false
+ for _, model := range providerConfig.Models {
+ if model.ID == c.Models.Small.ModelID {
+ modelExists = true
+ break
+ }
+ }
+ if !modelExists {
+ errors.Add("models.small.model_id", fmt.Sprintf("small model '%s' not found in provider '%s'", c.Models.Small.ModelID, c.Models.Small.Provider))
+ }
+ } else {
+ errors.Add("models.small.provider", fmt.Sprintf("small model provider '%s' not found", c.Models.Small.Provider))
}
- return true
+ }
+}
+
+// validateAgents validates agent configurations.
+func (c *Config) validateAgents(errors *ValidationErrors) {
+ if c.Agents == nil {
+ c.Agents = make(map[AgentID]Agent)
}
- if hasVertexAICredentials() {
- var model models.ModelID
- maxTokens := int64(5000)
+ validTools := []string{
+ "bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent",
+ }
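+ // Any tool name outside this list is rejected below.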
- if agent == AgentTitle {
- model = models.VertexAIGemini25Flash
- maxTokens = 80
- } else {
- model = models.VertexAIGemini25
+ for agentID, agent := range c.Agents {
+ fieldPrefix := fmt.Sprintf("agents.%s", agentID)
+
+ // Validate agent ID consistency
+ if agent.ID != agentID {
+ errors.Add(fieldPrefix+".id", fmt.Sprintf("agent ID mismatch: expected '%s', got '%s'", agentID, agent.ID))
}
- cfg.Agents[agent] = Agent{
- Model: model,
- MaxTokens: maxTokens,
+ // Validate required fields
+ if agent.ID == "" {
+ errors.Add(fieldPrefix+".id", "agent ID is required")
+ }
+ if agent.Name == "" {
+ errors.Add(fieldPrefix+".name", "agent name is required")
}
- return true
- }
- return false
-}
+ // Validate model type
+ if agent.Model != LargeModel && agent.Model != SmallModel {
+ errors.Add(fieldPrefix+".model", fmt.Sprintf("invalid model type: %s (must be 'large' or 'small')", agent.Model))
+ }
-func updateCfgFile(updateCfg func(config *Config)) error {
- if cfg == nil {
- return fmt.Errorf("config not loaded")
- }
+ // Validate allowed tools
+ if agent.AllowedTools != nil {
+ for i, tool := range agent.AllowedTools {
+ validTool := slices.Contains(validTools, tool)
+ if !validTool {
+ errors.Add(fmt.Sprintf("%s.allowed_tools[%d]", fieldPrefix, i), fmt.Sprintf("unknown tool: %s", tool))
+ }
+ }
+ }
- // Get the config file path
- configFile := viper.ConfigFileUsed()
- var configData []byte
- if configFile == "" {
- homeDir, err := os.UserHomeDir()
- if err != nil {
- return fmt.Errorf("failed to get home directory: %w", err)
+ // Validate MCP references
+ if agent.AllowedMCP != nil {
+ for mcpName := range agent.AllowedMCP {
+ if _, exists := c.MCP[mcpName]; !exists {
+ errors.Add(fieldPrefix+".allowed_mcp", fmt.Sprintf("referenced MCP '%s' not found", mcpName))
+ }
+ }
}
- configFile = filepath.Join(homeDir, fmt.Sprintf(".%s.json", appName))
- logging.Info("config file not found, creating new one", "path", configFile)
- configData = []byte(`{}`)
- } else {
- // Read the existing config file
- data, err := os.ReadFile(configFile)
- if err != nil {
- return fmt.Errorf("failed to read config file: %w", err)
+
+ // Validate LSP references
+ if agent.AllowedLSP != nil {
+ for _, lspName := range agent.AllowedLSP {
+ if _, exists := c.LSP[lspName]; !exists {
+ errors.Add(fieldPrefix+".allowed_lsp", fmt.Sprintf("referenced LSP '%s' not found", lspName))
+ }
+ }
+ }
+
+ // Validate context paths (basic path validation)
+ for i, contextPath := range agent.ContextPaths {
+ if contextPath == "" {
+ errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path cannot be empty")
+ }
+ // Check for invalid characters in path
+ if strings.Contains(contextPath, "\x00") {
+ errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path contains invalid characters")
+ }
}
- configData = data
- }
- // Parse the JSON
- var userCfg *Config
- if err := json.Unmarshal(configData, &userCfg); err != nil {
- return fmt.Errorf("failed to parse config file: %w", err)
+ // Validate that known agents maintain their core properties.
+ if agentID == AgentCoder {
+ if agent.Name != "Coder" {
+ errors.Add(fieldPrefix+".name", "coder agent name cannot be changed")
+ }
+ if agent.Description != "An agent that helps with executing coding tasks." {
+ errors.Add(fieldPrefix+".description", "coder agent description cannot be changed")
+ }
+ } else if agentID == AgentTask {
+ if agent.Name != "Task" {
+ errors.Add(fieldPrefix+".name", "task agent name cannot be changed")
+ }
+ if agent.Description != "An agent that helps with searching for context and finding implementation details." {
+ errors.Add(fieldPrefix+".description", "task agent description cannot be changed")
+ }
+ expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"}
+ if agent.AllowedTools != nil && !slices.Equal(agent.AllowedTools, expectedTools) {
+ errors.Add(fieldPrefix+".allowed_tools", "task agent allowed tools cannot be changed")
+ }
+ }
}
+}
- updateCfg(userCfg)
+// validateOptions validates configuration options.
+func (c *Config) validateOptions(errors *ValidationErrors) {
+ // Validate data directory
+ if c.Options.DataDirectory == "" {
+ errors.Add("options.data_directory", "data directory is required")
+ }
- // Write the updated config back to file
- updatedData, err := json.MarshalIndent(userCfg, "", " ")
- if err != nil {
- return fmt.Errorf("failed to marshal config: %w", err)
+ // Validate context paths
+ for i, contextPath := range c.Options.ContextPaths {
+ if contextPath == "" {
+ errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path cannot be empty")
+ }
+ if strings.Contains(contextPath, "\x00") {
+ errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path contains invalid characters")
+ }
}
+}
- if err := os.WriteFile(configFile, updatedData, 0o644); err != nil {
- return fmt.Errorf("failed to write config file: %w", err)
+// validateMCPs validates MCP configurations.
+func (c *Config) validateMCPs(errors *ValidationErrors) {
+ if c.MCP == nil {
+ c.MCP = make(map[string]MCP)
}
- return nil
-}
+ for mcpName, mcpConfig := range c.MCP {
+ fieldPrefix := fmt.Sprintf("mcp.%s", mcpName)
-// Get returns the current configuration.
-// It's safe to call this function multiple times.
-func Get() *Config {
- return cfg
-}
+ // Validate MCP type
+ if mcpConfig.Type != MCPStdio && mcpConfig.Type != MCPSse {
+ errors.Add(fieldPrefix+".type", fmt.Sprintf("invalid MCP type: %s (must be 'stdio' or 'sse')", mcpConfig.Type))
+ }
-// WorkingDirectory returns the current working directory from the configuration.
-func WorkingDirectory() string {
- if cfg == nil {
- panic("config not loaded")
+ // Validate based on type
+ if mcpConfig.Type == MCPStdio {
+ if mcpConfig.Command == "" {
+ errors.Add(fieldPrefix+".command", "command is required for stdio MCP")
+ }
+ } else if mcpConfig.Type == MCPSse {
+ if mcpConfig.URL == "" {
+ errors.Add(fieldPrefix+".url", "URL is required for SSE MCP")
+ }
+ }
}
- return cfg.WorkingDir
}
-func UpdateAgentModel(agentName AgentName, modelID models.ModelID) error {
- if cfg == nil {
- panic("config not loaded")
+// validateLSPs validates LSP configurations.
+func (c *Config) validateLSPs(errors *ValidationErrors) {
+ if c.LSP == nil {
+ c.LSP = make(map[string]LSPConfig)
}
- existingAgentCfg := cfg.Agents[agentName]
+ for lspName, lspConfig := range c.LSP {
+ fieldPrefix := fmt.Sprintf("lsp.%s", lspName)
- model, ok := models.SupportedModels[modelID]
- if !ok {
- return fmt.Errorf("model %s not supported", modelID)
+ if lspConfig.Command == "" {
+ errors.Add(fieldPrefix+".command", "command is required for LSP")
+ }
}
+}
+
+// validateCrossReferences validates cross-references between different config sections.
+func (c *Config) validateCrossReferences(errors *ValidationErrors) {
+ // Validate that agents can use their assigned model types
+ for agentID, agent := range c.Agents {
+ fieldPrefix := fmt.Sprintf("agents.%s", agentID)
+
+ var preferredModel PreferredModel
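+ // Resolve which preferred model this agent's model type refers to.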
+ switch agent.Model {
+ case LargeModel:
+ preferredModel = c.Models.Large
+ case SmallModel:
+ preferredModel = c.Models.Small
+ }
- maxTokens := existingAgentCfg.MaxTokens
- if model.DefaultMaxTokens > 0 {
- maxTokens = model.DefaultMaxTokens
+ if preferredModel.Provider != "" {
+ if providerConfig, exists := c.Providers[preferredModel.Provider]; exists {
+ if providerConfig.Disabled {
+ errors.Add(fieldPrefix+".model", fmt.Sprintf("agent cannot use model type '%s' because provider '%s' is disabled", agent.Model, preferredModel.Provider))
+ }
+ }
+ }
}
+}
- newAgentCfg := Agent{
- Model: modelID,
- MaxTokens: maxTokens,
- ReasoningEffort: existingAgentCfg.ReasoningEffort,
+// validateCompleteness validates that the configuration is complete and usable.
+func (c *Config) validateCompleteness(errors *ValidationErrors) {
+ // Check for at least one valid, non-disabled provider
+ hasValidProvider := false
+ for _, providerConfig := range c.Providers {
+ if !providerConfig.Disabled {
+ hasValidProvider = true
+ break
+ }
+ }
+ if !hasValidProvider {
+ errors.Add("providers", "at least one non-disabled provider is required")
}
- cfg.Agents[agentName] = newAgentCfg
- if err := validateAgent(cfg, agentName, newAgentCfg); err != nil {
- // revert config update on failure
- cfg.Agents[agentName] = existingAgentCfg
- return fmt.Errorf("failed to update agent model: %w", err)
+ // Check that default agents exist
+ if _, exists := c.Agents[AgentCoder]; !exists {
+ errors.Add("agents", "coder agent is required")
+ }
+ if _, exists := c.Agents[AgentTask]; !exists {
+ errors.Add("agents", "task agent is required")
}
- return updateCfgFile(func(config *Config) {
- if config.Agents == nil {
- config.Agents = make(map[AgentName]Agent)
+ // Check that preferred models are set if providers exist
+ if hasValidProvider {
+ if c.Models.Large.ModelID == "" || c.Models.Large.Provider == "" {
+ errors.Add("models.large", "large preferred model must be configured when providers are available")
}
- config.Agents[agentName] = newAgentCfg
- })
+ if c.Models.Small.ModelID == "" || c.Models.Small.Provider == "" {
+ errors.Add("models.small", "small preferred model must be configured when providers are available")
+ }
+ }
}
-// UpdateTheme updates the theme in the configuration and writes it to the config file.
-func UpdateTheme(themeName string) error {
- if cfg == nil {
- return fmt.Errorf("config not loaded")
+// JSONSchemaExtend adds custom schema properties for AgentID.
+func (AgentID) JSONSchemaExtend(schema *jsonschema.Schema) {
+ schema.Enum = []any{
+ string(AgentCoder),
+ string(AgentTask),
}
+}
- // Update the in-memory config
- cfg.TUI.Theme = themeName
+// JSONSchemaExtend adds custom schema properties for ModelType.
+func (ModelType) JSONSchemaExtend(schema *jsonschema.Schema) {
+ schema.Enum = []any{
+ string(LargeModel),
+ string(SmallModel),
+ }
+}
- // Update the file config
- return updateCfgFile(func(config *Config) {
- config.TUI.Theme = themeName
- })
+// JSONSchemaExtend adds custom schema properties for MCPType.
+func (MCPType) JSONSchemaExtend(schema *jsonschema.Schema) {
+ schema.Enum = []any{
+ string(MCPStdio),
+ string(MCPSse),
+ }
}
diff --git a/internal/config/config_test.go b/internal/config/config_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..2b4764489dd7795df0473c22eb529bbc65fc9a2b
--- /dev/null
+++ b/internal/config/config_test.go
@@ -0,0 +1,2079 @@
+package config
+
+import (
+ "encoding/json"
+ "os"
+ "path/filepath"
+ "sync"
+ "testing"
+
+ "github.com/charmbracelet/crush/internal/fur/provider"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func reset() {
+ // Clear all environment variables that could affect the config.
+ envVarsToUnset := []string{
+ // API Keys
+ "ANTHROPIC_API_KEY",
+ "OPENAI_API_KEY",
+ "GEMINI_API_KEY",
+ "XAI_API_KEY",
+ "OPENROUTER_API_KEY",
+
+ // Google Cloud / VertexAI
+ "GOOGLE_GENAI_USE_VERTEXAI",
+ "GOOGLE_CLOUD_PROJECT",
+ "GOOGLE_CLOUD_LOCATION",
+
+ // AWS Credentials
+ "AWS_ACCESS_KEY_ID",
+ "AWS_SECRET_ACCESS_KEY",
+ "AWS_REGION",
+ "AWS_DEFAULT_REGION",
+ "AWS_PROFILE",
+ "AWS_DEFAULT_PROFILE",
+ "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI",
+ "AWS_CONTAINER_CREDENTIALS_FULL_URI",
+
+ // Other
+ "CRUSH_DEV_DEBUG",
+ }
+
+ for _, envVar := range envVarsToUnset {
+ os.Unsetenv(envVar)
+ }
+
+ // Reset the singleton state.
+ once = sync.Once{}
+ instance = nil
+ cwd = ""
+ testConfigDir = ""
+
+ // Enable mock providers for all tests to avoid API calls.
+ UseMockProviders = true
+ ResetProviders()
+}
+
+// Core Configuration Loading Tests
+
+func TestInit_ValidWorkingDirectory(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ assert.NotNil(t, cfg)
+ assert.Equal(t, cwdDir, WorkingDirectory())
+ assert.Equal(t, defaultDataDirectory, cfg.Options.DataDirectory)
+ assert.Equal(t, defaultContextPaths, cfg.Options.ContextPaths)
+}
+
+func TestInit_WithDebugFlag(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ cfg, err := Init(cwdDir, true)
+
+ require.NoError(t, err)
+ assert.True(t, cfg.Options.Debug)
+}
+
+func TestInit_SingletonBehavior(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ cfg1, err1 := Init(cwdDir, false)
+ cfg2, err2 := Init(cwdDir, false)
+
+ require.NoError(t, err1)
+ require.NoError(t, err2)
+ assert.Same(t, cfg1, cfg2)
+}
+
+func TestGet_BeforeInitialization(t *testing.T) {
+ reset()
+
+ assert.Panics(t, func() {
+ Get()
+ })
+}
+
+func TestGet_AfterInitialization(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ cfg1, err := Init(cwdDir, false)
+ require.NoError(t, err)
+
+ cfg2 := Get()
+ assert.Same(t, cfg1, cfg2)
+}
+
+func TestLoadConfig_NoConfigFiles(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ assert.Len(t, cfg.Providers, 0)
+ assert.Equal(t, defaultContextPaths, cfg.Options.ContextPaths)
+}
+
+func TestLoadConfig_OnlyGlobalConfig(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ globalConfig := Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ ID: provider.InferenceProviderOpenAI,
+ APIKey: "test-key",
+ ProviderType: provider.TypeOpenAI,
+ DefaultLargeModel: "gpt-4",
+ DefaultSmallModel: "gpt-3.5-turbo",
+ Models: []Model{
+ {
+ ID: "gpt-4",
+ Name: "GPT-4",
+ CostPer1MIn: 30.0,
+ CostPer1MOut: 60.0,
+ ContextWindow: 8192,
+ DefaultMaxTokens: 4096,
+ },
+ {
+ ID: "gpt-3.5-turbo",
+ Name: "GPT-3.5 Turbo",
+ CostPer1MIn: 1.0,
+ CostPer1MOut: 2.0,
+ ContextWindow: 4096,
+ DefaultMaxTokens: 4096,
+ },
+ },
+ },
+ },
+ Options: Options{
+ ContextPaths: []string{"custom-context.md"},
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ assert.Len(t, cfg.Providers, 1)
+ assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI)
+ assert.Contains(t, cfg.Options.ContextPaths, "custom-context.md")
+}
+
+func TestLoadConfig_OnlyLocalConfig(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ localConfig := Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderAnthropic: {
+ ID: provider.InferenceProviderAnthropic,
+ APIKey: "local-key",
+ ProviderType: provider.TypeAnthropic,
+ DefaultLargeModel: "claude-3-opus",
+ DefaultSmallModel: "claude-3-haiku",
+ Models: []Model{
+ {
+ ID: "claude-3-opus",
+ Name: "Claude 3 Opus",
+ CostPer1MIn: 15.0,
+ CostPer1MOut: 75.0,
+ ContextWindow: 200000,
+ DefaultMaxTokens: 4096,
+ },
+ {
+ ID: "claude-3-haiku",
+ Name: "Claude 3 Haiku",
+ CostPer1MIn: 0.25,
+ CostPer1MOut: 1.25,
+ ContextWindow: 200000,
+ DefaultMaxTokens: 4096,
+ },
+ },
+ },
+ },
+ Options: Options{
+ TUI: TUIOptions{CompactMode: true},
+ },
+ }
+
+ localConfigPath := filepath.Join(cwdDir, "crush.json")
+ data, err := json.Marshal(localConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ assert.Len(t, cfg.Providers, 1)
+ assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic)
+ assert.True(t, cfg.Options.TUI.CompactMode)
+}
+
+func TestLoadConfig_BothGlobalAndLocal(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ globalConfig := Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ ID: provider.InferenceProviderOpenAI,
+ APIKey: "global-key",
+ ProviderType: provider.TypeOpenAI,
+ DefaultLargeModel: "gpt-4",
+ DefaultSmallModel: "gpt-3.5-turbo",
+ Models: []Model{
+ {
+ ID: "gpt-4",
+ Name: "GPT-4",
+ CostPer1MIn: 30.0,
+ CostPer1MOut: 60.0,
+ ContextWindow: 8192,
+ DefaultMaxTokens: 4096,
+ },
+ {
+ ID: "gpt-3.5-turbo",
+ Name: "GPT-3.5 Turbo",
+ CostPer1MIn: 1.0,
+ CostPer1MOut: 2.0,
+ ContextWindow: 4096,
+ DefaultMaxTokens: 4096,
+ },
+ },
+ },
+ },
+ Options: Options{
+ ContextPaths: []string{"global-context.md"},
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ localConfig := Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ APIKey: "local-key", // Override global
+ },
+ provider.InferenceProviderAnthropic: {
+ ID: provider.InferenceProviderAnthropic,
+ APIKey: "anthropic-key",
+ ProviderType: provider.TypeAnthropic,
+ DefaultLargeModel: "claude-3-opus",
+ DefaultSmallModel: "claude-3-haiku",
+ Models: []Model{
+ {
+ ID: "claude-3-opus",
+ Name: "Claude 3 Opus",
+ CostPer1MIn: 15.0,
+ CostPer1MOut: 75.0,
+ ContextWindow: 200000,
+ DefaultMaxTokens: 4096,
+ },
+ {
+ ID: "claude-3-haiku",
+ Name: "Claude 3 Haiku",
+ CostPer1MIn: 0.25,
+ CostPer1MOut: 1.25,
+ ContextWindow: 200000,
+ DefaultMaxTokens: 4096,
+ },
+ },
+ },
+ },
+ Options: Options{
+ ContextPaths: []string{"local-context.md"},
+ TUI: TUIOptions{CompactMode: true},
+ },
+ }
+
+ localConfigPath := filepath.Join(cwdDir, "crush.json")
+ data, err = json.Marshal(localConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ assert.Len(t, cfg.Providers, 2)
+
+ openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
+ assert.Equal(t, "local-key", openaiProvider.APIKey)
+
+ assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic)
+
+ assert.Contains(t, cfg.Options.ContextPaths, "global-context.md")
+ assert.Contains(t, cfg.Options.ContextPaths, "local-context.md")
+ assert.True(t, cfg.Options.TUI.CompactMode)
+}
+
+func TestLoadConfig_MalformedGlobalJSON(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ require.NoError(t, os.WriteFile(configPath, []byte(`{invalid json`), 0o644))
+
+ _, err := Init(cwdDir, false)
+ assert.Error(t, err)
+}
+
+func TestLoadConfig_MalformedLocalJSON(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ localConfigPath := filepath.Join(cwdDir, "crush.json")
+ require.NoError(t, os.WriteFile(localConfigPath, []byte(`{invalid json`), 0o644))
+
+ _, err := Init(cwdDir, false)
+ assert.Error(t, err)
+}
+
+func TestConfigWithoutEnv(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ cfg, _ := Init(cwdDir, false)
+ assert.Len(t, cfg.Providers, 0)
+}
+
+func TestConfigWithEnv(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key")
+ os.Setenv("OPENAI_API_KEY", "test-openai-key")
+ os.Setenv("GEMINI_API_KEY", "test-gemini-key")
+ os.Setenv("XAI_API_KEY", "test-xai-key")
+ os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key")
+
+ cfg, _ := Init(cwdDir, false)
+ assert.Len(t, cfg.Providers, 5)
+}
+
+// Environment Variable Tests
+
+func TestEnvVars_NoEnvironmentVariables(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ assert.Len(t, cfg.Providers, 0)
+}
+
+func TestEnvVars_AllSupportedAPIKeys(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key")
+ os.Setenv("OPENAI_API_KEY", "test-openai-key")
+ os.Setenv("GEMINI_API_KEY", "test-gemini-key")
+ os.Setenv("XAI_API_KEY", "test-xai-key")
+ os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key")
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ assert.Len(t, cfg.Providers, 5)
+
+ anthropicProvider := cfg.Providers[provider.InferenceProviderAnthropic]
+ assert.Equal(t, "test-anthropic-key", anthropicProvider.APIKey)
+ assert.Equal(t, provider.TypeAnthropic, anthropicProvider.ProviderType)
+
+ openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
+ assert.Equal(t, "test-openai-key", openaiProvider.APIKey)
+ assert.Equal(t, provider.TypeOpenAI, openaiProvider.ProviderType)
+
+ geminiProvider := cfg.Providers[provider.InferenceProviderGemini]
+ assert.Equal(t, "test-gemini-key", geminiProvider.APIKey)
+ assert.Equal(t, provider.TypeGemini, geminiProvider.ProviderType)
+
+ xaiProvider := cfg.Providers[provider.InferenceProviderXAI]
+ assert.Equal(t, "test-xai-key", xaiProvider.APIKey)
+ assert.Equal(t, provider.TypeXAI, xaiProvider.ProviderType)
+
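+ // OpenRouter is OpenAI-compatible, so it is configured as an OpenAI-type provider with its own base URL.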
+ openrouterProvider := cfg.Providers[provider.InferenceProviderOpenRouter]
+ assert.Equal(t, "test-openrouter-key", openrouterProvider.APIKey)
+ assert.Equal(t, provider.TypeOpenAI, openrouterProvider.ProviderType)
+ assert.Equal(t, "https://openrouter.ai/api/v1", openrouterProvider.BaseURL)
+}
+
+func TestEnvVars_PartialEnvironmentVariables(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key")
+ os.Setenv("OPENAI_API_KEY", "test-openai-key")
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ assert.Len(t, cfg.Providers, 2)
+ assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic)
+ assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI)
+ assert.NotContains(t, cfg.Providers, provider.InferenceProviderGemini)
+}
+
+func TestEnvVars_VertexAIConfiguration(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("GOOGLE_GENAI_USE_VERTEXAI", "true")
+ os.Setenv("GOOGLE_CLOUD_PROJECT", "test-project")
+ os.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1")
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ assert.Contains(t, cfg.Providers, provider.InferenceProviderVertexAI)
+
+ vertexProvider := cfg.Providers[provider.InferenceProviderVertexAI]
+ assert.Equal(t, provider.TypeVertexAI, vertexProvider.ProviderType)
+ assert.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
+ assert.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
+}
+
+func TestEnvVars_VertexAIWithoutUseFlag(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("GOOGLE_CLOUD_PROJECT", "test-project")
+ os.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1")
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ assert.NotContains(t, cfg.Providers, provider.InferenceProviderVertexAI)
+}
+
+func TestEnvVars_AWSBedrockWithAccessKeys(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("AWS_ACCESS_KEY_ID", "test-access-key")
+ os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key")
+ os.Setenv("AWS_DEFAULT_REGION", "us-east-1")
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ assert.Contains(t, cfg.Providers, provider.InferenceProviderBedrock)
+
+ bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock]
+ assert.Equal(t, provider.TypeBedrock, bedrockProvider.ProviderType)
+ assert.Equal(t, "us-east-1", bedrockProvider.ExtraParams["region"])
+}
+
+func TestEnvVars_AWSBedrockWithProfile(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("AWS_PROFILE", "test-profile")
+ os.Setenv("AWS_REGION", "eu-west-1")
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ assert.Contains(t, cfg.Providers, provider.InferenceProviderBedrock)
+
+ bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock]
+ assert.Equal(t, "eu-west-1", bedrockProvider.ExtraParams["region"])
+}
+
+func TestEnvVars_AWSBedrockWithContainerCredentials(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/v2/credentials/test")
+ os.Setenv("AWS_DEFAULT_REGION", "ap-southeast-1")
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ assert.Contains(t, cfg.Providers, provider.InferenceProviderBedrock)
+}
+
+func TestEnvVars_AWSBedrockRegionPriority(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("AWS_ACCESS_KEY_ID", "test-key")
+ os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret")
+ os.Setenv("AWS_DEFAULT_REGION", "us-west-2")
+ os.Setenv("AWS_REGION", "us-east-1")
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
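+ // AWS_DEFAULT_REGION takes precedence over AWS_REGION when both are set.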
+ bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock]
+ assert.Equal(t, "us-west-2", bedrockProvider.ExtraParams["region"])
+}
+
+func TestEnvVars_AWSBedrockFallbackRegion(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("AWS_ACCESS_KEY_ID", "test-key")
+ os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret")
+ os.Setenv("AWS_REGION", "us-east-1")
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock]
+ assert.Equal(t, "us-east-1", bedrockProvider.ExtraParams["region"])
+}
+
+func TestEnvVars_NoAWSCredentials(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ assert.NotContains(t, cfg.Providers, provider.InferenceProviderBedrock)
+}
+
+func TestEnvVars_CustomEnvironmentVariables(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("ANTHROPIC_API_KEY", "resolved-anthropic-key")
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ if len(cfg.Providers) > 0 {
+ if anthropicProvider, exists := cfg.Providers[provider.InferenceProviderAnthropic]; exists {
+ assert.Equal(t, "resolved-anthropic-key", anthropicProvider.APIKey)
+ }
+ }
+}
+
+func TestEnvVars_CombinedEnvironmentVariables(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("ANTHROPIC_API_KEY", "test-anthropic")
+ os.Setenv("OPENAI_API_KEY", "test-openai")
+ os.Setenv("GOOGLE_GENAI_USE_VERTEXAI", "true")
+ os.Setenv("GOOGLE_CLOUD_PROJECT", "test-project")
+ os.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1")
+ os.Setenv("AWS_ACCESS_KEY_ID", "test-aws-key")
+ os.Setenv("AWS_SECRET_ACCESS_KEY", "test-aws-secret")
+ os.Setenv("AWS_DEFAULT_REGION", "us-west-1")
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+
+ expectedProviders := []provider.InferenceProvider{
+ provider.InferenceProviderAnthropic,
+ provider.InferenceProviderOpenAI,
+ provider.InferenceProviderVertexAI,
+ provider.InferenceProviderBedrock,
+ }
+
+ for _, expectedProvider := range expectedProviders {
+ assert.Contains(t, cfg.Providers, expectedProvider)
+ }
+}
+
+func TestHasAWSCredentials_AccessKeys(t *testing.T) {
+ reset()
+
+ os.Setenv("AWS_ACCESS_KEY_ID", "test-key")
+ os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret")
+
+ assert.True(t, hasAWSCredentials())
+}
+
+func TestHasAWSCredentials_Profile(t *testing.T) {
+ reset()
+
+ os.Setenv("AWS_PROFILE", "test-profile")
+
+ assert.True(t, hasAWSCredentials())
+}
+
+func TestHasAWSCredentials_DefaultProfile(t *testing.T) {
+ reset()
+
+ os.Setenv("AWS_DEFAULT_PROFILE", "default")
+
+ assert.True(t, hasAWSCredentials())
+}
+
+func TestHasAWSCredentials_Region(t *testing.T) {
+ reset()
+
+ os.Setenv("AWS_REGION", "us-east-1")
+
+ assert.True(t, hasAWSCredentials())
+}
+
+func TestHasAWSCredentials_ContainerCredentials(t *testing.T) {
+ reset()
+
+ os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/v2/credentials/test")
+
+ assert.True(t, hasAWSCredentials())
+}
+
+func TestHasAWSCredentials_NoCredentials(t *testing.T) {
+ reset()
+
+ assert.False(t, hasAWSCredentials())
+}
+
+func TestProviderMerging_GlobalToBase(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ globalConfig := Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ ID: provider.InferenceProviderOpenAI,
+ APIKey: "global-openai-key",
+ ProviderType: provider.TypeOpenAI,
+ DefaultLargeModel: "gpt-4",
+ DefaultSmallModel: "gpt-3.5-turbo",
+ Models: []Model{
+ {
+ ID: "gpt-4",
+ Name: "GPT-4",
+ ContextWindow: 8192,
+ DefaultMaxTokens: 4096,
+ },
+ {
+ ID: "gpt-3.5-turbo",
+ Name: "GPT-3.5 Turbo",
+ ContextWindow: 4096,
+ DefaultMaxTokens: 2048,
+ },
+ },
+ },
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ assert.Len(t, cfg.Providers, 1)
+
+ openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
+ assert.Equal(t, "global-openai-key", openaiProvider.APIKey)
+ assert.Equal(t, "gpt-4", openaiProvider.DefaultLargeModel)
+ assert.Equal(t, "gpt-3.5-turbo", openaiProvider.DefaultSmallModel)
+ assert.Len(t, openaiProvider.Models, 2)
+}
+
+func TestProviderMerging_LocalToBase(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ localConfig := Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderAnthropic: {
+ ID: provider.InferenceProviderAnthropic,
+ APIKey: "local-anthropic-key",
+ ProviderType: provider.TypeAnthropic,
+ DefaultLargeModel: "claude-3-opus",
+ DefaultSmallModel: "claude-3-haiku",
+ Models: []Model{
+ {
+ ID: "claude-3-opus",
+ Name: "Claude 3 Opus",
+ ContextWindow: 200000,
+ DefaultMaxTokens: 4096,
+ CostPer1MIn: 15.0,
+ CostPer1MOut: 75.0,
+ },
+ {
+ ID: "claude-3-haiku",
+ Name: "Claude 3 Haiku",
+ ContextWindow: 200000,
+ DefaultMaxTokens: 4096,
+ CostPer1MIn: 0.25,
+ CostPer1MOut: 1.25,
+ },
+ },
+ },
+ },
+ }
+
+ localConfigPath := filepath.Join(cwdDir, "crush.json")
+ data, err := json.Marshal(localConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ assert.Len(t, cfg.Providers, 1)
+
+ anthropicProvider := cfg.Providers[provider.InferenceProviderAnthropic]
+ assert.Equal(t, "local-anthropic-key", anthropicProvider.APIKey)
+ assert.Equal(t, "claude-3-opus", anthropicProvider.DefaultLargeModel)
+ assert.Equal(t, "claude-3-haiku", anthropicProvider.DefaultSmallModel)
+ assert.Len(t, anthropicProvider.Models, 2)
+}
+
+func TestProviderMerging_ConflictingSettings(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ globalConfig := Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ ID: provider.InferenceProviderOpenAI,
+ APIKey: "global-key",
+ ProviderType: provider.TypeOpenAI,
+ DefaultLargeModel: "gpt-4",
+ DefaultSmallModel: "gpt-3.5-turbo",
+ Models: []Model{
+ {
+ ID: "gpt-4",
+ Name: "GPT-4",
+ ContextWindow: 8192,
+ DefaultMaxTokens: 4096,
+ },
+ {
+ ID: "gpt-3.5-turbo",
+ Name: "GPT-3.5 Turbo",
+ ContextWindow: 4096,
+ DefaultMaxTokens: 2048,
+ },
+ {
+ ID: "gpt-4-turbo",
+ Name: "GPT-4 Turbo",
+ ContextWindow: 128000,
+ DefaultMaxTokens: 4096,
+ },
+ },
+ },
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ // Create a local config that overrides the global one.
+ localConfig := Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ APIKey: "local-key",
+ DefaultLargeModel: "gpt-4-turbo",
+ },
+ },
+ }
+
+ localConfigPath := filepath.Join(cwdDir, "crush.json")
+ data, err = json.Marshal(localConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+
+ openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
+ assert.Equal(t, "local-key", openaiProvider.APIKey)
+ assert.Equal(t, "gpt-4-turbo", openaiProvider.DefaultLargeModel)
+ assert.False(t, openaiProvider.Disabled)
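+ // Fields the local config leaves unset keep their global values.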
+ assert.Equal(t, "gpt-3.5-turbo", openaiProvider.DefaultSmallModel)
+}
+
+func TestProviderMerging_CustomVsKnownProviders(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ customProviderID := provider.InferenceProvider("custom-provider")
+
+ globalConfig := Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ ID: provider.InferenceProviderOpenAI,
+ APIKey: "openai-key",
+ BaseURL: "should-not-override",
+ ProviderType: provider.TypeAnthropic,
+ DefaultLargeModel: "gpt-4",
+ DefaultSmallModel: "gpt-3.5-turbo",
+ Models: []Model{
+ {
+ ID: "gpt-4",
+ Name: "GPT-4",
+ ContextWindow: 8192,
+ DefaultMaxTokens: 4096,
+ },
+ {
+ ID: "gpt-3.5-turbo",
+ Name: "GPT-3.5 Turbo",
+ ContextWindow: 4096,
+ DefaultMaxTokens: 2048,
+ },
+ },
+ },
+ customProviderID: {
+ ID: customProviderID,
+ APIKey: "custom-key",
+ BaseURL: "https://custom.api.com",
+ ProviderType: provider.TypeOpenAI,
+ DefaultLargeModel: "custom-large",
+ DefaultSmallModel: "custom-small",
+ Models: []Model{
+ {
+ ID: "custom-large",
+ Name: "Custom Large",
+ ContextWindow: 8192,
+ DefaultMaxTokens: 4096,
+ },
+ {
+ ID: "custom-small",
+ Name: "Custom Small",
+ ContextWindow: 4096,
+ DefaultMaxTokens: 2048,
+ },
+ },
+ },
+ },
+ }
+
+ localConfig := Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ BaseURL: "https://should-not-change.com",
+ ProviderType: provider.TypeGemini, // Should not change
+ },
+ customProviderID: {
+ BaseURL: "https://updated-custom.api.com",
+ ProviderType: provider.TypeOpenAI,
+ },
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ localConfigPath := filepath.Join(cwdDir, "crush.json")
+ data, err = json.Marshal(localConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+
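+ // Known providers keep their canonical base URL and provider type; only the custom provider accepts these overrides.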
+ openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
+ assert.NotEqual(t, "https://should-not-change.com", openaiProvider.BaseURL)
+ assert.NotEqual(t, provider.TypeGemini, openaiProvider.ProviderType)
+
+ customProvider := cfg.Providers[customProviderID]
+ assert.Equal(t, "custom-key", customProvider.APIKey)
+ assert.Equal(t, "https://updated-custom.api.com", customProvider.BaseURL)
+ assert.Equal(t, provider.TypeOpenAI, customProvider.ProviderType)
+}
+
+func TestProviderValidation_CustomProviderMissingBaseURL(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ customProviderID := provider.InferenceProvider("custom-provider")
+
+ globalConfig := Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ customProviderID: {
+ ID: customProviderID,
+ APIKey: "custom-key",
+ ProviderType: provider.TypeOpenAI,
+ },
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ assert.NotContains(t, cfg.Providers, customProviderID)
+}
+
+func TestProviderValidation_CustomProviderMissingAPIKey(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ customProviderID := provider.InferenceProvider("custom-provider")
+
+ globalConfig := Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ customProviderID: {
+ ID: customProviderID,
+ BaseURL: "https://custom.api.com",
+ ProviderType: provider.TypeOpenAI,
+ },
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ assert.NotContains(t, cfg.Providers, customProviderID)
+}
+
+func TestProviderValidation_CustomProviderInvalidType(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ customProviderID := provider.InferenceProvider("custom-provider")
+
+ globalConfig := Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ customProviderID: {
+ ID: customProviderID,
+ APIKey: "custom-key",
+ BaseURL: "https://custom.api.com",
+ ProviderType: provider.Type("invalid-type"),
+ },
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ assert.NotContains(t, cfg.Providers, customProviderID)
+}
+
+func TestProviderValidation_KnownProviderValid(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ globalConfig := Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ ID: provider.InferenceProviderOpenAI,
+ APIKey: "openai-key",
+ ProviderType: provider.TypeOpenAI,
+ DefaultLargeModel: "gpt-4",
+ DefaultSmallModel: "gpt-3.5-turbo",
+ Models: []Model{
+ {
+ ID: "gpt-4",
+ Name: "GPT-4",
+ ContextWindow: 8192,
+ DefaultMaxTokens: 4096,
+ },
+ {
+ ID: "gpt-3.5-turbo",
+ Name: "GPT-3.5 Turbo",
+ ContextWindow: 4096,
+ DefaultMaxTokens: 2048,
+ },
+ },
+ },
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI)
+}
+
+func TestProviderValidation_DisabledProvider(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ globalConfig := Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ ID: provider.InferenceProviderOpenAI,
+ APIKey: "openai-key",
+ ProviderType: provider.TypeOpenAI,
+ Disabled: true,
+ DefaultLargeModel: "gpt-4",
+ DefaultSmallModel: "gpt-3.5-turbo",
+ Models: []Model{
+ {
+ ID: "gpt-4",
+ Name: "GPT-4",
+ ContextWindow: 8192,
+ DefaultMaxTokens: 4096,
+ },
+ {
+ ID: "gpt-3.5-turbo",
+ Name: "GPT-3.5 Turbo",
+ ContextWindow: 4096,
+ DefaultMaxTokens: 2048,
+ },
+ },
+ },
+ provider.InferenceProviderAnthropic: {
+ ID: provider.InferenceProviderAnthropic,
+ APIKey: "anthropic-key",
+ ProviderType: provider.TypeAnthropic,
+ Disabled: false, // This one is enabled
+ DefaultLargeModel: "claude-3-opus",
+ DefaultSmallModel: "claude-3-haiku",
+ Models: []Model{
+ {
+ ID: "claude-3-opus",
+ Name: "Claude 3 Opus",
+ ContextWindow: 200000,
+ DefaultMaxTokens: 4096,
+ },
+ {
+ ID: "claude-3-haiku",
+ Name: "Claude 3 Haiku",
+ ContextWindow: 200000,
+ DefaultMaxTokens: 4096,
+ },
+ },
+ },
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI)
+ assert.True(t, cfg.Providers[provider.InferenceProviderOpenAI].Disabled)
+ assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic)
+ assert.False(t, cfg.Providers[provider.InferenceProviderAnthropic].Disabled)
+}
+
+func TestProviderModels_AddingNewModels(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ globalConfig := Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ ID: provider.InferenceProviderOpenAI,
+ APIKey: "openai-key",
+ ProviderType: provider.TypeOpenAI,
+ DefaultLargeModel: "gpt-4",
+ DefaultSmallModel: "gpt-4-turbo",
+ Models: []Model{
+ {
+ ID: "gpt-4",
+ Name: "GPT-4",
+ ContextWindow: 8192,
+ DefaultMaxTokens: 4096,
+ },
+ },
+ },
+ },
+ }
+
+ localConfig := Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ Models: []Model{
+ {
+ ID: "gpt-4-turbo",
+ Name: "GPT-4 Turbo",
+ ContextWindow: 128000,
+ DefaultMaxTokens: 4096,
+ },
+ },
+ },
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ localConfigPath := filepath.Join(cwdDir, "crush.json")
+ data, err = json.Marshal(localConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+
+ openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
+ assert.Len(t, openaiProvider.Models, 2)
+
+ modelIDs := make([]string, len(openaiProvider.Models))
+ for i, model := range openaiProvider.Models {
+ modelIDs[i] = model.ID
+ }
+ assert.Contains(t, modelIDs, "gpt-4")
+ assert.Contains(t, modelIDs, "gpt-4-turbo")
+}
+
+func TestProviderModels_DuplicateModelHandling(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ globalConfig := Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ ID: provider.InferenceProviderOpenAI,
+ APIKey: "openai-key",
+ ProviderType: provider.TypeOpenAI,
+ DefaultLargeModel: "gpt-4",
+ DefaultSmallModel: "gpt-4",
+ Models: []Model{
+ {
+ ID: "gpt-4",
+ Name: "GPT-4",
+ ContextWindow: 8192,
+ DefaultMaxTokens: 4096,
+ },
+ },
+ },
+ },
+ }
+
+ localConfig := Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ Models: []Model{
+ {
+ ID: "gpt-4",
+ Name: "GPT-4 Updated",
+ ContextWindow: 16384,
+ DefaultMaxTokens: 8192,
+ },
+ },
+ },
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ localConfigPath := filepath.Join(cwdDir, "crush.json")
+ data, err = json.Marshal(localConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+
+ openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
+ assert.Len(t, openaiProvider.Models, 1)
+
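+ // The global definition takes precedence when a model ID is duplicated locally.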
+ model := openaiProvider.Models[0]
+ assert.Equal(t, "gpt-4", model.ID)
+ assert.Equal(t, "GPT-4", model.Name)
+ assert.Equal(t, int64(8192), model.ContextWindow)
+}
+
+func TestProviderModels_ModelCostAndCapabilities(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ globalConfig := Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ ID: provider.InferenceProviderOpenAI,
+ APIKey: "openai-key",
+ ProviderType: provider.TypeOpenAI,
+ DefaultLargeModel: "gpt-4",
+ DefaultSmallModel: "gpt-4",
+ Models: []Model{
+ {
+ ID: "gpt-4",
+ Name: "GPT-4",
+ CostPer1MIn: 30.0,
+ CostPer1MOut: 60.0,
+ CostPer1MInCached: 15.0,
+ CostPer1MOutCached: 30.0,
+ ContextWindow: 8192,
+ DefaultMaxTokens: 4096,
+ CanReason: true,
+ ReasoningEffort: "medium",
+ SupportsImages: true,
+ },
+ },
+ },
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+
+ openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
+ require.Len(t, openaiProvider.Models, 1)
+
+ model := openaiProvider.Models[0]
+ assert.Equal(t, 30.0, model.CostPer1MIn)
+ assert.Equal(t, 60.0, model.CostPer1MOut)
+ assert.Equal(t, 15.0, model.CostPer1MInCached)
+ assert.Equal(t, 30.0, model.CostPer1MOutCached)
+ assert.True(t, model.CanReason)
+ assert.Equal(t, "medium", model.ReasoningEffort)
+ assert.True(t, model.SupportsImages)
+}
+
+func TestDefaultAgents_CoderAgent(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("ANTHROPIC_API_KEY", "test-key")
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ assert.Contains(t, cfg.Agents, AgentCoder)
+
+ coderAgent := cfg.Agents[AgentCoder]
+ assert.Equal(t, AgentCoder, coderAgent.ID)
+ assert.Equal(t, "Coder", coderAgent.Name)
+ assert.Equal(t, "An agent that helps with executing coding tasks.", coderAgent.Description)
+ assert.Equal(t, LargeModel, coderAgent.Model)
+ assert.False(t, coderAgent.Disabled)
+ assert.Equal(t, cfg.Options.ContextPaths, coderAgent.ContextPaths)
+ assert.Nil(t, coderAgent.AllowedTools)
+}
+
+func TestDefaultAgents_TaskAgent(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("ANTHROPIC_API_KEY", "test-key")
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ assert.Contains(t, cfg.Agents, AgentTask)
+
+ taskAgent := cfg.Agents[AgentTask]
+ assert.Equal(t, AgentTask, taskAgent.ID)
+ assert.Equal(t, "Task", taskAgent.Name)
+ assert.Equal(t, "An agent that helps with searching for context and finding implementation details.", taskAgent.Description)
+ assert.Equal(t, LargeModel, taskAgent.Model)
+ assert.False(t, taskAgent.Disabled)
+ assert.Equal(t, cfg.Options.ContextPaths, taskAgent.ContextPaths)
+
+ expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"}
+ assert.Equal(t, expectedTools, taskAgent.AllowedTools)
+
+ assert.Equal(t, map[string][]string{}, taskAgent.AllowedMCP)
+ assert.Equal(t, []string{}, taskAgent.AllowedLSP)
+}
+
+func TestAgentMerging_CustomAgent(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("ANTHROPIC_API_KEY", "test-key")
+
+ globalConfig := Config{
+ Agents: map[AgentID]Agent{
+ AgentID("custom-agent"): {
+ ID: AgentID("custom-agent"),
+ Name: "Custom Agent",
+ Description: "A custom agent for testing",
+ Model: SmallModel,
+ AllowedTools: []string{"glob", "grep"},
+ AllowedMCP: map[string][]string{"mcp1": {"tool1", "tool2"}},
+ AllowedLSP: []string{"typescript", "go"},
+ ContextPaths: []string{"custom-context.md"},
+ },
+ },
+ MCP: map[string]MCP{
+ "mcp1": {
+ Type: MCPStdio,
+ Command: "test-mcp-command",
+ Args: []string{"--test"},
+ },
+ },
+ LSP: map[string]LSPConfig{
+ "typescript": {
+ Command: "typescript-language-server",
+ Args: []string{"--stdio"},
+ },
+ "go": {
+ Command: "gopls",
+ Args: []string{},
+ },
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+
+ assert.Contains(t, cfg.Agents, AgentCoder)
+ assert.Contains(t, cfg.Agents, AgentTask)
+ assert.Contains(t, cfg.Agents, AgentID("custom-agent"))
+
+ customAgent := cfg.Agents[AgentID("custom-agent")]
+ assert.Equal(t, "Custom Agent", customAgent.Name)
+ assert.Equal(t, "A custom agent for testing", customAgent.Description)
+ assert.Equal(t, SmallModel, customAgent.Model)
+ assert.Equal(t, []string{"glob", "grep"}, customAgent.AllowedTools)
+ assert.Equal(t, map[string][]string{"mcp1": {"tool1", "tool2"}}, customAgent.AllowedMCP)
+ assert.Equal(t, []string{"typescript", "go"}, customAgent.AllowedLSP)
+ expectedContextPaths := append(defaultContextPaths, "custom-context.md")
+ assert.Equal(t, expectedContextPaths, customAgent.ContextPaths)
+}
+
+func TestAgentMerging_ModifyDefaultCoderAgent(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("ANTHROPIC_API_KEY", "test-key")
+
+ globalConfig := Config{
+ Agents: map[AgentID]Agent{
+ AgentCoder: {
+ Model: SmallModel,
+ AllowedMCP: map[string][]string{"mcp1": {"tool1"}},
+ AllowedLSP: []string{"typescript"},
+ ContextPaths: []string{"coder-specific.md"},
+ },
+ },
+ MCP: map[string]MCP{
+ "mcp1": {
+ Type: MCPStdio,
+ Command: "test-mcp-command",
+ Args: []string{"--test"},
+ },
+ },
+ LSP: map[string]LSPConfig{
+ "typescript": {
+ Command: "typescript-language-server",
+ Args: []string{"--stdio"},
+ },
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+
+ coderAgent := cfg.Agents[AgentCoder]
+ assert.Equal(t, AgentCoder, coderAgent.ID)
+ assert.Equal(t, "Coder", coderAgent.Name)
+ assert.Equal(t, "An agent that helps with executing coding tasks.", coderAgent.Description)
+
+ expectedContextPaths := append(cfg.Options.ContextPaths, "coder-specific.md")
+ assert.Equal(t, expectedContextPaths, coderAgent.ContextPaths)
+
+ assert.Equal(t, SmallModel, coderAgent.Model)
+ assert.Equal(t, map[string][]string{"mcp1": {"tool1"}}, coderAgent.AllowedMCP)
+ assert.Equal(t, []string{"typescript"}, coderAgent.AllowedLSP)
+}
+
+func TestAgentMerging_ModifyDefaultTaskAgent(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("ANTHROPIC_API_KEY", "test-key")
+
+ globalConfig := Config{
+ Agents: map[AgentID]Agent{
+ AgentTask: {
+ Model: SmallModel,
+ AllowedMCP: map[string][]string{"search-mcp": nil},
+ AllowedLSP: []string{"python"},
+ Name: "Search Agent",
+ Description: "Custom search agent",
+ Disabled: true,
+ AllowedTools: []string{"glob", "grep", "view"},
+ },
+ },
+ MCP: map[string]MCP{
+ "search-mcp": {
+ Type: MCPStdio,
+ Command: "search-mcp-command",
+ Args: []string{"--search"},
+ },
+ },
+ LSP: map[string]LSPConfig{
+ "python": {
+ Command: "pylsp",
+ Args: []string{},
+ },
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+
+ taskAgent := cfg.Agents[AgentTask]
+ assert.Equal(t, "Task", taskAgent.Name)
+ assert.Equal(t, "An agent that helps with searching for context and finding implementation details.", taskAgent.Description)
+ assert.False(t, taskAgent.Disabled)
+ assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
+
+ assert.Equal(t, SmallModel, taskAgent.Model)
+ assert.Equal(t, map[string][]string{"search-mcp": nil}, taskAgent.AllowedMCP)
+ assert.Equal(t, []string{"python"}, taskAgent.AllowedLSP)
+}
+
+func TestAgentMerging_LocalOverridesGlobal(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("ANTHROPIC_API_KEY", "test-key")
+
+ globalConfig := Config{
+ Agents: map[AgentID]Agent{
+ AgentID("test-agent"): {
+ ID: AgentID("test-agent"),
+ Name: "Global Agent",
+ Description: "Global description",
+ Model: LargeModel,
+ AllowedTools: []string{"glob"},
+ },
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ // Create local config that overrides
+ localConfig := Config{
+ Agents: map[AgentID]Agent{
+ AgentID("test-agent"): {
+ Name: "Local Agent",
+ Description: "Local description",
+ Model: SmallModel,
+ Disabled: true,
+ AllowedTools: []string{"grep", "view"},
+ AllowedMCP: map[string][]string{"local-mcp": {"tool1"}},
+ },
+ },
+ MCP: map[string]MCP{
+ "local-mcp": {
+ Type: MCPStdio,
+ Command: "local-mcp-command",
+ Args: []string{"--local"},
+ },
+ },
+ }
+
+ localConfigPath := filepath.Join(cwdDir, "crush.json")
+ data, err = json.Marshal(localConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+
+ testAgent := cfg.Agents[AgentID("test-agent")]
+ assert.Equal(t, "Local Agent", testAgent.Name)
+ assert.Equal(t, "Local description", testAgent.Description)
+ assert.Equal(t, SmallModel, testAgent.Model)
+ assert.True(t, testAgent.Disabled)
+ assert.Equal(t, []string{"grep", "view"}, testAgent.AllowedTools)
+ assert.Equal(t, map[string][]string{"local-mcp": {"tool1"}}, testAgent.AllowedMCP)
+}
+
+func TestAgentModelTypeAssignment(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("ANTHROPIC_API_KEY", "test-key")
+
+ globalConfig := Config{
+ Agents: map[AgentID]Agent{
+ AgentID("large-agent"): {
+ ID: AgentID("large-agent"),
+ Name: "Large Model Agent",
+ Model: LargeModel,
+ },
+ AgentID("small-agent"): {
+ ID: AgentID("small-agent"),
+ Name: "Small Model Agent",
+ Model: SmallModel,
+ },
+ AgentID("default-agent"): {
+ ID: AgentID("default-agent"),
+ Name: "Default Model Agent",
+ },
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+
+ assert.Equal(t, LargeModel, cfg.Agents[AgentID("large-agent")].Model)
+ assert.Equal(t, SmallModel, cfg.Agents[AgentID("small-agent")].Model)
+ assert.Equal(t, LargeModel, cfg.Agents[AgentID("default-agent")].Model)
+}
+
+func TestAgentContextPathOverrides(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("ANTHROPIC_API_KEY", "test-key")
+
+ globalConfig := Config{
+ Options: Options{
+ ContextPaths: []string{"global-context.md", "shared-context.md"},
+ },
+ Agents: map[AgentID]Agent{
+ AgentID("custom-context-agent"): {
+ ID: AgentID("custom-context-agent"),
+ Name: "Custom Context Agent",
+ ContextPaths: []string{"agent-specific.md", "custom.md"},
+ },
+ AgentID("default-context-agent"): {
+ ID: AgentID("default-context-agent"),
+ Name: "Default Context Agent",
+ },
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+
+ customAgent := cfg.Agents[AgentID("custom-context-agent")]
+ expectedCustomPaths := append(defaultContextPaths, "global-context.md", "shared-context.md", "agent-specific.md", "custom.md")
+ assert.Equal(t, expectedCustomPaths, customAgent.ContextPaths)
+
+ defaultAgent := cfg.Agents[AgentID("default-context-agent")]
+ expectedContextPaths := append(defaultContextPaths, "global-context.md", "shared-context.md")
+ assert.Equal(t, expectedContextPaths, defaultAgent.ContextPaths)
+
+ coderAgent := cfg.Agents[AgentCoder]
+ assert.Equal(t, expectedContextPaths, coderAgent.ContextPaths)
+}
+
+func TestOptionsMerging_ContextPaths(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("ANTHROPIC_API_KEY", "test-key")
+
+ globalConfig := Config{
+ Options: Options{
+ ContextPaths: []string{"global1.md", "global2.md"},
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ localConfig := Config{
+ Options: Options{
+ ContextPaths: []string{"local1.md", "local2.md"},
+ },
+ }
+
+ localConfigPath := filepath.Join(cwdDir, "crush.json")
+ data, err = json.Marshal(localConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+
+ expectedContextPaths := append(defaultContextPaths, "global1.md", "global2.md", "local1.md", "local2.md")
+ assert.Equal(t, expectedContextPaths, cfg.Options.ContextPaths)
+}
+
+func TestOptionsMerging_TUIOptions(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("ANTHROPIC_API_KEY", "test-key")
+
+ globalConfig := Config{
+ Options: Options{
+ TUI: TUIOptions{
+ CompactMode: false,
+ },
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ localConfig := Config{
+ Options: Options{
+ TUI: TUIOptions{
+ CompactMode: true,
+ },
+ },
+ }
+
+ localConfigPath := filepath.Join(cwdDir, "crush.json")
+ data, err = json.Marshal(localConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+
+ assert.True(t, cfg.Options.TUI.CompactMode)
+}
+
+func TestOptionsMerging_DebugFlags(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("ANTHROPIC_API_KEY", "test-key")
+
+ globalConfig := Config{
+ Options: Options{
+ Debug: false,
+ DebugLSP: false,
+ DisableAutoSummarize: false,
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ localConfig := Config{
+ Options: Options{
+ DebugLSP: true,
+ DisableAutoSummarize: true,
+ },
+ }
+
+ localConfigPath := filepath.Join(cwdDir, "crush.json")
+ data, err = json.Marshal(localConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+
+ assert.False(t, cfg.Options.Debug)
+ assert.True(t, cfg.Options.DebugLSP)
+ assert.True(t, cfg.Options.DisableAutoSummarize)
+}
+
+func TestOptionsMerging_DataDirectory(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("ANTHROPIC_API_KEY", "test-key")
+
+ globalConfig := Config{
+ Options: Options{
+ DataDirectory: "global-data",
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ localConfig := Config{
+ Options: Options{
+ DataDirectory: "local-data",
+ },
+ }
+
+ localConfigPath := filepath.Join(cwdDir, "crush.json")
+ data, err = json.Marshal(localConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+
+ assert.Equal(t, "local-data", cfg.Options.DataDirectory)
+}
+
+func TestOptionsMerging_DefaultValues(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("ANTHROPIC_API_KEY", "test-key")
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+
+ assert.Equal(t, defaultDataDirectory, cfg.Options.DataDirectory)
+ assert.Equal(t, defaultContextPaths, cfg.Options.ContextPaths)
+ assert.False(t, cfg.Options.TUI.CompactMode)
+ assert.False(t, cfg.Options.Debug)
+ assert.False(t, cfg.Options.DebugLSP)
+ assert.False(t, cfg.Options.DisableAutoSummarize)
+}
+
+func TestOptionsMerging_DebugFlagFromInit(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("ANTHROPIC_API_KEY", "test-key")
+
+ globalConfig := Config{
+ Options: Options{
+ Debug: false,
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, true)
+
+ require.NoError(t, err)
+
+ // Debug flag from Init should take precedence
+ assert.True(t, cfg.Options.Debug)
+}
+
+func TestOptionsMerging_ComplexScenario(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ // Set up a provider
+ os.Setenv("ANTHROPIC_API_KEY", "test-key")
+
+ // Create global config with various options
+ globalConfig := Config{
+ Options: Options{
+ ContextPaths: []string{"global-context.md"},
+ DataDirectory: "global-data",
+ Debug: false,
+ DebugLSP: false,
+ DisableAutoSummarize: false,
+ TUI: TUIOptions{
+ CompactMode: false,
+ },
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ // Create local config that partially overrides
+ localConfig := Config{
+ Options: Options{
+ ContextPaths: []string{"local-context.md"},
+ DebugLSP: true, // Override
+ DisableAutoSummarize: true, // Override
+ TUI: TUIOptions{
+ CompactMode: true, // Override
+ },
+ // DataDirectory and Debug not specified - should keep global values
+ },
+ }
+
+ localConfigPath := filepath.Join(cwdDir, "crush.json")
+ data, err = json.Marshal(localConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+
+ // Check merged results
+ expectedContextPaths := append(defaultContextPaths, "global-context.md", "local-context.md")
+ assert.Equal(t, expectedContextPaths, cfg.Options.ContextPaths)
+ assert.Equal(t, "global-data", cfg.Options.DataDirectory) // From global
+ assert.False(t, cfg.Options.Debug) // From global
+ assert.True(t, cfg.Options.DebugLSP) // From local
+ assert.True(t, cfg.Options.DisableAutoSummarize) // From local
+ assert.True(t, cfg.Options.TUI.CompactMode) // From local
+}
+
+// Model Selection Tests
+
+func TestModelSelection_PreferredModelSelection(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ // Set up multiple providers to test selection logic
+ os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key")
+ os.Setenv("OPENAI_API_KEY", "test-openai-key")
+
+ cfg, err := Init(cwdDir, false)
+
+ require.NoError(t, err)
+ require.Len(t, cfg.Providers, 2)
+
+ // Should have preferred models set
+ assert.NotEmpty(t, cfg.Models.Large.ModelID)
+ assert.NotEmpty(t, cfg.Models.Large.Provider)
+ assert.NotEmpty(t, cfg.Models.Small.ModelID)
+ assert.NotEmpty(t, cfg.Models.Small.Provider)
+
+ // Both should use the same provider (first available)
+ assert.Equal(t, cfg.Models.Large.Provider, cfg.Models.Small.Provider)
+}
+
+func TestValidation_InvalidModelReference(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ globalConfig := Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ ID: provider.InferenceProviderOpenAI,
+ APIKey: "test-key",
+ ProviderType: provider.TypeOpenAI,
+ DefaultLargeModel: "non-existent-model",
+ DefaultSmallModel: "gpt-3.5-turbo",
+ Models: []Model{
+ {
+ ID: "gpt-3.5-turbo",
+ Name: "GPT-3.5 Turbo",
+ ContextWindow: 4096,
+ DefaultMaxTokens: 2048,
+ },
+ },
+ },
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ _, err = Init(cwdDir, false)
+ assert.Error(t, err)
+}
+
+func TestValidation_EmptyAPIKey(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ globalConfig := Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ ID: provider.InferenceProviderOpenAI,
+ ProviderType: provider.TypeOpenAI,
+ Models: []Model{
+ {
+ ID: "gpt-4",
+ Name: "GPT-4",
+ ContextWindow: 8192,
+ DefaultMaxTokens: 4096,
+ },
+ },
+ },
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ _, err = Init(cwdDir, false)
+ assert.Error(t, err)
+}
+
+func TestValidation_InvalidAgentModelType(t *testing.T) {
+ reset()
+ testConfigDir = t.TempDir()
+ cwdDir := t.TempDir()
+
+ os.Setenv("ANTHROPIC_API_KEY", "test-key")
+
+ globalConfig := Config{
+ Agents: map[AgentID]Agent{
+ AgentID("invalid-agent"): {
+ ID: AgentID("invalid-agent"),
+ Name: "Invalid Agent",
+ Model: ModelType("invalid"),
+ },
+ },
+ }
+
+ configPath := filepath.Join(testConfigDir, "crush.json")
+ require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+ data, err := json.Marshal(globalConfig)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(configPath, data, 0o644))
+
+ _, err = Init(cwdDir, false)
+ assert.Error(t, err)
+}
diff --git a/internal/config/fs.go b/internal/config/fs.go
new file mode 100644
index 0000000000000000000000000000000000000000..efa622cf937846370616042de4fe2bcd6f33b7a1
--- /dev/null
+++ b/internal/config/fs.go
@@ -0,0 +1,71 @@
+package config
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+ "runtime"
+)
+
+var testConfigDir string
+
+func baseConfigPath() string {
+ if testConfigDir != "" {
+ return testConfigDir
+ }
+
+ xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
+ if xdgConfigHome != "" {
+ return filepath.Join(xdgConfigHome, "crush")
+ }
+
+ // Return the path to the main config directory.
+ // For Windows, it should be in `%LOCALAPPDATA%/crush/`.
+ // For Linux and macOS, it should be in `$HOME/.config/crush/`.
+ if runtime.GOOS == "windows" {
+ localAppData := os.Getenv("LOCALAPPDATA")
+ if localAppData == "" {
+ localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
+ }
+ return filepath.Join(localAppData, appName)
+ }
+
+ return filepath.Join(os.Getenv("HOME"), ".config", appName)
+}
+
+func baseDataPath() string {
+ if testConfigDir != "" {
+ return testConfigDir
+ }
+
+ xdgDataHome := os.Getenv("XDG_DATA_HOME")
+ if xdgDataHome != "" {
+ return filepath.Join(xdgDataHome, appName)
+ }
+
+ // Return the path to the main data directory.
+ // For Windows, it should be in `%LOCALAPPDATA%/crush/`.
+ // For Linux and macOS, it should be in `$HOME/.local/share/crush/`.
+ if runtime.GOOS == "windows" {
+ localAppData := os.Getenv("LOCALAPPDATA")
+ if localAppData == "" {
+ localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
+ }
+ return filepath.Join(localAppData, appName)
+ }
+
+ return filepath.Join(os.Getenv("HOME"), ".local", "share", appName)
+}
+
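+// ConfigPath returns the full path to the global configuration file (crush.json).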
+func ConfigPath() string {
+ return filepath.Join(baseConfigPath(), fmt.Sprintf("%s.json", appName))
+}
+
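+// CrushInitialized reports whether the global configuration file already exists.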
+func CrushInitialized() bool {
+ cfgPath := ConfigPath()
+ if _, err := os.Stat(cfgPath); os.IsNotExist(err) {
+ // Config file does not exist, so Crush is not initialized.
+ return false
+ }
+ return true
+}
diff --git a/internal/config/init.go b/internal/config/init.go
index 1b603fbb846aba45230cd0f4683cb465e14db69a..f17e1db28e41cc44e168765e55e88311423e1102 100644
--- a/internal/config/init.go
+++ b/internal/config/init.go
@@ -17,23 +17,20 @@ type ProjectInitFlag struct {
Initialized bool `json:"initialized"`
}
-// ShouldShowInitDialog checks if the initialization dialog should be shown for the current directory
-func ShouldShowInitDialog() (bool, error) {
- if cfg == nil {
+// ProjectNeedsInitialization checks if the current project needs initialization
+func ProjectNeedsInitialization() (bool, error) {
+ if instance == nil {
return false, fmt.Errorf("config not loaded")
}
- // Create the flag file path
- flagFilePath := filepath.Join(cfg.Data.Directory, InitFlagFilename)
+ flagFilePath := filepath.Join(instance.Options.DataDirectory, InitFlagFilename)
// Check if the flag file exists
_, err := os.Stat(flagFilePath)
if err == nil {
- // File exists, don't show the dialog
return false, nil
}
- // If the error is not "file not found", return the error
if !os.IsNotExist(err) {
return false, fmt.Errorf("failed to check init flag file: %w", err)
}
@@ -44,11 +41,9 @@ func ShouldShowInitDialog() (bool, error) {
return false, fmt.Errorf("failed to check for CRUSH.md files: %w", err)
}
if crushExists {
- // CRUSH.md already exists, don't show the dialog
return false, nil
}
- // File doesn't exist, show the dialog
return true, nil
}
@@ -75,13 +70,11 @@ func crushMdExists(dir string) (bool, error) {
// MarkProjectInitialized marks the current project as initialized
func MarkProjectInitialized() error {
- if cfg == nil {
+ if instance == nil {
return fmt.Errorf("config not loaded")
}
- // Create the flag file path
- flagFilePath := filepath.Join(cfg.Data.Directory, InitFlagFilename)
+ flagFilePath := filepath.Join(instance.Options.DataDirectory, InitFlagFilename)
- // Create an empty file to mark the project as initialized
file, err := os.Create(flagFilePath)
if err != nil {
return fmt.Errorf("failed to create init flag file: %w", err)
diff --git a/internal/config/provider.go b/internal/config/provider.go
new file mode 100644
index 0000000000000000000000000000000000000000..09e3b0e3fc84b9e2688ccc4d2559604aca83ddfc
--- /dev/null
+++ b/internal/config/provider.go
@@ -0,0 +1,83 @@
+package config
+
+import (
+ "encoding/json"
+ "os"
+ "path/filepath"
+ "sync"
+
+ "github.com/charmbracelet/crush/internal/fur/client"
+ "github.com/charmbracelet/crush/internal/fur/provider"
+)
+
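+// fur is the client used to fetch provider metadata from the fur service.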
+var fur = client.New()
+
+var (
+ providerOnce sync.Once // Ensures the initialization happens only once
+ providerList []provider.Provider
+ // UseMockProviders can be set to true in tests to avoid API calls.
+ UseMockProviders bool
+)
+
+func providersPath() string {
+ return filepath.Join(baseDataPath(), "providers.json")
+}
+
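+// saveProviders writes the provider list to the local cache for offline fallback.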
+func saveProviders(providers []provider.Provider) error {
+ path := providersPath()
+ dir := filepath.Dir(path)
+ if err := os.MkdirAll(dir, 0o755); err != nil {
+ return err
+ }
+
+ data, err := json.MarshalIndent(providers, "", " ")
+ if err != nil {
+ return err
+ }
+
+ return os.WriteFile(path, data, 0o644)
+}
+
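+// loadProviders reads the provider list back from the local cache.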
+func loadProviders() ([]provider.Provider, error) {
+ path := providersPath()
+ data, err := os.ReadFile(path)
+ if err != nil {
+ return nil, err
+ }
+
+ var providers []provider.Provider
+ err = json.Unmarshal(data, &providers)
+ return providers, err
+}
+
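+// Providers returns the list of known inference providers, fetched at most
+// once per process and then cached. Resolution order: mock data when
+// UseMockProviders is set, then the upstream API, then the local cache.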
+func Providers() []provider.Provider {
+ providerOnce.Do(func() {
+ // Use mock providers when testing.
+ if UseMockProviders {
+ providerList = MockProviders()
+ return
+ }
+
+ // Try to get providers from the upstream API.
+ if providers, err := fur.GetProviders(); err == nil {
+ providerList = providers
+ // Save providers locally for future fallback.
+ _ = saveProviders(providers)
+ } else {
+ // If upstream fails, try to load from the local cache.
+ if localProviders, localErr := loadProviders(); localErr == nil {
+ providerList = localProviders
+ } else {
+ // If both fail, return an empty list.
+ providerList = []provider.Provider{}
+ }
+ }
+ })
+ return providerList
+}
+
+// ResetProviders resets the provider cache. Useful for testing.
+func ResetProviders() {
+ providerOnce = sync.Once{}
+ providerList = nil
+}
diff --git a/internal/config/provider_mock.go b/internal/config/provider_mock.go
new file mode 100644
index 0000000000000000000000000000000000000000..801afdd8d6c9891eb47fa53294c047917b031637
--- /dev/null
+++ b/internal/config/provider_mock.go
@@ -0,0 +1,293 @@
+package config
+
+import (
+ "github.com/charmbracelet/crush/internal/fur/provider"
+)
+
+// MockProviders returns a mock list of providers for testing.
+// This avoids making API calls during tests and provides consistent test data.
+// Simplified version with only default models from each provider.
+func MockProviders() []provider.Provider {
+ return []provider.Provider{
+ {
+ Name: "Anthropic",
+ ID: provider.InferenceProviderAnthropic,
+ APIKey: "$ANTHROPIC_API_KEY",
+ APIEndpoint: "$ANTHROPIC_API_ENDPOINT",
+ Type: provider.TypeAnthropic,
+ DefaultLargeModelID: "claude-sonnet-4-20250514",
+ DefaultSmallModelID: "claude-3-5-haiku-20241022",
+ Models: []provider.Model{
+ {
+ ID: "claude-sonnet-4-20250514",
+ Name: "Claude Sonnet 4",
+ CostPer1MIn: 3.0,
+ CostPer1MOut: 15.0,
+ CostPer1MInCached: 3.75,
+ CostPer1MOutCached: 0.3,
+ ContextWindow: 200000,
+ DefaultMaxTokens: 50000,
+ CanReason: true,
+ SupportsImages: true,
+ },
+ {
+ ID: "claude-3-5-haiku-20241022",
+ Name: "Claude 3.5 Haiku",
+ CostPer1MIn: 0.8,
+ CostPer1MOut: 4.0,
+ CostPer1MInCached: 1.0,
+ CostPer1MOutCached: 0.08,
+ ContextWindow: 200000,
+ DefaultMaxTokens: 5000,
+ CanReason: false,
+ SupportsImages: true,
+ },
+ },
+ },
+ {
+ Name: "OpenAI",
+ ID: provider.InferenceProviderOpenAI,
+ APIKey: "$OPENAI_API_KEY",
+ APIEndpoint: "$OPENAI_API_ENDPOINT",
+ Type: provider.TypeOpenAI,
+ DefaultLargeModelID: "codex-mini-latest",
+ DefaultSmallModelID: "gpt-4o",
+ Models: []provider.Model{
+ {
+ ID: "codex-mini-latest",
+ Name: "Codex Mini",
+ CostPer1MIn: 1.5,
+ CostPer1MOut: 6.0,
+ CostPer1MInCached: 0.0,
+ CostPer1MOutCached: 0.375,
+ ContextWindow: 200000,
+ DefaultMaxTokens: 50000,
+ CanReason: true,
+ HasReasoningEffort: true,
+ DefaultReasoningEffort: "medium",
+ SupportsImages: true,
+ },
+ {
+ ID: "gpt-4o",
+ Name: "GPT-4o",
+ CostPer1MIn: 2.5,
+ CostPer1MOut: 10.0,
+ CostPer1MInCached: 0.0,
+ CostPer1MOutCached: 1.25,
+ ContextWindow: 128000,
+ DefaultMaxTokens: 20000,
+ CanReason: false,
+ SupportsImages: true,
+ },
+ },
+ },
+ {
+ Name: "Google Gemini",
+ ID: provider.InferenceProviderGemini,
+ APIKey: "$GEMINI_API_KEY",
+ APIEndpoint: "$GEMINI_API_ENDPOINT",
+ Type: provider.TypeGemini,
+ DefaultLargeModelID: "gemini-2.5-pro",
+ DefaultSmallModelID: "gemini-2.5-flash",
+ Models: []provider.Model{
+ {
+ ID: "gemini-2.5-pro",
+ Name: "Gemini 2.5 Pro",
+ CostPer1MIn: 1.25,
+ CostPer1MOut: 10.0,
+ CostPer1MInCached: 1.625,
+ CostPer1MOutCached: 0.31,
+ ContextWindow: 1048576,
+ DefaultMaxTokens: 50000,
+ CanReason: true,
+ SupportsImages: true,
+ },
+ {
+ ID: "gemini-2.5-flash",
+ Name: "Gemini 2.5 Flash",
+ CostPer1MIn: 0.3,
+ CostPer1MOut: 2.5,
+ CostPer1MInCached: 0.3833,
+ CostPer1MOutCached: 0.075,
+ ContextWindow: 1048576,
+ DefaultMaxTokens: 50000,
+ CanReason: true,
+ SupportsImages: true,
+ },
+ },
+ },
+ {
+ Name: "xAI",
+ ID: provider.InferenceProviderXAI,
+ APIKey: "$XAI_API_KEY",
+ APIEndpoint: "https://api.x.ai/v1",
+ Type: provider.TypeXAI,
+ DefaultLargeModelID: "grok-3",
+ DefaultSmallModelID: "grok-3-mini",
+ Models: []provider.Model{
+ {
+ ID: "grok-3",
+ Name: "Grok 3",
+ CostPer1MIn: 3.0,
+ CostPer1MOut: 15.0,
+ CostPer1MInCached: 0.0,
+ CostPer1MOutCached: 0.75,
+ ContextWindow: 131072,
+ DefaultMaxTokens: 20000,
+ CanReason: false,
+ SupportsImages: false,
+ },
+ {
+ ID: "grok-3-mini",
+ Name: "Grok 3 Mini",
+ CostPer1MIn: 0.3,
+ CostPer1MOut: 0.5,
+ CostPer1MInCached: 0.0,
+ CostPer1MOutCached: 0.075,
+ ContextWindow: 131072,
+ DefaultMaxTokens: 20000,
+ CanReason: true,
+ SupportsImages: false,
+ },
+ },
+ },
+ {
+ Name: "Azure OpenAI",
+ ID: provider.InferenceProviderAzure,
+ APIKey: "$AZURE_OPENAI_API_KEY",
+ APIEndpoint: "$AZURE_OPENAI_API_ENDPOINT",
+ Type: provider.TypeAzure,
+ DefaultLargeModelID: "o4-mini",
+ DefaultSmallModelID: "gpt-4o",
+ Models: []provider.Model{
+ {
+ ID: "o4-mini",
+ Name: "o4 Mini",
+ CostPer1MIn: 1.1,
+ CostPer1MOut: 4.4,
+ CostPer1MInCached: 0.0,
+ CostPer1MOutCached: 0.275,
+ ContextWindow: 200000,
+ DefaultMaxTokens: 50000,
+ CanReason: true,
+ HasReasoningEffort: false,
+ DefaultReasoningEffort: "medium",
+ SupportsImages: true,
+ },
+ {
+ ID: "gpt-4o",
+ Name: "GPT-4o",
+ CostPer1MIn: 2.5,
+ CostPer1MOut: 10.0,
+ CostPer1MInCached: 0.0,
+ CostPer1MOutCached: 1.25,
+ ContextWindow: 128000,
+ DefaultMaxTokens: 20000,
+ CanReason: false,
+ SupportsImages: true,
+ },
+ },
+ },
+ {
+ Name: "AWS Bedrock",
+ ID: provider.InferenceProviderBedrock,
+ Type: provider.TypeBedrock,
+ DefaultLargeModelID: "anthropic.claude-sonnet-4-20250514-v1:0",
+ DefaultSmallModelID: "anthropic.claude-3-5-haiku-20241022-v1:0",
+ Models: []provider.Model{
+ {
+ ID: "anthropic.claude-sonnet-4-20250514-v1:0",
+ Name: "AWS Claude Sonnet 4",
+ CostPer1MIn: 3.0,
+ CostPer1MOut: 15.0,
+ CostPer1MInCached: 3.75,
+ CostPer1MOutCached: 0.3,
+ ContextWindow: 200000,
+ DefaultMaxTokens: 50000,
+ CanReason: true,
+ SupportsImages: true,
+ },
+ {
+ ID: "anthropic.claude-3-5-haiku-20241022-v1:0",
+ Name: "AWS Claude 3.5 Haiku",
+ CostPer1MIn: 0.8,
+ CostPer1MOut: 4.0,
+ CostPer1MInCached: 1.0,
+ CostPer1MOutCached: 0.08,
+ ContextWindow: 200000,
+ DefaultMaxTokens: 50000,
+ CanReason: false,
+ SupportsImages: true,
+ },
+ },
+ },
+ {
+ Name: "Google Vertex AI",
+ ID: provider.InferenceProviderVertexAI,
+ Type: provider.TypeVertexAI,
+ DefaultLargeModelID: "gemini-2.5-pro",
+ DefaultSmallModelID: "gemini-2.5-flash",
+ Models: []provider.Model{
+ {
+ ID: "gemini-2.5-pro",
+ Name: "Gemini 2.5 Pro",
+ CostPer1MIn: 1.25,
+ CostPer1MOut: 10.0,
+ CostPer1MInCached: 1.625,
+ CostPer1MOutCached: 0.31,
+ ContextWindow: 1048576,
+ DefaultMaxTokens: 50000,
+ CanReason: true,
+ SupportsImages: true,
+ },
+ {
+ ID: "gemini-2.5-flash",
+ Name: "Gemini 2.5 Flash",
+ CostPer1MIn: 0.3,
+ CostPer1MOut: 2.5,
+ CostPer1MInCached: 0.3833,
+ CostPer1MOutCached: 0.075,
+ ContextWindow: 1048576,
+ DefaultMaxTokens: 50000,
+ CanReason: true,
+ SupportsImages: true,
+ },
+ },
+ },
+ {
+ Name: "OpenRouter",
+ ID: provider.InferenceProviderOpenRouter,
+ APIKey: "$OPENROUTER_API_KEY",
+ APIEndpoint: "https://openrouter.ai/api/v1",
+ Type: provider.TypeOpenAI,
+ DefaultLargeModelID: "anthropic/claude-sonnet-4",
+ DefaultSmallModelID: "anthropic/claude-haiku-3.5",
+ Models: []provider.Model{
+ {
+ ID: "anthropic/claude-sonnet-4",
+ Name: "Anthropic: Claude Sonnet 4",
+ CostPer1MIn: 3.0,
+ CostPer1MOut: 15.0,
+ CostPer1MInCached: 3.75,
+ CostPer1MOutCached: 0.3,
+ ContextWindow: 200000,
+ DefaultMaxTokens: 32000,
+ CanReason: true,
+ SupportsImages: true,
+ },
+ {
+ ID: "anthropic/claude-haiku-3.5",
+ Name: "Anthropic: Claude 3.5 Haiku",
+ CostPer1MIn: 0.8,
+ CostPer1MOut: 4.0,
+ CostPer1MInCached: 1.0,
+ CostPer1MOutCached: 0.08,
+ ContextWindow: 200000,
+ DefaultMaxTokens: 4096,
+ CanReason: false,
+ SupportsImages: true,
+ },
+ },
+ },
+ }
+}
diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..53a084d244c2d48538a514e8c72530a3850782d7
--- /dev/null
+++ b/internal/config/provider_test.go
@@ -0,0 +1,81 @@
+package config
+
+import (
+ "testing"
+
+ "github.com/charmbracelet/crush/internal/fur/provider"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestProviders_MockEnabled(t *testing.T) {
+ originalUseMock := UseMockProviders
+ UseMockProviders = true
+ defer func() {
+ UseMockProviders = originalUseMock
+ ResetProviders()
+ }()
+
+ ResetProviders()
+ providers := Providers()
+ require.NotEmpty(t, providers)
+
+ providerIDs := make(map[provider.InferenceProvider]bool)
+ for _, p := range providers {
+ providerIDs[p.ID] = true
+ }
+
+ assert.True(t, providerIDs[provider.InferenceProviderAnthropic])
+ assert.True(t, providerIDs[provider.InferenceProviderOpenAI])
+ assert.True(t, providerIDs[provider.InferenceProviderGemini])
+}
+
+func TestProviders_ResetFunctionality(t *testing.T) {
+ UseMockProviders = true
+ defer func() {
+ UseMockProviders = false
+ ResetProviders()
+ }()
+
+ // Clear any previously cached providers before exercising the cache.
+ ResetProviders()
+ providers1 := Providers()
+ require.NotEmpty(t, providers1)
+
+ ResetProviders()
+ providers2 := Providers()
+ require.NotEmpty(t, providers2)
+
+ assert.Equal(t, len(providers1), len(providers2))
+}
+
+func TestProviders_ModelCapabilities(t *testing.T) {
+ originalUseMock := UseMockProviders
+ UseMockProviders = true
+ defer func() {
+ UseMockProviders = originalUseMock
+ ResetProviders()
+ }()
+
+ ResetProviders()
+ providers := Providers()
+
+ var openaiProvider provider.Provider
+ for _, p := range providers {
+ if p.ID == provider.InferenceProviderOpenAI {
+ openaiProvider = p
+ break
+ }
+ }
+ require.NotEmpty(t, openaiProvider.ID)
+
+ var foundReasoning, foundNonReasoning bool
+ for _, model := range openaiProvider.Models {
+ if model.CanReason && model.HasReasoningEffort {
+ foundReasoning = true
+ } else if !model.CanReason {
+ foundNonReasoning = true
+ }
+ }
+
+ assert.True(t, foundReasoning)
+ assert.True(t, foundNonReasoning)
+}
diff --git a/internal/config/shell.go b/internal/config/shell.go
new file mode 100644
index 0000000000000000000000000000000000000000..b7c3c8c5a787def8ff28aec677193f5ac58b652a
--- /dev/null
+++ b/internal/config/shell.go
@@ -0,0 +1,73 @@
+package config
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "strings"
+ "time"
+
+ "github.com/charmbracelet/crush/internal/logging"
+ "github.com/charmbracelet/crush/internal/shell"
+)
+
+// ExecuteCommand executes a shell command and returns its trimmed stdout.
+// This is a shared utility that can be used by both provider config and tools.
+func ExecuteCommand(ctx context.Context, command string, workingDir string) (string, error) {
+ if workingDir == "" {
+ workingDir = WorkingDirectory()
+ }
+
+ persistentShell := shell.NewShell(&shell.Options{WorkingDir: workingDir})
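+ // A fresh shell is created per call; no state persists between invocations.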
+
+ stdout, stderr, err := persistentShell.Exec(ctx, command)
+ if err != nil {
+ logging.Debug("Command execution failed", "command", command, "error", err, "stderr", stderr)
+ return "", fmt.Errorf("command execution failed: %w", err)
+ }
+
+ return strings.TrimSpace(stdout), nil
+}
+
+// ResolveAPIKey resolves an API key that can be one of:
+// - A direct string value
+// - An environment variable (prefixed with $)
+// - A shell command (wrapped in $(...))
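+//
+// For example (the command shown is hypothetical):
+//
+//	ResolveAPIKey("sk-123")          // returned as-is
+//	ResolveAPIKey("$OPENAI_API_KEY") // read from the environment
+//	ResolveAPIKey("$(cat key.txt)")  // stdout of the command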
+func ResolveAPIKey(apiKey string) (string, error) {
+ if !strings.HasPrefix(apiKey, "$") {
+ return apiKey, nil
+ }
+
+ if strings.HasPrefix(apiKey, "$(") && strings.HasSuffix(apiKey, ")") {
+ command := strings.TrimSuffix(strings.TrimPrefix(apiKey, "$("), ")")
+ logging.Debug("Resolving API key from command", "command", command)
+ return resolveCommandAPIKey(command)
+ }
+
+ envVar := strings.TrimPrefix(apiKey, "$")
+ if value := os.Getenv(envVar); value != "" {
+ logging.Debug("Resolved environment variable", "envVar", envVar, "value", value)
+ return value, nil
+ }
+
+ logging.Debug("Environment variable not found", "envVar", envVar)
+
+ return "", fmt.Errorf("environment variable %s not found", envVar)
+}
+
+// resolveCommandAPIKey executes a command to get an API key, bounded by a 30-second timeout.
+func resolveCommandAPIKey(command string) (string, error) {
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ logging.Debug("Executing command for API key", "command", command)
+
+ workingDir := WorkingDirectory()
+
+ result, err := ExecuteCommand(ctx, command, workingDir)
+ if err != nil {
+ return "", fmt.Errorf("failed to execute API key command: %w", err)
+ }
+ logging.Debug("Command executed successfully", "command", command, "result", result)
+ return result, nil
+}
diff --git a/internal/config/validation_test.go b/internal/config/validation_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..0aef035ae7bddfc7532e9dde550ab0184ed180db
--- /dev/null
+++ b/internal/config/validation_test.go
@@ -0,0 +1,462 @@
+package config
+
+import (
+ "testing"
+
+ "github.com/charmbracelet/crush/internal/fur/provider"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestConfig_Validate_ValidConfig(t *testing.T) {
+ cfg := &Config{
+ Models: PreferredModels{
+ Large: PreferredModel{
+ ModelID: "gpt-4",
+ Provider: provider.InferenceProviderOpenAI,
+ },
+ Small: PreferredModel{
+ ModelID: "gpt-3.5-turbo",
+ Provider: provider.InferenceProviderOpenAI,
+ },
+ },
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ ID: provider.InferenceProviderOpenAI,
+ APIKey: "test-key",
+ ProviderType: provider.TypeOpenAI,
+ DefaultLargeModel: "gpt-4",
+ DefaultSmallModel: "gpt-3.5-turbo",
+ Models: []Model{
+ {
+ ID: "gpt-4",
+ Name: "GPT-4",
+ ContextWindow: 8192,
+ DefaultMaxTokens: 4096,
+ CostPer1MIn: 30.0,
+ CostPer1MOut: 60.0,
+ },
+ {
+ ID: "gpt-3.5-turbo",
+ Name: "GPT-3.5 Turbo",
+ ContextWindow: 4096,
+ DefaultMaxTokens: 2048,
+ CostPer1MIn: 1.5,
+ CostPer1MOut: 2.0,
+ },
+ },
+ },
+ },
+ Agents: map[AgentID]Agent{
+ AgentCoder: {
+ ID: AgentCoder,
+ Name: "Coder",
+ Description: "An agent that helps with executing coding tasks.",
+ Model: LargeModel,
+ ContextPaths: []string{"CRUSH.md"},
+ },
+ AgentTask: {
+ ID: AgentTask,
+ Name: "Task",
+ Description: "An agent that helps with searching for context and finding implementation details.",
+ Model: LargeModel,
+ ContextPaths: []string{"CRUSH.md"},
+ AllowedTools: []string{"glob", "grep", "ls", "sourcegraph", "view"},
+ AllowedMCP: map[string][]string{},
+ AllowedLSP: []string{},
+ },
+ },
+ MCP: map[string]MCP{},
+ LSP: map[string]LSPConfig{},
+ Options: Options{
+ DataDirectory: ".crush",
+ ContextPaths: []string{"CRUSH.md"},
+ },
+ }
+
+ err := cfg.Validate()
+ assert.NoError(t, err)
+}
+
+func TestConfig_Validate_MissingAPIKey(t *testing.T) {
+ cfg := &Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ ID: provider.InferenceProviderOpenAI,
+ ProviderType: provider.TypeOpenAI,
+ // Missing APIKey
+ },
+ },
+ Options: Options{
+ DataDirectory: ".crush",
+ ContextPaths: []string{"CRUSH.md"},
+ },
+ }
+
+ err := cfg.Validate()
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "API key is required")
+}
+
+func TestConfig_Validate_InvalidProviderType(t *testing.T) {
+ cfg := &Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ ID: provider.InferenceProviderOpenAI,
+ APIKey: "test-key",
+ ProviderType: provider.Type("invalid"),
+ },
+ },
+ Options: Options{
+ DataDirectory: ".crush",
+ ContextPaths: []string{"CRUSH.md"},
+ },
+ }
+
+ err := cfg.Validate()
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "invalid provider type")
+}
+
+func TestConfig_Validate_CustomProviderMissingBaseURL(t *testing.T) {
+ customProvider := provider.InferenceProvider("custom-provider")
+ cfg := &Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ customProvider: {
+ ID: customProvider,
+ APIKey: "test-key",
+ ProviderType: provider.TypeOpenAI,
+ // Missing BaseURL for custom provider
+ },
+ },
+ Options: Options{
+ DataDirectory: ".crush",
+ ContextPaths: []string{"CRUSH.md"},
+ },
+ }
+
+ err := cfg.Validate()
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "BaseURL is required for custom providers")
+}
+
+func TestConfig_Validate_DuplicateModelIDs(t *testing.T) {
+ cfg := &Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ ID: provider.InferenceProviderOpenAI,
+ APIKey: "test-key",
+ ProviderType: provider.TypeOpenAI,
+ Models: []Model{
+ {
+ ID: "gpt-4",
+ Name: "GPT-4",
+ ContextWindow: 8192,
+ DefaultMaxTokens: 4096,
+ },
+ {
+ ID: "gpt-4", // Duplicate ID
+ Name: "GPT-4 Duplicate",
+ ContextWindow: 8192,
+ DefaultMaxTokens: 4096,
+ },
+ },
+ },
+ },
+ Options: Options{
+ DataDirectory: ".crush",
+ ContextPaths: []string{"CRUSH.md"},
+ },
+ }
+
+ err := cfg.Validate()
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "duplicate model ID")
+}
+
+func TestConfig_Validate_InvalidModelFields(t *testing.T) {
+ cfg := &Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ ID: provider.InferenceProviderOpenAI,
+ APIKey: "test-key",
+ ProviderType: provider.TypeOpenAI,
+ Models: []Model{
+ {
+ ID: "", // Empty ID
+ Name: "GPT-4",
+ ContextWindow: 0, // Invalid context window
+ DefaultMaxTokens: -1, // Invalid max tokens
+ CostPer1MIn: -5.0, // Negative cost
+ },
+ },
+ },
+ },
+ Options: Options{
+ DataDirectory: ".crush",
+ ContextPaths: []string{"CRUSH.md"},
+ },
+ }
+
+ err := cfg.Validate()
+ require.Error(t, err)
+ validationErr, ok := err.(ValidationErrors)
+ require.True(t, ok, "expected a ValidationErrors value")
+ assert.GreaterOrEqual(t, len(validationErr), 4) // Should have multiple validation errors
+}
+
+func TestConfig_Validate_DefaultModelNotFound(t *testing.T) {
+ cfg := &Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ ID: provider.InferenceProviderOpenAI,
+ APIKey: "test-key",
+ ProviderType: provider.TypeOpenAI,
+ DefaultLargeModel: "nonexistent-model",
+ Models: []Model{
+ {
+ ID: "gpt-4",
+ Name: "GPT-4",
+ ContextWindow: 8192,
+ DefaultMaxTokens: 4096,
+ },
+ },
+ },
+ },
+ Options: Options{
+ DataDirectory: ".crush",
+ ContextPaths: []string{"CRUSH.md"},
+ },
+ }
+
+ err := cfg.Validate()
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "default large model 'nonexistent-model' not found")
+}
+
+func TestConfig_Validate_AgentIDMismatch(t *testing.T) {
+ cfg := &Config{
+ Agents: map[AgentID]Agent{
+ AgentCoder: {
+ ID: AgentTask, // Wrong ID
+ Name: "Coder",
+ },
+ },
+ Options: Options{
+ DataDirectory: ".crush",
+ ContextPaths: []string{"CRUSH.md"},
+ },
+ }
+
+ err := cfg.Validate()
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "agent ID mismatch")
+}
+
+func TestConfig_Validate_InvalidAgentModelType(t *testing.T) {
+ cfg := &Config{
+ Agents: map[AgentID]Agent{
+ AgentCoder: {
+ ID: AgentCoder,
+ Name: "Coder",
+ Model: ModelType("invalid"),
+ },
+ },
+ Options: Options{
+ DataDirectory: ".crush",
+ ContextPaths: []string{"CRUSH.md"},
+ },
+ }
+
+ err := cfg.Validate()
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "invalid model type")
+}
+
+func TestConfig_Validate_UnknownTool(t *testing.T) {
+ cfg := &Config{
+ Agents: map[AgentID]Agent{
+ AgentID("custom-agent"): {
+ ID: AgentID("custom-agent"),
+ Name: "Custom Agent",
+ Model: LargeModel,
+ AllowedTools: []string{"unknown-tool"},
+ },
+ },
+ Options: Options{
+ DataDirectory: ".crush",
+ ContextPaths: []string{"CRUSH.md"},
+ },
+ }
+
+ err := cfg.Validate()
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "unknown tool")
+}
+
+func TestConfig_Validate_MCPReference(t *testing.T) {
+ cfg := &Config{
+ Agents: map[AgentID]Agent{
+ AgentID("custom-agent"): {
+ ID: AgentID("custom-agent"),
+ Name: "Custom Agent",
+ Model: LargeModel,
+ AllowedMCP: map[string][]string{"nonexistent-mcp": nil},
+ },
+ },
+ MCP: map[string]MCP{}, // Empty MCP map
+ Options: Options{
+ DataDirectory: ".crush",
+ ContextPaths: []string{"CRUSH.md"},
+ },
+ }
+
+ err := cfg.Validate()
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "referenced MCP 'nonexistent-mcp' not found")
+}
+
+func TestConfig_Validate_InvalidMCPType(t *testing.T) {
+ cfg := &Config{
+ MCP: map[string]MCP{
+ "test-mcp": {
+ Type: MCPType("invalid"),
+ },
+ },
+ Options: Options{
+ DataDirectory: ".crush",
+ ContextPaths: []string{"CRUSH.md"},
+ },
+ }
+
+ err := cfg.Validate()
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "invalid MCP type")
+}
+
+func TestConfig_Validate_MCPMissingCommand(t *testing.T) {
+ cfg := &Config{
+ MCP: map[string]MCP{
+ "test-mcp": {
+ Type: MCPStdio,
+ // Missing Command
+ },
+ },
+ Options: Options{
+ DataDirectory: ".crush",
+ ContextPaths: []string{"CRUSH.md"},
+ },
+ }
+
+ err := cfg.Validate()
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "command is required for stdio MCP")
+}
+
+func TestConfig_Validate_LSPMissingCommand(t *testing.T) {
+ cfg := &Config{
+ LSP: map[string]LSPConfig{
+ "test-lsp": {
+ // Missing Command
+ },
+ },
+ Options: Options{
+ DataDirectory: ".crush",
+ ContextPaths: []string{"CRUSH.md"},
+ },
+ }
+
+ err := cfg.Validate()
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "command is required for LSP")
+}
+
+func TestConfig_Validate_NoValidProviders(t *testing.T) {
+ cfg := &Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ ID: provider.InferenceProviderOpenAI,
+ APIKey: "test-key",
+ ProviderType: provider.TypeOpenAI,
+ Disabled: true, // Disabled
+ },
+ },
+ Options: Options{
+ DataDirectory: ".crush",
+ ContextPaths: []string{"CRUSH.md"},
+ },
+ }
+
+ err := cfg.Validate()
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "at least one non-disabled provider is required")
+}
+
+func TestConfig_Validate_MissingDefaultAgents(t *testing.T) {
+ cfg := &Config{
+ Providers: map[provider.InferenceProvider]ProviderConfig{
+ provider.InferenceProviderOpenAI: {
+ ID: provider.InferenceProviderOpenAI,
+ APIKey: "test-key",
+ ProviderType: provider.TypeOpenAI,
+ },
+ },
+ Agents: map[AgentID]Agent{}, // Missing default agents
+ Options: Options{
+ DataDirectory: ".crush",
+ ContextPaths: []string{"CRUSH.md"},
+ },
+ }
+
+ err := cfg.Validate()
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "coder agent is required")
+ assert.Contains(t, err.Error(), "task agent is required")
+}
+
+func TestConfig_Validate_KnownAgentProtection(t *testing.T) {
+ cfg := &Config{
+ Agents: map[AgentID]Agent{
+ AgentCoder: {
+ ID: AgentCoder,
+ Name: "Modified Coder", // Should not be allowed
+ Description: "Modified description", // Should not be allowed
+ Model: LargeModel,
+ },
+ },
+ Options: Options{
+ DataDirectory: ".crush",
+ ContextPaths: []string{"CRUSH.md"},
+ },
+ }
+
+ err := cfg.Validate()
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "coder agent name cannot be changed")
+ assert.Contains(t, err.Error(), "coder agent description cannot be changed")
+}
+
+func TestConfig_Validate_EmptyDataDirectory(t *testing.T) {
+ cfg := &Config{
+ Options: Options{
+ DataDirectory: "", // Empty
+ ContextPaths: []string{"CRUSH.md"},
+ },
+ }
+
+ err := cfg.Validate()
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "data directory is required")
+}
+
+func TestConfig_Validate_EmptyContextPath(t *testing.T) {
+ cfg := &Config{
+ Options: Options{
+ DataDirectory: ".crush",
+ ContextPaths: []string{""}, // Empty context path
+ },
+ }
+
+ err := cfg.Validate()
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "context path cannot be empty")
+}
diff --git a/internal/db/connect.go b/internal/db/connect.go
index 9212ce1f097e6877a9ce9b368e77d76e739b673f..95df2e491d13918d8ea66a3f8312438a42e81b4f 100644
--- a/internal/db/connect.go
+++ b/internal/db/connect.go
@@ -17,7 +17,7 @@ import (
)
func Connect(ctx context.Context) (*sql.DB, error) {
- dataDir := config.Get().Data.Directory
+ dataDir := config.Get().Options.DataDirectory
if dataDir == "" {
return nil, fmt.Errorf("data.dir is not set")
}
diff --git a/internal/db/messages.sql.go b/internal/db/messages.sql.go
index 2acfe18fdbc63312c49d65e9e3acb1bd24cf4d7e..81f322921db87dde7ade48ce64322aa01004d255 100644
--- a/internal/db/messages.sql.go
+++ b/internal/db/messages.sql.go
@@ -17,12 +17,13 @@ INSERT INTO messages (
role,
parts,
model,
+ provider,
created_at,
updated_at
) VALUES (
- ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
+ ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
)
-RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at
+RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at, provider
`
type CreateMessageParams struct {
@@ -31,6 +32,7 @@ type CreateMessageParams struct {
Role string `json:"role"`
Parts string `json:"parts"`
Model sql.NullString `json:"model"`
+ Provider sql.NullString `json:"provider"`
}
func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (Message, error) {
@@ -40,6 +42,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M
arg.Role,
arg.Parts,
arg.Model,
+ arg.Provider,
)
var i Message
err := row.Scan(
@@ -51,6 +54,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M
&i.CreatedAt,
&i.UpdatedAt,
&i.FinishedAt,
+ &i.Provider,
)
return i, err
}
@@ -76,7 +80,7 @@ func (q *Queries) DeleteSessionMessages(ctx context.Context, sessionID string) e
}
const getMessage = `-- name: GetMessage :one
-SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at
+SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider
FROM messages
WHERE id = ? LIMIT 1
`
@@ -93,12 +97,13 @@ func (q *Queries) GetMessage(ctx context.Context, id string) (Message, error) {
&i.CreatedAt,
&i.UpdatedAt,
&i.FinishedAt,
+ &i.Provider,
)
return i, err
}
const listMessagesBySession = `-- name: ListMessagesBySession :many
-SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at
+SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider
FROM messages
WHERE session_id = ?
ORDER BY created_at ASC
@@ -122,6 +127,7 @@ func (q *Queries) ListMessagesBySession(ctx context.Context, sessionID string) (
&i.CreatedAt,
&i.UpdatedAt,
&i.FinishedAt,
+ &i.Provider,
); err != nil {
return nil, err
}
diff --git a/internal/db/migrations/20250627000000_add_provider_to_messages.sql b/internal/db/migrations/20250627000000_add_provider_to_messages.sql
new file mode 100644
index 0000000000000000000000000000000000000000..9bf0ed9749c49640f10407c97deb032f60baaac2
--- /dev/null
+++ b/internal/db/migrations/20250627000000_add_provider_to_messages.sql
@@ -0,0 +1,11 @@
+-- +goose Up
+-- +goose StatementBegin
+-- Add provider column to messages table
+ALTER TABLE messages ADD COLUMN provider TEXT;
+-- +goose StatementEnd
+
+-- +goose Down
+-- +goose StatementBegin
+-- Remove provider column from messages table
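+-- Note: DROP COLUMN requires SQLite 3.35 or newer.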
+ALTER TABLE messages DROP COLUMN provider;
+-- +goose StatementEnd
\ No newline at end of file
diff --git a/internal/db/models.go b/internal/db/models.go
index ec19f99b213e041331b5d6a14dee3648bc14c1de..ec3e6e10ad990d0f1a3d03a7533c8b1aed184447 100644
--- a/internal/db/models.go
+++ b/internal/db/models.go
@@ -27,6 +27,7 @@ type Message struct {
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
FinishedAt sql.NullInt64 `json:"finished_at"`
+ Provider sql.NullString `json:"provider"`
}
type Session struct {
diff --git a/internal/db/sql/messages.sql b/internal/db/sql/messages.sql
index a59cebe7d00fe5fd7cbd449df681df45e832979a..ea946177591d1e145a59475a1ca9272f3191d4d6 100644
--- a/internal/db/sql/messages.sql
+++ b/internal/db/sql/messages.sql
@@ -16,10 +16,11 @@ INSERT INTO messages (
role,
parts,
model,
+ provider,
created_at,
updated_at
) VALUES (
- ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
+ ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
)
RETURNING *;
diff --git a/internal/fsext/fileutil.go b/internal/fsext/fileutil.go
index 1726f916b07ac9ac0defdf7c06dae7a8768b30c5..cc430d73edb34cc5b81e1e36ecbad550bf4312fe 100644
--- a/internal/fsext/fileutil.go
+++ b/internal/fsext/fileutil.go
@@ -55,7 +55,7 @@ func GetRgSearchCmd(pattern, path, include string) *exec.Cmd {
return nil
}
- // Use -n to show line numbers and include the matched line
- args := []string{"-n", pattern}
+ // Use -H to always print file names and -n to show line numbers.
+ args := []string{"-H", "-n", pattern}
if include != "" {
args = append(args, "--glob", include)
}
diff --git a/internal/fur/client/client.go b/internal/fur/client/client.go
new file mode 100644
index 0000000000000000000000000000000000000000..5f0ddeaeee708d4b5475403ce1874591f7e9bb2c
--- /dev/null
+++ b/internal/fur/client/client.go
@@ -0,0 +1,63 @@
+// Package client provides a client for interacting with the fur service.
+package client
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "os"
+
+ "github.com/charmbracelet/crush/internal/fur/provider"
+)
+
+const defaultURL = "https://fur.charmcli.dev"
+
+// Client represents a client for the fur service.
+type Client struct {
+ baseURL string
+ httpClient *http.Client
+}
+
+// New creates a new client instance.
+// It uses the FUR_URL environment variable or falls back to the default URL.
+func New() *Client {
+ baseURL := os.Getenv("FUR_URL")
+ if baseURL == "" {
+ baseURL = defaultURL
+ }
+
+ return &Client{
+ baseURL: baseURL,
+ httpClient: &http.Client{},
+ }
+}
+
+// NewWithURL creates a new client with a specific URL.
+func NewWithURL(url string) *Client {
+ return &Client{
+ baseURL: url,
+ httpClient: &http.Client{},
+ }
+}
+
+// GetProviders retrieves all available providers from the service.
+func (c *Client) GetProviders() ([]provider.Provider, error) {
+ url := fmt.Sprintf("%s/providers", c.baseURL)
+
+ resp, err := c.httpClient.Get(url) //nolint:noctx
+ if err != nil {
+ return nil, fmt.Errorf("failed to make request: %w", err)
+ }
+ defer resp.Body.Close() //nolint:errcheck
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+ }
+
+ var providers []provider.Provider
+ if err := json.NewDecoder(resp.Body).Decode(&providers); err != nil {
+ return nil, fmt.Errorf("failed to decode response: %w", err)
+ }
+
+ return providers, nil
+}
diff --git a/internal/fur/provider/provider.go b/internal/fur/provider/provider.go
new file mode 100644
index 0000000000000000000000000000000000000000..e3c0f6209cbe71c239da104b38c3022e090599aa
--- /dev/null
+++ b/internal/fur/provider/provider.go
@@ -0,0 +1,73 @@
+// Package provider provides types and constants for AI providers.
+package provider
+
+// Type represents the type of AI provider.
+type Type string
+
+// All the supported AI provider types.
+const (
+ TypeOpenAI Type = "openai"
+ TypeAnthropic Type = "anthropic"
+ TypeGemini Type = "gemini"
+ TypeAzure Type = "azure"
+ TypeBedrock Type = "bedrock"
+ TypeVertexAI Type = "vertexai"
+ TypeXAI Type = "xai"
+)
+
+// InferenceProvider represents the inference provider identifier.
+type InferenceProvider string
+
+// All the inference providers supported by the system.
+const (
+ InferenceProviderOpenAI InferenceProvider = "openai"
+ InferenceProviderAnthropic InferenceProvider = "anthropic"
+ InferenceProviderGemini InferenceProvider = "gemini"
+ InferenceProviderAzure InferenceProvider = "azure"
+ InferenceProviderBedrock InferenceProvider = "bedrock"
+ InferenceProviderVertexAI InferenceProvider = "vertexai"
+ InferenceProviderXAI InferenceProvider = "xai"
+ InferenceProviderOpenRouter InferenceProvider = "openrouter"
+)
+
+// Provider represents an AI provider configuration.
+type Provider struct {
+ Name string `json:"name"`
+ ID InferenceProvider `json:"id"`
+ APIKey string `json:"api_key,omitempty"`
+ APIEndpoint string `json:"api_endpoint,omitempty"`
+ Type Type `json:"type,omitempty"`
+ DefaultLargeModelID string `json:"default_large_model_id,omitempty"`
+ DefaultSmallModelID string `json:"default_small_model_id,omitempty"`
+ Models []Model `json:"models,omitempty"`
+}
+
+// Model represents an AI model configuration.
+type Model struct {
+ ID string `json:"id"`
+ Name string `json:"model"`
+ CostPer1MIn float64 `json:"cost_per_1m_in"`
+ CostPer1MOut float64 `json:"cost_per_1m_out"`
+ CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
+ CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
+ ContextWindow int64 `json:"context_window"`
+ DefaultMaxTokens int64 `json:"default_max_tokens"`
+ CanReason bool `json:"can_reason"`
+ HasReasoningEffort bool `json:"has_reasoning_efforts"`
+ DefaultReasoningEffort string `json:"default_reasoning_effort,omitempty"`
+ SupportsImages bool `json:"supports_attachments"`
+}
+
+// KnownProviders returns all the known inference providers.
+func KnownProviders() []InferenceProvider {
+ return []InferenceProvider{
+ InferenceProviderOpenAI,
+ InferenceProviderAnthropic,
+ InferenceProviderGemini,
+ InferenceProviderAzure,
+ InferenceProviderBedrock,
+ InferenceProviderVertexAI,
+ InferenceProviderXAI,
+ InferenceProviderOpenRouter,
+ }
+}
diff --git a/internal/llm/agent/agent-tool.go b/internal/llm/agent/agent-tool.go
index de4a86ac36d62ef0990a58d6abeb9a53572bc215..ae15c5867e7321f2ac29e1809f5eb7effb830fdc 100644
--- a/internal/llm/agent/agent-tool.go
+++ b/internal/llm/agent/agent-tool.go
@@ -5,17 +5,15 @@ import (
"encoding/json"
"fmt"
- "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/llm/tools"
- "github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/session"
)
type agentTool struct {
- sessions session.Service
- messages message.Service
- lspClients map[string]*lsp.Client
+ agent Service
+ sessions session.Service
+ messages message.Service
}
const (
@@ -26,6 +24,10 @@ type AgentParams struct {
Prompt string `json:"prompt"`
}
+func (b *agentTool) Name() string {
+ return AgentToolName
+}
+
func (b *agentTool) Info() tools.ToolInfo {
return tools.ToolInfo{
Name: AgentToolName,
@@ -54,17 +56,12 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
return tools.ToolResponse{}, fmt.Errorf("session_id and message_id are required")
}
- agent, err := NewAgent(config.AgentTask, b.sessions, b.messages, TaskAgentTools(b.lspClients))
- if err != nil {
- return tools.ToolResponse{}, fmt.Errorf("error creating agent: %s", err)
- }
-
session, err := b.sessions.CreateTaskSession(ctx, call.ID, sessionID, "New Agent Session")
if err != nil {
return tools.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
}
- done, err := agent.Run(ctx, session.ID, params.Prompt)
+ done, err := b.agent.Run(ctx, session.ID, params.Prompt)
if err != nil {
return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", err)
}
@@ -97,13 +94,13 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
}
func NewAgentTool(
- Sessions session.Service,
- Messages message.Service,
- LspClients map[string]*lsp.Client,
+ agent Service,
+ sessions session.Service,
+ messages message.Service,
) tools.BaseTool {
return &agentTool{
- sessions: Sessions,
- messages: Messages,
- lspClients: LspClients,
+ sessions: sessions,
+ messages: messages,
+ agent: agent,
}
}
diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go
index 26d952441317607323e1170d6a06559f9173605d..cd2e2fdaccc9108af3bab8a0072baad062585846 100644
--- a/internal/llm/agent/agent.go
+++ b/internal/llm/agent/agent.go
@@ -4,16 +4,18 @@ import (
"context"
"errors"
"fmt"
+ "slices"
"strings"
"sync"
"time"
"github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/llm/models"
+ "github.com/charmbracelet/crush/internal/history"
"github.com/charmbracelet/crush/internal/llm/prompt"
"github.com/charmbracelet/crush/internal/llm/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/logging"
+ "github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/permission"
"github.com/charmbracelet/crush/internal/pubsub"
@@ -47,72 +49,189 @@ type AgentEvent struct {
type Service interface {
pubsub.Suscriber[AgentEvent]
- Model() models.Model
+ Model() config.Model
+ EffectiveMaxTokens() int64
Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
Cancel(sessionID string)
CancelAll()
IsSessionBusy(sessionID string) bool
IsBusy() bool
- Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error)
Summarize(ctx context.Context, sessionID string) error
+ UpdateModel() error
}
type agent struct {
*pubsub.Broker[AgentEvent]
+ agentCfg config.Agent
sessions session.Service
messages message.Service
- tools []tools.BaseTool
- provider provider.Provider
+ tools []tools.BaseTool
+ provider provider.Provider
+ providerID string
- titleProvider provider.Provider
- summarizeProvider provider.Provider
+ titleProvider provider.Provider
+ summarizeProvider provider.Provider
+ summarizeProviderID string
activeRequests sync.Map
}
+var agentPromptMap = map[config.AgentID]prompt.PromptID{
+ config.AgentCoder: prompt.PromptCoder,
+ config.AgentTask: prompt.PromptTask,
+}
+
func NewAgent(
- agentName config.AgentName,
+ agentCfg config.Agent,
+ // These services are needed by the tools.
+ permissions permission.Service,
sessions session.Service,
messages message.Service,
- agentTools []tools.BaseTool,
+ history history.Service,
+ lspClients map[string]*lsp.Client,
) (Service, error) {
- agentProvider, err := createAgentProvider(agentName)
+ ctx := context.Background()
+ cfg := config.Get()
+ otherTools := GetMcpTools(ctx, permissions)
+ if len(lspClients) > 0 {
+ otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
+ }
+
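+ // Base tool set shared by all agents.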
+ allTools := []tools.BaseTool{
+ tools.NewBashTool(permissions),
+ tools.NewEditTool(lspClients, permissions, history),
+ tools.NewFetchTool(permissions),
+ tools.NewGlobTool(),
+ tools.NewGrepTool(),
+ tools.NewLsTool(),
+ tools.NewSourcegraphTool(),
+ tools.NewViewTool(lspClients),
+ tools.NewWriteTool(lspClients, permissions, history),
+ }
+
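+ // The coder agent additionally gets the task agent exposed as a tool.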
+ if agentCfg.ID == config.AgentCoder {
+ taskAgentCfg := config.Get().Agents[config.AgentTask]
+ if taskAgentCfg.ID == "" {
+ return nil, fmt.Errorf("task agent not found in config")
+ }
+ taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create task agent: %w", err)
+ }
+
+ allTools = append(
+ allTools,
+ NewAgentTool(
+ taskAgent,
+ sessions,
+ messages,
+ ),
+ )
+ }
+
+ allTools = append(allTools, otherTools...)
+ providerCfg := config.GetAgentProvider(agentCfg.ID)
+ if providerCfg.ID == "" {
+ return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
+ }
+ model := config.GetAgentModel(agentCfg.ID)
+
+ if model.ID == "" {
+ return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
+ }
+
+ promptID := agentPromptMap[agentCfg.ID]
+ if promptID == "" {
+ promptID = prompt.PromptDefault
+ }
+ opts := []provider.ProviderClientOption{
+ provider.WithModel(agentCfg.Model),
+ provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
+ }
+ agentProvider, err := provider.NewProvider(providerCfg, opts...)
if err != nil {
return nil, err
}
- var titleProvider provider.Provider
- // Only generate titles for the coder agent
- if agentName == config.AgentCoder {
- titleProvider, err = createAgentProvider(config.AgentTitle)
- if err != nil {
- return nil, err
+
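+ // Resolve the configured small model and its provider, used for title and summary generation.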
+ smallModelCfg := cfg.Models.Small
+ var smallModel config.Model
+
+ var smallModelProviderCfg config.ProviderConfig
+ if smallModelCfg.Provider == providerCfg.ID {
+ smallModelProviderCfg = providerCfg
+ } else {
+ for _, p := range cfg.Providers {
+ if p.ID == smallModelCfg.Provider {
+ smallModelProviderCfg = p
+ break
+ }
+ }
+ if smallModelProviderCfg.ID == "" {
+ return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
}
}
- var summarizeProvider provider.Provider
- if agentName == config.AgentCoder {
- summarizeProvider, err = createAgentProvider(config.AgentSummarizer)
- if err != nil {
- return nil, err
+ for _, m := range smallModelProviderCfg.Models {
+ if m.ID == smallModelCfg.ModelID {
+ smallModel = m
+ break
+ }
+ }
+ if smallModel.ID == "" {
+ return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
+ }
+
+ titleOpts := []provider.ProviderClientOption{
+ provider.WithModel(config.SmallModel),
+ provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
+ }
+ titleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
+ if err != nil {
+ return nil, err
+ }
+ summarizeOpts := []provider.ProviderClientOption{
+ provider.WithModel(config.SmallModel),
+ provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
+ }
+ summarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
+ if err != nil {
+ return nil, err
+ }
+
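+ // If the agent config allowlists tools, keep only those; otherwise use the full set.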
+ agentTools := []tools.BaseTool{}
+ if agentCfg.AllowedTools == nil {
+ agentTools = allTools
+ } else {
+ for _, tool := range allTools {
+ if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
+ agentTools = append(agentTools, tool)
+ }
}
}
agent := &agent{
- Broker: pubsub.NewBroker[AgentEvent](),
- provider: agentProvider,
- messages: messages,
- sessions: sessions,
- tools: agentTools,
- titleProvider: titleProvider,
- summarizeProvider: summarizeProvider,
- activeRequests: sync.Map{},
+ Broker: pubsub.NewBroker[AgentEvent](),
+ agentCfg: agentCfg,
+ provider: agentProvider,
+ providerID: string(providerCfg.ID),
+ messages: messages,
+ sessions: sessions,
+ tools: agentTools,
+ titleProvider: titleProvider,
+ summarizeProvider: summarizeProvider,
+ summarizeProviderID: string(smallModelProviderCfg.ID),
+ activeRequests: sync.Map{},
}
return agent, nil
}
-func (a *agent) Model() models.Model {
- return a.provider.Model()
+func (a *agent) Model() config.Model {
+ return config.GetAgentModel(a.agentCfg.ID)
+}
+
+func (a *agent) EffectiveMaxTokens() int64 {
+ return config.GetAgentEffectiveMaxTokens(a.agentCfg.ID)
}
func (a *agent) Cancel(sessionID string) {
@@ -139,10 +258,10 @@ func (a *agent) IsBusy() bool {
if cancelFunc, ok := value.(context.CancelFunc); ok {
if cancelFunc != nil {
busy = true
- return false // Stop iterating
+ return false
}
}
- return true // Continue iterating
+ return true
})
return busy
}
@@ -163,7 +282,9 @@ func (a *agent) generateTitle(ctx context.Context, sessionID string, content str
if err != nil {
return err
}
- parts := []message.ContentPart{message.TextContent{Text: content}}
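+ // Wrap the content in an explicit instruction so the model returns just a title.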
+ parts := []message.ContentPart{message.TextContent{
+ Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
+ }}
// Use streaming approach like summarization
response := a.titleProvider.StreamResponse(
@@ -207,7 +328,7 @@ func (a *agent) err(err error) AgentEvent {
}
func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
- if !a.provider.Model().SupportsAttachments && attachments != nil {
+ if !a.Model().SupportsImages && attachments != nil {
attachments = nil
}
events := make(chan AgentEvent)
@@ -242,6 +363,7 @@ func (a *agent) Run(ctx context.Context, sessionID string, content string, attac
}
func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
+ cfg := config.Get()
// List existing messages; if none, start title generation asynchronously.
msgs, err := a.messages.List(ctx, sessionID)
if err != nil {
@@ -300,7 +422,13 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
}
return a.err(fmt.Errorf("failed to process events: %w", err))
}
- logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
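+ // In debug mode, write the full tool results to a JSON file and log only its path.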
+ if cfg.Options.Debug {
+ seqID := (len(msgHistory) + 1) / 2
+ toolResultFilepath := logging.WriteToolResultsJson(sessionID, seqID, toolResults)
+ logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", "{}", "filepath", toolResultFilepath)
+ } else {
+ logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
+ }
if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
// We are not done, we need to respond with the tool response
msgHistory = append(msgHistory, agentMessage, *toolResults)
@@ -324,12 +452,14 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string
}
func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
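+ // Attach the session ID to the context before streaming so tools can read it.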
+ ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
- Role: message.Assistant,
- Parts: []message.ContentPart{},
- Model: a.provider.Model().ID,
+ Role: message.Assistant,
+ Parts: []message.ContentPart{},
+ Model: a.Model().ID,
+ Provider: a.providerID,
})
if err != nil {
return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
@@ -337,7 +467,6 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
// Add the session and message ID into the context if needed by tools.
ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
- ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
// Process each event in the stream.
for event := range eventChan {
@@ -369,9 +498,10 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
default:
// Continue processing
var tool tools.BaseTool
- for _, availableTools := range a.tools {
- if availableTools.Info().Name == toolCall.Name {
- tool = availableTools
+ for _, availableTool := range a.tools {
+ if availableTool.Info().Name == toolCall.Name {
+ tool = availableTool
+ break
}
}
@@ -424,8 +554,9 @@ out:
parts = append(parts, tr)
}
msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
- Role: message.Tool,
- Parts: parts,
+ Role: message.Tool,
+ Parts: parts,
+ Provider: a.providerID,
})
if err != nil {
return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
@@ -478,13 +609,13 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg
if err := a.messages.Update(ctx, *assistantMsg); err != nil {
return fmt.Errorf("failed to update message: %w", err)
}
- return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
+ return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
}
return nil
}
-func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
+func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error {
sess, err := a.sessions.Get(ctx, sessionID)
if err != nil {
return fmt.Errorf("failed to get session: %w", err)
@@ -506,25 +637,6 @@ func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.M
return nil
}
-func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error) {
- if a.IsBusy() {
- return models.Model{}, fmt.Errorf("cannot change model while processing requests")
- }
-
- if err := config.UpdateAgentModel(agentName, modelID); err != nil {
- return models.Model{}, fmt.Errorf("failed to update config: %w", err)
- }
-
- provider, err := createAgentProvider(agentName)
- if err != nil {
- return models.Model{}, fmt.Errorf("failed to create provider for model %s: %w", modelID, err)
- }
-
- a.provider = provider
-
- return a.provider.Model(), nil
-}
-
func (a *agent) Summarize(ctx context.Context, sessionID string) error {
if a.summarizeProvider == nil {
return fmt.Errorf("summarize provider not available")
@@ -561,6 +673,7 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error {
a.Publish(pubsub.CreatedEvent, event)
return
}
+ summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
if len(msgs) == 0 {
event = AgentEvent{
@@ -654,7 +767,8 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error {
Time: time.Now().Unix(),
},
},
- Model: a.summarizeProvider.Model().ID,
+ Model: a.summarizeProvider.Model().ID,
+ Provider: a.summarizeProviderID,
})
if err != nil {
event = AgentEvent{
@@ -706,56 +820,98 @@ func (a *agent) CancelAll() {
})
}
-func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
+func (a *agent) UpdateModel() error {
cfg := config.Get()
- agentConfig, ok := cfg.Agents[agentName]
- if !ok {
- return nil, fmt.Errorf("agent %s not found", agentName)
- }
- model, ok := models.SupportedModels[agentConfig.Model]
- if !ok {
- return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
- }
- providerCfg, ok := cfg.Providers[model.Provider]
- if !ok {
- return nil, fmt.Errorf("provider %s not supported", model.Provider)
+ // Get the current provider configuration.
+ currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID)
+ if currentProviderCfg.ID == "" {
+ return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
}
- if providerCfg.Disabled {
- return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
+
+ // Check if the provider has changed.
+ if string(currentProviderCfg.ID) != a.providerID {
+ // Provider changed; recreate the main provider.
+ model := config.GetAgentModel(a.agentCfg.ID)
+ if model.ID == "" {
+ return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
+ }
+
+ promptID := agentPromptMap[a.agentCfg.ID]
+ if promptID == "" {
+ promptID = prompt.PromptDefault
+ }
+
+ opts := []provider.ProviderClientOption{
+ provider.WithModel(a.agentCfg.Model),
+ provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
+ }
+
+ newProvider, err := provider.NewProvider(currentProviderCfg, opts...)
+ if err != nil {
+ return fmt.Errorf("failed to create new provider: %w", err)
+ }
+
+ // Update the provider and provider ID.
+ a.provider = newProvider
+ a.providerID = string(currentProviderCfg.ID)
}
- maxTokens := model.DefaultMaxTokens
- if agentConfig.MaxTokens > 0 {
- maxTokens = agentConfig.MaxTokens
+
+ // Check if the small model provider has changed (affects the title and summarize providers).
+ smallModelCfg := cfg.Models.Small
+ var smallModelProviderCfg config.ProviderConfig
+
+ for _, p := range cfg.Providers {
+ if p.ID == smallModelCfg.Provider {
+ smallModelProviderCfg = p
+ break
+ }
}
- opts := []provider.ProviderClientOption{
- provider.WithAPIKey(providerCfg.APIKey),
- provider.WithModel(model),
- provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
- provider.WithMaxTokens(maxTokens),
- }
- if (model.Provider == models.ProviderOpenAI || model.Provider == models.ProviderLocal) && model.CanReason {
- opts = append(
- opts,
- provider.WithOpenAIOptions(
- provider.WithReasoningEffort(agentConfig.ReasoningEffort),
- ),
- )
- } else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder {
- opts = append(
- opts,
- provider.WithAnthropicOptions(
- provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn),
- ),
- )
+
+ if smallModelProviderCfg.ID == "" {
+ return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
}
- agentProvider, err := provider.NewProvider(
- model.Provider,
- opts...,
- )
- if err != nil {
- return nil, fmt.Errorf("could not create provider: %v", err)
+
+ // Check if the summarize provider has changed.
+ if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
+ var smallModel config.Model
+ for _, m := range smallModelProviderCfg.Models {
+ if m.ID == smallModelCfg.ModelID {
+ smallModel = m
+ break
+ }
+ }
+ if smallModel.ID == "" {
+ return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
+ }
+
+ // Recreate the title provider.
+ titleOpts := []provider.ProviderClientOption{
+ provider.WithModel(config.SmallModel),
+ provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
+ // We want the title to be short, so we limit the max tokens.
+ provider.WithMaxTokens(40),
+ }
+ newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
+ if err != nil {
+ return fmt.Errorf("failed to create new title provider: %w", err)
+ }
+
+ // Recreate the summarize provider.
+ summarizeOpts := []provider.ProviderClientOption{
+ provider.WithModel(config.SmallModel),
+ provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
+ }
+ newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
+ if err != nil {
+ return fmt.Errorf("failed to create new summarize provider: %w", err)
+ }
+
+ // Update the providers and provider ID.
+ a.titleProvider = newTitleProvider
+ a.summarizeProvider = newSummarizeProvider
+ a.summarizeProviderID = string(smallModelProviderCfg.ID)
}
- return agentProvider, nil
+ return nil
}
diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go
index 55b6983d053a70ccebf56f7c6d239246acf8c317..fed0c06196c600bb5ecc06d1f92a1f3a07f14b38 100644
--- a/internal/llm/agent/mcp-tools.go
+++ b/internal/llm/agent/mcp-tools.go
@@ -18,7 +18,7 @@ import (
type mcpTool struct {
mcpName string
tool mcp.Tool
- mcpConfig config.MCPServer
+ mcpConfig config.MCP
permissions permission.Service
}
@@ -32,12 +32,20 @@ type MCPClient interface {
Close() error
}
+func (b *mcpTool) Name() string {
+ return fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name)
+}
+
func (b *mcpTool) Info() tools.ToolInfo {
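+ // The input schema's required list may be nil; normalize it to an empty slice.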
+ required := b.tool.InputSchema.Required
+ if required == nil {
+ required = make([]string, 0)
+ }
return tools.ToolInfo{
Name: fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name),
Description: b.tool.Description,
Parameters: b.tool.InputSchema.Properties,
- Required: b.tool.InputSchema.Required,
+ Required: required,
}
}
@@ -124,7 +132,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
return tools.NewTextErrorResponse("invalid mcp type"), nil
}
-func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPServer) tools.BaseTool {
+func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCP) tools.BaseTool {
return &mcpTool{
mcpName: name,
tool: tool,
@@ -135,7 +143,7 @@ func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpC
var mcpTools []tools.BaseTool
-func getTools(ctx context.Context, name string, m config.MCPServer, permissions permission.Service, c MCPClient) []tools.BaseTool {
+func getTools(ctx context.Context, name string, m config.MCP, permissions permission.Service, c MCPClient) []tools.BaseTool {
var stdioTools []tools.BaseTool
initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
@@ -166,7 +174,7 @@ func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.Ba
if len(mcpTools) > 0 {
return mcpTools
}
- for name, m := range config.Get().MCPServers {
+ for name, m := range config.Get().MCP {
switch m.Type {
case config.MCPStdio:
c, err := client.NewStdioMCPClient(
diff --git a/internal/llm/agent/tools.go b/internal/llm/agent/tools.go
deleted file mode 100644
index 0fe2c530ca6dc30916fd2dfa094ad6303bf39443..0000000000000000000000000000000000000000
--- a/internal/llm/agent/tools.go
+++ /dev/null
@@ -1,50 +0,0 @@
-package agent
-
-import (
- "context"
-
- "github.com/charmbracelet/crush/internal/history"
- "github.com/charmbracelet/crush/internal/llm/tools"
- "github.com/charmbracelet/crush/internal/lsp"
- "github.com/charmbracelet/crush/internal/message"
- "github.com/charmbracelet/crush/internal/permission"
- "github.com/charmbracelet/crush/internal/session"
-)
-
-func CoderAgentTools(
- permissions permission.Service,
- sessions session.Service,
- messages message.Service,
- history history.Service,
- lspClients map[string]*lsp.Client,
-) []tools.BaseTool {
- ctx := context.Background()
- otherTools := GetMcpTools(ctx, permissions)
- if len(lspClients) > 0 {
- otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
- }
- return append(
- []tools.BaseTool{
- tools.NewBashTool(permissions),
- tools.NewEditTool(lspClients, permissions, history),
- tools.NewFetchTool(permissions),
- tools.NewGlobTool(),
- tools.NewGrepTool(),
- tools.NewLsTool(),
- tools.NewSourcegraphTool(),
- tools.NewViewTool(lspClients),
- tools.NewWriteTool(lspClients, permissions, history),
- NewAgentTool(sessions, messages, lspClients),
- }, otherTools...,
- )
-}
-
-func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool {
- return []tools.BaseTool{
- tools.NewGlobTool(),
- tools.NewGrepTool(),
- tools.NewLsTool(),
- tools.NewSourcegraphTool(),
- tools.NewViewTool(lspClients),
- }
-}
diff --git a/internal/llm/models/anthropic.go b/internal/llm/models/anthropic.go
deleted file mode 100644
index 9da03a835126956a74ee16888397abb343811ec4..0000000000000000000000000000000000000000
--- a/internal/llm/models/anthropic.go
+++ /dev/null
@@ -1,111 +0,0 @@
-package models
-
-const (
- ProviderAnthropic ModelProvider = "anthropic"
-
- // Models
- Claude35Sonnet ModelID = "claude-3.5-sonnet"
- Claude3Haiku ModelID = "claude-3-haiku"
- Claude37Sonnet ModelID = "claude-3.7-sonnet"
- Claude35Haiku ModelID = "claude-3.5-haiku"
- Claude3Opus ModelID = "claude-3-opus"
- Claude4Opus ModelID = "claude-4-opus"
- Claude4Sonnet ModelID = "claude-4-sonnet"
-)
-
-// https://docs.anthropic.com/en/docs/about-claude/models/all-models
-var AnthropicModels = map[ModelID]Model{
- Claude35Sonnet: {
- ID: Claude35Sonnet,
- Name: "Claude 3.5 Sonnet",
- Provider: ProviderAnthropic,
- APIModel: "claude-3-5-sonnet-latest",
- CostPer1MIn: 3.0,
- CostPer1MInCached: 3.75,
- CostPer1MOutCached: 0.30,
- CostPer1MOut: 15.0,
- ContextWindow: 200000,
- DefaultMaxTokens: 5000,
- SupportsAttachments: true,
- },
- Claude3Haiku: {
- ID: Claude3Haiku,
- Name: "Claude 3 Haiku",
- Provider: ProviderAnthropic,
- APIModel: "claude-3-haiku-20240307", // doesn't support "-latest"
- CostPer1MIn: 0.25,
- CostPer1MInCached: 0.30,
- CostPer1MOutCached: 0.03,
- CostPer1MOut: 1.25,
- ContextWindow: 200000,
- DefaultMaxTokens: 4096,
- SupportsAttachments: true,
- },
- Claude37Sonnet: {
- ID: Claude37Sonnet,
- Name: "Claude 3.7 Sonnet",
- Provider: ProviderAnthropic,
- APIModel: "claude-3-7-sonnet-latest",
- CostPer1MIn: 3.0,
- CostPer1MInCached: 3.75,
- CostPer1MOutCached: 0.30,
- CostPer1MOut: 15.0,
- ContextWindow: 200000,
- DefaultMaxTokens: 50000,
- CanReason: true,
- SupportsAttachments: true,
- },
- Claude35Haiku: {
- ID: Claude35Haiku,
- Name: "Claude 3.5 Haiku",
- Provider: ProviderAnthropic,
- APIModel: "claude-3-5-haiku-latest",
- CostPer1MIn: 0.80,
- CostPer1MInCached: 1.0,
- CostPer1MOutCached: 0.08,
- CostPer1MOut: 4.0,
- ContextWindow: 200000,
- DefaultMaxTokens: 4096,
- SupportsAttachments: true,
- },
- Claude3Opus: {
- ID: Claude3Opus,
- Name: "Claude 3 Opus",
- Provider: ProviderAnthropic,
- APIModel: "claude-3-opus-latest",
- CostPer1MIn: 15.0,
- CostPer1MInCached: 18.75,
- CostPer1MOutCached: 1.50,
- CostPer1MOut: 75.0,
- ContextWindow: 200000,
- DefaultMaxTokens: 4096,
- SupportsAttachments: true,
- },
- Claude4Sonnet: {
- ID: Claude4Sonnet,
- Name: "Claude 4 Sonnet",
- Provider: ProviderAnthropic,
- APIModel: "claude-sonnet-4-20250514",
- CostPer1MIn: 3.0,
- CostPer1MInCached: 3.75,
- CostPer1MOutCached: 0.30,
- CostPer1MOut: 15.0,
- ContextWindow: 200000,
- DefaultMaxTokens: 50000,
- CanReason: true,
- SupportsAttachments: true,
- },
- Claude4Opus: {
- ID: Claude4Opus,
- Name: "Claude 4 Opus",
- Provider: ProviderAnthropic,
- APIModel: "claude-opus-4-20250514",
- CostPer1MIn: 15.0,
- CostPer1MInCached: 18.75,
- CostPer1MOutCached: 1.50,
- CostPer1MOut: 75.0,
- ContextWindow: 200000,
- DefaultMaxTokens: 4096,
- SupportsAttachments: true,
- },
-}
diff --git a/internal/llm/models/azure.go b/internal/llm/models/azure.go
deleted file mode 100644
index 416597302f362b4f2d7c605f7166ced2b200885a..0000000000000000000000000000000000000000
--- a/internal/llm/models/azure.go
+++ /dev/null
@@ -1,168 +0,0 @@
-package models
-
-const ProviderAzure ModelProvider = "azure"
-
-const (
- AzureGPT41 ModelID = "azure.gpt-4.1"
- AzureGPT41Mini ModelID = "azure.gpt-4.1-mini"
- AzureGPT41Nano ModelID = "azure.gpt-4.1-nano"
- AzureGPT45Preview ModelID = "azure.gpt-4.5-preview"
- AzureGPT4o ModelID = "azure.gpt-4o"
- AzureGPT4oMini ModelID = "azure.gpt-4o-mini"
- AzureO1 ModelID = "azure.o1"
- AzureO1Mini ModelID = "azure.o1-mini"
- AzureO3 ModelID = "azure.o3"
- AzureO3Mini ModelID = "azure.o3-mini"
- AzureO4Mini ModelID = "azure.o4-mini"
-)
-
-var AzureModels = map[ModelID]Model{
- AzureGPT41: {
- ID: AzureGPT41,
- Name: "Azure OpenAI – GPT 4.1",
- Provider: ProviderAzure,
- APIModel: "gpt-4.1",
- CostPer1MIn: OpenAIModels[GPT41].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[GPT41].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[GPT41].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[GPT41].CostPer1MOutCached,
- ContextWindow: OpenAIModels[GPT41].ContextWindow,
- DefaultMaxTokens: OpenAIModels[GPT41].DefaultMaxTokens,
- SupportsAttachments: true,
- },
- AzureGPT41Mini: {
- ID: AzureGPT41Mini,
- Name: "Azure OpenAI – GPT 4.1 mini",
- Provider: ProviderAzure,
- APIModel: "gpt-4.1-mini",
- CostPer1MIn: OpenAIModels[GPT41Mini].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[GPT41Mini].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[GPT41Mini].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[GPT41Mini].CostPer1MOutCached,
- ContextWindow: OpenAIModels[GPT41Mini].ContextWindow,
- DefaultMaxTokens: OpenAIModels[GPT41Mini].DefaultMaxTokens,
- SupportsAttachments: true,
- },
- AzureGPT41Nano: {
- ID: AzureGPT41Nano,
- Name: "Azure OpenAI – GPT 4.1 nano",
- Provider: ProviderAzure,
- APIModel: "gpt-4.1-nano",
- CostPer1MIn: OpenAIModels[GPT41Nano].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[GPT41Nano].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[GPT41Nano].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[GPT41Nano].CostPer1MOutCached,
- ContextWindow: OpenAIModels[GPT41Nano].ContextWindow,
- DefaultMaxTokens: OpenAIModels[GPT41Nano].DefaultMaxTokens,
- SupportsAttachments: true,
- },
- AzureGPT45Preview: {
- ID: AzureGPT45Preview,
- Name: "Azure OpenAI – GPT 4.5 preview",
- Provider: ProviderAzure,
- APIModel: "gpt-4.5-preview",
- CostPer1MIn: OpenAIModels[GPT45Preview].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[GPT45Preview].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[GPT45Preview].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[GPT45Preview].CostPer1MOutCached,
- ContextWindow: OpenAIModels[GPT45Preview].ContextWindow,
- DefaultMaxTokens: OpenAIModels[GPT45Preview].DefaultMaxTokens,
- SupportsAttachments: true,
- },
- AzureGPT4o: {
- ID: AzureGPT4o,
- Name: "Azure OpenAI – GPT-4o",
- Provider: ProviderAzure,
- APIModel: "gpt-4o",
- CostPer1MIn: OpenAIModels[GPT4o].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[GPT4o].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[GPT4o].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[GPT4o].CostPer1MOutCached,
- ContextWindow: OpenAIModels[GPT4o].ContextWindow,
- DefaultMaxTokens: OpenAIModels[GPT4o].DefaultMaxTokens,
- SupportsAttachments: true,
- },
- AzureGPT4oMini: {
- ID: AzureGPT4oMini,
- Name: "Azure OpenAI – GPT-4o mini",
- Provider: ProviderAzure,
- APIModel: "gpt-4o-mini",
- CostPer1MIn: OpenAIModels[GPT4oMini].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[GPT4oMini].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[GPT4oMini].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[GPT4oMini].CostPer1MOutCached,
- ContextWindow: OpenAIModels[GPT4oMini].ContextWindow,
- DefaultMaxTokens: OpenAIModels[GPT4oMini].DefaultMaxTokens,
- SupportsAttachments: true,
- },
- AzureO1: {
- ID: AzureO1,
- Name: "Azure OpenAI – O1",
- Provider: ProviderAzure,
- APIModel: "o1",
- CostPer1MIn: OpenAIModels[O1].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[O1].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[O1].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[O1].CostPer1MOutCached,
- ContextWindow: OpenAIModels[O1].ContextWindow,
- DefaultMaxTokens: OpenAIModels[O1].DefaultMaxTokens,
- CanReason: OpenAIModels[O1].CanReason,
- SupportsAttachments: true,
- },
- AzureO1Mini: {
- ID: AzureO1Mini,
- Name: "Azure OpenAI – O1 mini",
- Provider: ProviderAzure,
- APIModel: "o1-mini",
- CostPer1MIn: OpenAIModels[O1Mini].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[O1Mini].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[O1Mini].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[O1Mini].CostPer1MOutCached,
- ContextWindow: OpenAIModels[O1Mini].ContextWindow,
- DefaultMaxTokens: OpenAIModels[O1Mini].DefaultMaxTokens,
- CanReason: OpenAIModels[O1Mini].CanReason,
- SupportsAttachments: true,
- },
- AzureO3: {
- ID: AzureO3,
- Name: "Azure OpenAI – O3",
- Provider: ProviderAzure,
- APIModel: "o3",
- CostPer1MIn: OpenAIModels[O3].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[O3].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[O3].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[O3].CostPer1MOutCached,
- ContextWindow: OpenAIModels[O3].ContextWindow,
- DefaultMaxTokens: OpenAIModels[O3].DefaultMaxTokens,
- CanReason: OpenAIModels[O3].CanReason,
- SupportsAttachments: true,
- },
- AzureO3Mini: {
- ID: AzureO3Mini,
- Name: "Azure OpenAI – O3 mini",
- Provider: ProviderAzure,
- APIModel: "o3-mini",
- CostPer1MIn: OpenAIModels[O3Mini].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[O3Mini].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[O3Mini].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[O3Mini].CostPer1MOutCached,
- ContextWindow: OpenAIModels[O3Mini].ContextWindow,
- DefaultMaxTokens: OpenAIModels[O3Mini].DefaultMaxTokens,
- CanReason: OpenAIModels[O3Mini].CanReason,
- SupportsAttachments: false,
- },
- AzureO4Mini: {
- ID: AzureO4Mini,
- Name: "Azure OpenAI – O4 mini",
- Provider: ProviderAzure,
- APIModel: "o4-mini",
- CostPer1MIn: OpenAIModels[O4Mini].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[O4Mini].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[O4Mini].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[O4Mini].CostPer1MOutCached,
- ContextWindow: OpenAIModels[O4Mini].ContextWindow,
- DefaultMaxTokens: OpenAIModels[O4Mini].DefaultMaxTokens,
- CanReason: OpenAIModels[O4Mini].CanReason,
- SupportsAttachments: true,
- },
-}
diff --git a/internal/llm/models/gemini.go b/internal/llm/models/gemini.go
deleted file mode 100644
index 794ec3f0a06a0e9975d110cd4fb89e1427a32552..0000000000000000000000000000000000000000
--- a/internal/llm/models/gemini.go
+++ /dev/null
@@ -1,67 +0,0 @@
-package models
-
-const (
- ProviderGemini ModelProvider = "gemini"
-
- // Models
- Gemini25Flash ModelID = "gemini-2.5-flash"
- Gemini25 ModelID = "gemini-2.5"
- Gemini20Flash ModelID = "gemini-2.0-flash"
- Gemini20FlashLite ModelID = "gemini-2.0-flash-lite"
-)
-
-var GeminiModels = map[ModelID]Model{
- Gemini25Flash: {
- ID: Gemini25Flash,
- Name: "Gemini 2.5 Flash",
- Provider: ProviderGemini,
- APIModel: "gemini-2.5-flash-preview-04-17",
- CostPer1MIn: 0.15,
- CostPer1MInCached: 0,
- CostPer1MOutCached: 0,
- CostPer1MOut: 0.60,
- ContextWindow: 1000000,
- DefaultMaxTokens: 50000,
- SupportsAttachments: true,
- },
- Gemini25: {
- ID: Gemini25,
- Name: "Gemini 2.5 Pro",
- Provider: ProviderGemini,
- APIModel: "gemini-2.5-pro-preview-05-06",
- CostPer1MIn: 1.25,
- CostPer1MInCached: 0,
- CostPer1MOutCached: 0,
- CostPer1MOut: 10,
- ContextWindow: 1000000,
- DefaultMaxTokens: 50000,
- SupportsAttachments: true,
- },
-
- Gemini20Flash: {
- ID: Gemini20Flash,
- Name: "Gemini 2.0 Flash",
- Provider: ProviderGemini,
- APIModel: "gemini-2.0-flash",
- CostPer1MIn: 0.10,
- CostPer1MInCached: 0,
- CostPer1MOutCached: 0,
- CostPer1MOut: 0.40,
- ContextWindow: 1000000,
- DefaultMaxTokens: 6000,
- SupportsAttachments: true,
- },
- Gemini20FlashLite: {
- ID: Gemini20FlashLite,
- Name: "Gemini 2.0 Flash Lite",
- Provider: ProviderGemini,
- APIModel: "gemini-2.0-flash-lite",
- CostPer1MIn: 0.05,
- CostPer1MInCached: 0,
- CostPer1MOutCached: 0,
- CostPer1MOut: 0.30,
- ContextWindow: 1000000,
- DefaultMaxTokens: 6000,
- SupportsAttachments: true,
- },
-}
diff --git a/internal/llm/models/groq.go b/internal/llm/models/groq.go
deleted file mode 100644
index 19917f20bb2647e296db681e30b1b0f379bf7349..0000000000000000000000000000000000000000
--- a/internal/llm/models/groq.go
+++ /dev/null
@@ -1,87 +0,0 @@
-package models
-
-const (
- ProviderGROQ ModelProvider = "groq"
-
- // GROQ
- QWENQwq ModelID = "qwen-qwq"
-
- // GROQ preview models
- Llama4Scout ModelID = "meta-llama/llama-4-scout-17b-16e-instruct"
- Llama4Maverick ModelID = "meta-llama/llama-4-maverick-17b-128e-instruct"
- Llama3_3_70BVersatile ModelID = "llama-3.3-70b-versatile"
- DeepseekR1DistillLlama70b ModelID = "deepseek-r1-distill-llama-70b"
-)
-
-var GroqModels = map[ModelID]Model{
- //
- // GROQ
- QWENQwq: {
- ID: QWENQwq,
- Name: "Qwen Qwq",
- Provider: ProviderGROQ,
- APIModel: "qwen-qwq-32b",
- CostPer1MIn: 0.29,
- CostPer1MInCached: 0.275,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 0.39,
- ContextWindow: 128_000,
- DefaultMaxTokens: 50000,
- // for some reason, the groq api doesn't like the reasoningEffort parameter
- CanReason: false,
- SupportsAttachments: false,
- },
-
- Llama4Scout: {
- ID: Llama4Scout,
- Name: "Llama4Scout",
- Provider: ProviderGROQ,
- APIModel: "meta-llama/llama-4-scout-17b-16e-instruct",
- CostPer1MIn: 0.11,
- CostPer1MInCached: 0,
- CostPer1MOutCached: 0,
- CostPer1MOut: 0.34,
- ContextWindow: 128_000, // 10M when?
- SupportsAttachments: true,
- },
-
- Llama4Maverick: {
- ID: Llama4Maverick,
- Name: "Llama4Maverick",
- Provider: ProviderGROQ,
- APIModel: "meta-llama/llama-4-maverick-17b-128e-instruct",
- CostPer1MIn: 0.20,
- CostPer1MInCached: 0,
- CostPer1MOutCached: 0,
- CostPer1MOut: 0.20,
- ContextWindow: 128_000,
- SupportsAttachments: true,
- },
-
- Llama3_3_70BVersatile: {
- ID: Llama3_3_70BVersatile,
- Name: "Llama3_3_70BVersatile",
- Provider: ProviderGROQ,
- APIModel: "llama-3.3-70b-versatile",
- CostPer1MIn: 0.59,
- CostPer1MInCached: 0,
- CostPer1MOutCached: 0,
- CostPer1MOut: 0.79,
- ContextWindow: 128_000,
- SupportsAttachments: false,
- },
-
- DeepseekR1DistillLlama70b: {
- ID: DeepseekR1DistillLlama70b,
- Name: "DeepseekR1DistillLlama70b",
- Provider: ProviderGROQ,
- APIModel: "deepseek-r1-distill-llama-70b",
- CostPer1MIn: 0.75,
- CostPer1MInCached: 0,
- CostPer1MOutCached: 0,
- CostPer1MOut: 0.99,
- ContextWindow: 128_000,
- CanReason: true,
- SupportsAttachments: false,
- },
-}
diff --git a/internal/llm/models/local.go b/internal/llm/models/local.go
deleted file mode 100644
index 3a50fdf48fe86167600eceee3cce26b6caac900e..0000000000000000000000000000000000000000
--- a/internal/llm/models/local.go
+++ /dev/null
@@ -1,206 +0,0 @@
-package models
-
-import (
- "cmp"
- "context"
- "encoding/json"
- "net/http"
- "net/url"
- "os"
- "regexp"
- "strings"
- "unicode"
-
- "github.com/charmbracelet/crush/internal/logging"
- "github.com/spf13/viper"
-)
-
-const (
- ProviderLocal ModelProvider = "local"
-
- localModelsPath = "v1/models"
- lmStudioBetaModelsPath = "api/v0/models"
-)
-
-func init() {
- if endpoint := os.Getenv("LOCAL_ENDPOINT"); endpoint != "" {
- localEndpoint, err := url.Parse(endpoint)
- if err != nil {
- logging.Debug("Failed to parse local endpoint",
- "error", err,
- "endpoint", endpoint,
- )
- return
- }
-
- load := func(url *url.URL, path string) []localModel {
- url.Path = path
- return listLocalModels(url.String())
- }
-
- models := load(localEndpoint, lmStudioBetaModelsPath)
-
- if len(models) == 0 {
- models = load(localEndpoint, localModelsPath)
- }
-
- if len(models) == 0 {
- logging.Debug("No local models found",
- "endpoint", endpoint,
- )
- return
- }
-
- loadLocalModels(models)
-
- viper.SetDefault("providers.local.apiKey", "dummy")
- }
-}
-
-type localModelList struct {
- Data []localModel `json:"data"`
-}
-
-type localModel struct {
- ID string `json:"id"`
- Object string `json:"object"`
- Type string `json:"type"`
- Publisher string `json:"publisher"`
- Arch string `json:"arch"`
- CompatibilityType string `json:"compatibility_type"`
- Quantization string `json:"quantization"`
- State string `json:"state"`
- MaxContextLength int64 `json:"max_context_length"`
- LoadedContextLength int64 `json:"loaded_context_length"`
-}
-
-func listLocalModels(modelsEndpoint string) []localModel {
- res, err := http.NewRequestWithContext(context.Background(), http.MethodGet, modelsEndpoint, nil)
- if err != nil {
- logging.Debug("Failed to list local models",
- "error", err,
- "endpoint", modelsEndpoint,
- )
- }
- defer res.Body.Close()
-
- if res.Response.StatusCode != http.StatusOK {
- logging.Debug("Failed to list local models",
- "status", res.Response.Status,
- "endpoint", modelsEndpoint,
- )
- }
-
- var modelList localModelList
- if err = json.NewDecoder(res.Body).Decode(&modelList); err != nil {
- logging.Debug("Failed to list local models",
- "error", err,
- "endpoint", modelsEndpoint,
- )
- }
-
- var supportedModels []localModel
- for _, model := range modelList.Data {
- if strings.HasSuffix(modelsEndpoint, lmStudioBetaModelsPath) {
- if model.Object != "model" || model.Type != "llm" {
- logging.Debug("Skipping unsupported LMStudio model",
- "endpoint", modelsEndpoint,
- "id", model.ID,
- "object", model.Object,
- "type", model.Type,
- )
-
- continue
- }
- }
-
- supportedModels = append(supportedModels, model)
- }
-
- return supportedModels
-}
-
-func loadLocalModels(models []localModel) {
- for i, m := range models {
- model := convertLocalModel(m)
- SupportedModels[model.ID] = model
-
- if i == 0 || m.State == "loaded" {
- viper.SetDefault("agents.coder.model", model.ID)
- viper.SetDefault("agents.summarizer.model", model.ID)
- viper.SetDefault("agents.task.model", model.ID)
- viper.SetDefault("agents.title.model", model.ID)
- }
- }
-}
-
-func convertLocalModel(model localModel) Model {
- return Model{
- ID: ModelID("local." + model.ID),
- Name: friendlyModelName(model.ID),
- Provider: ProviderLocal,
- APIModel: model.ID,
- ContextWindow: cmp.Or(model.LoadedContextLength, 4096),
- DefaultMaxTokens: cmp.Or(model.LoadedContextLength, 4096),
- CanReason: true,
- SupportsAttachments: true,
- }
-}
-
-var modelInfoRegex = regexp.MustCompile(`(?i)^([a-z0-9]+)(?:[-_]?([rv]?\d[\.\d]*))?(?:[-_]?([a-z]+))?.*`)
-
-func friendlyModelName(modelID string) string {
- mainID := modelID
- tag := ""
-
- if slash := strings.LastIndex(mainID, "/"); slash != -1 {
- mainID = mainID[slash+1:]
- }
-
- if at := strings.Index(modelID, "@"); at != -1 {
- mainID = modelID[:at]
- tag = modelID[at+1:]
- }
-
- match := modelInfoRegex.FindStringSubmatch(mainID)
- if match == nil {
- return modelID
- }
-
- capitalize := func(s string) string {
- if s == "" {
- return ""
- }
- runes := []rune(s)
- runes[0] = unicode.ToUpper(runes[0])
- return string(runes)
- }
-
- family := capitalize(match[1])
- version := ""
- label := ""
-
- if len(match) > 2 && match[2] != "" {
- version = strings.ToUpper(match[2])
- }
-
- if len(match) > 3 && match[3] != "" {
- label = capitalize(match[3])
- }
-
- var parts []string
- if family != "" {
- parts = append(parts, family)
- }
- if version != "" {
- parts = append(parts, version)
- }
- if label != "" {
- parts = append(parts, label)
- }
- if tag != "" {
- parts = append(parts, tag)
- }
-
- return strings.Join(parts, " ")
-}
diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go
deleted file mode 100644
index 50e8723989ccb268a9f515b4c693662654fa38d5..0000000000000000000000000000000000000000
--- a/internal/llm/models/models.go
+++ /dev/null
@@ -1,60 +0,0 @@
-package models
-
-import "maps"
-
-type (
- ModelID string
- ModelProvider string
-)
-
-type Model struct {
- ID ModelID `json:"id"`
- Name string `json:"name"`
- Provider ModelProvider `json:"provider"`
- APIModel string `json:"api_model"`
- CostPer1MIn float64 `json:"cost_per_1m_in"`
- CostPer1MOut float64 `json:"cost_per_1m_out"`
- CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
- CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
- ContextWindow int64 `json:"context_window"`
- DefaultMaxTokens int64 `json:"default_max_tokens"`
- CanReason bool `json:"can_reason"`
- SupportsAttachments bool `json:"supports_attachments"`
-}
-
-// Model IDs
-const ( // GEMINI
- // Bedrock
- BedrockClaude37Sonnet ModelID = "bedrock.claude-3.7-sonnet"
-)
-
-const (
- ProviderBedrock ModelProvider = "bedrock"
- // ForTests
- ProviderMock ModelProvider = "__mock"
-)
-
-var SupportedModels = map[ModelID]Model{
- // Bedrock
- BedrockClaude37Sonnet: {
- ID: BedrockClaude37Sonnet,
- Name: "Bedrock: Claude 3.7 Sonnet",
- Provider: ProviderBedrock,
- APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0",
- CostPer1MIn: 3.0,
- CostPer1MInCached: 3.75,
- CostPer1MOutCached: 0.30,
- CostPer1MOut: 15.0,
- },
-}
-
-func init() {
- maps.Copy(SupportedModels, AnthropicModels)
- maps.Copy(SupportedModels, OpenAIModels)
- maps.Copy(SupportedModels, GeminiModels)
- maps.Copy(SupportedModels, GroqModels)
- maps.Copy(SupportedModels, AzureModels)
- maps.Copy(SupportedModels, OpenRouterModels)
- maps.Copy(SupportedModels, XAIModels)
- maps.Copy(SupportedModels, VertexAIGeminiModels)
-}
diff --git a/internal/llm/models/openai.go b/internal/llm/models/openai.go
deleted file mode 100644
index abe0e30c53207c0a7cacfe3d27f43495cfec6e8e..0000000000000000000000000000000000000000
--- a/internal/llm/models/openai.go
+++ /dev/null
@@ -1,181 +0,0 @@
-package models
-
-const (
- ProviderOpenAI ModelProvider = "openai"
-
- GPT41 ModelID = "gpt-4.1"
- GPT41Mini ModelID = "gpt-4.1-mini"
- GPT41Nano ModelID = "gpt-4.1-nano"
- GPT45Preview ModelID = "gpt-4.5-preview"
- GPT4o ModelID = "gpt-4o"
- GPT4oMini ModelID = "gpt-4o-mini"
- O1 ModelID = "o1"
- O1Pro ModelID = "o1-pro"
- O1Mini ModelID = "o1-mini"
- O3 ModelID = "o3"
- O3Mini ModelID = "o3-mini"
- O4Mini ModelID = "o4-mini"
-)
-
-var OpenAIModels = map[ModelID]Model{
- GPT41: {
- ID: GPT41,
- Name: "GPT 4.1",
- Provider: ProviderOpenAI,
- APIModel: "gpt-4.1",
- CostPer1MIn: 2.00,
- CostPer1MInCached: 0.50,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 8.00,
- ContextWindow: 1_047_576,
- DefaultMaxTokens: 20000,
- SupportsAttachments: true,
- },
- GPT41Mini: {
- ID: GPT41Mini,
- Name: "GPT 4.1 mini",
- Provider: ProviderOpenAI,
- APIModel: "gpt-4.1",
- CostPer1MIn: 0.40,
- CostPer1MInCached: 0.10,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 1.60,
- ContextWindow: 200_000,
- DefaultMaxTokens: 20000,
- SupportsAttachments: true,
- },
- GPT41Nano: {
- ID: GPT41Nano,
- Name: "GPT 4.1 nano",
- Provider: ProviderOpenAI,
- APIModel: "gpt-4.1-nano",
- CostPer1MIn: 0.10,
- CostPer1MInCached: 0.025,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 0.40,
- ContextWindow: 1_047_576,
- DefaultMaxTokens: 20000,
- SupportsAttachments: true,
- },
- GPT45Preview: {
- ID: GPT45Preview,
- Name: "GPT 4.5 preview",
- Provider: ProviderOpenAI,
- APIModel: "gpt-4.5-preview",
- CostPer1MIn: 75.00,
- CostPer1MInCached: 37.50,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 150.00,
- ContextWindow: 128_000,
- DefaultMaxTokens: 15000,
- SupportsAttachments: true,
- },
- GPT4o: {
- ID: GPT4o,
- Name: "GPT 4o",
- Provider: ProviderOpenAI,
- APIModel: "gpt-4o",
- CostPer1MIn: 2.50,
- CostPer1MInCached: 1.25,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 10.00,
- ContextWindow: 128_000,
- DefaultMaxTokens: 4096,
- SupportsAttachments: true,
- },
- GPT4oMini: {
- ID: GPT4oMini,
- Name: "GPT 4o mini",
- Provider: ProviderOpenAI,
- APIModel: "gpt-4o-mini",
- CostPer1MIn: 0.15,
- CostPer1MInCached: 0.075,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 0.60,
- ContextWindow: 128_000,
- SupportsAttachments: true,
- },
- O1: {
- ID: O1,
- Name: "O1",
- Provider: ProviderOpenAI,
- APIModel: "o1",
- CostPer1MIn: 15.00,
- CostPer1MInCached: 7.50,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 60.00,
- ContextWindow: 200_000,
- DefaultMaxTokens: 50000,
- CanReason: true,
- SupportsAttachments: true,
- },
- O1Pro: {
- ID: O1Pro,
- Name: "o1 pro",
- Provider: ProviderOpenAI,
- APIModel: "o1-pro",
- CostPer1MIn: 150.00,
- CostPer1MInCached: 0.0,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 600.00,
- ContextWindow: 200_000,
- DefaultMaxTokens: 50000,
- CanReason: true,
- SupportsAttachments: true,
- },
- O1Mini: {
- ID: O1Mini,
- Name: "o1 mini",
- Provider: ProviderOpenAI,
- APIModel: "o1-mini",
- CostPer1MIn: 1.10,
- CostPer1MInCached: 0.55,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 4.40,
- ContextWindow: 128_000,
- DefaultMaxTokens: 50000,
- CanReason: true,
- SupportsAttachments: true,
- },
- O3: {
- ID: O3,
- Name: "o3",
- Provider: ProviderOpenAI,
- APIModel: "o3",
- CostPer1MIn: 10.00,
- CostPer1MInCached: 2.50,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 40.00,
- ContextWindow: 200_000,
- CanReason: true,
- SupportsAttachments: true,
- },
- O3Mini: {
- ID: O3Mini,
- Name: "o3 mini",
- Provider: ProviderOpenAI,
- APIModel: "o3-mini",
- CostPer1MIn: 1.10,
- CostPer1MInCached: 0.55,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 4.40,
- ContextWindow: 200_000,
- DefaultMaxTokens: 50000,
- CanReason: true,
- SupportsAttachments: false,
- },
- O4Mini: {
- ID: O4Mini,
- Name: "o4 mini",
- Provider: ProviderOpenAI,
- APIModel: "o4-mini",
- CostPer1MIn: 1.10,
- CostPer1MInCached: 0.275,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 4.40,
- ContextWindow: 128_000,
- DefaultMaxTokens: 50000,
- CanReason: true,
- SupportsAttachments: true,
- },
-}
diff --git a/internal/llm/models/openrouter.go b/internal/llm/models/openrouter.go
deleted file mode 100644
index 10ad5a0409937e060616eda6c188991d979e3ea1..0000000000000000000000000000000000000000
--- a/internal/llm/models/openrouter.go
+++ /dev/null
@@ -1,276 +0,0 @@
-package models
-
-const (
- ProviderOpenRouter ModelProvider = "openrouter"
-
- OpenRouterGPT41 ModelID = "openrouter.gpt-4.1"
- OpenRouterGPT41Mini ModelID = "openrouter.gpt-4.1-mini"
- OpenRouterGPT41Nano ModelID = "openrouter.gpt-4.1-nano"
- OpenRouterGPT45Preview ModelID = "openrouter.gpt-4.5-preview"
- OpenRouterGPT4o ModelID = "openrouter.gpt-4o"
- OpenRouterGPT4oMini ModelID = "openrouter.gpt-4o-mini"
- OpenRouterO1 ModelID = "openrouter.o1"
- OpenRouterO1Pro ModelID = "openrouter.o1-pro"
- OpenRouterO1Mini ModelID = "openrouter.o1-mini"
- OpenRouterO3 ModelID = "openrouter.o3"
- OpenRouterO3Mini ModelID = "openrouter.o3-mini"
- OpenRouterO4Mini ModelID = "openrouter.o4-mini"
- OpenRouterGemini25Flash ModelID = "openrouter.gemini-2.5-flash"
- OpenRouterGemini25 ModelID = "openrouter.gemini-2.5"
- OpenRouterClaude35Sonnet ModelID = "openrouter.claude-3.5-sonnet"
- OpenRouterClaude3Haiku ModelID = "openrouter.claude-3-haiku"
- OpenRouterClaude37Sonnet ModelID = "openrouter.claude-3.7-sonnet"
- OpenRouterClaude35Haiku ModelID = "openrouter.claude-3.5-haiku"
- OpenRouterClaude3Opus ModelID = "openrouter.claude-3-opus"
- OpenRouterDeepSeekR1Free ModelID = "openrouter.deepseek-r1-free"
-)
-
-var OpenRouterModels = map[ModelID]Model{
- OpenRouterGPT41: {
- ID: OpenRouterGPT41,
- Name: "OpenRouter – GPT 4.1",
- Provider: ProviderOpenRouter,
- APIModel: "openai/gpt-4.1",
- CostPer1MIn: OpenAIModels[GPT41].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[GPT41].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[GPT41].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[GPT41].CostPer1MOutCached,
- ContextWindow: OpenAIModels[GPT41].ContextWindow,
- DefaultMaxTokens: OpenAIModels[GPT41].DefaultMaxTokens,
- },
- OpenRouterGPT41Mini: {
- ID: OpenRouterGPT41Mini,
- Name: "OpenRouter – GPT 4.1 mini",
- Provider: ProviderOpenRouter,
- APIModel: "openai/gpt-4.1-mini",
- CostPer1MIn: OpenAIModels[GPT41Mini].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[GPT41Mini].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[GPT41Mini].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[GPT41Mini].CostPer1MOutCached,
- ContextWindow: OpenAIModels[GPT41Mini].ContextWindow,
- DefaultMaxTokens: OpenAIModels[GPT41Mini].DefaultMaxTokens,
- },
- OpenRouterGPT41Nano: {
- ID: OpenRouterGPT41Nano,
- Name: "OpenRouter – GPT 4.1 nano",
- Provider: ProviderOpenRouter,
- APIModel: "openai/gpt-4.1-nano",
- CostPer1MIn: OpenAIModels[GPT41Nano].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[GPT41Nano].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[GPT41Nano].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[GPT41Nano].CostPer1MOutCached,
- ContextWindow: OpenAIModels[GPT41Nano].ContextWindow,
- DefaultMaxTokens: OpenAIModels[GPT41Nano].DefaultMaxTokens,
- },
- OpenRouterGPT45Preview: {
- ID: OpenRouterGPT45Preview,
- Name: "OpenRouter – GPT 4.5 preview",
- Provider: ProviderOpenRouter,
- APIModel: "openai/gpt-4.5-preview",
- CostPer1MIn: OpenAIModels[GPT45Preview].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[GPT45Preview].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[GPT45Preview].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[GPT45Preview].CostPer1MOutCached,
- ContextWindow: OpenAIModels[GPT45Preview].ContextWindow,
- DefaultMaxTokens: OpenAIModels[GPT45Preview].DefaultMaxTokens,
- },
- OpenRouterGPT4o: {
- ID: OpenRouterGPT4o,
- Name: "OpenRouter – GPT 4o",
- Provider: ProviderOpenRouter,
- APIModel: "openai/gpt-4o",
- CostPer1MIn: OpenAIModels[GPT4o].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[GPT4o].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[GPT4o].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[GPT4o].CostPer1MOutCached,
- ContextWindow: OpenAIModels[GPT4o].ContextWindow,
- DefaultMaxTokens: OpenAIModels[GPT4o].DefaultMaxTokens,
- },
- OpenRouterGPT4oMini: {
- ID: OpenRouterGPT4oMini,
- Name: "OpenRouter – GPT 4o mini",
- Provider: ProviderOpenRouter,
- APIModel: "openai/gpt-4o-mini",
- CostPer1MIn: OpenAIModels[GPT4oMini].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[GPT4oMini].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[GPT4oMini].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[GPT4oMini].CostPer1MOutCached,
- ContextWindow: OpenAIModels[GPT4oMini].ContextWindow,
- },
- OpenRouterO1: {
- ID: OpenRouterO1,
- Name: "OpenRouter – O1",
- Provider: ProviderOpenRouter,
- APIModel: "openai/o1",
- CostPer1MIn: OpenAIModels[O1].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[O1].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[O1].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[O1].CostPer1MOutCached,
- ContextWindow: OpenAIModels[O1].ContextWindow,
- DefaultMaxTokens: OpenAIModels[O1].DefaultMaxTokens,
- CanReason: OpenAIModels[O1].CanReason,
- },
- OpenRouterO1Pro: {
- ID: OpenRouterO1Pro,
- Name: "OpenRouter – o1 pro",
- Provider: ProviderOpenRouter,
- APIModel: "openai/o1-pro",
- CostPer1MIn: OpenAIModels[O1Pro].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[O1Pro].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[O1Pro].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[O1Pro].CostPer1MOutCached,
- ContextWindow: OpenAIModels[O1Pro].ContextWindow,
- DefaultMaxTokens: OpenAIModels[O1Pro].DefaultMaxTokens,
- CanReason: OpenAIModels[O1Pro].CanReason,
- },
- OpenRouterO1Mini: {
- ID: OpenRouterO1Mini,
- Name: "OpenRouter – o1 mini",
- Provider: ProviderOpenRouter,
- APIModel: "openai/o1-mini",
- CostPer1MIn: OpenAIModels[O1Mini].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[O1Mini].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[O1Mini].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[O1Mini].CostPer1MOutCached,
- ContextWindow: OpenAIModels[O1Mini].ContextWindow,
- DefaultMaxTokens: OpenAIModels[O1Mini].DefaultMaxTokens,
- CanReason: OpenAIModels[O1Mini].CanReason,
- },
- OpenRouterO3: {
- ID: OpenRouterO3,
- Name: "OpenRouter – o3",
- Provider: ProviderOpenRouter,
- APIModel: "openai/o3",
- CostPer1MIn: OpenAIModels[O3].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[O3].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[O3].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[O3].CostPer1MOutCached,
- ContextWindow: OpenAIModels[O3].ContextWindow,
- DefaultMaxTokens: OpenAIModels[O3].DefaultMaxTokens,
- CanReason: OpenAIModels[O3].CanReason,
- },
- OpenRouterO3Mini: {
- ID: OpenRouterO3Mini,
- Name: "OpenRouter – o3 mini",
- Provider: ProviderOpenRouter,
- APIModel: "openai/o3-mini-high",
- CostPer1MIn: OpenAIModels[O3Mini].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[O3Mini].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[O3Mini].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[O3Mini].CostPer1MOutCached,
- ContextWindow: OpenAIModels[O3Mini].ContextWindow,
- DefaultMaxTokens: OpenAIModels[O3Mini].DefaultMaxTokens,
- CanReason: OpenAIModels[O3Mini].CanReason,
- },
- OpenRouterO4Mini: {
- ID: OpenRouterO4Mini,
- Name: "OpenRouter – o4 mini",
- Provider: ProviderOpenRouter,
- APIModel: "openai/o4-mini-high",
- CostPer1MIn: OpenAIModels[O4Mini].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[O4Mini].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[O4Mini].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[O4Mini].CostPer1MOutCached,
- ContextWindow: OpenAIModels[O4Mini].ContextWindow,
- DefaultMaxTokens: OpenAIModels[O4Mini].DefaultMaxTokens,
- CanReason: OpenAIModels[O4Mini].CanReason,
- },
- OpenRouterGemini25Flash: {
- ID: OpenRouterGemini25Flash,
- Name: "OpenRouter – Gemini 2.5 Flash",
- Provider: ProviderOpenRouter,
- APIModel: "google/gemini-2.5-flash-preview:thinking",
- CostPer1MIn: GeminiModels[Gemini25Flash].CostPer1MIn,
- CostPer1MInCached: GeminiModels[Gemini25Flash].CostPer1MInCached,
- CostPer1MOut: GeminiModels[Gemini25Flash].CostPer1MOut,
- CostPer1MOutCached: GeminiModels[Gemini25Flash].CostPer1MOutCached,
- ContextWindow: GeminiModels[Gemini25Flash].ContextWindow,
- DefaultMaxTokens: GeminiModels[Gemini25Flash].DefaultMaxTokens,
- },
- OpenRouterGemini25: {
- ID: OpenRouterGemini25,
- Name: "OpenRouter – Gemini 2.5 Pro",
- Provider: ProviderOpenRouter,
- APIModel: "google/gemini-2.5-pro-preview-03-25",
- CostPer1MIn: GeminiModels[Gemini25].CostPer1MIn,
- CostPer1MInCached: GeminiModels[Gemini25].CostPer1MInCached,
- CostPer1MOut: GeminiModels[Gemini25].CostPer1MOut,
- CostPer1MOutCached: GeminiModels[Gemini25].CostPer1MOutCached,
- ContextWindow: GeminiModels[Gemini25].ContextWindow,
- DefaultMaxTokens: GeminiModels[Gemini25].DefaultMaxTokens,
- },
- OpenRouterClaude35Sonnet: {
- ID: OpenRouterClaude35Sonnet,
- Name: "OpenRouter – Claude 3.5 Sonnet",
- Provider: ProviderOpenRouter,
- APIModel: "anthropic/claude-3.5-sonnet",
- CostPer1MIn: AnthropicModels[Claude35Sonnet].CostPer1MIn,
- CostPer1MInCached: AnthropicModels[Claude35Sonnet].CostPer1MInCached,
- CostPer1MOut: AnthropicModels[Claude35Sonnet].CostPer1MOut,
- CostPer1MOutCached: AnthropicModels[Claude35Sonnet].CostPer1MOutCached,
- ContextWindow: AnthropicModels[Claude35Sonnet].ContextWindow,
- DefaultMaxTokens: AnthropicModels[Claude35Sonnet].DefaultMaxTokens,
- },
- OpenRouterClaude3Haiku: {
- ID: OpenRouterClaude3Haiku,
- Name: "OpenRouter – Claude 3 Haiku",
- Provider: ProviderOpenRouter,
- APIModel: "anthropic/claude-3-haiku",
- CostPer1MIn: AnthropicModels[Claude3Haiku].CostPer1MIn,
- CostPer1MInCached: AnthropicModels[Claude3Haiku].CostPer1MInCached,
- CostPer1MOut: AnthropicModels[Claude3Haiku].CostPer1MOut,
- CostPer1MOutCached: AnthropicModels[Claude3Haiku].CostPer1MOutCached,
- ContextWindow: AnthropicModels[Claude3Haiku].ContextWindow,
- DefaultMaxTokens: AnthropicModels[Claude3Haiku].DefaultMaxTokens,
- },
- OpenRouterClaude37Sonnet: {
- ID: OpenRouterClaude37Sonnet,
- Name: "OpenRouter – Claude 3.7 Sonnet",
- Provider: ProviderOpenRouter,
- APIModel: "anthropic/claude-3.7-sonnet",
- CostPer1MIn: AnthropicModels[Claude37Sonnet].CostPer1MIn,
- CostPer1MInCached: AnthropicModels[Claude37Sonnet].CostPer1MInCached,
- CostPer1MOut: AnthropicModels[Claude37Sonnet].CostPer1MOut,
- CostPer1MOutCached: AnthropicModels[Claude37Sonnet].CostPer1MOutCached,
- ContextWindow: AnthropicModels[Claude37Sonnet].ContextWindow,
- DefaultMaxTokens: AnthropicModels[Claude37Sonnet].DefaultMaxTokens,
- CanReason: AnthropicModels[Claude37Sonnet].CanReason,
- },
- OpenRouterClaude35Haiku: {
- ID: OpenRouterClaude35Haiku,
- Name: "OpenRouter – Claude 3.5 Haiku",
- Provider: ProviderOpenRouter,
- APIModel: "anthropic/claude-3.5-haiku",
- CostPer1MIn: AnthropicModels[Claude35Haiku].CostPer1MIn,
- CostPer1MInCached: AnthropicModels[Claude35Haiku].CostPer1MInCached,
- CostPer1MOut: AnthropicModels[Claude35Haiku].CostPer1MOut,
- CostPer1MOutCached: AnthropicModels[Claude35Haiku].CostPer1MOutCached,
- ContextWindow: AnthropicModels[Claude35Haiku].ContextWindow,
- DefaultMaxTokens: AnthropicModels[Claude35Haiku].DefaultMaxTokens,
- },
- OpenRouterClaude3Opus: {
- ID: OpenRouterClaude3Opus,
- Name: "OpenRouter – Claude 3 Opus",
- Provider: ProviderOpenRouter,
- APIModel: "anthropic/claude-3-opus",
- CostPer1MIn: AnthropicModels[Claude3Opus].CostPer1MIn,
- CostPer1MInCached: AnthropicModels[Claude3Opus].CostPer1MInCached,
- CostPer1MOut: AnthropicModels[Claude3Opus].CostPer1MOut,
- CostPer1MOutCached: AnthropicModels[Claude3Opus].CostPer1MOutCached,
- ContextWindow: AnthropicModels[Claude3Opus].ContextWindow,
- DefaultMaxTokens: AnthropicModels[Claude3Opus].DefaultMaxTokens,
- },
-
- OpenRouterDeepSeekR1Free: {
- ID: OpenRouterDeepSeekR1Free,
- Name: "OpenRouter – DeepSeek R1 Free",
- Provider: ProviderOpenRouter,
- APIModel: "deepseek/deepseek-r1-0528:free",
- CostPer1MIn: 0,
- CostPer1MInCached: 0,
- CostPer1MOut: 0,
- CostPer1MOutCached: 0,
- ContextWindow: 163_840,
- DefaultMaxTokens: 10000,
- },
-}
diff --git a/internal/llm/models/vertexai.go b/internal/llm/models/vertexai.go
deleted file mode 100644
index d71dfc0bed0a8071c89ab22883990413384cd56f..0000000000000000000000000000000000000000
--- a/internal/llm/models/vertexai.go
+++ /dev/null
@@ -1,38 +0,0 @@
-package models
-
-const (
- ProviderVertexAI ModelProvider = "vertexai"
-
- // Models
- VertexAIGemini25Flash ModelID = "vertexai.gemini-2.5-flash"
- VertexAIGemini25 ModelID = "vertexai.gemini-2.5"
-)
-
-var VertexAIGeminiModels = map[ModelID]Model{
- VertexAIGemini25Flash: {
- ID: VertexAIGemini25Flash,
- Name: "VertexAI: Gemini 2.5 Flash",
- Provider: ProviderVertexAI,
- APIModel: "gemini-2.5-flash-preview-04-17",
- CostPer1MIn: GeminiModels[Gemini25Flash].CostPer1MIn,
- CostPer1MInCached: GeminiModels[Gemini25Flash].CostPer1MInCached,
- CostPer1MOut: GeminiModels[Gemini25Flash].CostPer1MOut,
- CostPer1MOutCached: GeminiModels[Gemini25Flash].CostPer1MOutCached,
- ContextWindow: GeminiModels[Gemini25Flash].ContextWindow,
- DefaultMaxTokens: GeminiModels[Gemini25Flash].DefaultMaxTokens,
- SupportsAttachments: true,
- },
- VertexAIGemini25: {
- ID: VertexAIGemini25,
- Name: "VertexAI: Gemini 2.5 Pro",
- Provider: ProviderVertexAI,
- APIModel: "gemini-2.5-pro-preview-03-25",
- CostPer1MIn: GeminiModels[Gemini25].CostPer1MIn,
- CostPer1MInCached: GeminiModels[Gemini25].CostPer1MInCached,
- CostPer1MOut: GeminiModels[Gemini25].CostPer1MOut,
- CostPer1MOutCached: GeminiModels[Gemini25].CostPer1MOutCached,
- ContextWindow: GeminiModels[Gemini25].ContextWindow,
- DefaultMaxTokens: GeminiModels[Gemini25].DefaultMaxTokens,
- SupportsAttachments: true,
- },
-}
diff --git a/internal/llm/models/xai.go b/internal/llm/models/xai.go
deleted file mode 100644
index 00caf3b89750c0789f75f6273d49e38a4cdf6282..0000000000000000000000000000000000000000
--- a/internal/llm/models/xai.go
+++ /dev/null
@@ -1,61 +0,0 @@
-package models
-
-const (
- ProviderXAI ModelProvider = "xai"
-
- XAIGrok3Beta ModelID = "grok-3-beta"
- XAIGrok3MiniBeta ModelID = "grok-3-mini-beta"
- XAIGrok3FastBeta ModelID = "grok-3-fast-beta"
- XAiGrok3MiniFastBeta ModelID = "grok-3-mini-fast-beta"
-)
-
-var XAIModels = map[ModelID]Model{
- XAIGrok3Beta: {
- ID: XAIGrok3Beta,
- Name: "Grok3 Beta",
- Provider: ProviderXAI,
- APIModel: "grok-3-beta",
- CostPer1MIn: 3.0,
- CostPer1MInCached: 0,
- CostPer1MOut: 15,
- CostPer1MOutCached: 0,
- ContextWindow: 131_072,
- DefaultMaxTokens: 20_000,
- },
- XAIGrok3MiniBeta: {
- ID: XAIGrok3MiniBeta,
- Name: "Grok3 Mini Beta",
- Provider: ProviderXAI,
- APIModel: "grok-3-mini-beta",
- CostPer1MIn: 0.3,
- CostPer1MInCached: 0,
- CostPer1MOut: 0.5,
- CostPer1MOutCached: 0,
- ContextWindow: 131_072,
- DefaultMaxTokens: 20_000,
- },
- XAIGrok3FastBeta: {
- ID: XAIGrok3FastBeta,
- Name: "Grok3 Fast Beta",
- Provider: ProviderXAI,
- APIModel: "grok-3-fast-beta",
- CostPer1MIn: 5,
- CostPer1MInCached: 0,
- CostPer1MOut: 25,
- CostPer1MOutCached: 0,
- ContextWindow: 131_072,
- DefaultMaxTokens: 20_000,
- },
- XAiGrok3MiniFastBeta: {
- ID: XAiGrok3MiniFastBeta,
- Name: "Grok3 Mini Fast Beta",
- Provider: ProviderXAI,
- APIModel: "grok-3-mini-fast-beta",
- CostPer1MIn: 0.6,
- CostPer1MInCached: 0,
- CostPer1MOut: 4.0,
- CostPer1MOutCached: 0,
- ContextWindow: 131_072,
- DefaultMaxTokens: 20_000,
- },
-}
diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go
index ea31bfa0297c1ce207e188a7f162e26831927636..523933d18e5c39ea766c42e1aafe09b5aaff3e63 100644
--- a/internal/llm/prompt/coder.go
+++ b/internal/llm/prompt/coder.go
@@ -9,62 +9,119 @@ import (
"time"
"github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/llm/models"
+ "github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
+ "github.com/charmbracelet/crush/internal/logging"
)
-func CoderPrompt(provider models.ModelProvider) string {
- basePrompt := baseAnthropicCoderPrompt
- switch provider {
- case models.ProviderOpenAI:
+func CoderPrompt(p provider.InferenceProvider, contextFiles ...string) string {
+ var basePrompt string
+ switch p {
+ case provider.InferenceProviderOpenAI:
basePrompt = baseOpenAICoderPrompt
+ case provider.InferenceProviderGemini, provider.InferenceProviderVertexAI:
+ basePrompt = baseGeminiCoderPrompt
+ default:
+ basePrompt = baseAnthropicCoderPrompt
}
envInfo := getEnvironmentInfo()
- return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
+ basePrompt = fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
+
+ contextContent := getContextFromPaths(contextFiles)
+ logging.Debug("Context content", "Context", contextContent)
+ if contextContent != "" {
+ return fmt.Sprintf("%s\n\n# Project-Specific Context\n Make sure to follow the instructions in the context below\n%s", basePrompt, contextContent)
+ }
+ return basePrompt
}
const baseOpenAICoderPrompt = `
-You are operating as and within the Crush CLI, a terminal-based agentic coding assistant built by OpenAI. It wraps OpenAI models to enable natural language interaction with a local codebase. You are expected to be precise, safe, and helpful.
-
-You can:
-- Receive user prompts, project context, and files.
-- Stream responses and emit function calls (e.g., shell commands, code edits).
+Please resolve the user's task by editing and testing the code files in your current code execution session.
+You are a deployed coding agent.
+Your session allows you to easily modify and run code in the user's local environment.
+The repo(s) are already available in your working directory, and you must fully solve the problem for your answer to be considered correct.
+IMPORTANT: Before you begin work, think about what the code you're editing is supposed to do based on the filenames and directory structure.
-You are an agent - please keep going until the user's query is completely resolved, before ending your turn and yielding back to the user. Only terminate your turn when you are sure that the problem is solved. If you are not sure about file content or codebase structure pertaining to the user's request, use your tools to read files and gather the relevant information: do NOT guess or make up an answer.
+# Memory
+If the current working directory contains a file called CRUSH.md, it will be automatically added to your context. This file serves multiple purposes:
+1. Storing frequently used bash commands (build, test, lint, etc.) so you can use them without searching each time
+2. Recording the user's code style preferences (naming conventions, preferred libraries, etc.)
+3. Maintaining useful information about the codebase structure and organization
-Please resolve the user's task by editing and testing the code files in your current code execution session. You are a deployed coding agent. Your session allows for you to modify and run code. The repo(s) are already cloned in your working directory, and you must fully solve the problem for your answer to be considered correct.
+When you spend time searching for commands to typecheck, lint, build, or test, you should ask the user if it's okay to add those commands to CRUSH.md. Similarly, when learning about code style preferences or important codebase information, ask if it's okay to add that to CRUSH.md so you can remember it for next time.
You MUST adhere to the following criteria when executing the task:
+
- Working on the repo(s) in the current environment is allowed, even if they are proprietary.
- Analyzing code for vulnerabilities is allowed.
- Showing user code and tool call details is allowed.
-- User instructions may overwrite the *CODING GUIDELINES* section in this developer message.
+- User instructions may overwrite the _CODING GUIDELINES_ section in this developer message.
+- Do not use ` + "`ls -R`" + `, ` + "`find`" + `, or ` + "`grep`" + ` - these are slow in large repos. Use the Agent tool for searching instead.
+- Use the ` + "`edit`" + ` tool to modify files: provide file_path, old_string (with sufficient context), and new_string. The edit tool requires:
+ - Absolute file paths (starting with /)
+ - Unique old_string matches with 3-5 lines of context before and after
+ - Exact whitespace and indentation matching
+ - For new files: provide file_path and new_string, leave old_string empty
+ - For deleting content: provide file_path and old_string, leave new_string empty
+
+# Following conventions
+When making changes to files, first understand the file's code conventions. Mimic code style, use existing libraries and utilities, and follow existing patterns.
+- NEVER assume that a given library is available, even if it is well known. Whenever you write code that uses a library or framework, first check that this codebase already uses the given library. For example, you might look at neighboring files, or check the package.json (or cargo.toml, and so on depending on the language).
+- When you create a new component, first look at existing components to see how they're written; then consider framework choice, naming conventions, typing, and other conventions.
+- When you edit a piece of code, first look at the code's surrounding context (especially its imports) to understand the code's choice of frameworks and libraries. Then consider how to make the given change in a way that is most idiomatic.
+- Always follow security best practices. Never introduce code that exposes or logs secrets and keys. Never commit secrets or keys to the repository.
+
+# Code style
+- Do not add comments to the code you write, unless the user asks you to, or the code is complex and requires additional context.
+
- If completing the user's task requires writing or modifying files:
- - Your code and final answer should follow these *CODING GUIDELINES*:
- - Fix the problem at the root cause rather than applying surface-level patches, when possible.
- - Avoid unneeded complexity in your solution.
- - Ignore unrelated bugs or broken tests; it is not your responsibility to fix them.
- - Update documentation as necessary.
- - Keep changes consistent with the style of the existing codebase. Changes should be minimal and focused on the task.
- - Use "git log" and "git blame" to search the history of the codebase if additional context is required; internet access is disabled.
- - NEVER add copyright or license headers unless specifically requested.
- - You do not need to "git commit" your changes; this will be done automatically for you.
- - Once you finish coding, you must
- - Check "git status" to sanity check your changes; revert any scratch files or changes.
- - Remove all inline comments you added as much as possible, even if they look normal. Check using "git diff". Inline comments must be generally avoided, unless active maintainers of the repo, after long careful study of the code and the issue, will still misinterpret the code without the comments.
- - Check if you accidentally add copyright or license headers. If so, remove them.
- - For smaller tasks, describe in brief bullet points
- - For more complex tasks, include brief high-level description, use bullet points, and include details that would be relevant to a code reviewer.
+ - Your code and final answer should follow these _CODING GUIDELINES_:
+ - Fix the problem at the root cause rather than applying surface-level patches, when possible.
+ - Avoid unneeded complexity in your solution.
+ - Ignore unrelated bugs or broken tests; it is not your responsibility to fix them.
+ - Update documentation as necessary.
+ - Keep changes consistent with the style of the existing codebase. Changes should be minimal and focused on the task.
+ - Use ` + "`git log`" + ` and ` + "`git blame`" + ` to search the history of the codebase if additional context is required.
+ - NEVER add copyright or license headers unless specifically requested.
+ - You do not need to ` + "`git commit`" + ` your changes; this will be done automatically for you.
+ - If there is a .pre-commit-config.yaml, use ` + "`pre-commit run --files ...`" + ` to check that your changes pass the pre-commit checks. However, do not fix pre-existing errors on lines you didn't touch.
+ - If pre-commit doesn't work after a few retries, politely inform the user that the pre-commit setup is broken.
+ - Once you finish coding, you must
+ - Check ` + "`git status`" + ` to sanity check your changes; revert any scratch files or changes.
+ - Remove all inline comments you added as much as possible, even if they look normal. Check using ` + "`git diff`" + `. Inline comments must be generally avoided, unless active maintainers of the repo, after long careful study of the code and the issue, will still misinterpret the code without the comments.
+ - Check if you accidentally add copyright or license headers. If so, remove them.
+ - Try to run pre-commit if it is available.
+ - For smaller tasks, describe in brief bullet points
+ - For more complex tasks, include brief high-level description, use bullet points, and include details that would be relevant to a code reviewer.
+
+# Doing tasks
+The user will primarily request you perform software engineering tasks. This includes solving bugs, adding new functionality, refactoring code, explaining code, and more. For these tasks the following steps are recommended:
+1. Use the available search tools to understand the codebase and the user's query.
+2. Implement the solution using all tools available to you.
+3. Verify the solution if possible with tests. NEVER assume specific test framework or test script. Check the README or search codebase to determine the testing approach.
+4. VERY IMPORTANT: When you have completed a task, you MUST run the lint and typecheck commands (e.g. npm run lint, npm run typecheck, ruff, etc.) if they were provided to you, to ensure your code is correct. If you are unable to find the correct command, ask the user for the command to run, and if they supply it, proactively suggest writing it to CRUSH.md so that you will know to run it next time.
+
+NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTANT to only commit when explicitly asked, otherwise the user will feel that you are being too proactive.
+
+# Tool usage policy
+- When doing file search, prefer to use the Agent tool in order to reduce context usage.
+- If you intend to call multiple tools and there are no dependencies between the calls, make all of the independent calls in parallel.
+- IMPORTANT: The user does not see the full output of the tool responses, so if you need the output of the tool for the response make sure to summarize it for the user.
+
+# Proactiveness
+You are allowed to be proactive, but only when the user asks you to do something. You should strive to strike a balance between:
+1. Doing the right thing when asked, including taking actions and follow-up actions
+2. Not surprising the user with actions you take without asking
+3. Not adding code explanation summaries unless requested by the user. After working on a file, just stop, rather than providing an explanation of what you did.
+For example, if the user asks you how to approach something, you should do your best to answer their question first, and not immediately jump into taking actions.
+
- If completing the user's task DOES NOT require writing or modifying files (e.g., the user asks a question about the code base):
- - Respond in a friendly tune as a remote teammate, who is knowledgeable, capable and eager to help with coding.
+ - Respond in a friendly tone as a remote teammate, who is knowledgeable, capable and eager to help with coding.
- When your task involves writing or modifying files:
- - Do NOT tell the user to "save the file" or "copy the code into a file" if you already created or modified the file using "edit/write". Instead, reference the file as already saved.
- - Do NOT show the full contents of large files you have already written, unless the user explicitly asks for them.
-- When doing things with paths, always use use the full path, if the working directory is /abc/xyz and you want to edit the file abc.go in the working dir refer to it as /abc/xyz/abc.go.
-- If you send a path not including the working dir, the working dir will be prepended to it.
-- Remember the user does not see the full output of tools
+ - Do NOT tell the user to "save the file" or "copy the code into a file" if you already created or modified the file using ` + "`edit`" + `. Instead, reference the file as already saved.
+ - Do NOT show the full contents of large files you have already written, unless the user explicitly asks for them.
- NEVER use emojis in your responses
`
@@ -167,6 +224,161 @@ VERY IMPORTANT NEVER use emojis in your responses.
You MUST answer concisely with fewer than 4 lines of text (not including tool use or code generation), unless user asks for detail.`
+const baseGeminiCoderPrompt = `
+You are an interactive CLI agent specializing in software engineering tasks. Your primary goal is to help users safely and efficiently, adhering strictly to the following instructions and utilizing your available tools.
+
+IMPORTANT: Before you begin work, think about what the code you're editing is supposed to do based on the filenames and directory structure.
+
+# Memory
+If the current working directory contains a file called CRUSH.md, it will be automatically added to your context. This file serves multiple purposes:
+1. Storing frequently used bash commands (build, test, lint, etc.) so you can use them without searching each time
+2. Recording the user's code style preferences (naming conventions, preferred libraries, etc.)
+3. Maintaining useful information about the codebase structure and organization
+
+When you spend time searching for commands to typecheck, lint, build, or test, you should ask the user if it's okay to add those commands to CRUSH.md. Similarly, when learning about code style preferences or important codebase information, ask if it's okay to add that to CRUSH.md so you can remember it for next time.
+
+# Core Mandates
+
+- **Conventions:** Rigorously adhere to existing project conventions when reading or modifying code. Analyze surrounding code, tests, and configuration first.
+- **Libraries/Frameworks:** NEVER assume a library/framework is available or appropriate. Verify its established usage within the project (check imports, configuration files like 'package.json', 'Cargo.toml', 'requirements.txt', 'build.gradle', etc., or observe neighboring files) before employing it.
+- **Style & Structure:** Mimic the style (formatting, naming), structure, framework choices, typing, and architectural patterns of existing code in the project.
+- **Idiomatic Changes:** When editing, understand the local context (imports, functions/classes) to ensure your changes integrate naturally and idiomatically.
+- **Comments:** Add code comments sparingly. Focus on *why* something is done, especially for complex logic, rather than *what* is done. Only add high-value comments if necessary for clarity or if requested by the user. Do not edit comments that are separate from the code you are changing. *NEVER* talk to the user or describe your changes through comments.
+- **Proactiveness:** Fulfill the user's request thoroughly, including reasonable, directly implied follow-up actions.
+- **Confirm Ambiguity/Expansion:** Do not take significant actions beyond the clear scope of the request without confirming with the user. If asked *how* to do something, explain first, don't just do it.
+- **Explaining Changes:** After completing a code modification or file operation *do not* provide summaries unless asked.
+- **Do Not revert changes:** Do not revert changes to the codebase unless asked to do so by the user. Only revert changes made by you if they have resulted in an error or if the user has explicitly asked you to revert the changes.
+
+# Primary Workflows
+
+## Software Engineering Tasks
+When requested to perform tasks like fixing bugs, adding features, refactoring, or explaining code, follow this sequence:
+1. **Understand:** Think about the user's request and the relevant codebase context. Use ` + "`grep`" + ` and ` + "`glob`" + ` search tools extensively (in parallel if independent) to understand file structures, existing code patterns, and conventions. Use ` + "`view`" + ` to understand context and validate any assumptions you may have.
+2. **Plan:** Build a coherent and grounded (based on the understanding in step 1) plan for how you intend to resolve the user's task. Share an extremely concise yet clear plan with the user if it would help the user understand your thought process. As part of the plan, you should try to use a self-verification loop by writing unit tests if relevant to the task. Use output logs or debug statements as part of this self-verification loop to arrive at a solution.
+3. **Implement:** Use the available tools (e.g., ` + "`edit`" + `, ` + "`write`" + `, ` + "`bash`" + ` ...) to act on the plan, strictly adhering to the project's established conventions (detailed under 'Core Mandates').
+4. **Verify (Tests):** If applicable and feasible, verify the changes using the project's testing procedures. Identify the correct test commands and frameworks by examining 'README' files, build/package configuration (e.g., 'package.json'), or existing test execution patterns. NEVER assume standard test commands.
+5. **Verify (Standards):** VERY IMPORTANT: After making code changes, execute the project-specific build, linting and type-checking commands (e.g., 'tsc', 'npm run lint', 'ruff check .') that you have identified for this project (or obtained from the user). This ensures code quality and adherence to standards. If unsure about these commands, you can ask the user if they'd like you to run them and if so how to.
+
+NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTANT to only commit when explicitly asked, otherwise the user will feel that you are being too proactive.
+
+# Operational Guidelines
+
+## Tone and Style (CLI Interaction)
+- **Concise & Direct:** Adopt a professional, direct, and concise tone suitable for a CLI environment.
+- **Minimal Output:** Aim for fewer than 3 lines of text output (excluding tool use/code generation) per response whenever practical. Focus strictly on the user's query.
+- **Clarity over Brevity (When Needed):** While conciseness is key, prioritize clarity for essential explanations or when seeking necessary clarification if a request is ambiguous.
+- **No Chitchat:** Avoid conversational filler, preambles ("Okay, I will now..."), or postambles ("I have finished the changes..."). Get straight to the action or answer.
+- **Formatting:** Use GitHub-flavored Markdown. Responses will be rendered in monospace.
+- **Tools vs. Text:** Use tools for actions, text output *only* for communication. Do not add explanatory comments within tool calls or code blocks unless specifically part of the required code/command itself.
+- **Handling Inability:** If unable/unwilling to fulfill a request, state so briefly (1-2 sentences) without excessive justification. Offer alternatives if appropriate.
+
+## Security and Safety Rules
+- **Explain Critical Commands:** Before executing commands with ` + "`bash`" + ` that modify the file system, codebase, or system state, you *must* provide a brief explanation of the command's purpose and potential impact. Prioritize user understanding and safety.
+- **Security First:** Always apply security best practices. Never introduce code that exposes, logs, or commits secrets, API keys, or other sensitive information.
+
+## Tool Usage
+- **File Paths:** Always use absolute paths when referring to files with tools like ` + "`view`" + ` or ` + "`write`" + `. Relative paths are not supported. You must provide an absolute path.
+- **Parallelism:** Execute multiple independent tool calls in parallel when feasible (e.g. searching the codebase).
+- **Command Execution:** Use the ` + "`bash`" + ` tool for running shell commands, remembering the safety rule to explain modifying commands first.
+- **Background Processes:** Use background processes (via ` + "`&`" + `) for commands that are unlikely to stop on their own, e.g. ` + "`node server.js &`" + `. If unsure, ask the user.
+- **Interactive Commands:** Try to avoid shell commands that are likely to require user interaction (e.g. ` + "`git rebase -i`" + `). Use non-interactive versions of commands (e.g. ` + "`npm init -y`" + ` instead of ` + "`npm init`" + `) when available, and otherwise remind the user that interactive shell commands are not supported and may cause hangs until canceled by the user.
+
+# Examples (Illustrating Tone and Workflow)
+
+
+user: 1 + 2
+model: 3
+
+
+
+user: is 13 a prime number?
+model: true
+
+
+
+user: list files here.
+model: [tool_call: ls for path '.']
+
+
+
+user: start the server implemented in server.js
+model: [tool_call: bash for 'node server.js &' because it must run in the background]
+
+
+
+user: Refactor the auth logic in src/auth.py to use the requests library instead of urllib.
+model: Okay, I can refactor 'src/auth.py'.
+First, I'll analyze the code and check for a test safety net before planning any changes.
+[tool_call: glob for pattern 'tests/test_auth.py']
+[tool_call: view for absolute_path '/path/to/tests/test_auth.py']
+(After analysis)
+Great, 'tests/test_auth.py' exists and covers the core authentication logic. With this safety net in place, I can safely plan the refactoring.
+I'll also confirm 'requests' is a dependency.
+[tool_call: view for absolute_path '/path/to/requirements.txt']
+(After analysis)
+Looks good, 'requests' is available.
+
+Here's the plan:
+1. Replace the 'urllib' calls with 'requests'.
+2. Add proper 'try...except' error handling for the new network calls.
+3. Remove the old 'urllib' import.
+4. Run the project's linter and tests to verify the changes.
+
+Should I proceed?
+user: Yes
+model:
+[tool_call: edit to apply the refactoring to 'src/auth.py']
+Refactoring complete. Running verification...
+[tool_call: bash for 'ruff check src/auth.py && pytest']
+(After verification passes)
+All checks passed. This is a stable checkpoint.
+
+
+
+user: Delete the temp directory.
+model: I can run ` + "`rm -rf ./temp`" + `. This will permanently delete the directory and all its contents.
+
+
+
+user: Write tests for someFile.ts
+model:
+Okay, I can write those tests. First, I'll read ` + "`someFile.ts`" + ` to understand its functionality.
+[tool_call: view for absolute_path '/path/to/someFile.ts' or use glob to find ` + "`someFile.ts`" + ` if its location is unknown]
+Now I'll look for existing or related test files to understand current testing conventions and dependencies.
+[tool_call: glob for pattern '**/*.test.ts' and '**/*.spec.ts']
+(After reviewing existing tests and the file content)
+[tool_call: write to create /path/to/someFile.test.ts with the test code]
+I've written the tests. Now I'll run the project's test command to verify them.
+[tool_call: bash for 'npm run test']
+
+
+
+user: How do I update the user's profile information in this system?
+model:
+I'm not immediately sure how user profile information is updated. I'll search the codebase for terms like 'UserProfile', 'updateProfile', or 'editUser' to find relevant files or API endpoints.
+[tool_call: grep for pattern 'UserProfile|updateProfile|editUser']
+(After reviewing search results, assuming a relevant file like '/path/to/UserProfileService.java' was found)
+Okay, ` + "`/path/to/UserProfileService.java`" + ` seems like the most relevant file. I'll read its content to understand how updates are handled.
+[tool_call: view for absolute_path '/path/to/UserProfileService.java']
+(After reading the file)
+It appears the ` + "`updateUserProfile`" + ` method in ` + "`UserProfileService.java`" + ` is responsible for this. It expects a user ID and a ` + "`UserProfileDTO`" + ` object...
+
+
+
+user: Where are all the 'app.config' files in this project? I need to check their settings.
+model:
+[tool_call: glob for pattern '**/app.config']
+(Assuming glob returns a list of paths like ['/path/to/moduleA/app.config', '/path/to/moduleB/app.config'])
+I found the following 'app.config' files:
+- /path/to/moduleA/app.config
+- /path/to/moduleB/app.config
+To help you check their settings, I can read their contents. Which one would you like to start with, or should I read all of them?
+
+
+# Final Reminder
+Your core function is efficient and safe assistance. Balance extreme conciseness with the crucial need for clarity, especially regarding safety and potential system modifications. Always prioritize user control and project conventions. Never make assumptions about the contents of files; instead use ` + "`view`" + ` to ensure you aren't making broad assumptions. Finally, you are an agent - please keep going until the user's query is completely resolved.
+`
+
func getEnvironmentInfo() string {
cwd := config.WorkingDirectory()
isGit := isGitRepo(cwd)
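
`CoderPrompt` now selects its base prompt from the fur provider ID and receives the context file paths as an explicit variadic argument instead of reading them from global config. A minimal sketch of a call site, assuming it lives inside this module (the `internal` packages are not importable elsewhere) and that `config.Init` has already run; the context path is illustrative:

```go
package main

import (
	"fmt"

	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/charmbracelet/crush/internal/llm/prompt"
)

func main() {
	// VertexAI shares the Gemini base prompt; unknown providers fall back
	// to the Anthropic one. The context path here is a hypothetical example.
	sys := prompt.CoderPrompt(
		provider.InferenceProviderVertexAI,
		"CRUSH.md",
	)
	fmt.Println(len(sys))
}
```
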
diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go
index 9065fd6bc0bb4a69eab19479b9a18b270dddee02..36148edd9c71790c3a4cb06d551cdee06272c8b7 100644
--- a/internal/llm/prompt/prompt.go
+++ b/internal/llm/prompt/prompt.go
@@ -1,60 +1,44 @@
package prompt
import (
- "fmt"
"os"
"path/filepath"
"strings"
"sync"
"github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/llm/models"
- "github.com/charmbracelet/crush/internal/logging"
+ "github.com/charmbracelet/crush/internal/fur/provider"
)
-func GetAgentPrompt(agentName config.AgentName, provider models.ModelProvider) string {
+type PromptID string
+
+const (
+ PromptCoder PromptID = "coder"
+ PromptTitle PromptID = "title"
+ PromptTask PromptID = "task"
+ PromptSummarizer PromptID = "summarizer"
+ PromptDefault PromptID = "default"
+)
+
+func GetPrompt(promptID PromptID, provider provider.InferenceProvider, contextPaths ...string) string {
basePrompt := ""
- switch agentName {
- case config.AgentCoder:
+ switch promptID {
+ case PromptCoder:
-		basePrompt = CoderPrompt(provider)
+		basePrompt = CoderPrompt(provider, contextPaths...)
- case config.AgentTitle:
+ case PromptTitle:
basePrompt = TitlePrompt(provider)
- case config.AgentTask:
+ case PromptTask:
basePrompt = TaskPrompt(provider)
- case config.AgentSummarizer:
+ case PromptSummarizer:
basePrompt = SummarizerPrompt(provider)
default:
basePrompt = "You are a helpful assistant"
}
-
- if agentName == config.AgentCoder || agentName == config.AgentTask {
- // Add context from project-specific instruction files if they exist
- contextContent := getContextFromPaths()
- logging.Debug("Context content", "Context", contextContent)
- if contextContent != "" {
- return fmt.Sprintf("%s\n\n# Project-Specific Context\n Make sure to follow the instructions in the context below\n%s", basePrompt, contextContent)
- }
- }
return basePrompt
}
-var (
- onceContext sync.Once
- contextContent string
-)
-
-func getContextFromPaths() string {
- onceContext.Do(func() {
- var (
- cfg = config.Get()
- workDir = cfg.WorkingDir
- contextPaths = cfg.ContextPaths
- )
-
- contextContent = processContextPaths(workDir, contextPaths)
- })
-
- return contextContent
+func getContextFromPaths(contextPaths []string) string {
+ return processContextPaths(config.WorkingDirectory(), contextPaths)
}
func processContextPaths(workDir string, paths []string) string {
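
With agent names decoupled from prompt selection, callers now dispatch through `GetPrompt` and a `PromptID`. A hedged sketch of that dispatch, again assuming module-internal code; the ID and provider are illustrative:

```go
package main

import (
	"fmt"

	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/charmbracelet/crush/internal/llm/prompt"
)

func main() {
	// Unknown IDs fall back to "You are a helpful assistant"; the variadic
	// context paths are only consumed by the coder prompt.
	sys := prompt.GetPrompt(prompt.PromptTitle, provider.InferenceProviderOpenAI)
	fmt.Println(sys != "")
}
```
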
diff --git a/internal/llm/prompt/prompt_test.go b/internal/llm/prompt/prompt_test.go
index a350c55a32260173dabd56e22d9e514e97b3e5a3..41e3fe92c7fb5615b6c93e2aa89bad35820567ef 100644
--- a/internal/llm/prompt/prompt_test.go
+++ b/internal/llm/prompt/prompt_test.go
@@ -15,16 +15,10 @@ func TestGetContextFromPaths(t *testing.T) {
t.Parallel()
tmpDir := t.TempDir()
- _, err := config.Load(tmpDir, false)
+ _, err := config.Init(tmpDir, false)
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
- cfg := config.Get()
- cfg.WorkingDir = tmpDir
- cfg.ContextPaths = []string{
- "file.txt",
- "directory/",
- }
testFiles := []string{
"file.txt",
"directory/file_a.txt",
@@ -34,7 +28,12 @@ func TestGetContextFromPaths(t *testing.T) {
createTestFiles(t, tmpDir, testFiles)
- context := getContextFromPaths()
+ context := getContextFromPaths(
+ []string{
+ "file.txt",
+ "directory/",
+ },
+ )
expectedContext := fmt.Sprintf("# From:%s/file.txt\nfile.txt: test content\n# From:%s/directory/file_a.txt\ndirectory/file_a.txt: test content\n# From:%s/directory/file_b.txt\ndirectory/file_b.txt: test content\n# From:%s/directory/file_c.txt\ndirectory/file_c.txt: test content", tmpDir, tmpDir, tmpDir, tmpDir)
assert.Equal(t, expectedContext, context)
}
diff --git a/internal/llm/prompt/summarizer.go b/internal/llm/prompt/summarizer.go
index 87a0f95c66af8b51d07a3a4e792c07dea7dab503..77d98184bcf985ebb2bc569205b6b4cc77b3d601 100644
--- a/internal/llm/prompt/summarizer.go
+++ b/internal/llm/prompt/summarizer.go
@@ -1,8 +1,10 @@
package prompt
-import "github.com/charmbracelet/crush/internal/llm/models"
+import (
+ "github.com/charmbracelet/crush/internal/fur/provider"
+)
-func SummarizerPrompt(_ models.ModelProvider) string {
+func SummarizerPrompt(_ provider.InferenceProvider) string {
return `You are a helpful AI assistant tasked with summarizing conversations.
When asked to summarize, provide a detailed but concise summary of the conversation.
diff --git a/internal/llm/prompt/task.go b/internal/llm/prompt/task.go
index 53fd67dc2f88928b4fbe9773db0cd1487bcd811a..719c0ef45778814e38b391e86174708edcdd7c3e 100644
--- a/internal/llm/prompt/task.go
+++ b/internal/llm/prompt/task.go
@@ -3,10 +3,10 @@ package prompt
import (
"fmt"
- "github.com/charmbracelet/crush/internal/llm/models"
+ "github.com/charmbracelet/crush/internal/fur/provider"
)
-func TaskPrompt(_ models.ModelProvider) string {
+func TaskPrompt(_ provider.InferenceProvider) string {
agentPrompt := `You are an agent for Crush. Given the user's prompt, you should use the tools available to you to answer the user's question.
Notes:
1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is .", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
diff --git a/internal/llm/prompt/title.go b/internal/llm/prompt/title.go
index 03e47288507fa66bb88605bff4b2194b889cc3f7..11bab4b6835ac0e53adc578cfddd3133f8b654e5 100644
--- a/internal/llm/prompt/title.go
+++ b/internal/llm/prompt/title.go
@@ -1,8 +1,10 @@
package prompt
-import "github.com/charmbracelet/crush/internal/llm/models"
+import (
+ "github.com/charmbracelet/crush/internal/fur/provider"
+)
-func TitlePrompt(_ models.ModelProvider) string {
+func TitlePrompt(_ provider.InferenceProvider) string {
return `you will generate a short title based on the first message a user begins a conversation with
- ensure it is not more than 50 characters long
- the title should be a summary of the user's message
diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go
index 634040ebc3dad8d6dc9e7642ebbe95ac3b051c63..05f09ad77e224a59bcb825e85f353e317c7c4a83 100644
--- a/internal/llm/provider/anthropic.go
+++ b/internal/llm/provider/anthropic.go
@@ -6,55 +6,42 @@ import (
"errors"
"fmt"
"io"
- "strings"
"time"
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/bedrock"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/llm/models"
+ "github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/logging"
"github.com/charmbracelet/crush/internal/message"
)
-type anthropicOptions struct {
- useBedrock bool
- disableCache bool
- shouldThink func(userMessage string) bool
-}
-
-type AnthropicOption func(*anthropicOptions)
-
type anthropicClient struct {
providerOptions providerClientOptions
- options anthropicOptions
+ useBedrock bool
client anthropic.Client
}
type AnthropicClient ProviderClient
-func newAnthropicClient(opts providerClientOptions) AnthropicClient {
- anthropicOpts := anthropicOptions{}
- for _, o := range opts.anthropicOptions {
- o(&anthropicOpts)
+func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient {
+	return &anthropicClient{
+		providerOptions: opts,
+		useBedrock:      useBedrock,
+		client:          createAnthropicClient(opts, useBedrock),
}
+}
+func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropic.Client {
anthropicClientOptions := []option.RequestOption{}
if opts.apiKey != "" {
anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
}
- if anthropicOpts.useBedrock {
+ if useBedrock {
anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
}
-
- client := anthropic.NewClient(anthropicClientOptions...)
- return &anthropicClient{
- providerOptions: opts,
- options: anthropicOpts,
- client: client,
- }
+ return anthropic.NewClient(anthropicClientOptions...)
}
func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) {
@@ -66,7 +53,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic
switch msg.Role {
case message.User:
content := anthropic.NewTextBlock(msg.Content().String())
- if cache && !a.options.disableCache {
+ if cache && !a.providerOptions.disableCache {
content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
@@ -74,7 +61,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic
var contentBlocks []anthropic.ContentBlockParamUnion
contentBlocks = append(contentBlocks, content)
for _, binaryContent := range msg.BinaryContent() {
- base64Image := binaryContent.String(models.ProviderAnthropic)
+ base64Image := binaryContent.String(provider.InferenceProviderAnthropic)
imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image)
contentBlocks = append(contentBlocks, imageBlock)
}
@@ -84,7 +71,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic
blocks := []anthropic.ContentBlockParamUnion{}
if msg.Content().String() != "" {
content := anthropic.NewTextBlock(msg.Content().String())
- if cache && !a.options.disableCache {
+ if cache && !a.providerOptions.disableCache {
content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
@@ -132,7 +119,7 @@ func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolU
},
}
- if i == len(tools)-1 && !a.options.disableCache {
+ if i == len(tools)-1 && !a.providerOptions.disableCache {
toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
@@ -160,26 +147,33 @@ func (a *anthropicClient) finishReason(reason string) message.FinishReason {
}
func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams {
+ model := a.providerOptions.model(a.providerOptions.modelType)
var thinkingParam anthropic.ThinkingConfigParamUnion
- lastMessage := messages[len(messages)-1]
- isUser := lastMessage.Role == anthropic.MessageParamRoleUser
- messageContent := ""
+ cfg := config.Get()
+ modelConfig := cfg.Models.Large
+ if a.providerOptions.modelType == config.SmallModel {
+ modelConfig = cfg.Models.Small
+ }
temperature := anthropic.Float(0)
- if isUser {
- for _, m := range lastMessage.Content {
- if m.OfText != nil && m.OfText.Text != "" {
- messageContent = m.OfText.Text
- }
- }
- if messageContent != "" && a.options.shouldThink != nil && a.options.shouldThink(messageContent) {
- thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(a.providerOptions.maxTokens) * 0.8))
- temperature = anthropic.Float(1)
- }
+
+ if a.Model().CanReason && modelConfig.Think {
+ thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(a.providerOptions.maxTokens) * 0.8))
+ temperature = anthropic.Float(1)
+ }
+
+ maxTokens := model.DefaultMaxTokens
+ if modelConfig.MaxTokens > 0 {
+ maxTokens = modelConfig.MaxTokens
+ }
+
+	// Override max tokens if set in provider options.
+ if a.providerOptions.maxTokens > 0 {
+ maxTokens = a.providerOptions.maxTokens
}
return anthropic.MessageNewParams{
- Model: anthropic.Model(a.providerOptions.model.APIModel),
- MaxTokens: a.providerOptions.maxTokens,
+ Model: anthropic.Model(model.ID),
+ MaxTokens: maxTokens,
Temperature: temperature,
Messages: messages,
Tools: tools,
@@ -198,7 +192,7 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
cfg := config.Get()
- if cfg.Debug {
+ if cfg.Options.Debug {
jsonData, _ := json.Marshal(preparedMessages)
logging.Debug("Prepared messages", "messages", string(jsonData))
}
@@ -247,7 +241,7 @@ func (a *anthropicClient) send(ctx context.Context, messages []message.Message,
func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
cfg := config.Get()
- if cfg.Debug {
+ if cfg.Options.Debug {
// jsonData, _ := json.Marshal(preparedMessages)
// logging.Debug("Prepared messages", "messages", string(jsonData))
}
@@ -383,12 +377,21 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
}
func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) {
- var apierr *anthropic.Error
- if !errors.As(err, &apierr) {
+ var apiErr *anthropic.Error
+ if !errors.As(err, &apiErr) {
return false, 0, err
}
- if apierr.StatusCode != 429 && apierr.StatusCode != 529 {
+ if apiErr.StatusCode == 401 {
+ a.providerOptions.apiKey, err = config.ResolveAPIKey(a.providerOptions.config.APIKey)
+ if err != nil {
+ return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
+ }
+ a.client = createAnthropicClient(a.providerOptions, a.useBedrock)
+ return true, 0, nil
+ }
+
+ if apiErr.StatusCode != 429 && apiErr.StatusCode != 529 {
return false, 0, err
}
@@ -397,7 +400,7 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err
}
retryMs := 0
- retryAfterValues := apierr.Response.Header.Values("Retry-After")
+ retryAfterValues := apiErr.Response.Header.Values("Retry-After")
backoffMs := 2000 * (1 << (attempts - 1))
jitterMs := int(float64(backoffMs) * 0.2)
@@ -439,24 +442,6 @@ func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
}
}
-func WithAnthropicBedrock(useBedrock bool) AnthropicOption {
- return func(options *anthropicOptions) {
- options.useBedrock = useBedrock
- }
-}
-
-func WithAnthropicDisableCache() AnthropicOption {
- return func(options *anthropicOptions) {
- options.disableCache = true
- }
-}
-
-func DefaultShouldThinkFn(s string) bool {
- return strings.Contains(strings.ToLower(s), "think")
-}
-
-func WithAnthropicShouldThinkFn(fn func(string) bool) AnthropicOption {
- return func(options *anthropicOptions) {
- options.shouldThink = fn
- }
+func (a *anthropicClient) Model() config.Model {
+ return a.providerOptions.model(a.providerOptions.modelType)
}
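
The new 401 branch in `shouldRetry` re-resolves the API key and rebuilds the SDK client before signalling one more attempt. A standalone sketch of that refresh-then-retry shape, using hypothetical names rather than the crush types:

```go
package main

import (
	"errors"
	"fmt"
)

// apiError stands in for the SDK error type carrying an HTTP status.
type apiError struct{ status int }

func (e *apiError) Error() string { return fmt.Sprintf("status %d", e.status) }

// retryAfterReauth mirrors the 401 branch above: refresh the credential,
// rebuild the client with it, and ask the caller to retry immediately.
func retryAfterReauth(err error, refresh func() (string, error), rebuild func(key string)) (bool, error) {
	var apiErr *apiError
	if !errors.As(err, &apiErr) || apiErr.status != 401 {
		return false, err
	}
	key, rerr := refresh()
	if rerr != nil {
		return false, fmt.Errorf("failed to resolve API key: %w", rerr)
	}
	rebuild(key)
	return true, nil
}

func main() {
	retry, err := retryAfterReauth(
		&apiError{status: 401},
		func() (string, error) { return "fresh-key", nil },
		func(key string) { /* recreate the SDK client here */ },
	)
	fmt.Println(retry, err) // true <nil>
}
```
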
diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go
index 8d3a86198aab5a38742e33b167f2545efd808873..1519099b00401e32ad5f19c1f6ed253eb8b7130d 100644
--- a/internal/llm/provider/bedrock.go
+++ b/internal/llm/provider/bedrock.go
@@ -4,65 +4,54 @@ import (
"context"
"errors"
"fmt"
- "os"
"strings"
+ "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/message"
)
-type bedrockOptions struct {
- // Bedrock specific options can be added here
-}
-
-type BedrockOption func(*bedrockOptions)
-
type bedrockClient struct {
providerOptions providerClientOptions
- options bedrockOptions
childProvider ProviderClient
}
type BedrockClient ProviderClient
func newBedrockClient(opts providerClientOptions) BedrockClient {
- bedrockOpts := bedrockOptions{}
- // Apply bedrock specific options if they are added in the future
-
// Get AWS region from environment
- region := os.Getenv("AWS_REGION")
- if region == "" {
- region = os.Getenv("AWS_DEFAULT_REGION")
- }
-
+ region := opts.extraParams["region"]
if region == "" {
region = "us-east-1" // default region
}
if len(region) < 2 {
return &bedrockClient{
providerOptions: opts,
- options: bedrockOpts,
childProvider: nil, // Will cause an error when used
}
}
- // Prefix the model name with region
- regionPrefix := region[:2]
- modelName := opts.model.APIModel
- opts.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, modelName)
+ opts.model = func(modelType config.ModelType) config.Model {
+ model := config.GetModel(modelType)
+
+		// Prefix the model name with region.
+ regionPrefix := region[:2]
+ modelName := model.ID
+ model.ID = fmt.Sprintf("%s.%s", regionPrefix, modelName)
+ return model
+ }
+
+ model := opts.model(opts.modelType)
// Determine which provider to use based on the model
- if strings.Contains(string(opts.model.APIModel), "anthropic") {
+ if strings.Contains(string(model.ID), "anthropic") {
// Create Anthropic client with Bedrock configuration
anthropicOpts := opts
- anthropicOpts.anthropicOptions = append(anthropicOpts.anthropicOptions,
- WithAnthropicBedrock(true),
- WithAnthropicDisableCache(),
- )
+		// TODO: later find a way to check if the AWS account has caching enabled.
+		anthropicOpts.disableCache = true // Disable cache for Bedrock
return &bedrockClient{
providerOptions: opts,
- options: bedrockOpts,
- childProvider: newAnthropicClient(anthropicOpts),
+ childProvider: newAnthropicClient(anthropicOpts, true),
}
}
@@ -70,7 +59,6 @@ func newBedrockClient(opts providerClientOptions) BedrockClient {
// This will cause an error when used
return &bedrockClient{
providerOptions: opts,
- options: bedrockOpts,
childProvider: nil,
}
}
@@ -98,3 +86,7 @@ func (b *bedrockClient) stream(ctx context.Context, messages []message.Message,
return b.childProvider.stream(ctx, messages, tools)
}
+
+func (b *bedrockClient) Model() config.Model {
+ return b.providerOptions.model(b.providerOptions.modelType)
+}
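
The Bedrock client now resolves its model lazily and prefixes the model ID with the first two characters of the region (for example `us.` or `eu.`), in the style of Bedrock's region-prefixed inference profile IDs. A self-contained sketch of just that transformation, with illustrative inputs:

```go
package main

import "fmt"

// prefixModelID mirrors the resolver installed in newBedrockClient: default
// the region, bail out on implausible values, then prepend the two-letter
// region prefix to the model ID.
func prefixModelID(region, modelID string) (string, error) {
	if region == "" {
		region = "us-east-1" // same default as above
	}
	if len(region) < 2 {
		return "", fmt.Errorf("invalid region %q", region)
	}
	return fmt.Sprintf("%s.%s", region[:2], modelID), nil
}

func main() {
	id, _ := prefixModelID("eu-west-1", "anthropic.claude-3-7-sonnet")
	fmt.Println(id) // eu.anthropic.claude-3-7-sonnet
}
```
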
diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go
index 9481d8d545aab12a3739fe99b4af61f4ed99a514..e80af34d0815695ea6ed76d01c25262381a836ec 100644
--- a/internal/llm/provider/gemini.go
+++ b/internal/llm/provider/gemini.go
@@ -17,27 +17,15 @@ import (
"google.golang.org/genai"
)
-type geminiOptions struct {
- disableCache bool
-}
-
-type GeminiOption func(*geminiOptions)
-
type geminiClient struct {
providerOptions providerClientOptions
- options geminiOptions
client *genai.Client
}
type GeminiClient ProviderClient
func newGeminiClient(opts providerClientOptions) GeminiClient {
- geminiOpts := geminiOptions{}
- for _, o := range opts.geminiOptions {
- o(&geminiOpts)
- }
-
- client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
+ client, err := createGeminiClient(opts)
if err != nil {
logging.Error("Failed to create Gemini client", "error", err)
return nil
@@ -45,11 +33,18 @@ func newGeminiClient(opts providerClientOptions) GeminiClient {
return &geminiClient{
providerOptions: opts,
- options: geminiOpts,
client: client,
}
}
+func createGeminiClient(opts providerClientOptions) (*genai.Client, error) {
+ client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
+ if err != nil {
+ return nil, err
+ }
+ return client, nil
+}
+
func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
var history []*genai.Content
for _, msg := range messages {
@@ -168,17 +163,26 @@ func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishRea
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
// Convert messages
geminiMessages := g.convertMessages(messages)
-
+ model := g.providerOptions.model(g.providerOptions.modelType)
cfg := config.Get()
- if cfg.Debug {
+ if cfg.Options.Debug {
jsonData, _ := json.Marshal(geminiMessages)
logging.Debug("Prepared messages", "messages", string(jsonData))
}
+ modelConfig := cfg.Models.Large
+ if g.providerOptions.modelType == config.SmallModel {
+ modelConfig = cfg.Models.Small
+ }
+
+ maxTokens := model.DefaultMaxTokens
+ if modelConfig.MaxTokens > 0 {
+ maxTokens = modelConfig.MaxTokens
+ }
history := geminiMessages[:len(geminiMessages)-1] // All but last message
lastMsg := geminiMessages[len(geminiMessages)-1]
config := &genai.GenerateContentConfig{
- MaxOutputTokens: int32(g.providerOptions.maxTokens),
+ MaxOutputTokens: int32(maxTokens),
SystemInstruction: &genai.Content{
Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
},
@@ -186,7 +190,7 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
if len(tools) > 0 {
config.Tools = g.convertTools(tools)
}
- chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)
+ chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
attempts := 0
for {
@@ -257,16 +261,30 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
// Convert messages
geminiMessages := g.convertMessages(messages)
+ model := g.providerOptions.model(g.providerOptions.modelType)
cfg := config.Get()
- if cfg.Debug {
+ if cfg.Options.Debug {
jsonData, _ := json.Marshal(geminiMessages)
logging.Debug("Prepared messages", "messages", string(jsonData))
}
+ modelConfig := cfg.Models.Large
+ if g.providerOptions.modelType == config.SmallModel {
+ modelConfig = cfg.Models.Small
+ }
+ maxTokens := model.DefaultMaxTokens
+ if modelConfig.MaxTokens > 0 {
+ maxTokens = modelConfig.MaxTokens
+ }
+
+	// Override max tokens if set in provider options.
+ if g.providerOptions.maxTokens > 0 {
+ maxTokens = g.providerOptions.maxTokens
+ }
history := geminiMessages[:len(geminiMessages)-1] // All but last message
lastMsg := geminiMessages[len(geminiMessages)-1]
config := &genai.GenerateContentConfig{
- MaxOutputTokens: int32(g.providerOptions.maxTokens),
+ MaxOutputTokens: int32(maxTokens),
SystemInstruction: &genai.Content{
Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
},
@@ -274,7 +292,7 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
if len(tools) > 0 {
config.Tools = g.convertTools(tools)
}
- chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)
+ chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
attempts := 0
eventChan := make(chan ProviderEvent)
@@ -404,6 +422,19 @@ func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error)
errMsg := err.Error()
isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")
+ // Check for token expiration (401 Unauthorized)
+ if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
+ g.providerOptions.apiKey, err = config.ResolveAPIKey(g.providerOptions.config.APIKey)
+ if err != nil {
+ return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
+ }
+ g.client, err = createGeminiClient(g.providerOptions)
+ if err != nil {
+ return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err)
+ }
+ return true, 0, nil
+ }
+
// Check for common rate limit error messages
if !isRateLimit {
@@ -418,27 +449,6 @@ func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error)
return true, int64(retryMs), nil
}
-func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
- var toolCalls []message.ToolCall
-
- if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
- for _, part := range resp.Candidates[0].Content.Parts {
- if part.FunctionCall != nil {
- id := "call_" + uuid.New().String()
- args, _ := json.Marshal(part.FunctionCall.Args)
- toolCalls = append(toolCalls, message.ToolCall{
- ID: id,
- Name: part.FunctionCall.Name,
- Input: string(args),
- Type: "function",
- })
- }
- }
- }
-
- return toolCalls
-}
-
func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
if resp == nil || resp.UsageMetadata == nil {
return TokenUsage{}
@@ -452,10 +462,8 @@ func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
}
}
-func WithGeminiDisableCache() GeminiOption {
- return func(options *geminiOptions) {
- options.disableCache = true
- }
+func (g *geminiClient) Model() config.Model {
+ return g.providerOptions.model(g.providerOptions.modelType)
}
// Helper functions
diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go
index 05658dd6db760a1d05a88ae4931de5c70d9cc453..bddf820c2d3ccf9bba1a683ed4fe469d05fa31bf 100644
--- a/internal/llm/provider/openai.go
+++ b/internal/llm/provider/openai.go
@@ -9,7 +9,7 @@ import (
"time"
"github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/llm/models"
+ "github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/logging"
"github.com/charmbracelet/crush/internal/message"
@@ -18,51 +18,36 @@ import (
"github.com/openai/openai-go/shared"
)
-type openaiOptions struct {
- baseURL string
- disableCache bool
- reasoningEffort string
- extraHeaders map[string]string
-}
-
-type OpenAIOption func(*openaiOptions)
-
type openaiClient struct {
providerOptions providerClientOptions
- options openaiOptions
client openai.Client
}
type OpenAIClient ProviderClient
func newOpenAIClient(opts providerClientOptions) OpenAIClient {
- openaiOpts := openaiOptions{
- reasoningEffort: "medium",
- }
- for _, o := range opts.openaiOptions {
- o(&openaiOpts)
+ return &openaiClient{
+ providerOptions: opts,
+ client: createOpenAIClient(opts),
}
+}
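+
+// createOpenAIClient builds an OpenAI SDK client from the resolved provider options.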
+func createOpenAIClient(opts providerClientOptions) openai.Client {
openaiClientOptions := []option.RequestOption{}
if opts.apiKey != "" {
openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
}
- if openaiOpts.baseURL != "" {
- openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(openaiOpts.baseURL))
+ if opts.baseURL != "" {
+ openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(opts.baseURL))
}
- if openaiOpts.extraHeaders != nil {
- for key, value := range openaiOpts.extraHeaders {
+ if opts.extraHeaders != nil {
+ for key, value := range opts.extraHeaders {
openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
}
}
- client := openai.NewClient(openaiClientOptions...)
- return &openaiClient{
- providerOptions: opts,
- options: openaiOpts,
- client: client,
- }
+ return openai.NewClient(openaiClientOptions...)
}
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
@@ -76,7 +61,7 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag
textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
for _, binaryContent := range msg.BinaryContent() {
- imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(models.ProviderOpenAI)}
+ imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(provider.InferenceProviderOpenAI)}
imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
@@ -160,15 +145,37 @@ func (o *openaiClient) finishReason(reason string) message.FinishReason {
}
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
+ model := o.providerOptions.model(o.providerOptions.modelType)
+ cfg := config.Get()
+
+ modelConfig := cfg.Models.Large
+ if o.providerOptions.modelType == config.SmallModel {
+ modelConfig = cfg.Models.Small
+ }
+
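+ // Prefer the user-configured reasoning effort, falling back to the model's default.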
+ reasoningEffort := model.ReasoningEffort
+ if modelConfig.ReasoningEffort != "" {
+ reasoningEffort = modelConfig.ReasoningEffort
+ }
+
params := openai.ChatCompletionNewParams{
- Model: openai.ChatModel(o.providerOptions.model.APIModel),
+ Model: openai.ChatModel(model.ID),
Messages: messages,
Tools: tools,
}
- if o.providerOptions.model.CanReason {
- params.MaxCompletionTokens = openai.Int(o.providerOptions.maxTokens)
- switch o.options.reasoningEffort {
+ maxTokens := model.DefaultMaxTokens
+ if modelConfig.MaxTokens > 0 {
+ maxTokens = modelConfig.MaxTokens
+ }
+
+ // Override max tokens if set in provider options.
+ if o.providerOptions.maxTokens > 0 {
+ maxTokens = o.providerOptions.maxTokens
+ }
+ if model.CanReason {
+ params.MaxCompletionTokens = openai.Int(maxTokens)
+ switch reasoningEffort {
case "low":
params.ReasoningEffort = shared.ReasoningEffortLow
case "medium":
@@ -176,10 +183,10 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar
case "high":
params.ReasoningEffort = shared.ReasoningEffortHigh
default:
- params.ReasoningEffort = shared.ReasoningEffortMedium
+ params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort)
}
} else {
- params.MaxTokens = openai.Int(o.providerOptions.maxTokens)
+ params.MaxTokens = openai.Int(maxTokens)
}
return params
@@ -188,7 +195,7 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
cfg := config.Get()
- if cfg.Debug {
+ if cfg.Options.Debug {
jsonData, _ := json.Marshal(params)
logging.Debug("Prepared messages", "messages", string(jsonData))
}
@@ -245,7 +252,7 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
}
cfg := config.Get()
- if cfg.Debug {
+ if cfg.Options.Debug {
jsonData, _ := json.Marshal(params)
logging.Debug("Prepared messages", "messages", string(jsonData))
}
@@ -335,12 +342,22 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
}
func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
- var apierr *openai.Error
- if !errors.As(err, &apierr) {
+ var apiErr *openai.Error
+ if !errors.As(err, &apiErr) {
return false, 0, err
}
- if apierr.StatusCode != 429 && apierr.StatusCode != 500 {
+ // Check for token expiration (401 Unauthorized).
+ if apiErr.StatusCode == 401 {
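+ // Re-resolve the API key and rebuild the client before retrying.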
+ o.providerOptions.apiKey, err = config.ResolveAPIKey(o.providerOptions.config.APIKey)
+ if err != nil {
+ return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
+ }
+ o.client = createOpenAIClient(o.providerOptions)
+ return true, 0, nil
+ }
+
+ if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
return false, 0, err
}
@@ -349,7 +366,7 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error)
}
retryMs := 0
- retryAfterValues := apierr.Response.Header.Values("Retry-After")
+ retryAfterValues := apiErr.Response.Header.Values("Retry-After")
backoffMs := 2000 * (1 << (attempts - 1))
jitterMs := int(float64(backoffMs) * 0.2)
@@ -393,33 +410,6 @@ func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
}
}
-func WithOpenAIBaseURL(baseURL string) OpenAIOption {
- return func(options *openaiOptions) {
- options.baseURL = baseURL
- }
-}
-
-func WithOpenAIExtraHeaders(headers map[string]string) OpenAIOption {
- return func(options *openaiOptions) {
- options.extraHeaders = headers
- }
-}
-
-func WithOpenAIDisableCache() OpenAIOption {
- return func(options *openaiOptions) {
- options.disableCache = true
- }
-}
-
-func WithReasoningEffort(effort string) OpenAIOption {
- return func(options *openaiOptions) {
- defaultReasoningEffort := "medium"
- switch effort {
- case "low", "medium", "high":
- defaultReasoningEffort = effort
- default:
- logging.Warn("Invalid reasoning effort, using default: medium")
- }
- options.reasoningEffort = defaultReasoningEffort
- }
+func (o *openaiClient) Model() config.Model {
+ return o.providerOptions.model(o.providerOptions.modelType)
}
diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go
index d63f73f67cf0455bcccbb06ae70e3dde6a09557c..3ffbf86c00c5e3ca27f1b68965f4ff950f1f7454 100644
--- a/internal/llm/provider/provider.go
+++ b/internal/llm/provider/provider.go
@@ -3,9 +3,9 @@ package provider
import (
"context"
"fmt"
- "os"
- "github.com/charmbracelet/crush/internal/llm/models"
+ "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/message"
)
@@ -55,19 +55,20 @@ type Provider interface {
StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
- Model() models.Model
+ Model() config.Model
}
type providerClientOptions struct {
+ baseURL string
+ config config.ProviderConfig
apiKey string
- model models.Model
- maxTokens int64
+ modelType config.ModelType
+ model func(config.ModelType) config.Model
+ disableCache bool
systemMessage string
-
- anthropicOptions []AnthropicOption
- openaiOptions []OpenAIOption
- geminiOptions []GeminiOption
- bedrockOptions []BedrockOption
+ maxTokens int64
+ extraHeaders map[string]string
+ extraParams map[string]string
}
type ProviderClientOption func(*providerClientOptions)
@@ -75,6 +76,8 @@ type ProviderClientOption func(*providerClientOptions)
type ProviderClient interface {
send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
+
+ Model() config.Model
}
type baseProvider[C ProviderClient] struct {
@@ -82,85 +85,6 @@ type baseProvider[C ProviderClient] struct {
client C
}
-func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption) (Provider, error) {
- clientOptions := providerClientOptions{}
- for _, o := range opts {
- o(&clientOptions)
- }
- switch providerName {
- case models.ProviderAnthropic:
- return &baseProvider[AnthropicClient]{
- options: clientOptions,
- client: newAnthropicClient(clientOptions),
- }, nil
- case models.ProviderOpenAI:
- return &baseProvider[OpenAIClient]{
- options: clientOptions,
- client: newOpenAIClient(clientOptions),
- }, nil
- case models.ProviderGemini:
- return &baseProvider[GeminiClient]{
- options: clientOptions,
- client: newGeminiClient(clientOptions),
- }, nil
- case models.ProviderBedrock:
- return &baseProvider[BedrockClient]{
- options: clientOptions,
- client: newBedrockClient(clientOptions),
- }, nil
- case models.ProviderGROQ:
- clientOptions.openaiOptions = append(clientOptions.openaiOptions,
- WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
- )
- return &baseProvider[OpenAIClient]{
- options: clientOptions,
- client: newOpenAIClient(clientOptions),
- }, nil
- case models.ProviderAzure:
- return &baseProvider[AzureClient]{
- options: clientOptions,
- client: newAzureClient(clientOptions),
- }, nil
- case models.ProviderVertexAI:
- return &baseProvider[VertexAIClient]{
- options: clientOptions,
- client: newVertexAIClient(clientOptions),
- }, nil
- case models.ProviderOpenRouter:
- clientOptions.openaiOptions = append(clientOptions.openaiOptions,
- WithOpenAIBaseURL("https://openrouter.ai/api/v1"),
- WithOpenAIExtraHeaders(map[string]string{
- "HTTP-Referer": "crush.charm.land",
- "X-Title": "Crush",
- }),
- )
- return &baseProvider[OpenAIClient]{
- options: clientOptions,
- client: newOpenAIClient(clientOptions),
- }, nil
- case models.ProviderXAI:
- clientOptions.openaiOptions = append(clientOptions.openaiOptions,
- WithOpenAIBaseURL("https://api.x.ai/v1"),
- )
- return &baseProvider[OpenAIClient]{
- options: clientOptions,
- client: newOpenAIClient(clientOptions),
- }, nil
- case models.ProviderLocal:
- clientOptions.openaiOptions = append(clientOptions.openaiOptions,
- WithOpenAIBaseURL(os.Getenv("LOCAL_ENDPOINT")),
- )
- return &baseProvider[OpenAIClient]{
- options: clientOptions,
- client: newOpenAIClient(clientOptions),
- }, nil
- case models.ProviderMock:
- // TODO: implement mock client for test
- panic("not implemented")
- }
- return nil, fmt.Errorf("provider not supported: %s", providerName)
-}
-
func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
for _, msg := range messages {
// The message has no content
@@ -177,30 +101,24 @@ func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.M
return p.client.send(ctx, messages, tools)
}
-func (p *baseProvider[C]) Model() models.Model {
- return p.options.model
-}
-
func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
messages = p.cleanMessages(messages)
return p.client.stream(ctx, messages, tools)
}
-func WithAPIKey(apiKey string) ProviderClientOption {
- return func(options *providerClientOptions) {
- options.apiKey = apiKey
- }
+func (p *baseProvider[C]) Model() config.Model {
+ return p.client.Model()
}
-func WithModel(model models.Model) ProviderClientOption {
+func WithModel(model config.ModelType) ProviderClientOption {
return func(options *providerClientOptions) {
- options.model = model
+ options.modelType = model
}
}
-func WithMaxTokens(maxTokens int64) ProviderClientOption {
+func WithDisableCache(disableCache bool) ProviderClientOption {
return func(options *providerClientOptions) {
- options.maxTokens = maxTokens
+ options.disableCache = disableCache
}
}
@@ -210,26 +128,67 @@ func WithSystemMessage(systemMessage string) ProviderClientOption {
}
}
-func WithAnthropicOptions(anthropicOptions ...AnthropicOption) ProviderClientOption {
+func WithMaxTokens(maxTokens int64) ProviderClientOption {
return func(options *providerClientOptions) {
- options.anthropicOptions = anthropicOptions
+ options.maxTokens = maxTokens
}
}
-func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption {
- return func(options *providerClientOptions) {
- options.openaiOptions = openaiOptions
+func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
+ resolvedAPIKey, err := config.ResolveAPIKey(cfg.APIKey)
+ if err != nil {
+ return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
}
-}
-func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption {
- return func(options *providerClientOptions) {
- options.geminiOptions = geminiOptions
+ clientOptions := providerClientOptions{
+ baseURL: cfg.BaseURL,
+ config: cfg,
+ apiKey: resolvedAPIKey,
+ extraHeaders: cfg.ExtraHeaders,
+ model: func(tp config.ModelType) config.Model {
+ return config.GetModel(tp)
+ },
}
-}
-
-func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption {
- return func(options *providerClientOptions) {
- options.bedrockOptions = bedrockOptions
+ for _, o := range opts {
+ o(&clientOptions)
+ }
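+ // Route to the client implementation for the configured provider type.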
+ switch cfg.ProviderType {
+ case provider.TypeAnthropic:
+ return &baseProvider[AnthropicClient]{
+ options: clientOptions,
+ client: newAnthropicClient(clientOptions, false),
+ }, nil
+ case provider.TypeOpenAI:
+ return &baseProvider[OpenAIClient]{
+ options: clientOptions,
+ client: newOpenAIClient(clientOptions),
+ }, nil
+ case provider.TypeGemini:
+ return &baseProvider[GeminiClient]{
+ options: clientOptions,
+ client: newGeminiClient(clientOptions),
+ }, nil
+ case provider.TypeBedrock:
+ return &baseProvider[BedrockClient]{
+ options: clientOptions,
+ client: newBedrockClient(clientOptions),
+ }, nil
+ case provider.TypeAzure:
+ return &baseProvider[AzureClient]{
+ options: clientOptions,
+ client: newAzureClient(clientOptions),
+ }, nil
+ case provider.TypeVertexAI:
+ return &baseProvider[VertexAIClient]{
+ options: clientOptions,
+ client: newVertexAIClient(clientOptions),
+ }, nil
+ case provider.TypeXAI:
+ clientOptions.baseURL = "https://api.x.ai/v1"
+ return &baseProvider[OpenAIClient]{
+ options: clientOptions,
+ client: newOpenAIClient(clientOptions),
+ }, nil
}
+ return nil, fmt.Errorf("provider not supported: %s", cfg.ProviderType)
}
diff --git a/internal/llm/provider/vertexai.go b/internal/llm/provider/vertexai.go
index fe2de2f4588f9dbe583e4f8af85e61eea67d5648..2d95ad3f60db22e1338db3931b0900e83bccab52 100644
--- a/internal/llm/provider/vertexai.go
+++ b/internal/llm/provider/vertexai.go
@@ -2,7 +2,6 @@ package provider
import (
"context"
- "os"
"github.com/charmbracelet/crush/internal/logging"
"google.golang.org/genai"
@@ -11,14 +10,11 @@ import (
type VertexAIClient ProviderClient
func newVertexAIClient(opts providerClientOptions) VertexAIClient {
- geminiOpts := geminiOptions{}
- for _, o := range opts.geminiOptions {
- o(&geminiOpts)
- }
-
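+ // Vertex AI project and location are supplied via the provider's extra headers.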
+ project := opts.extraHeaders["project"]
+ location := opts.extraHeaders["location"]
client, err := genai.NewClient(context.Background(), &genai.ClientConfig{
- Project: os.Getenv("VERTEXAI_PROJECT"),
- Location: os.Getenv("VERTEXAI_LOCATION"),
+ Project: project,
+ Location: location,
Backend: genai.BackendVertexAI,
})
if err != nil {
@@ -28,7 +24,6 @@ func newVertexAIClient(opts providerClientOptions) VertexAIClient {
return &geminiClient{
providerOptions: opts,
- options: geminiOpts,
client: client,
}
}
diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go
index abbd19113db746cd8e82c5cdebc02c4b8fc28b99..5f8b41338c8c5ef6f771e80fbd4e1355b27eb036 100644
--- a/internal/llm/tools/bash.go
+++ b/internal/llm/tools/bash.go
@@ -250,6 +250,10 @@ func NewBashTool(permission permission.Service) BaseTool {
}
}
+func (b *bashTool) Name() string {
+ return BashToolName
+}
+
func (b *bashTool) Info() ToolInfo {
return ToolInfo{
Name: BashToolName,
diff --git a/internal/llm/tools/diagnostics.go b/internal/llm/tools/diagnostics.go
index e48132414ad338635ada6294134e3d7f2d938c7d..5e0d19f9e5b3f72ea7588e4c4a15b71ae7b262d9 100644
--- a/internal/llm/tools/diagnostics.go
+++ b/internal/llm/tools/diagnostics.go
@@ -51,6 +51,10 @@ func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool {
}
}
+func (b *diagnosticsTool) Name() string {
+ return DiagnosticsToolName
+}
+
func (b *diagnosticsTool) Info() ToolInfo {
return ToolInfo{
Name: DiagnosticsToolName,
diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go
index 52a202e32cd95869e3f76ebe6e3e01ca2f4802cc..b72112f43e140edd7298e802ab88ba2747784d7c 100644
--- a/internal/llm/tools/edit.go
+++ b/internal/llm/tools/edit.go
@@ -106,6 +106,10 @@ func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Servi
}
}
+func (e *editTool) Name() string {
+ return EditToolName
+}
+
func (e *editTool) Info() ToolInfo {
return ToolInfo{
Name: EditToolName,
diff --git a/internal/llm/tools/fetch.go b/internal/llm/tools/fetch.go
index 7acf23bae61df88792dd805317bdf8a67095dd0d..ac73ddbf3b0033cf503bdc8cfa2ef065a0072477 100644
--- a/internal/llm/tools/fetch.go
+++ b/internal/llm/tools/fetch.go
@@ -79,6 +79,10 @@ func NewFetchTool(permissions permission.Service) BaseTool {
}
}
+func (t *fetchTool) Name() string {
+ return FetchToolName
+}
+
func (t *fetchTool) Info() ToolInfo {
return ToolInfo{
Name: FetchToolName,
diff --git a/internal/llm/tools/glob.go b/internal/llm/tools/glob.go
index 39471a8ae81f9c31c4bccccf273f88044184f6ad..25c80860b791a5b601366d455f5ddd1ea91523ed 100644
--- a/internal/llm/tools/glob.go
+++ b/internal/llm/tools/glob.go
@@ -74,6 +74,10 @@ func NewGlobTool() BaseTool {
return &globTool{}
}
+func (g *globTool) Name() string {
+ return GlobToolName
+}
+
func (g *globTool) Info() ToolInfo {
return ToolInfo{
Name: GlobToolName,
diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go
index 7321ad4ed9a2d713cb8685b3f4b23dbc62cc6c2a..c3e13766884f17932187ad63cb5ffaacdf375b45 100644
--- a/internal/llm/tools/grep.go
+++ b/internal/llm/tools/grep.go
@@ -140,6 +140,10 @@ func NewGrepTool() BaseTool {
return &grepTool{}
}
+func (g *grepTool) Name() string {
+ return GrepToolName
+}
+
func (g *grepTool) Info() ToolInfo {
return ToolInfo{
Name: GrepToolName,
diff --git a/internal/llm/tools/ls.go b/internal/llm/tools/ls.go
index aeda758a66bc9dc796dc8d4e3722f8711aa39f67..a51b5bdb5dccb7c209d9cdc28e94dad328e8c093 100644
--- a/internal/llm/tools/ls.go
+++ b/internal/llm/tools/ls.go
@@ -74,6 +74,10 @@ func NewLsTool() BaseTool {
return &lsTool{}
}
+func (l *lsTool) Name() string {
+ return LSToolName
+}
+
func (l *lsTool) Info() ToolInfo {
return ToolInfo{
Name: LSToolName,
diff --git a/internal/llm/tools/sourcegraph.go b/internal/llm/tools/sourcegraph.go
index 29518b7b818da5746d195ea8b7da521d80429962..fcc9bb57428b45a4620417775c61acb380b4d7ad 100644
--- a/internal/llm/tools/sourcegraph.go
+++ b/internal/llm/tools/sourcegraph.go
@@ -138,6 +138,10 @@ func NewSourcegraphTool() BaseTool {
}
}
+func (t *sourcegraphTool) Name() string {
+ return SourcegraphToolName
+}
+
func (t *sourcegraphTool) Info() ToolInfo {
return ToolInfo{
Name: SourcegraphToolName,
diff --git a/internal/llm/tools/tools.go b/internal/llm/tools/tools.go
index bf0f8df0bacdc34fc6a3f5f89f3591b8b6fd8f15..41c0515616032b117f3c09a0056cac9e86b62c66 100644
--- a/internal/llm/tools/tools.go
+++ b/internal/llm/tools/tools.go
@@ -68,6 +68,7 @@ type ToolCall struct {
type BaseTool interface {
Info() ToolInfo
+ Name() string
Run(ctx context.Context, params ToolCall) (ToolResponse, error)
}
diff --git a/internal/llm/tools/view.go b/internal/llm/tools/view.go
index 7d9cdaefcb98e0d6289396c9c44d45c4aa0b03c5..750efef73795f115e3ad90e4da9a2d955ee10529 100644
--- a/internal/llm/tools/view.go
+++ b/internal/llm/tools/view.go
@@ -77,6 +77,10 @@ func NewViewTool(lspClients map[string]*lsp.Client) BaseTool {
}
}
+func (v *viewTool) Name() string {
+ return ViewToolName
+}
+
func (v *viewTool) Info() ToolInfo {
return ToolInfo{
Name: ViewToolName,
diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go
index 4186954d03acc4768233435c71c2ac5d78606119..0c213cec1f4e0a9bc8fc205a183206c0842f9688 100644
--- a/internal/llm/tools/write.go
+++ b/internal/llm/tools/write.go
@@ -84,6 +84,10 @@ func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Serv
}
}
+func (w *writeTool) Name() string {
+ return WriteToolName
+}
+
func (w *writeTool) Info() ToolInfo {
return ToolInfo{
Name: WriteToolName,
diff --git a/internal/logging/logger.go b/internal/logging/logger.go
index 9c2cfb50f33d27d52b9acb3009859f3509484253..ef33ba299855f24f0b48ae5af9ed80a069ba7a31 100644
--- a/internal/logging/logger.go
+++ b/internal/logging/logger.go
@@ -4,16 +4,35 @@ import (
"fmt"
"log/slog"
"os"
+
+ // "path/filepath"
+ "encoding/json"
+ "runtime"
"runtime/debug"
+ "sync"
"time"
)
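+// getCaller returns the file and line of the logging call site, two frames up
+// the stack.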
+func getCaller() string {
+ var caller string
+ if _, file, line, ok := runtime.Caller(2); ok {
+ caller = fmt.Sprintf("%s:%d", file, line)
+ } else {
+ caller = "unknown"
+ }
+ return caller
+}
+
func Info(msg string, args ...any) {
- slog.Info(msg, args...)
+ source := getCaller()
+ slog.Info(msg, append([]any{"source", source}, args...)...)
}
func Debug(msg string, args ...any) {
- slog.Debug(msg, args...)
+ source := getCaller()
+ slog.Debug(msg, append([]any{"source", source}, args...)...)
}
func Warn(msg string, args ...any) {
@@ -76,3 +95,115 @@ func RecoverPanic(name string, cleanup func()) {
}
}
}
+
+// Message logging for debugging.
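+
+// MessageDir is the root directory for per-session message logs; when empty,
+// message logging is disabled.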
+var MessageDir string
+
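+// GetSessionPrefix returns a short session ID prefix used as the name of the
+// session's log directory.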
+func GetSessionPrefix(sessionId string) string {
+ if len(sessionId) < 8 {
+ return sessionId
+ }
+ return sessionId[:8]
+}
+
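+// sessionLogMutex serializes writes to session log files.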
+var sessionLogMutex sync.Mutex
+
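+// AppendToSessionLogFile appends content to the named file under the
+// session's log directory, returning the path written to ("" on failure).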
+func AppendToSessionLogFile(sessionId string, filename string, content string) string {
+ if MessageDir == "" || sessionId == "" {
+ return ""
+ }
+ sessionPrefix := GetSessionPrefix(sessionId)
+
+ sessionLogMutex.Lock()
+ defer sessionLogMutex.Unlock()
+
+ sessionPath := fmt.Sprintf("%s/%s", MessageDir, sessionPrefix)
+ if _, err := os.Stat(sessionPath); os.IsNotExist(err) {
+ if err := os.MkdirAll(sessionPath, 0o755); err != nil {
+ Error("Failed to create session directory", "dirpath", sessionPath, "error", err)
+ return ""
+ }
+ }
+
+ filePath := fmt.Sprintf("%s/%s", sessionPath, filename)
+
+ f, err := os.OpenFile(filePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
+ if err != nil {
+ Error("Failed to open session log file", "filepath", filePath, "error", err)
+ return ""
+ }
+ defer f.Close()
+
+ // Append the chunk to the file.
+ _, err = f.WriteString(content)
+ if err != nil {
+ Error("Failed to write chunk to session log file", "filepath", filePath, "error", err)
+ return ""
+ }
+ return filePath
+}
+
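+// WriteRequestMessageJson marshals message to JSON and writes it as the
+// request log entry for the given request sequence ID.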
+func WriteRequestMessageJson(sessionId string, requestSeqId int, message any) string {
+ if MessageDir == "" || sessionId == "" || requestSeqId <= 0 {
+ return ""
+ }
+ msgJson, err := json.Marshal(message)
+ if err != nil {
+ Error("Failed to marshal message", "session_id", sessionId, "request_seq_id", requestSeqId, "error", err)
+ return ""
+ }
+ return WriteRequestMessage(sessionId, requestSeqId, string(msgJson))
+}
+
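+// WriteRequestMessage writes a raw request message to the session's request
+// log for the given request sequence ID.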
+func WriteRequestMessage(sessionId string, requestSeqId int, message string) string {
+ if MessageDir == "" || sessionId == "" || requestSeqId <= 0 {
+ return ""
+ }
+ filename := fmt.Sprintf("%d_request.json", requestSeqId)
+
+ return AppendToSessionLogFile(sessionId, filename, message)
+}
+
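+// AppendToStreamSessionLogJson marshals a chunk to JSON and appends it to the
+// streaming response log for the given request sequence ID.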
+func AppendToStreamSessionLogJson(sessionId string, requestSeqId int, jsonableChunk any) string {
+ if MessageDir == "" || sessionId == "" || requestSeqId <= 0 {
+ return ""
+ }
+ chunkJson, err := json.Marshal(jsonableChunk)
+ if err != nil {
+ Error("Failed to marshal message", "session_id", sessionId, "request_seq_id", requestSeqId, "error", err)
+ return ""
+ }
+ return AppendToStreamSessionLog(sessionId, requestSeqId, string(chunkJson))
+}
+
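+// AppendToStreamSessionLog appends a raw chunk to the streaming response log
+// for the given request sequence ID.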
+func AppendToStreamSessionLog(sessionId string, requestSeqId int, chunk string) string {
+ if MessageDir == "" || sessionId == "" || requestSeqId <= 0 {
+ return ""
+ }
+ filename := fmt.Sprintf("%d_response_stream.log", requestSeqId)
+ return AppendToSessionLogFile(sessionId, filename, chunk)
+}
+
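+// WriteChatResponseJson marshals the final response to JSON and writes it to
+// the session's response log for the given request sequence ID.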
+func WriteChatResponseJson(sessionId string, requestSeqId int, response any) string {
+ if MessageDir == "" || sessionId == "" || requestSeqId <= 0 {
+ return ""
+ }
+ responseJson, err := json.Marshal(response)
+ if err != nil {
+ Error("Failed to marshal response", "session_id", sessionId, "request_seq_id", requestSeqId, "error", err)
+ return ""
+ }
+ filename := fmt.Sprintf("%d_response.json", requestSeqId)
+
+ return AppendToSessionLogFile(sessionId, filename, string(responseJson))
+}
+
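+// WriteToolResultsJson marshals tool results to JSON and writes them to the
+// session's tool results log for the given request sequence ID.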
+func WriteToolResultsJson(sessionId string, requestSeqId int, toolResults any) string {
+ if MessageDir == "" || sessionId == "" || requestSeqId <= 0 {
+ return ""
+ }
+ toolResultsJson, err := json.Marshal(toolResults)
+ if err != nil {
+ Error("Failed to marshal tool results", "session_id", sessionId, "request_seq_id", requestSeqId, "error", err)
+ return ""
+ }
+ filename := fmt.Sprintf("%d_tool_results.json", requestSeqId)
+ return AppendToSessionLogFile(sessionId, filename, string(toolResultsJson))
+}
diff --git a/internal/lsp/client.go b/internal/lsp/client.go
index f65b3dee20a3ee0264742257ca78a116661f1165..24ff0238c355edb5499640b93f9e06f0f07568c9 100644
--- a/internal/lsp/client.go
+++ b/internal/lsp/client.go
@@ -96,10 +96,10 @@ func NewClient(ctx context.Context, command string, args ...string) (*Client, er
go func() {
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
- fmt.Fprintf(os.Stderr, "LSP Server: %s\n", scanner.Text())
+ logging.Error("LSP Server", "err", scanner.Text())
}
if err := scanner.Err(); err != nil {
- fmt.Fprintf(os.Stderr, "Error reading stderr: %v\n", err)
+ logging.Error("Error reading", "err", err)
}
}()
@@ -286,7 +286,7 @@ func (c *Client) SetServerState(state ServerState) {
// WaitForServerReady waits for the server to be ready by polling the server
// with a simple request until it responds successfully or times out
func (c *Client) WaitForServerReady(ctx context.Context) error {
- cnf := config.Get()
+ cfg := config.Get()
// Set initial state
c.SetServerState(StateStarting)
@@ -299,7 +299,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Waiting for LSP server to be ready...")
}
@@ -308,7 +308,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
// For TypeScript-like servers, we need to open some key files first
if serverType == ServerTypeTypeScript {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("TypeScript-like server detected, opening key configuration files")
}
c.openKeyConfigFiles(ctx)
@@ -325,7 +325,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
if err == nil {
// Server responded successfully
c.SetServerState(StateReady)
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("LSP server is ready")
}
return nil
@@ -333,7 +333,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
logging.Debug("LSP server not ready yet", "error", err, "serverType", serverType)
}
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("LSP server not ready yet", "error", err, "serverType", serverType)
}
}
@@ -496,7 +496,7 @@ func (c *Client) pingTypeScriptServer(ctx context.Context) error {
// openTypeScriptFiles finds and opens TypeScript files to help initialize the server
func (c *Client) openTypeScriptFiles(ctx context.Context, workDir string) {
- cnf := config.Get()
+ cfg := config.Get()
filesOpened := 0
maxFilesToOpen := 5 // Limit to a reasonable number of files
@@ -526,7 +526,7 @@ func (c *Client) openTypeScriptFiles(ctx context.Context, workDir string) {
// Try to open the file
if err := c.OpenFile(ctx, path); err == nil {
filesOpened++
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Opened TypeScript file for initialization", "file", path)
}
}
@@ -535,11 +535,11 @@ func (c *Client) openTypeScriptFiles(ctx context.Context, workDir string) {
return nil
})
- if err != nil && cnf.DebugLSP {
+ if err != nil && cfg.Options.DebugLSP {
logging.Debug("Error walking directory for TypeScript files", "error", err)
}
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Opened TypeScript files for initialization", "count", filesOpened)
}
}
@@ -664,7 +664,7 @@ func (c *Client) NotifyChange(ctx context.Context, filepath string) error {
}
func (c *Client) CloseFile(ctx context.Context, filepath string) error {
- cnf := config.Get()
+ cfg := config.Get()
uri := string(protocol.URIFromPath(filepath))
c.openFilesMu.Lock()
@@ -680,7 +680,7 @@ func (c *Client) CloseFile(ctx context.Context, filepath string) error {
},
}
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Closing file", "file", filepath)
}
if err := c.Notify(ctx, "textDocument/didClose", params); err != nil {
@@ -704,7 +704,7 @@ func (c *Client) IsFileOpen(filepath string) bool {
// CloseAllFiles closes all currently open files
func (c *Client) CloseAllFiles(ctx context.Context) {
- cnf := config.Get()
+ cfg := config.Get()
c.openFilesMu.Lock()
filesToClose := make([]string, 0, len(c.openFiles))
@@ -719,12 +719,12 @@ func (c *Client) CloseAllFiles(ctx context.Context) {
// Then close them all
for _, filePath := range filesToClose {
err := c.CloseFile(ctx, filePath)
- if err != nil && cnf.DebugLSP {
+ if err != nil && cfg.Options.DebugLSP {
logging.Warn("Error closing file", "file", filePath, "error", err)
}
}
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Closed all files", "files", filesToClose)
}
}
diff --git a/internal/lsp/handlers.go b/internal/lsp/handlers.go
index 9eb258d761ee36a909cddec16b72b2a3d933a5b4..f2fbfd0a589651590185fe9f73fc222e5bd6b08d 100644
--- a/internal/lsp/handlers.go
+++ b/internal/lsp/handlers.go
@@ -82,13 +82,13 @@ func notifyFileWatchRegistration(id string, watchers []protocol.FileSystemWatche
// Notifications
func HandleServerMessage(params json.RawMessage) {
- cnf := config.Get()
+ cfg := config.Get()
var msg struct {
Type int `json:"type"`
Message string `json:"message"`
}
if err := json.Unmarshal(params, &msg); err == nil {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Server message", "type", msg.Type, "message", msg.Message)
}
}
diff --git a/internal/lsp/transport.go b/internal/lsp/transport.go
index c3d5d762feeccaaa363a189fd8014b705a583681..5433fb552d6ee3dae390dcf74e3e1d9c8b0d74f9 100644
--- a/internal/lsp/transport.go
+++ b/internal/lsp/transport.go
@@ -18,9 +18,9 @@ func WriteMessage(w io.Writer, msg *Message) error {
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
}
- cnf := config.Get()
+ cfg := config.Get()
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Sending message to server", "method", msg.Method, "id", msg.ID)
}
@@ -39,7 +39,7 @@ func WriteMessage(w io.Writer, msg *Message) error {
// ReadMessage reads a single LSP message from the given reader
func ReadMessage(r *bufio.Reader) (*Message, error) {
- cnf := config.Get()
+ cfg := config.Get()
// Read headers
var contentLength int
for {
@@ -49,7 +49,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
}
line = strings.TrimSpace(line)
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Received header", "line", line)
}
@@ -65,7 +65,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
}
}
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Content-Length", "length", contentLength)
}
@@ -76,7 +76,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
return nil, fmt.Errorf("failed to read content: %w", err)
}
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Received content", "content", string(content))
}
@@ -91,11 +91,11 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
// handleMessages reads and dispatches messages in a loop
func (c *Client) handleMessages() {
- cnf := config.Get()
+ cfg := config.Get()
for {
msg, err := ReadMessage(c.stdout)
if err != nil {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Error("Error reading message", "error", err)
}
return
@@ -103,7 +103,7 @@ func (c *Client) handleMessages() {
// Handle server->client request (has both Method and ID)
if msg.Method != "" && msg.ID != 0 {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Received request from server", "method", msg.Method, "id", msg.ID)
}
@@ -157,11 +157,11 @@ func (c *Client) handleMessages() {
c.notificationMu.RUnlock()
if ok {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Handling notification", "method", msg.Method)
}
go handler(msg.Params)
- } else if cnf.DebugLSP {
+ } else if cfg.Options.DebugLSP {
logging.Debug("No handler for notification", "method", msg.Method)
}
continue
@@ -174,12 +174,12 @@ func (c *Client) handleMessages() {
c.handlersMu.RUnlock()
if ok {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Received response for request", "id", msg.ID)
}
ch <- msg
close(ch)
- } else if cnf.DebugLSP {
+ } else if cfg.Options.DebugLSP {
logging.Debug("No handler for response", "id", msg.ID)
}
}
@@ -188,10 +188,10 @@ func (c *Client) handleMessages() {
// Call makes a request and waits for the response
func (c *Client) Call(ctx context.Context, method string, params any, result any) error {
- cnf := config.Get()
+ cfg := config.Get()
id := c.nextID.Add(1)
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Making call", "method", method, "id", id)
}
@@ -217,14 +217,14 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
return fmt.Errorf("failed to send request: %w", err)
}
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Request sent", "method", method, "id", id)
}
// Wait for response
resp := <-ch
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Received response", "id", id)
}
@@ -249,8 +249,8 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
// Notify sends a notification (a request without an ID that doesn't expect a response)
func (c *Client) Notify(ctx context.Context, method string, params any) error {
- cnf := config.Get()
- if cnf.DebugLSP {
+ cfg := config.Get()
+ if cfg.Options.DebugLSP {
logging.Debug("Sending notification", "method", method)
}
diff --git a/internal/lsp/watcher/watcher.go b/internal/lsp/watcher/watcher.go
index a69b3c10577d0c89ffb8aa9972a928201e2124f6..3c2dc05909bd8e6d473696efd09d22435f68dc10 100644
--- a/internal/lsp/watcher/watcher.go
+++ b/internal/lsp/watcher/watcher.go
@@ -43,7 +43,7 @@ func NewWorkspaceWatcher(client *lsp.Client) *WorkspaceWatcher {
// AddRegistrations adds file watchers to track
func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watchers []protocol.FileSystemWatcher) {
- cnf := config.Get()
+ cfg := config.Get()
logging.Debug("Adding file watcher registrations")
w.registrationMu.Lock()
@@ -53,7 +53,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
w.registrations = append(w.registrations, watchers...)
// Print detailed registration information for debugging
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Adding file watcher registrations",
"id", id,
"watchers", len(watchers),
@@ -122,7 +122,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
highPriorityFilesOpened := w.openHighPriorityFiles(ctx, serverName)
filesOpened += highPriorityFilesOpened
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Opened high-priority files",
"count", highPriorityFilesOpened,
"serverName", serverName)
@@ -130,7 +130,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
// If we've already opened enough high-priority files, we might not need more
if filesOpened >= maxFilesToOpen {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Reached file limit with high-priority files",
"filesOpened", filesOpened,
"maxFiles", maxFilesToOpen)
@@ -148,7 +148,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
// Skip directories that should be excluded
if d.IsDir() {
if path != w.workspacePath && shouldExcludeDir(path) {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Skipping excluded directory", "path", path)
}
return filepath.SkipDir
@@ -176,7 +176,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
})
elapsedTime := time.Since(startTime)
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Limited workspace scan complete",
"filesOpened", filesOpened,
"maxFiles", maxFilesToOpen,
@@ -185,11 +185,11 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
)
}
- if err != nil && cnf.DebugLSP {
+ if err != nil && cfg.Options.DebugLSP {
logging.Debug("Error scanning workspace for files to open", "error", err)
}
}()
- } else if cnf.DebugLSP {
+ } else if cfg.Options.DebugLSP {
logging.Debug("Using on-demand file loading for server", "server", serverName)
}
}
@@ -197,7 +197,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
// openHighPriorityFiles opens important files for the server type
// Returns the number of files opened
func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName string) int {
- cnf := config.Get()
+ cfg := config.Get()
filesOpened := 0
// Define patterns for high-priority files based on server type
@@ -265,7 +265,7 @@ func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName
// Use doublestar.Glob to find files matching the pattern (supports ** patterns)
matches, err := doublestar.Glob(os.DirFS(w.workspacePath), pattern)
if err != nil {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Error finding high-priority files", "pattern", pattern, "error", err)
}
continue
@@ -299,12 +299,12 @@ func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName
for j := i; j < end; j++ {
fullPath := filesToOpen[j]
if err := w.client.OpenFile(ctx, fullPath); err != nil {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Error opening high-priority file", "path", fullPath, "error", err)
}
} else {
filesOpened++
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Opened high-priority file", "path", fullPath)
}
}
@@ -321,7 +321,7 @@ func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName
// WatchWorkspace sets up file watching for a workspace
func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath string) {
- cnf := config.Get()
+ cfg := config.Get()
w.workspacePath = workspacePath
// Store the watcher in the context for later use
@@ -356,7 +356,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
// Skip excluded directories (except workspace root)
if d.IsDir() && path != workspacePath {
if shouldExcludeDir(path) {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Skipping excluded directory", "path", path)
}
return filepath.SkipDir
@@ -409,7 +409,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
}
// Debug logging
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
matched, kind := w.isPathWatched(event.Name)
logging.Debug("File event",
"path", event.Name,
@@ -676,8 +676,8 @@ func (w *WorkspaceWatcher) handleFileEvent(ctx context.Context, uri string, chan
// notifyFileEvent sends a didChangeWatchedFiles notification for a file event
func (w *WorkspaceWatcher) notifyFileEvent(ctx context.Context, uri string, changeType protocol.FileChangeType) error {
- cnf := config.Get()
- if cnf.DebugLSP {
+ cfg := config.Get()
+ if cfg.Options.DebugLSP {
logging.Debug("Notifying file event",
"uri", uri,
"changeType", changeType,
@@ -826,7 +826,7 @@ func shouldExcludeDir(dirPath string) bool {
// shouldExcludeFile returns true if the file should be excluded from opening
func shouldExcludeFile(filePath string) bool {
fileName := filepath.Base(filePath)
- cnf := config.Get()
+ cfg := config.Get()
// Skip dot files
if strings.HasPrefix(fileName, ".") {
return true
@@ -852,12 +852,12 @@ func shouldExcludeFile(filePath string) bool {
// Skip large files
if info.Size() > maxFileSize {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Skipping large file",
"path", filePath,
"size", info.Size(),
"maxSize", maxFileSize,
- "debug", cnf.Debug,
+ "debug", cfg.Options.Debug,
"sizeMB", float64(info.Size())/(1024*1024),
"maxSizeMB", float64(maxFileSize)/(1024*1024),
)
@@ -870,7 +870,7 @@ func shouldExcludeFile(filePath string) bool {
// openMatchingFile opens a file if it matches any of the registered patterns
func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
- cnf := config.Get()
+ cfg := config.Get()
// Skip directories
info, err := os.Stat(path)
if err != nil || info.IsDir() {
@@ -890,10 +890,10 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
// Check if the file is a high-priority file that should be opened immediately
// This helps with project initialization for certain language servers
if isHighPriorityFile(path, serverName) {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Opening high-priority file", "path", path, "serverName", serverName)
}
- if err := w.client.OpenFile(ctx, path); err != nil && cnf.DebugLSP {
+ if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP {
logging.Error("Error opening high-priority file", "path", path, "error", err)
}
return
@@ -905,7 +905,7 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
// Check file size - for preloading we're more conservative
if info.Size() > (1 * 1024 * 1024) { // 1MB limit for preloaded files
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Skipping large file for preloading", "path", path, "size", info.Size())
}
return
@@ -937,7 +937,7 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
if shouldOpen {
// Don't need to check if it's already open - the client.OpenFile handles that
- if err := w.client.OpenFile(ctx, path); err != nil && cnf.DebugLSP {
+ if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP {
logging.Error("Error opening file", "path", path, "error", err)
}
}
diff --git a/internal/message/content.go b/internal/message/content.go
index 383134b596e62a5fc18b2c8404d770fc6a2d4112..3ab53e381aaf7755c141985ebe740dbc44356471 100644
--- a/internal/message/content.go
+++ b/internal/message/content.go
@@ -5,7 +5,7 @@ import (
"slices"
"time"
- "github.com/charmbracelet/crush/internal/llm/models"
+ "github.com/charmbracelet/crush/internal/fur/provider"
)
type MessageRole string
@@ -71,9 +71,9 @@ type BinaryContent struct {
Data []byte
}
-func (bc BinaryContent) String(provider models.ModelProvider) string {
+func (bc BinaryContent) String(p provider.InferenceProvider) string {
base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
- if provider == models.ProviderOpenAI {
+ if p == provider.InferenceProviderOpenAI {
return "data:" + bc.MIMEType + ";base64," + base64Encoded
}
return base64Encoded
@@ -113,7 +113,8 @@ type Message struct {
Role MessageRole
SessionID string
Parts []ContentPart
- Model models.ModelID
+ Model string
+ Provider string
CreatedAt int64
UpdatedAt int64
}
diff --git a/internal/message/message.go b/internal/message/message.go
index 9e241a0b011ee6277402709fdd8be3aefb5df6fe..7cd823bc3129df5f807ec478d9d6c02364c6cfec 100644
--- a/internal/message/message.go
+++ b/internal/message/message.go
@@ -8,15 +8,15 @@ import (
"time"
"github.com/charmbracelet/crush/internal/db"
- "github.com/charmbracelet/crush/internal/llm/models"
"github.com/charmbracelet/crush/internal/pubsub"
"github.com/google/uuid"
)
type CreateMessageParams struct {
- Role MessageRole
- Parts []ContentPart
- Model models.ModelID
+ Role MessageRole
+ Parts []ContentPart
+ Model string
+ Provider string
}
type Service interface {
@@ -70,6 +70,7 @@ func (s *service) Create(ctx context.Context, sessionID string, params CreateMes
Role: string(params.Role),
Parts: string(partsJSON),
Model: sql.NullString{String: string(params.Model), Valid: true},
+ Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""},
})
if err != nil {
return Message{}, err
@@ -154,7 +155,8 @@ func (s *service) fromDBItem(item db.Message) (Message, error) {
SessionID: item.SessionID,
Role: MessageRole(item.Role),
Parts: parts,
- Model: models.ModelID(item.Model.String),
+ Model: item.Model.String,
+ Provider: item.Provider.String,
CreatedAt: item.CreatedAt,
UpdatedAt: item.UpdatedAt,
}, nil
diff --git a/internal/tui/components/chat/header/header.go b/internal/tui/components/chat/header/header.go
index 78620161a75a3ade2e0e2416351c50699ac8bd4d..d924bdc3453dc3fce0351c490cb17b726fcc2549 100644
--- a/internal/tui/components/chat/header/header.go
+++ b/internal/tui/components/chat/header/header.go
@@ -7,7 +7,6 @@ import (
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/fsext"
- "github.com/charmbracelet/crush/internal/llm/models"
"github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/lsp/protocol"
"github.com/charmbracelet/crush/internal/pubsub"
@@ -112,11 +111,7 @@ func (h *header) details() string {
parts = append(parts, t.S().Error.Render(fmt.Sprintf("%s%d", styles.ErrorIcon, errorCount)))
}
- cfg := config.Get()
- agentCfg := cfg.Agents[config.AgentCoder]
- selectedModelID := agentCfg.Model
- model := models.SupportedModels[selectedModelID]
-
+ model := config.GetAgentModel(config.AgentCoder)
percentage := (float64(h.session.CompletionTokens+h.session.PromptTokens) / float64(model.ContextWindow)) * 100
formattedPercentage := t.S().Muted.Render(fmt.Sprintf("%d%%", int(percentage)))
parts = append(parts, formattedPercentage)
diff --git a/internal/tui/components/chat/messages/messages.go b/internal/tui/components/chat/messages/messages.go
index d5e95b4e3ebded500f73840fda483d3be53ca71d..98d8b2979a90f46fa5901bc77d1e8b4a5105f04d 100644
--- a/internal/tui/components/chat/messages/messages.go
+++ b/internal/tui/components/chat/messages/messages.go
@@ -10,7 +10,8 @@ import (
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/lipgloss/v2"
- "github.com/charmbracelet/crush/internal/llm/models"
+ "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/tui/components/anim"
"github.com/charmbracelet/crush/internal/tui/components/core"
@@ -290,8 +291,9 @@ func (m *assistantSectionModel) View() tea.View {
duration := finishTime.Sub(m.lastUserMessageTime)
infoMsg := t.S().Subtle.Render(duration.String())
icon := t.S().Subtle.Render(styles.ModelIcon)
- model := t.S().Muted.Render(models.SupportedModels[m.message.Model].Name)
- assistant := fmt.Sprintf("%s %s %s", icon, model, infoMsg)
+ model := config.GetProviderModel(provider.InferenceProvider(m.message.Provider), m.message.Model)
+ modelFormatted := t.S().Muted.Render(model.Name)
+ assistant := fmt.Sprintf("%s %s %s", icon, modelFormatted, infoMsg)
return tea.NewView(
t.S().Base.PaddingLeft(2).Render(
core.Section(assistant, m.width-2),
diff --git a/internal/tui/components/chat/sidebar/sidebar.go b/internal/tui/components/chat/sidebar/sidebar.go
index afd067adc06d99c3c9da911812750631423231e6..bfcc74c43a2727138d479af647ba461bdcc7520c 100644
--- a/internal/tui/components/chat/sidebar/sidebar.go
+++ b/internal/tui/components/chat/sidebar/sidebar.go
@@ -13,7 +13,6 @@ import (
"github.com/charmbracelet/crush/internal/diff"
"github.com/charmbracelet/crush/internal/fsext"
"github.com/charmbracelet/crush/internal/history"
- "github.com/charmbracelet/crush/internal/llm/models"
"github.com/charmbracelet/crush/internal/logging"
"github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/lsp/protocol"
@@ -76,7 +75,7 @@ func NewSidebarCmp(history history.Service, lspClients map[string]*lsp.Client, c
}
func (m *sidebarCmp) Init() tea.Cmd {
- m.logo = m.logoBlock(false)
+ m.logo = m.logoBlock()
m.cwd = cwd()
return nil
}
@@ -231,9 +230,9 @@ func (m *sidebarCmp) loadSessionFiles() tea.Msg {
func (m *sidebarCmp) SetSize(width, height int) tea.Cmd {
if width < logoBreakpoint && (m.width == 0 || m.width >= logoBreakpoint) {
- m.logo = m.logoBlock(true)
+ m.logo = m.logoBlock()
} else if width >= logoBreakpoint && (m.width == 0 || m.width < logoBreakpoint) {
- m.logo = m.logoBlock(false)
+ m.logo = m.logoBlock()
}
m.width = width
@@ -245,9 +244,9 @@ func (m *sidebarCmp) GetSize() (int, int) {
return m.width, m.height
}
-func (m *sidebarCmp) logoBlock(compact bool) string {
+func (m *sidebarCmp) logoBlock() string {
t := styles.CurrentTheme()
- return logo.Render(version.Version, compact, logo.Opts{
+ return logo.Render(version.Version, true, logo.Opts{
FieldColor: t.Primary,
TitleColorA: t.Secondary,
TitleColorB: t.Primary,
@@ -406,7 +405,7 @@ func (m *sidebarCmp) mcpBlock() string {
mcpList := []string{section, ""}
- mcp := config.Get().MCPServers
+ mcp := config.Get().MCP
if len(mcp) == 0 {
return lipgloss.JoinVertical(
lipgloss.Left,
@@ -475,10 +474,7 @@ func formatTokensAndCost(tokens, contextWindow int64, cost float64) string {
}
func (s *sidebarCmp) currentModelBlock() string {
- cfg := config.Get()
- agentCfg := cfg.Agents[config.AgentCoder]
- selectedModelID := agentCfg.Model
- model := models.SupportedModels[selectedModelID]
+ model := config.GetAgentModel(config.AgentCoder)
t := styles.CurrentTheme()
diff --git a/internal/tui/components/chat/splash/keys.go b/internal/tui/components/chat/splash/keys.go
new file mode 100644
index 0000000000000000000000000000000000000000..df715c89e86971a0f788915737bf41a212c65b5a
--- /dev/null
+++ b/internal/tui/components/chat/splash/keys.go
@@ -0,0 +1,18 @@
+package splash
+
+import (
+ "github.com/charmbracelet/bubbles/v2/key"
+)
+
+type KeyMap struct {
+ Cancel key.Binding
+}
+
+func DefaultKeyMap() KeyMap {
+ return KeyMap{
+ Cancel: key.NewBinding(
+ key.WithKeys("esc"),
+ key.WithHelp("esc", "cancel"),
+ ),
+ }
+}
diff --git a/internal/tui/components/chat/splash/splash.go b/internal/tui/components/chat/splash/splash.go
new file mode 100644
index 0000000000000000000000000000000000000000..75718b25471088607a66a19c7b8a56d36bd5d2d1
--- /dev/null
+++ b/internal/tui/components/chat/splash/splash.go
@@ -0,0 +1,85 @@
+package splash
+
+import (
+ "github.com/charmbracelet/bubbles/v2/key"
+ tea "github.com/charmbracelet/bubbletea/v2"
+ "github.com/charmbracelet/crush/internal/tui/components/core/layout"
+ "github.com/charmbracelet/crush/internal/tui/components/logo"
+ "github.com/charmbracelet/crush/internal/tui/styles"
+ "github.com/charmbracelet/crush/internal/tui/util"
+ "github.com/charmbracelet/crush/internal/version"
+ "github.com/charmbracelet/lipgloss/v2"
+)
+
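+// Splash is the splash screen component; it renders the app logo and exposes
+// size and help bindings.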
+type Splash interface {
+ util.Model
+ layout.Sizeable
+ layout.Help
+}
+
+type splashCmp struct {
+ width, height int
+ keyMap KeyMap
+ logoRendered string
+}
+
+func New() Splash {
+ return &splashCmp{
+ width: 0,
+ height: 0,
+ keyMap: DefaultKeyMap(),
+ logoRendered: "",
+ }
+}
+
+// GetSize implements Splash.
+func (s *splashCmp) GetSize() (int, int) {
+ return s.width, s.height
+}
+
+// Init implements Splash.
+func (s *splashCmp) Init() tea.Cmd {
+ return nil
+}
+
+// SetSize implements Splash.
+func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
+ s.width = width
+ s.height = height
+ s.logoRendered = s.logoBlock()
+ return nil
+}
+
+// Update implements Splash.
+func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
+ switch msg := msg.(type) {
+ case tea.WindowSizeMsg:
+ return s, s.SetSize(msg.Width, msg.Height)
+ }
+ return s, nil
+}
+
+// View implements Splash.
+func (s *splashCmp) View() tea.View {
+ content := lipgloss.JoinVertical(lipgloss.Left, s.logoRendered)
+ return tea.NewView(content)
+}
+
+func (s *splashCmp) logoBlock() string {
+ t := styles.CurrentTheme()
+ return logo.Render(version.Version, false, logo.Opts{
+ FieldColor: t.Primary,
+ TitleColorA: t.Secondary,
+ TitleColorB: t.Primary,
+ CharmColor: t.Secondary,
+ VersionColor: t.Primary,
+ Width: s.width - 2, // -2 for padding
+ })
+}
+
+// Bindings implements Splash.
+func (s *splashCmp) Bindings() []key.Binding {
+ return []key.Binding{
+ s.keyMap.Cancel,
+ }
+}
diff --git a/internal/tui/components/dialogs/commands/loader.go b/internal/tui/components/dialogs/commands/loader.go
index 9f70afa3cd60342028b6d3fd00e017221c179686..9aee528ee48d0f23e48c417f8bee5bc0e3f381c5 100644
--- a/internal/tui/components/dialogs/commands/loader.go
+++ b/internal/tui/components/dialogs/commands/loader.go
@@ -63,7 +63,7 @@ func buildCommandSources(cfg *config.Config) []commandSource {
// Project directory
sources = append(sources, commandSource{
- path: filepath.Join(cfg.Data.Directory, "commands"),
+ path: filepath.Join(cfg.Options.DataDirectory, "commands"),
prefix: ProjectCommandPrefix,
})
diff --git a/internal/tui/components/dialogs/models/keys.go b/internal/tui/components/dialogs/models/keys.go
index f9c3c549c30c6d95282e88818c247ec80a2f0e4c..bb70785172bac66d9fda905172572c881b2ecd35 100644
--- a/internal/tui/components/dialogs/models/keys.go
+++ b/internal/tui/components/dialogs/models/keys.go
@@ -8,13 +8,14 @@ type KeyMap struct {
Select,
Next,
Previous,
+ Tab,
Close key.Binding
}
func DefaultKeyMap() KeyMap {
return KeyMap{
Select: key.NewBinding(
- key.WithKeys("enter", "tab", "ctrl+y"),
+ key.WithKeys("enter", "ctrl+y"),
key.WithHelp("enter", "confirm"),
),
Next: key.NewBinding(
@@ -25,6 +26,10 @@ func DefaultKeyMap() KeyMap {
key.WithKeys("up", "ctrl+p"),
key.WithHelp("↑", "previous item"),
),
+ Tab: key.NewBinding(
+ key.WithKeys("tab"),
+ key.WithHelp("tab", "toggle type"),
+ ),
Close: key.NewBinding(
key.WithKeys("esc"),
key.WithHelp("esc", "cancel"),
@@ -38,6 +43,7 @@ func (k KeyMap) KeyBindings() []key.Binding {
k.Select,
k.Next,
k.Previous,
+ k.Tab,
k.Close,
}
}
@@ -57,10 +63,10 @@ func (k KeyMap) FullHelp() [][]key.Binding {
func (k KeyMap) ShortHelp() []key.Binding {
return []key.Binding{
key.NewBinding(
-
key.WithKeys("down", "up"),
key.WithHelp("↑↓", "choose"),
),
+ k.Tab,
k.Select,
k.Close,
}
diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go
index 906d87a9dfe65c1ec09bd5abaf4f9d6865545038..aa7a505bd19af72a55e134fc0b077085a761faa6 100644
--- a/internal/tui/components/dialogs/models/models.go
+++ b/internal/tui/components/dialogs/models/models.go
@@ -7,7 +7,7 @@ import (
"github.com/charmbracelet/bubbles/v2/key"
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/llm/models"
+ "github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/tui/components/completions"
"github.com/charmbracelet/crush/internal/tui/components/core"
"github.com/charmbracelet/crush/internal/tui/components/core/list"
@@ -24,9 +24,15 @@ const (
defaultWidth = 60
)
+const (
+ LargeModelType int = iota
+ SmallModelType
+)
+
// ModelSelectedMsg is sent when a model is selected
type ModelSelectedMsg struct {
- Model models.Model
+ Model config.PreferredModel
+ ModelType config.ModelType
}
// CloseModelDialogMsg is sent when the model dialog is closed
@@ -37,14 +43,20 @@ type ModelDialog interface {
dialogs.DialogModel
}
+type ModelOption struct {
+ Provider provider.Provider
+ Model provider.Model
+}
+
type modelDialogCmp struct {
width int
- wWidth int // Width of the terminal window
- wHeight int // Height of the terminal window
+ wWidth int
+ wHeight int
modelList list.ListModel
keyMap KeyMap
help help.Model
+ modelType int
}
func NewModelDialogCmp() ModelDialog {
@@ -77,49 +89,12 @@ func NewModelDialogCmp() ModelDialog {
width: defaultWidth,
keyMap: DefaultKeyMap(),
help: help,
+ modelType: LargeModelType,
}
}
-var ProviderPopularity = map[models.ModelProvider]int{
- models.ProviderAnthropic: 1,
- models.ProviderOpenAI: 2,
- models.ProviderGemini: 3,
- models.ProviderGROQ: 4,
- models.ProviderOpenRouter: 5,
- models.ProviderBedrock: 6,
- models.ProviderAzure: 7,
- models.ProviderVertexAI: 8,
- models.ProviderXAI: 9,
-}
-
-var ProviderName = map[models.ModelProvider]string{
- models.ProviderAnthropic: "Anthropic",
- models.ProviderOpenAI: "OpenAI",
- models.ProviderGemini: "Gemini",
- models.ProviderGROQ: "Groq",
- models.ProviderOpenRouter: "OpenRouter",
- models.ProviderBedrock: "AWS Bedrock",
- models.ProviderAzure: "Azure",
- models.ProviderVertexAI: "VertexAI",
- models.ProviderXAI: "xAI",
-}
-
func (m *modelDialogCmp) Init() tea.Cmd {
- cfg := config.Get()
- enabledProviders := getEnabledProviders(cfg)
-
- modelItems := []util.Model{}
- for _, provider := range enabledProviders {
- name, ok := ProviderName[provider]
- if !ok {
- name = string(provider) // Fallback to provider ID if name is not defined
- }
- modelItems = append(modelItems, commands.NewItemSection(name))
- for _, model := range getModelsForProvider(provider) {
- modelItems = append(modelItems, completions.NewCompletionItem(model.Name, model))
- }
- }
- m.modelList.SetItems(modelItems)
+ m.SetModelType(m.modelType)
return m.modelList.Init()
}
@@ -128,21 +103,41 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case tea.WindowSizeMsg:
m.wWidth = msg.Width
m.wHeight = msg.Height
+ m.SetModelType(m.modelType)
return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
case tea.KeyPressMsg:
switch {
case key.Matches(msg, m.keyMap.Select):
selectedItemInx := m.modelList.SelectedIndex()
if selectedItemInx == list.NoSelection {
- return m, nil // No item selected, do nothing
+ return m, nil
}
items := m.modelList.Items()
- selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(models.Model)
+ selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption)
+
+ var modelType config.ModelType
+ if m.modelType == LargeModelType {
+ modelType = config.LargeModel
+ } else {
+ modelType = config.SmallModel
+ }
return m, tea.Sequence(
util.CmdHandler(dialogs.CloseDialogMsg{}),
- util.CmdHandler(ModelSelectedMsg{Model: selectedItem}),
+ util.CmdHandler(ModelSelectedMsg{
+ Model: config.PreferredModel{
+ ModelID: selectedItem.Model.ID,
+ Provider: selectedItem.Provider.ID,
+ },
+ ModelType: modelType,
+ }),
)
+ case key.Matches(msg, m.keyMap.Tab):
+ if m.modelType == LargeModelType {
+ return m, m.SetModelType(SmallModelType)
+ } else {
+ return m, m.SetModelType(LargeModelType)
+ }
case key.Matches(msg, m.keyMap.Close):
return m, util.CmdHandler(dialogs.CloseDialogMsg{})
default:
@@ -157,9 +152,10 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
func (m *modelDialogCmp) View() tea.View {
t := styles.CurrentTheme()
listView := m.modelList.View()
+ radio := m.modelTypeRadio()
content := lipgloss.JoinVertical(
lipgloss.Left,
- t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-4)),
+ t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
listView.String(),
"",
t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
@@ -189,58 +185,6 @@ func (m *modelDialogCmp) listHeight() int {
return min(listHeigh, m.wHeight/2)
}
-func GetSelectedModel(cfg *config.Config) models.Model {
- agentCfg := cfg.Agents[config.AgentCoder]
- selectedModelID := agentCfg.Model
- return models.SupportedModels[selectedModelID]
-}
-
-func getEnabledProviders(cfg *config.Config) []models.ModelProvider {
- var providers []models.ModelProvider
- for providerID, provider := range cfg.Providers {
- if !provider.Disabled {
- providers = append(providers, providerID)
- }
- }
-
- // Sort by provider popularity
- slices.SortFunc(providers, func(a, b models.ModelProvider) int {
- rA := ProviderPopularity[a]
- rB := ProviderPopularity[b]
-
- // models not included in popularity ranking default to last
- if rA == 0 {
- rA = 999
- }
- if rB == 0 {
- rB = 999
- }
- return rA - rB
- })
- return providers
-}
-
-func getModelsForProvider(provider models.ModelProvider) []models.Model {
- var providerModels []models.Model
- for _, model := range models.SupportedModels {
- if model.Provider == provider {
- providerModels = append(providerModels, model)
- }
- }
-
- // reverse alphabetical order (if llm naming was consistent latest would appear first)
- slices.SortFunc(providerModels, func(a, b models.Model) int {
- if a.Name > b.Name {
- return -1
- } else if a.Name < b.Name {
- return 1
- }
- return 0
- })
-
- return providerModels
-}
-
func (m *modelDialogCmp) Position() (int, int) {
row := m.wHeight/4 - 2 // just a bit above the center
col := m.wWidth / 2
@@ -259,3 +203,117 @@ func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
func (m *modelDialogCmp) ID() dialogs.DialogID {
return ModelsDialogID
}
+
+func (m *modelDialogCmp) modelTypeRadio() string {
+ t := styles.CurrentTheme()
+ choices := []string{"Large Task", "Small Task"}
+ iconSelected := "◉"
+ iconUnselected := "○"
+ if m.modelType == LargeModelType {
+ return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
+ }
+ return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
+}
+
+func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
+ m.modelType = modelType
+
+ providers := config.Providers()
+ modelItems := []util.Model{}
+ selectIndex := 0
+
+ cfg := config.Get()
+ var currentModel config.PreferredModel
+ if m.modelType == LargeModelType {
+ currentModel = cfg.Models.Large
+ } else {
+ currentModel = cfg.Models.Small
+ }
+
+ // Create a map to track which providers we've already added
+ addedProviders := make(map[provider.InferenceProvider]bool)
+
+ // First, add any configured providers that are not in the known providers list
+ // These should appear at the top of the list
+ knownProviders := provider.KnownProviders()
+ for providerID, providerConfig := range cfg.Providers {
+ if providerConfig.Disabled {
+ continue
+ }
+
+ // Check if this provider is not in the known providers list
+ if !slices.Contains(knownProviders, providerID) {
+ // Convert config provider to provider.Provider format
+ configProvider := provider.Provider{
+ Name: string(providerID), // Use provider ID as name for unknown providers
+ ID: providerID,
+ Models: make([]provider.Model, len(providerConfig.Models)),
+ }
+
+ // Convert models
+ for i, model := range providerConfig.Models {
+ configProvider.Models[i] = provider.Model{
+ ID: model.ID,
+ Name: model.Name,
+ CostPer1MIn: model.CostPer1MIn,
+ CostPer1MOut: model.CostPer1MOut,
+ CostPer1MInCached: model.CostPer1MInCached,
+ CostPer1MOutCached: model.CostPer1MOutCached,
+ ContextWindow: model.ContextWindow,
+ DefaultMaxTokens: model.DefaultMaxTokens,
+ CanReason: model.CanReason,
+ HasReasoningEffort: model.HasReasoningEffort,
+ DefaultReasoningEffort: model.ReasoningEffort,
+ SupportsImages: model.SupportsImages,
+ }
+ }
+
+ // Add this unknown provider to the list
+ name := configProvider.Name
+ if name == "" {
+ name = string(configProvider.ID)
+ }
+ modelItems = append(modelItems, commands.NewItemSection(name))
+ for _, model := range configProvider.Models {
+ modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
+ Provider: configProvider,
+ Model: model,
+ }))
+ if model.ID == currentModel.ModelID && configProvider.ID == currentModel.Provider {
+ selectIndex = len(modelItems) - 1 // Set the selected index to the current model
+ }
+ }
+ addedProviders[providerID] = true
+ }
+ }
+
+ // Then add the known providers from the predefined list
+ for _, provider := range providers {
+ // Skip if we already added this provider as an unknown provider
+ if addedProviders[provider.ID] {
+ continue
+ }
+
+ // Check if this provider is configured and not disabled
+ if providerConfig, exists := cfg.Providers[provider.ID]; exists && providerConfig.Disabled {
+ continue
+ }
+
+ name := provider.Name
+ if name == "" {
+ name = string(provider.ID)
+ }
+ modelItems = append(modelItems, commands.NewItemSection(name))
+ for _, model := range provider.Models {
+ modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
+ Provider: provider,
+ Model: model,
+ }))
+ if model.ID == currentModel.ModelID && provider.ID == currentModel.Provider {
+ selectIndex = len(modelItems) - 1 // Set the selected index to the current model
+ }
+ }
+ }
+
+ return tea.Sequence(m.modelList.SetItems(modelItems), m.modelList.SetSelected(selectIndex))
+}
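The dialog tracks the radio state as a plain int (`LargeModelType`/`SmallModelType`) and only converts it to a `config.ModelType` when emitting `ModelSelectedMsg`. A small helper could centralize that branch; a sketch under the same assumptions (the helper itself is hypothetical, not part of this change):

```go
// modelTypeFor maps the dialog's radio state to the config-level type.
// Hypothetical refactor of the inline branch used in Update above.
func modelTypeFor(radio int) config.ModelType {
	if radio == LargeModelType {
		return config.LargeModel
	}
	return config.SmallModel
}
```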
diff --git a/internal/tui/components/logo/logo.go b/internal/tui/components/logo/logo.go
index 4b044c9dbd45284c72b7d03636d7399555e5f388..9d170ee6c2a0036ad9e4ca8b11c1a373fbb15080 100644
--- a/internal/tui/components/logo/logo.go
+++ b/internal/tui/components/logo/logo.go
@@ -26,6 +26,7 @@ type Opts struct {
TitleColorB color.Color // right gradient ramp point
CharmColor color.Color // Charm™ text color
VersionColor color.Color // Version text color
+ Width int // width of the rendered logo, used for truncation
}
// Render renders the Crush logo. Set the argument to true to render the narrow
@@ -76,7 +77,7 @@ func Render(version string, compact bool, o Opts) string {
}
// Right field.
- const rightWidth = 15
+ rightWidth := max(15, o.Width-crushWidth-leftWidth) // Stretch to fill o.Width; 15 is the floor.
const stepDownAt = 0
rightField := new(strings.Builder)
for i := range fieldHeight {
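With the new `Width` option the right field stretches to fill whatever the title leaves over, while `max` keeps the original 15 columns as a floor when `Width` is unset or too small. A worked example with hypothetical widths:

```go
// Hypothetical numbers, for illustration only.
crushWidth, leftWidth := 30, 10
_ = max(15, 60-crushWidth-leftWidth) // 20: fills a 60-column logo.
_ = max(15, 0-crushWidth-leftWidth)  // 15: unset Width falls back to the old fixed size.
```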
diff --git a/internal/tui/components/logs/details.go b/internal/tui/components/logs/details.go
index 46ec1ce02d1e4573c04f99f45e1af2a2a3a4a731..c73876aff8a6898809a170fec176def30af81314 100644
--- a/internal/tui/components/logs/details.go
+++ b/internal/tui/components/logs/details.go
@@ -161,7 +161,6 @@ func (i *detailCmp) GetSize() (int, int) {
}
func (i *detailCmp) SetSize(width int, height int) tea.Cmd {
- logging.Info("Setting size for detail component", "width", width, "height", height)
i.width = width
i.height = height
i.viewport.SetWidth(i.width - 4)
diff --git a/internal/tui/page/chat/chat.go b/internal/tui/page/chat/chat.go
index ffb6debb0f61cb1fcfa7e180b042b3b8325dd2e5..44d623847765175d3c38eb81122fa3d55abc430d 100644
--- a/internal/tui/page/chat/chat.go
+++ b/internal/tui/page/chat/chat.go
@@ -9,7 +9,6 @@ import (
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/crush/internal/app"
"github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/llm/models"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/session"
"github.com/charmbracelet/crush/internal/tui/components/chat"
@@ -171,14 +170,11 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
util.CmdHandler(ChatFocusedMsg{Focused: false}),
)
case key.Matches(msg, p.keyMap.AddAttachment):
- cfg := config.Get()
- agentCfg := cfg.Agents[config.AgentCoder]
- selectedModelID := agentCfg.Model
- model := models.SupportedModels[selectedModelID]
- if model.SupportsAttachments {
+ model := config.GetAgentModel(config.AgentCoder)
+ if model.SupportsImages {
return p, util.CmdHandler(OpenFilePickerMsg{})
} else {
- return p, util.ReportWarn("File attachments are not supported by the current model: " + string(selectedModelID))
+ return p, util.ReportWarn("File attachments are not supported by the current model: " + model.Name)
}
case key.Matches(msg, p.keyMap.Tab):
if p.session.ID == "" {
diff --git a/internal/tui/tui.go b/internal/tui/tui.go
index c6dee6532993becfbda24d115b8e1e5d05e4fd60..fb77e5a8f30d8f4cd290d3a8d4026694c690a109 100644
--- a/internal/tui/tui.go
+++ b/internal/tui/tui.go
@@ -69,7 +69,7 @@ func (a appModel) Init() tea.Cmd {
// Check if we should show the init dialog
cmds = append(cmds, func() tea.Msg {
- shouldShow, err := config.ShouldShowInitDialog()
+ shouldShow, err := config.ProjectNeedsInitialization()
if err != nil {
return util.InfoMsg{
Type: util.InfoTypeError,
@@ -172,12 +172,19 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// Model Switch
case models.ModelSelectedMsg:
- model, err := a.app.CoderAgent.Update(config.AgentCoder, msg.Model.ID)
- if err != nil {
- return a, util.ReportError(err)
+ config.UpdatePreferredModel(msg.ModelType, msg.Model)
+
+ // Update the agent with the new model/provider configuration
+ if err := a.app.UpdateAgentModel(); err != nil {
+ logging.ErrorPersist(fmt.Sprintf("Failed to update agent model: %v", err))
+ return a, util.ReportError(fmt.Errorf("model changed to %s but failed to update agent: %v", msg.Model.ModelID, err))
}
- return a, util.ReportInfo(fmt.Sprintf("Model changed to %s", model.Name))
+ modelTypeName := "large"
+ if msg.ModelType == config.SmallModel {
+ modelTypeName = "small"
+ }
+ return a, util.ReportInfo(fmt.Sprintf("%s model changed to %s", modelTypeName, msg.Model.ModelID))
// File Picker
case chat.OpenFilePickerMsg:
@@ -221,8 +228,23 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if err == nil {
model := a.app.CoderAgent.Model()
contextWindow := model.ContextWindow
- tokens := session.CompletionTokens + session.PromptTokens
- if (tokens >= int64(float64(contextWindow)*0.95)) && config.Get().AutoCompact {
+ usedTokens := session.CompletionTokens + session.PromptTokens
+ remainingTokens := contextWindow - usedTokens
+
+ // Get effective max tokens for this agent (considering overrides)
+ maxTokens := a.app.CoderAgent.EffectiveMaxTokens()
+
+ // Apply 10% margin to max tokens
+ maxTokensWithMargin := int64(float64(maxTokens) * 1.1)
+
+ // Trigger auto-summarize if remaining tokens < max tokens + 10% margin
+ // Also ensure we have a reasonable minimum threshold to avoid too-frequent summaries
+ minThreshold := int64(1000) // Minimum 1000 tokens remaining before triggering
+ if maxTokensWithMargin < minThreshold {
+ maxTokensWithMargin = minThreshold
+ }
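+ // Worked example with hypothetical numbers: contextWindow=200000 and
+ // usedTokens=190000 leave remainingTokens=10000, while maxTokens=8192
+ // yields maxTokensWithMargin=9011; summarization triggers only once
+ // fewer than 9011 tokens remain.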
+
+ if remainingTokens < maxTokensWithMargin && !config.Get().Options.DisableAutoSummarize {
// Show compact confirmation dialog
cmds = append(cmds, util.CmdHandler(dialogs.OpenDialogMsg{
Model: compact.NewCompactDialogCmp(a.app.CoderAgent, a.selectedSessionID, false),