diff --git a/Taskfile.yaml b/Taskfile.yaml index 7e181d84513bc1abc5d5a32807f88a75898633ea..9453eb363513d5b9e0987a637d761eef1bfe63c3 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -44,3 +44,9 @@ tasks: desc: Allocations profile cmds: - go tool pprof -http :6061 'http://localhost:6060/debug/pprof/allocs' + + schema: + desc: Generate JSON schema for configuration + cmds: + - go run cmd/schema/main.go > crush-schema.json + - echo "Generated crush-schema.json" diff --git a/cmd/schema/README.md b/cmd/schema/README.md index 517fdb4d20fb9f2b819051bd72e6c33f5dea2195..c7f0102e590dcd98ce150e506c28f9016ef50bbc 100644 --- a/cmd/schema/README.md +++ b/cmd/schema/README.md @@ -1,64 +1,182 @@ # Crush Configuration Schema Generator -This tool generates a JSON Schema for the Crush configuration file. The schema can be used to validate configuration files and provide autocompletion in editors that support JSON Schema. +This tool automatically generates a JSON Schema for the Crush configuration file by using Go reflection to analyze the configuration structs. The schema provides validation, autocompletion, and documentation for configuration files. + +## Features + +- **Automated Generation**: Uses reflection to automatically generate schemas from Go structs +- **Always Up-to-Date**: Schema stays in sync with code changes automatically +- **Comprehensive**: Includes all configuration options, types, and validation rules +- **Enhanced**: Adds provider enums, model lists, and custom descriptions +- **Extensible**: Easy to add new fields and modify existing ones ## Usage ```bash +# Generate the schema go run cmd/schema/main.go > crush-schema.json + +# Or use the task runner +task schema ``` -This will generate a JSON Schema file that can be used to validate configuration files. +## How It Works + +The generator: + +1. **Reflects on Config Structs**: Analyzes the `config.Config` struct and all related types +2. **Generates Base Schema**: Creates JSON Schema definitions for all struct fields +3. **Enhances with Runtime Data**: Adds provider lists, model enums, and tool lists from the actual codebase +4. **Adds Custom Descriptions**: Provides meaningful descriptions for configuration options +5. **Sets Default Values**: Includes appropriate defaults for optional fields ## Schema Features The generated schema includes: -- All configuration options with descriptions -- Default values where applicable -- Validation for enum values (e.g., model IDs, provider types) -- Required fields -- Type checking +- **Type Safety**: Proper type definitions for all configuration fields +- **Validation**: Required fields, enum constraints, and format validation +- **Documentation**: Descriptions for all configuration options +- **Defaults**: Default values for optional settings +- **Provider Enums**: Current list of supported providers +- **Model Enums**: Available models from all configured providers +- **Tool Lists**: Valid tool names for agent configurations +- **Cross-References**: Proper relationships between different config sections + +## Adding New Configuration Fields + +To add new configuration options: + +1. **Add to Config Structs**: Add the field to the appropriate struct in `internal/config/` +2. **Add JSON Tags**: Include proper JSON tags with field names +3. **Regenerate Schema**: Run the schema generator to update the JSON schema +4. **Update Validation**: Add any custom validation logic if needed + +Example: +```go +type Options struct { + // ... existing fields ... + + // New field with JSON tag and description + NewFeature bool `json:"new_feature,omitempty"` +} +``` + +The schema generator will automatically: +- Detect the new field +- Generate appropriate JSON schema +- Add type information +- Include in validation ## Using the Schema -You can use the generated schema in several ways: +### Editor Integration + +Most modern editors support JSON Schema: + +**VS Code**: Add to your workspace settings: +```json +{ + "json.schemas": [ + { + "fileMatch": ["crush.json", ".crush.json"], + "url": "./crush-schema.json" + } + ] +} +``` -1. **Editor Integration**: Many editors (VS Code, JetBrains IDEs, etc.) support JSON Schema for validation and autocompletion. You can configure your editor to use the generated schema for `.crush.json` files. +**JetBrains IDEs**: Configure in Settings → Languages & Frameworks → Schemas and DTDs → JSON Schema Mappings -2. **Validation Tools**: You can use tools like [jsonschema](https://github.com/Julian/jsonschema) to validate your configuration files against the schema. +### Validation Tools -3. **Documentation**: The schema serves as documentation for the configuration options. +```bash +# Using jsonschema (Python) +pip install jsonschema +jsonschema -i crush.json crush-schema.json -## Example Configuration +# Using ajv-cli (Node.js) +npm install -g ajv-cli +ajv validate -s crush-schema.json -d crush.json +``` -Here's an example configuration that conforms to the schema: +### Configuration Example ```json { - "data": { - "directory": ".crush" + "models": { + "large": { + "model_id": "claude-3-5-sonnet-20241022", + "provider": "anthropic", + "reasoning_effort": "medium", + "max_tokens": 8192 + }, + "small": { + "model_id": "claude-3-5-haiku-20241022", + "provider": "anthropic" + } }, - "debug": false, "providers": { "anthropic": { - "apiKey": "your-api-key" + "id": "anthropic", + "provider_type": "anthropic", + "api_key": "your-api-key", + "disabled": false } }, "agents": { "coder": { - "model": "claude-3.7-sonnet", - "maxTokens": 5000, - "reasoningEffort": "medium" - }, - "task": { - "model": "claude-3.7-sonnet", - "maxTokens": 5000 + "id": "coder", + "name": "Coder", + "model": "large", + "disabled": false }, - "title": { - "model": "claude-3.7-sonnet", - "maxTokens": 80 + "custom-agent": { + "id": "custom-agent", + "name": "Custom Agent", + "description": "A custom agent for specific tasks", + "model": "small", + "allowed_tools": ["glob", "grep", "view"], + "allowed_mcp": { + "filesystem": ["read", "write"] + } + } + }, + "mcp": { + "filesystem": { + "command": "mcp-filesystem", + "args": ["--root", "/workspace"], + "type": "stdio" + } + }, + "lsp": { + "typescript": { + "command": "typescript-language-server", + "args": ["--stdio"], + "enabled": true + } + }, + "options": { + "context_paths": [ + "README.md", + "docs/", + ".cursorrules" + ], + "data_directory": ".crush", + "debug": false, + "tui": { + "compact_mode": false } } } ``` + +## Maintenance + +The schema generator is designed to be maintenance-free. As long as: + +- Configuration structs have proper JSON tags +- New enums are added to the enhancement functions +- The generator is run after significant config changes + +The schema will stay current with the codebase automatically. \ No newline at end of file diff --git a/cmd/schema/main.go b/cmd/schema/main.go index 9eb88769fd84772628df5332d3dcc1b1b234ac90..34939f1e84b9f3df04c7419a9ac4d7dfdc76386a 100644 --- a/cmd/schema/main.go +++ b/cmd/schema/main.go @@ -1,30 +1,70 @@ -// TODO: FIX THIS package main import ( "encoding/json" "fmt" "os" + "reflect" + "slices" + "strings" "github.com/charmbracelet/crush/internal/config" ) -// JSONSchemaType represents a JSON Schema type -type JSONSchemaType struct { - Type string `json:"type,omitempty"` - Description string `json:"description,omitempty"` - Properties map[string]any `json:"properties,omitempty"` - Required []string `json:"required,omitempty"` - AdditionalProperties any `json:"additionalProperties,omitempty"` - Enum []any `json:"enum,omitempty"` - Items map[string]any `json:"items,omitempty"` - OneOf []map[string]any `json:"oneOf,omitempty"` - AnyOf []map[string]any `json:"anyOf,omitempty"` - Default any `json:"default,omitempty"` +// JSONSchema represents a JSON Schema +type JSONSchema struct { + Schema string `json:"$schema,omitempty"` + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` + Type string `json:"type,omitempty"` + Properties map[string]*JSONSchema `json:"properties,omitempty"` + Items *JSONSchema `json:"items,omitempty"` + Required []string `json:"required,omitempty"` + AdditionalProperties any `json:"additionalProperties,omitempty"` + Enum []any `json:"enum,omitempty"` + Default any `json:"default,omitempty"` + Definitions map[string]*JSONSchema `json:"definitions,omitempty"` + Ref string `json:"$ref,omitempty"` + OneOf []*JSONSchema `json:"oneOf,omitempty"` + AnyOf []*JSONSchema `json:"anyOf,omitempty"` + AllOf []*JSONSchema `json:"allOf,omitempty"` + Not *JSONSchema `json:"not,omitempty"` + Format string `json:"format,omitempty"` + Pattern string `json:"pattern,omitempty"` + MinLength *int `json:"minLength,omitempty"` + MaxLength *int `json:"maxLength,omitempty"` + Minimum *float64 `json:"minimum,omitempty"` + Maximum *float64 `json:"maximum,omitempty"` + ExclusiveMinimum *float64 `json:"exclusiveMinimum,omitempty"` + ExclusiveMaximum *float64 `json:"exclusiveMaximum,omitempty"` + MultipleOf *float64 `json:"multipleOf,omitempty"` + MinItems *int `json:"minItems,omitempty"` + MaxItems *int `json:"maxItems,omitempty"` + UniqueItems *bool `json:"uniqueItems,omitempty"` + MinProperties *int `json:"minProperties,omitempty"` + MaxProperties *int `json:"maxProperties,omitempty"` +} + +// SchemaGenerator generates JSON schemas from Go types +type SchemaGenerator struct { + definitions map[string]*JSONSchema + visited map[reflect.Type]bool +} + +// NewSchemaGenerator creates a new schema generator +func NewSchemaGenerator() *SchemaGenerator { + return &SchemaGenerator{ + definitions: make(map[string]*JSONSchema), + visited: make(map[reflect.Type]bool), + } } func main() { - schema := generateSchema() + // Enable mock providers to avoid API calls during schema generation + config.UseMockProviders = true + + generator := NewSchemaGenerator() + schema := generator.GenerateSchema() // Pretty print the schema encoder := json.NewEncoder(os.Stdout) @@ -35,261 +75,457 @@ func main() { } } -func generateSchema() map[string]any { - schema := map[string]any{ - "$schema": "http://json-schema.org/draft-07/schema#", - "title": "Crush Configuration", - "description": "Configuration schema for the Crush application", - "type": "object", - "properties": map[string]any{}, - } - - // Add Data configuration - schema["properties"].(map[string]any)["data"] = map[string]any{ - "type": "object", - "description": "Storage configuration", - "properties": map[string]any{ - "directory": map[string]any{ - "type": "string", - "description": "Directory where application data is stored", - "default": ".crush", - }, - }, - "required": []string{"directory"}, - } - - // Add working directory - schema["properties"].(map[string]any)["wd"] = map[string]any{ - "type": "string", - "description": "Working directory for the application", - } - - // Add debug flags - schema["properties"].(map[string]any)["debug"] = map[string]any{ - "type": "boolean", - "description": "Enable debug mode", - "default": false, - } - - schema["properties"].(map[string]any)["debugLSP"] = map[string]any{ - "type": "boolean", - "description": "Enable LSP debug mode", - "default": false, - } - - schema["properties"].(map[string]any)["contextPaths"] = map[string]any{ - "type": "array", - "description": "Context paths for the application", - "items": map[string]any{ - "type": "string", - }, - "default": []string{ - ".github/copilot-instructions.md", - ".cursorrules", - ".cursor/rules/", - "CLAUDE.md", - "CLAUDE.local.md", - "GEMINI.md", - "gemini.md", - "crush.md", - "crush.local.md", - "Crush.md", - "Crush.local.md", - "CRUSH.md", - "CRUSH.local.md", - }, - } - - schema["properties"].(map[string]any)["tui"] = map[string]any{ - "type": "object", - "description": "Terminal User Interface configuration", - "properties": map[string]any{ - "theme": map[string]any{ - "type": "string", - "description": "TUI theme name", - "default": "crush", - "enum": []string{ - "crush", - "catppuccin", - "dracula", - "flexoki", - "gruvbox", - "monokai", - "onedark", - "tokyonight", - "tron", - }, - }, - }, - } - - // Add MCP servers - schema["properties"].(map[string]any)["mcpServers"] = map[string]any{ - "type": "object", - "description": "Model Control Protocol server configurations", - "additionalProperties": map[string]any{ - "type": "object", - "description": "MCP server configuration", - "properties": map[string]any{ - "command": map[string]any{ - "type": "string", - "description": "Command to execute for the MCP server", - }, - "env": map[string]any{ - "type": "array", - "description": "Environment variables for the MCP server", - "items": map[string]any{ - "type": "string", - }, - }, - "args": map[string]any{ - "type": "array", - "description": "Command arguments for the MCP server", - "items": map[string]any{ - "type": "string", - }, - }, - "type": map[string]any{ - "type": "string", - "description": "Type of MCP server", - "enum": []string{"stdio", "sse"}, - "default": "stdio", - }, - "url": map[string]any{ - "type": "string", - "description": "URL for SSE type MCP servers", - }, - "headers": map[string]any{ - "type": "object", - "description": "HTTP headers for SSE type MCP servers", - "additionalProperties": map[string]any{ - "type": "string", - }, - }, - }, - "required": []string{"command"}, - }, - } - - // Add providers - providerSchema := map[string]any{ - "type": "object", - "description": "LLM provider configurations", - "additionalProperties": map[string]any{ - "type": "object", - "description": "Provider configuration", - "properties": map[string]any{ - "apiKey": map[string]any{ - "type": "string", - "description": "API key for the provider", - }, - "disabled": map[string]any{ - "type": "boolean", - "description": "Whether the provider is disabled", - "default": false, - }, - }, - }, - } - - providerSchema["additionalProperties"].(map[string]any)["properties"].(map[string]any)["provider"] = map[string]any{ - "type": "string", - "description": "Provider type", - "enum": []string{}, - } - - schema["properties"].(map[string]any)["providers"] = providerSchema - - // Add agents - agentSchema := map[string]any{ - "type": "object", - "description": "Agent configurations", - "additionalProperties": map[string]any{ - "type": "object", - "description": "Agent configuration", - "properties": map[string]any{ - "model": map[string]any{ - "type": "string", - "description": "Model ID for the agent", - }, - "maxTokens": map[string]any{ - "type": "integer", - "description": "Maximum tokens for the agent", - "minimum": 1, - }, - "reasoningEffort": map[string]any{ - "type": "string", - "description": "Reasoning effort for models that support it (OpenAI, Anthropic)", - "enum": []string{"low", "medium", "high"}, - }, - }, - "required": []string{"model"}, - }, - } - - // Add model enum - modelEnum := []string{} - - agentSchema["additionalProperties"].(map[string]any)["properties"].(map[string]any)["model"].(map[string]any)["enum"] = modelEnum - - // Add specific agent properties - agentProperties := map[string]any{} - knownAgents := []string{ +// GenerateSchema generates the complete JSON schema for the Crush configuration +func (g *SchemaGenerator) GenerateSchema() *JSONSchema { + // Generate schema for the main Config struct + configType := reflect.TypeOf(config.Config{}) + configSchema := g.generateTypeSchema(configType) + + // Create the root schema + schema := &JSONSchema{ + Schema: "http://json-schema.org/draft-07/schema#", + Title: "Crush Configuration", + Description: "Configuration schema for the Crush application", + Type: configSchema.Type, + Properties: configSchema.Properties, + Required: configSchema.Required, + Definitions: g.definitions, + } + + // Add custom enhancements + g.enhanceSchema(schema) + + return schema +} + +// generateTypeSchema generates a JSON schema for a given Go type +func (g *SchemaGenerator) generateTypeSchema(t reflect.Type) *JSONSchema { + // Handle pointers + if t.Kind() == reflect.Ptr { + return g.generateTypeSchema(t.Elem()) + } + + // Check if we've already processed this type + if g.visited[t] { + // Return a reference to avoid infinite recursion + return &JSONSchema{ + Ref: fmt.Sprintf("#/definitions/%s", t.Name()), + } + } + + switch t.Kind() { + case reflect.String: + return &JSONSchema{Type: "string"} + case reflect.Bool: + return &JSONSchema{Type: "boolean"} + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return &JSONSchema{Type: "integer"} + case reflect.Float32, reflect.Float64: + return &JSONSchema{Type: "number"} + case reflect.Slice, reflect.Array: + itemSchema := g.generateTypeSchema(t.Elem()) + return &JSONSchema{ + Type: "array", + Items: itemSchema, + } + case reflect.Map: + valueSchema := g.generateTypeSchema(t.Elem()) + return &JSONSchema{ + Type: "object", + AdditionalProperties: valueSchema, + } + case reflect.Struct: + return g.generateStructSchema(t) + case reflect.Interface: + // For interface{} types, allow any value + return &JSONSchema{} + default: + // Fallback for unknown types + return &JSONSchema{} + } +} + +// generateStructSchema generates a JSON schema for a struct type +func (g *SchemaGenerator) generateStructSchema(t reflect.Type) *JSONSchema { + // Mark as visited to prevent infinite recursion + g.visited[t] = true + + schema := &JSONSchema{ + Type: "object", + Properties: make(map[string]*JSONSchema), + } + + var required []string + + for i := range t.NumField() { + field := t.Field(i) + + // Skip unexported fields + if !field.IsExported() { + continue + } + + // Get JSON tag + jsonTag := field.Tag.Get("json") + if jsonTag == "-" { + continue + } + + // Parse JSON tag + jsonName, options := parseJSONTag(jsonTag) + if jsonName == "" { + jsonName = strings.ToLower(field.Name) + } + + // Generate field schema + fieldSchema := g.generateTypeSchema(field.Type) + + // Add description from field name if not present + if fieldSchema.Description == "" { + fieldSchema.Description = generateFieldDescription(field.Name, field.Type) + } + + // Check if field is required (not omitempty and not a pointer) + if !slices.Contains(options, "omitempty") && field.Type.Kind() != reflect.Ptr { + required = append(required, jsonName) + } + + schema.Properties[jsonName] = fieldSchema + } + + if len(required) > 0 { + schema.Required = required + } + + // Store in definitions if it's a named type + if t.Name() != "" { + g.definitions[t.Name()] = schema + } + + return schema +} + +// parseJSONTag parses a JSON struct tag +func parseJSONTag(tag string) (name string, options []string) { + if tag == "" { + return "", nil + } + + parts := strings.Split(tag, ",") + name = parts[0] + if len(parts) > 1 { + options = parts[1:] + } + return name, options +} + +// generateFieldDescription generates a description for a field based on its name and type +func generateFieldDescription(fieldName string, fieldType reflect.Type) string { + // Convert camelCase to words + words := camelCaseToWords(fieldName) + description := strings.Join(words, " ") + + // Add type-specific information + switch fieldType.Kind() { + case reflect.Bool: + if !strings.Contains(strings.ToLower(description), "enable") && + !strings.Contains(strings.ToLower(description), "disable") { + description = "Enable " + strings.ToLower(description) + } + case reflect.Slice: + if !strings.HasSuffix(description, "s") { + description = description + " list" + } + case reflect.Map: + description = description + " configuration" + } + + return description +} + +// camelCaseToWords converts camelCase to separate words +func camelCaseToWords(s string) []string { + var words []string + var currentWord strings.Builder + + for i, r := range s { + if i > 0 && r >= 'A' && r <= 'Z' { + if currentWord.Len() > 0 { + words = append(words, currentWord.String()) + currentWord.Reset() + } + } + currentWord.WriteRune(r) + } + + if currentWord.Len() > 0 { + words = append(words, currentWord.String()) + } + + return words +} + +// enhanceSchema adds custom enhancements to the generated schema +func (g *SchemaGenerator) enhanceSchema(schema *JSONSchema) { + // Add provider enums + g.addProviderEnums(schema) + + // Add model enums + g.addModelEnums(schema) + + // Add agent enums + g.addAgentEnums(schema) + + // Add tool enums + g.addToolEnums(schema) + + // Add MCP type enums + g.addMCPTypeEnums(schema) + + // Add model type enums + g.addModelTypeEnums(schema) + + // Add default values + g.addDefaultValues(schema) + + // Add custom descriptions + g.addCustomDescriptions(schema) +} + +// addProviderEnums adds provider enums to the schema +func (g *SchemaGenerator) addProviderEnums(schema *JSONSchema) { + providers := config.Providers() + var providerIDs []any + for _, p := range providers { + providerIDs = append(providerIDs, string(p.ID)) + } + + // Add to PreferredModel provider field + if preferredModelDef, exists := schema.Definitions["PreferredModel"]; exists { + if providerProp, exists := preferredModelDef.Properties["provider"]; exists { + providerProp.Enum = providerIDs + } + } + + // Add to ProviderConfig ID field + if providerConfigDef, exists := schema.Definitions["ProviderConfig"]; exists { + if idProp, exists := providerConfigDef.Properties["id"]; exists { + idProp.Enum = providerIDs + } + } +} + +// addModelEnums adds model enums to the schema +func (g *SchemaGenerator) addModelEnums(schema *JSONSchema) { + providers := config.Providers() + var modelIDs []any + for _, p := range providers { + for _, m := range p.Models { + modelIDs = append(modelIDs, m.ID) + } + } + + // Add to PreferredModel model_id field + if preferredModelDef, exists := schema.Definitions["PreferredModel"]; exists { + if modelIDProp, exists := preferredModelDef.Properties["model_id"]; exists { + modelIDProp.Enum = modelIDs + } + } +} + +// addAgentEnums adds agent ID enums to the schema +func (g *SchemaGenerator) addAgentEnums(schema *JSONSchema) { + agentIDs := []any{ string(config.AgentCoder), string(config.AgentTask), } - for _, agentName := range knownAgents { - agentProperties[agentName] = map[string]any{ - "$ref": "#/definitions/agent", - } - } - - // Create a combined schema that allows both specific agents and additional ones - combinedAgentSchema := map[string]any{ - "type": "object", - "description": "Agent configurations", - "properties": agentProperties, - "additionalProperties": agentSchema["additionalProperties"], - } - - schema["properties"].(map[string]any)["agents"] = combinedAgentSchema - schema["definitions"] = map[string]any{ - "agent": agentSchema["additionalProperties"], - } - - // Add LSP configuration - schema["properties"].(map[string]any)["lsp"] = map[string]any{ - "type": "object", - "description": "Language Server Protocol configurations", - "additionalProperties": map[string]any{ - "type": "object", - "description": "LSP configuration for a language", - "properties": map[string]any{ - "disabled": map[string]any{ - "type": "boolean", - "description": "Whether the LSP is disabled", - "default": false, - }, - "command": map[string]any{ - "type": "string", - "description": "Command to execute for the LSP server", - }, - "args": map[string]any{ - "type": "array", - "description": "Command arguments for the LSP server", - "items": map[string]any{ - "type": "string", - }, - }, - "options": map[string]any{ - "type": "object", - "description": "Additional options for the LSP server", - }, - }, - "required": []string{"command"}, - }, + if agentDef, exists := schema.Definitions["Agent"]; exists { + if idProp, exists := agentDef.Properties["id"]; exists { + idProp.Enum = agentIDs + } } +} - return schema +// addToolEnums adds tool enums to the schema +func (g *SchemaGenerator) addToolEnums(schema *JSONSchema) { + tools := []any{ + "bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent", + } + + if agentDef, exists := schema.Definitions["Agent"]; exists { + if allowedToolsProp, exists := agentDef.Properties["allowed_tools"]; exists { + if allowedToolsProp.Items != nil { + allowedToolsProp.Items.Enum = tools + } + } + } +} + +// addMCPTypeEnums adds MCP type enums to the schema +func (g *SchemaGenerator) addMCPTypeEnums(schema *JSONSchema) { + mcpTypes := []any{ + string(config.MCPStdio), + string(config.MCPSse), + } + + if mcpDef, exists := schema.Definitions["MCP"]; exists { + if typeProp, exists := mcpDef.Properties["type"]; exists { + typeProp.Enum = mcpTypes + } + } +} + +// addModelTypeEnums adds model type enums to the schema +func (g *SchemaGenerator) addModelTypeEnums(schema *JSONSchema) { + modelTypes := []any{ + string(config.LargeModel), + string(config.SmallModel), + } + + if agentDef, exists := schema.Definitions["Agent"]; exists { + if modelProp, exists := agentDef.Properties["model"]; exists { + modelProp.Enum = modelTypes + } + } +} + +// addDefaultValues adds default values to the schema +func (g *SchemaGenerator) addDefaultValues(schema *JSONSchema) { + // Add default context paths + if optionsDef, exists := schema.Definitions["Options"]; exists { + if contextPathsProp, exists := optionsDef.Properties["context_paths"]; exists { + contextPathsProp.Default = []any{ + ".github/copilot-instructions.md", + ".cursorrules", + ".cursor/rules/", + "CLAUDE.md", + "CLAUDE.local.md", + "GEMINI.md", + "gemini.md", + "crush.md", + "crush.local.md", + "Crush.md", + "Crush.local.md", + "CRUSH.md", + "CRUSH.local.md", + } + } + if dataDirProp, exists := optionsDef.Properties["data_directory"]; exists { + dataDirProp.Default = ".crush" + } + if debugProp, exists := optionsDef.Properties["debug"]; exists { + debugProp.Default = false + } + if debugLSPProp, exists := optionsDef.Properties["debug_lsp"]; exists { + debugLSPProp.Default = false + } + if disableAutoSummarizeProp, exists := optionsDef.Properties["disable_auto_summarize"]; exists { + disableAutoSummarizeProp.Default = false + } + } + + // Add default MCP type + if mcpDef, exists := schema.Definitions["MCP"]; exists { + if typeProp, exists := mcpDef.Properties["type"]; exists { + typeProp.Default = string(config.MCPStdio) + } + } + + // Add default TUI options + if tuiOptionsDef, exists := schema.Definitions["TUIOptions"]; exists { + if compactModeProp, exists := tuiOptionsDef.Properties["compact_mode"]; exists { + compactModeProp.Default = false + } + } + + // Add default provider disabled + if providerConfigDef, exists := schema.Definitions["ProviderConfig"]; exists { + if disabledProp, exists := providerConfigDef.Properties["disabled"]; exists { + disabledProp.Default = false + } + } + + // Add default agent disabled + if agentDef, exists := schema.Definitions["Agent"]; exists { + if disabledProp, exists := agentDef.Properties["disabled"]; exists { + disabledProp.Default = false + } + } + + // Add default LSP disabled + if lspConfigDef, exists := schema.Definitions["LSPConfig"]; exists { + if disabledProp, exists := lspConfigDef.Properties["enabled"]; exists { + disabledProp.Default = true + } + } +} + +// addCustomDescriptions adds custom descriptions to improve the schema +func (g *SchemaGenerator) addCustomDescriptions(schema *JSONSchema) { + // Enhance main config descriptions + if schema.Properties != nil { + if modelsProp, exists := schema.Properties["models"]; exists { + modelsProp.Description = "Preferred model configurations for large and small model types" + } + if providersProp, exists := schema.Properties["providers"]; exists { + providersProp.Description = "LLM provider configurations" + } + if agentsProp, exists := schema.Properties["agents"]; exists { + agentsProp.Description = "Agent configurations for different tasks" + } + if mcpProp, exists := schema.Properties["mcp"]; exists { + mcpProp.Description = "Model Control Protocol server configurations" + } + if lspProp, exists := schema.Properties["lsp"]; exists { + lspProp.Description = "Language Server Protocol configurations" + } + if optionsProp, exists := schema.Properties["options"]; exists { + optionsProp.Description = "General application options and settings" + } + } + + // Enhance specific field descriptions + if providerConfigDef, exists := schema.Definitions["ProviderConfig"]; exists { + if apiKeyProp, exists := providerConfigDef.Properties["api_key"]; exists { + apiKeyProp.Description = "API key for authenticating with the provider" + } + if baseURLProp, exists := providerConfigDef.Properties["base_url"]; exists { + baseURLProp.Description = "Base URL for the provider API (required for custom providers)" + } + if extraHeadersProp, exists := providerConfigDef.Properties["extra_headers"]; exists { + extraHeadersProp.Description = "Additional HTTP headers to send with requests" + } + if extraParamsProp, exists := providerConfigDef.Properties["extra_params"]; exists { + extraParamsProp.Description = "Additional provider-specific parameters" + } + } + + if agentDef, exists := schema.Definitions["Agent"]; exists { + if allowedToolsProp, exists := agentDef.Properties["allowed_tools"]; exists { + allowedToolsProp.Description = "List of tools this agent is allowed to use (if nil, all tools are allowed)" + } + if allowedMCPProp, exists := agentDef.Properties["allowed_mcp"]; exists { + allowedMCPProp.Description = "Map of MCP servers this agent can use and their allowed tools" + } + if allowedLSPProp, exists := agentDef.Properties["allowed_lsp"]; exists { + allowedLSPProp.Description = "List of LSP servers this agent can use (if nil, all LSPs are allowed)" + } + if contextPathsProp, exists := agentDef.Properties["context_paths"]; exists { + contextPathsProp.Description = "Custom context paths for this agent (additive to global context paths)" + } + } + + if mcpDef, exists := schema.Definitions["MCP"]; exists { + if commandProp, exists := mcpDef.Properties["command"]; exists { + commandProp.Description = "Command to execute for stdio MCP servers" + } + if urlProp, exists := mcpDef.Properties["url"]; exists { + urlProp.Description = "URL for SSE MCP servers" + } + if headersProp, exists := mcpDef.Properties["headers"]; exists { + headersProp.Description = "HTTP headers for SSE MCP servers" + } + } } diff --git a/crush-schema.json b/crush-schema.json index 5412a4badecb3e9d49022a69b3c7eb20fce0812b..f5fa562c5aff42972eb2308c3374969e5d42cac8 100644 --- a/crush-schema.json +++ b/crush-schema.json @@ -1,383 +1,1505 @@ { "$schema": "http://json-schema.org/draft-07/schema#", - "definitions": { - "agent": { - "description": "Agent configuration", - "properties": { - "maxTokens": { - "description": "Maximum tokens for the agent", - "minimum": 1, - "type": "integer" - }, - "model": { - "description": "Model ID for the agent", - "enum": [ - "gpt-4.1", - "llama-3.3-70b-versatile", - "azure.gpt-4.1", - "openrouter.gpt-4o", - "openrouter.o1-mini", - "openrouter.claude-3-haiku", - "claude-3-opus", - "gpt-4o", - "gpt-4o-mini", - "o1", - "meta-llama/llama-4-maverick-17b-128e-instruct", - "azure.o3-mini", - "openrouter.gpt-4o-mini", - "openrouter.o1", - "claude-3.5-haiku", - "o4-mini", - "azure.gpt-4.1-mini", - "openrouter.o3", - "grok-3-beta", - "o3-mini", - "qwen-qwq", - "azure.o1", - "openrouter.gemini-2.5-flash", - "openrouter.gemini-2.5", - "o1-mini", - "azure.gpt-4o", - "openrouter.gpt-4.1-mini", - "openrouter.claude-3.5-sonnet", - "openrouter.o3-mini", - "gpt-4.1-mini", - "gpt-4.5-preview", - "gpt-4.1-nano", - "deepseek-r1-distill-llama-70b", - "azure.gpt-4o-mini", - "openrouter.gpt-4.1", - "bedrock.claude-3.7-sonnet", - "claude-3-haiku", - "o3", - "gemini-2.0-flash-lite", - "azure.o3", - "azure.gpt-4.5-preview", - "openrouter.claude-3-opus", - "grok-3-mini-fast-beta", - "claude-4-sonnet", - "azure.o4-mini", - "grok-3-fast-beta", - "claude-3.5-sonnet", - "azure.o1-mini", - "openrouter.claude-3.7-sonnet", - "openrouter.gpt-4.5-preview", - "grok-3-mini-beta", - "claude-3.7-sonnet", - "gemini-2.0-flash", - "openrouter.deepseek-r1-free", - "vertexai.gemini-2.5-flash", - "vertexai.gemini-2.5", - "o1-pro", - "gemini-2.5", - "meta-llama/llama-4-scout-17b-16e-instruct", - "azure.gpt-4.1-nano", - "openrouter.gpt-4.1-nano", - "gemini-2.5-flash", - "openrouter.o4-mini", - "openrouter.claude-3.5-haiku", - "claude-4-opus", - "openrouter.o1-pro" - ], - "type": "string" - }, - "reasoningEffort": { - "description": "Reasoning effort for models that support it (OpenAI, Anthropic)", - "enum": ["low", "medium", "high"], - "type": "string" - } - }, - "required": ["model"], - "type": "object" - } - }, + "title": "Crush Configuration", "description": "Configuration schema for the Crush application", + "type": "object", "properties": { "agents": { + "description": "Agent configurations for different tasks", + "type": "object", "additionalProperties": { - "description": "Agent configuration", + "type": "object", "properties": { - "maxTokens": { - "description": "Maximum tokens for the agent", - "minimum": 1, - "type": "integer" + "allowed_lsp": { + "description": "List of LSP servers this agent can use (if nil, all LSPs are allowed)", + "type": "array", + "items": { + "type": "string" + } + }, + "allowed_mcp": { + "description": "Map of MCP servers this agent can use and their allowed tools", + "type": "object", + "additionalProperties": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "allowed_tools": { + "description": "List of tools this agent is allowed to use (if nil, all tools are allowed)", + "type": "array", + "items": { + "type": "string", + "enum": [ + "bash", + "edit", + "fetch", + "glob", + "grep", + "ls", + "sourcegraph", + "view", + "write", + "agent" + ] + } + }, + "context_paths": { + "description": "Custom context paths for this agent (additive to global context paths)", + "type": "array", + "items": { + "type": "string" + } + }, + "description": { + "description": "Description", + "type": "string" + }, + "disabled": { + "description": "Disabled", + "type": "boolean", + "default": false + }, + "id": { + "description": "I D", + "type": "string", + "enum": [ + "coder", + "task" + ] }, "model": { - "description": "Model ID for the agent", + "description": "Model", + "type": "string", "enum": [ - "gpt-4.1", - "llama-3.3-70b-versatile", - "azure.gpt-4.1", - "openrouter.gpt-4o", - "openrouter.o1-mini", - "openrouter.claude-3-haiku", - "claude-3-opus", - "gpt-4o", - "gpt-4o-mini", - "o1", - "meta-llama/llama-4-maverick-17b-128e-instruct", - "azure.o3-mini", - "openrouter.gpt-4o-mini", - "openrouter.o1", - "claude-3.5-haiku", - "o4-mini", - "azure.gpt-4.1-mini", - "openrouter.o3", - "grok-3-beta", - "o3-mini", - "qwen-qwq", - "azure.o1", - "openrouter.gemini-2.5-flash", - "openrouter.gemini-2.5", - "o1-mini", - "azure.gpt-4o", - "openrouter.gpt-4.1-mini", - "openrouter.claude-3.5-sonnet", - "openrouter.o3-mini", - "gpt-4.1-mini", - "gpt-4.5-preview", - "gpt-4.1-nano", - "deepseek-r1-distill-llama-70b", - "azure.gpt-4o-mini", - "openrouter.gpt-4.1", - "bedrock.claude-3.7-sonnet", - "claude-3-haiku", - "o3", - "gemini-2.0-flash-lite", - "azure.o3", - "azure.gpt-4.5-preview", - "openrouter.claude-3-opus", - "grok-3-mini-fast-beta", - "claude-4-sonnet", - "azure.o4-mini", - "grok-3-fast-beta", - "claude-3.5-sonnet", - "azure.o1-mini", - "openrouter.claude-3.7-sonnet", - "openrouter.gpt-4.5-preview", - "grok-3-mini-beta", - "claude-3.7-sonnet", - "gemini-2.0-flash", - "openrouter.deepseek-r1-free", - "vertexai.gemini-2.5-flash", - "vertexai.gemini-2.5", - "o1-pro", - "gemini-2.5", - "meta-llama/llama-4-scout-17b-16e-instruct", - "azure.gpt-4.1-nano", - "openrouter.gpt-4.1-nano", - "gemini-2.5-flash", - "openrouter.o4-mini", - "openrouter.claude-3.5-haiku", - "claude-4-opus", - "openrouter.o1-pro" - ], - "type": "string" + "large", + "small" + ] }, - "reasoningEffort": { - "description": "Reasoning effort for models that support it (OpenAI, Anthropic)", - "enum": ["low", "medium", "high"], + "name": { + "description": "Name", "type": "string" } }, - "required": ["model"], - "type": "object" - }, - "description": "Agent configurations", - "properties": { - "coder": { - "$ref": "#/definitions/agent" - }, - "task": { - "$ref": "#/definitions/agent" - }, - "title": { - "$ref": "#/definitions/agent" - } - }, - "type": "object" - }, - "contextPaths": { - "default": [ - ".github/copilot-instructions.md", - ".cursorrules", - ".cursor/rules/", - "CLAUDE.md", - "CLAUDE.local.md", - "GEMINI.md", - "gemini.md", - "crush.md", - "crush.local.md", - "Crush.md", - "Crush.local.md", - "CRUSH.md", - "CRUSH.local.md" - ], - "description": "Context paths for the application", - "items": { - "type": "string" - }, - "type": "array" - }, - "data": { - "description": "Storage configuration", - "properties": { - "directory": { - "default": ".crush", - "description": "Directory where application data is stored", - "type": "string" - } - }, - "required": ["directory"], - "type": "object" - }, - "debug": { - "default": false, - "description": "Enable debug mode", - "type": "boolean" - }, - "debugLSP": { - "default": false, - "description": "Enable LSP debug mode", - "type": "boolean" + "required": [ + "id", + "name", + "disabled", + "model", + "allowed_tools", + "allowed_mcp", + "allowed_lsp", + "context_paths" + ] + } }, "lsp": { + "description": "Language Server Protocol configurations", + "type": "object", "additionalProperties": { - "description": "LSP configuration for a language", + "type": "object", "properties": { "args": { - "description": "Command arguments for the LSP server", + "description": "Args", + "type": "array", "items": { "type": "string" - }, - "type": "array" + } }, "command": { - "description": "Command to execute for the LSP server", + "description": "Command", "type": "string" }, - "disabled": { - "default": false, - "description": "Whether the LSP is disabled", - "type": "boolean" + "enabled": { + "description": "Disabled", + "type": "boolean", + "default": true }, "options": { - "description": "Additional options for the LSP server", - "type": "object" + "description": "Options" } }, - "required": ["command"], - "type": "object" - }, - "description": "Language Server Protocol configurations", - "type": "object" + "required": [ + "enabled", + "command", + "args", + "options" + ] + } }, - "mcpServers": { + "mcp": { + "description": "Model Control Protocol server configurations", + "type": "object", "additionalProperties": { - "description": "MCP server configuration", + "type": "object", "properties": { "args": { - "description": "Command arguments for the MCP server", + "description": "Args", + "type": "array", "items": { "type": "string" - }, - "type": "array" + } }, "command": { - "description": "Command to execute for the MCP server", + "description": "Command to execute for stdio MCP servers", "type": "string" }, "env": { - "description": "Environment variables for the MCP server", + "description": "Env list", + "type": "array", "items": { "type": "string" - }, - "type": "array" + } }, "headers": { + "description": "HTTP headers for SSE MCP servers", + "type": "object", "additionalProperties": { "type": "string" - }, - "description": "HTTP headers for SSE type MCP servers", - "type": "object" + } }, "type": { - "default": "stdio", - "description": "Type of MCP server", - "enum": ["stdio", "sse"], - "type": "string" + "description": "Type", + "type": "string", + "enum": [ + "stdio", + "sse" + ], + "default": "stdio" }, "url": { - "description": "URL for SSE type MCP servers", + "description": "URL for SSE MCP servers", "type": "string" } }, - "required": ["command"], - "type": "object" + "required": [ + "command", + "env", + "args", + "type", + "url", + "headers" + ] + } + }, + "models": { + "description": "Preferred model configurations for large and small model types", + "type": "object", + "properties": { + "large": { + "description": "Large", + "type": "object", + "properties": { + "max_tokens": { + "description": "Max Tokens", + "type": "integer" + }, + "model_id": { + "description": "Model I D", + "type": "string", + "enum": [ + "claude-3-opus", + "claude-3-haiku", + "claude-3-5-sonnet-20241022", + "claude-3-5-haiku-20241022", + "gpt-4", + "gpt-3.5-turbo", + "gpt-4-turbo", + "gpt-4o", + "gpt-4o-mini", + "o1-preview", + "o1-mini", + "gemini-2.5-pro", + "gemini-2.5-flash", + "grok-beta", + "anthropic/claude-3.5-sonnet", + "anthropic/claude-3.5-haiku" + ] + }, + "provider": { + "description": "Provider", + "type": "string", + "enum": [ + "anthropic", + "openai", + "gemini", + "xai", + "openrouter" + ] + }, + "reasoning_effort": { + "description": "Reasoning Effort", + "type": "string" + }, + "think": { + "description": "Enable think", + "type": "boolean" + } + }, + "required": [ + "model_id", + "provider" + ] + }, + "small": { + "description": "Small", + "$ref": "#/definitions/PreferredModel" + } }, - "description": "Model Control Protocol server configurations", - "type": "object" + "required": [ + "large", + "small" + ] + }, + "options": { + "description": "General application options and settings", + "type": "object", + "properties": { + "context_paths": { + "description": "Context Paths", + "type": "array", + "items": { + "type": "string" + }, + "default": [ + ".github/copilot-instructions.md", + ".cursorrules", + ".cursor/rules/", + "CLAUDE.md", + "CLAUDE.local.md", + "GEMINI.md", + "gemini.md", + "crush.md", + "crush.local.md", + "Crush.md", + "Crush.local.md", + "CRUSH.md", + "CRUSH.local.md" + ] + }, + "data_directory": { + "description": "Data Directory", + "type": "string", + "default": ".crush" + }, + "debug": { + "description": "Enable debug", + "type": "boolean", + "default": false + }, + "debug_lsp": { + "description": "Enable debug l s p", + "type": "boolean", + "default": false + }, + "disable_auto_summarize": { + "description": "Disable Auto Summarize", + "type": "boolean", + "default": false + }, + "tui": { + "description": "T U I", + "type": "object", + "properties": { + "compact_mode": { + "description": "Enable compact mode", + "type": "boolean", + "default": false + } + }, + "required": [ + "compact_mode" + ] + } + }, + "required": [ + "context_paths", + "tui", + "debug", + "debug_lsp", + "disable_auto_summarize", + "data_directory" + ] }, "providers": { + "description": "LLM provider configurations", + "type": "object", "additionalProperties": { - "description": "Provider configuration", + "type": "object", "properties": { - "apiKey": { - "description": "API key for the provider", + "api_key": { + "description": "API key for authenticating with the provider", + "type": "string" + }, + "base_url": { + "description": "Base URL for the provider API (required for custom providers)", + "type": "string" + }, + "default_large_model": { + "description": "Default Large Model", + "type": "string" + }, + "default_small_model": { + "description": "Default Small Model", "type": "string" }, "disabled": { - "default": false, - "description": "Whether the provider is disabled", - "type": "boolean" + "description": "Disabled", + "type": "boolean", + "default": false + }, + "extra_headers": { + "description": "Additional HTTP headers to send with requests", + "type": "object", + "additionalProperties": { + "type": "string" + } }, - "provider": { - "description": "Provider type", + "extra_params": { + "description": "Additional provider-specific parameters", + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "id": { + "description": "I D", + "type": "string", "enum": [ "anthropic", "openai", "gemini", - "groq", - "openrouter", - "bedrock", - "azure", - "vertexai" - ], + "xai", + "openrouter" + ] + }, + "models": { + "description": "Models", + "type": "array", + "items": { + "type": "object", + "properties": { + "can_reason": { + "description": "Enable can reason", + "type": "boolean" + }, + "context_window": { + "description": "Context Window", + "type": "integer" + }, + "cost_per_1m_in": { + "description": "Cost Per1 M In", + "type": "number" + }, + "cost_per_1m_in_cached": { + "description": "Cost Per1 M In Cached", + "type": "number" + }, + "cost_per_1m_out": { + "description": "Cost Per1 M Out", + "type": "number" + }, + "cost_per_1m_out_cached": { + "description": "Cost Per1 M Out Cached", + "type": "number" + }, + "default_max_tokens": { + "description": "Default Max Tokens", + "type": "integer" + }, + "has_reasoning_effort": { + "description": "Enable has reasoning effort", + "type": "boolean" + }, + "id": { + "description": "I D", + "type": "string" + }, + "model": { + "description": "Name", + "type": "string" + }, + "reasoning_effort": { + "description": "Reasoning Effort", + "type": "string" + }, + "supports_attachments": { + "description": "Enable supports images", + "type": "boolean" + } + }, + "required": [ + "id", + "model", + "cost_per_1m_in", + "cost_per_1m_out", + "cost_per_1m_in_cached", + "cost_per_1m_out_cached", + "context_window", + "default_max_tokens", + "can_reason", + "reasoning_effort", + "has_reasoning_effort", + "supports_attachments" + ] + } + }, + "provider_type": { + "description": "Provider Type", + "type": "string" + } + }, + "required": [ + "id", + "provider_type", + "disabled" + ] + } + } + }, + "required": [ + "models", + "options" + ], + "definitions": { + "Agent": { + "type": "object", + "properties": { + "allowed_lsp": { + "description": "List of LSP servers this agent can use (if nil, all LSPs are allowed)", + "type": "array", + "items": { + "type": "string" + } + }, + "allowed_mcp": { + "description": "Map of MCP servers this agent can use and their allowed tools", + "type": "object", + "additionalProperties": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "allowed_tools": { + "description": "List of tools this agent is allowed to use (if nil, all tools are allowed)", + "type": "array", + "items": { + "type": "string", + "enum": [ + "bash", + "edit", + "fetch", + "glob", + "grep", + "ls", + "sourcegraph", + "view", + "write", + "agent" + ] + } + }, + "context_paths": { + "description": "Custom context paths for this agent (additive to global context paths)", + "type": "array", + "items": { "type": "string" } }, - "type": "object" + "description": { + "description": "Description", + "type": "string" + }, + "disabled": { + "description": "Disabled", + "type": "boolean", + "default": false + }, + "id": { + "description": "I D", + "type": "string", + "enum": [ + "coder", + "task" + ] + }, + "model": { + "description": "Model", + "type": "string", + "enum": [ + "large", + "small" + ] + }, + "name": { + "description": "Name", + "type": "string" + } }, - "description": "LLM provider configurations", - "type": "object" + "required": [ + "id", + "name", + "disabled", + "model", + "allowed_tools", + "allowed_mcp", + "allowed_lsp", + "context_paths" + ] + }, + "Config": { + "type": "object", + "properties": { + "agents": { + "description": "Agent configurations for different tasks", + "type": "object", + "additionalProperties": { + "type": "object", + "properties": { + "allowed_lsp": { + "description": "List of LSP servers this agent can use (if nil, all LSPs are allowed)", + "type": "array", + "items": { + "type": "string" + } + }, + "allowed_mcp": { + "description": "Map of MCP servers this agent can use and their allowed tools", + "type": "object", + "additionalProperties": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "allowed_tools": { + "description": "List of tools this agent is allowed to use (if nil, all tools are allowed)", + "type": "array", + "items": { + "type": "string", + "enum": [ + "bash", + "edit", + "fetch", + "glob", + "grep", + "ls", + "sourcegraph", + "view", + "write", + "agent" + ] + } + }, + "context_paths": { + "description": "Custom context paths for this agent (additive to global context paths)", + "type": "array", + "items": { + "type": "string" + } + }, + "description": { + "description": "Description", + "type": "string" + }, + "disabled": { + "description": "Disabled", + "type": "boolean", + "default": false + }, + "id": { + "description": "I D", + "type": "string", + "enum": [ + "coder", + "task" + ] + }, + "model": { + "description": "Model", + "type": "string", + "enum": [ + "large", + "small" + ] + }, + "name": { + "description": "Name", + "type": "string" + } + }, + "required": [ + "id", + "name", + "disabled", + "model", + "allowed_tools", + "allowed_mcp", + "allowed_lsp", + "context_paths" + ] + } + }, + "lsp": { + "description": "Language Server Protocol configurations", + "type": "object", + "additionalProperties": { + "type": "object", + "properties": { + "args": { + "description": "Args", + "type": "array", + "items": { + "type": "string" + } + }, + "command": { + "description": "Command", + "type": "string" + }, + "enabled": { + "description": "Disabled", + "type": "boolean", + "default": true + }, + "options": { + "description": "Options" + } + }, + "required": [ + "enabled", + "command", + "args", + "options" + ] + } + }, + "mcp": { + "description": "Model Control Protocol server configurations", + "type": "object", + "additionalProperties": { + "type": "object", + "properties": { + "args": { + "description": "Args", + "type": "array", + "items": { + "type": "string" + } + }, + "command": { + "description": "Command to execute for stdio MCP servers", + "type": "string" + }, + "env": { + "description": "Env list", + "type": "array", + "items": { + "type": "string" + } + }, + "headers": { + "description": "HTTP headers for SSE MCP servers", + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "type": { + "description": "Type", + "type": "string", + "enum": [ + "stdio", + "sse" + ], + "default": "stdio" + }, + "url": { + "description": "URL for SSE MCP servers", + "type": "string" + } + }, + "required": [ + "command", + "env", + "args", + "type", + "url", + "headers" + ] + } + }, + "models": { + "description": "Preferred model configurations for large and small model types", + "type": "object", + "properties": { + "large": { + "description": "Large", + "type": "object", + "properties": { + "max_tokens": { + "description": "Max Tokens", + "type": "integer" + }, + "model_id": { + "description": "Model I D", + "type": "string", + "enum": [ + "claude-3-opus", + "claude-3-haiku", + "claude-3-5-sonnet-20241022", + "claude-3-5-haiku-20241022", + "gpt-4", + "gpt-3.5-turbo", + "gpt-4-turbo", + "gpt-4o", + "gpt-4o-mini", + "o1-preview", + "o1-mini", + "gemini-2.5-pro", + "gemini-2.5-flash", + "grok-beta", + "anthropic/claude-3.5-sonnet", + "anthropic/claude-3.5-haiku" + ] + }, + "provider": { + "description": "Provider", + "type": "string", + "enum": [ + "anthropic", + "openai", + "gemini", + "xai", + "openrouter" + ] + }, + "reasoning_effort": { + "description": "Reasoning Effort", + "type": "string" + }, + "think": { + "description": "Enable think", + "type": "boolean" + } + }, + "required": [ + "model_id", + "provider" + ] + }, + "small": { + "description": "Small", + "$ref": "#/definitions/PreferredModel" + } + }, + "required": [ + "large", + "small" + ] + }, + "options": { + "description": "General application options and settings", + "type": "object", + "properties": { + "context_paths": { + "description": "Context Paths", + "type": "array", + "items": { + "type": "string" + }, + "default": [ + ".github/copilot-instructions.md", + ".cursorrules", + ".cursor/rules/", + "CLAUDE.md", + "CLAUDE.local.md", + "GEMINI.md", + "gemini.md", + "crush.md", + "crush.local.md", + "Crush.md", + "Crush.local.md", + "CRUSH.md", + "CRUSH.local.md" + ] + }, + "data_directory": { + "description": "Data Directory", + "type": "string", + "default": ".crush" + }, + "debug": { + "description": "Enable debug", + "type": "boolean", + "default": false + }, + "debug_lsp": { + "description": "Enable debug l s p", + "type": "boolean", + "default": false + }, + "disable_auto_summarize": { + "description": "Disable Auto Summarize", + "type": "boolean", + "default": false + }, + "tui": { + "description": "T U I", + "type": "object", + "properties": { + "compact_mode": { + "description": "Enable compact mode", + "type": "boolean", + "default": false + } + }, + "required": [ + "compact_mode" + ] + } + }, + "required": [ + "context_paths", + "tui", + "debug", + "debug_lsp", + "disable_auto_summarize", + "data_directory" + ] + }, + "providers": { + "description": "LLM provider configurations", + "type": "object", + "additionalProperties": { + "type": "object", + "properties": { + "api_key": { + "description": "API key for authenticating with the provider", + "type": "string" + }, + "base_url": { + "description": "Base URL for the provider API (required for custom providers)", + "type": "string" + }, + "default_large_model": { + "description": "Default Large Model", + "type": "string" + }, + "default_small_model": { + "description": "Default Small Model", + "type": "string" + }, + "disabled": { + "description": "Disabled", + "type": "boolean", + "default": false + }, + "extra_headers": { + "description": "Additional HTTP headers to send with requests", + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "extra_params": { + "description": "Additional provider-specific parameters", + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "id": { + "description": "I D", + "type": "string", + "enum": [ + "anthropic", + "openai", + "gemini", + "xai", + "openrouter" + ] + }, + "models": { + "description": "Models", + "type": "array", + "items": { + "type": "object", + "properties": { + "can_reason": { + "description": "Enable can reason", + "type": "boolean" + }, + "context_window": { + "description": "Context Window", + "type": "integer" + }, + "cost_per_1m_in": { + "description": "Cost Per1 M In", + "type": "number" + }, + "cost_per_1m_in_cached": { + "description": "Cost Per1 M In Cached", + "type": "number" + }, + "cost_per_1m_out": { + "description": "Cost Per1 M Out", + "type": "number" + }, + "cost_per_1m_out_cached": { + "description": "Cost Per1 M Out Cached", + "type": "number" + }, + "default_max_tokens": { + "description": "Default Max Tokens", + "type": "integer" + }, + "has_reasoning_effort": { + "description": "Enable has reasoning effort", + "type": "boolean" + }, + "id": { + "description": "I D", + "type": "string" + }, + "model": { + "description": "Name", + "type": "string" + }, + "reasoning_effort": { + "description": "Reasoning Effort", + "type": "string" + }, + "supports_attachments": { + "description": "Enable supports images", + "type": "boolean" + } + }, + "required": [ + "id", + "model", + "cost_per_1m_in", + "cost_per_1m_out", + "cost_per_1m_in_cached", + "cost_per_1m_out_cached", + "context_window", + "default_max_tokens", + "can_reason", + "reasoning_effort", + "has_reasoning_effort", + "supports_attachments" + ] + } + }, + "provider_type": { + "description": "Provider Type", + "type": "string" + } + }, + "required": [ + "id", + "provider_type", + "disabled" + ] + } + } + }, + "required": [ + "models", + "options" + ] }, - "tui": { - "description": "Terminal User Interface configuration", + "LSPConfig": { + "type": "object", "properties": { - "theme": { - "default": "crush", - "description": "TUI theme name", + "args": { + "description": "Args", + "type": "array", + "items": { + "type": "string" + } + }, + "command": { + "description": "Command", + "type": "string" + }, + "enabled": { + "description": "Disabled", + "type": "boolean", + "default": true + }, + "options": { + "description": "Options" + } + }, + "required": [ + "enabled", + "command", + "args", + "options" + ] + }, + "MCP": { + "type": "object", + "properties": { + "args": { + "description": "Args", + "type": "array", + "items": { + "type": "string" + } + }, + "command": { + "description": "Command to execute for stdio MCP servers", + "type": "string" + }, + "env": { + "description": "Env list", + "type": "array", + "items": { + "type": "string" + } + }, + "headers": { + "description": "HTTP headers for SSE MCP servers", + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "type": { + "description": "Type", + "type": "string", "enum": [ - "crush", - "catppuccin", - "dracula", - "flexoki", - "gruvbox", - "monokai", - "onedark", - "tokyonight", - "tron" + "stdio", + "sse" ], + "default": "stdio" + }, + "url": { + "description": "URL for SSE MCP servers", + "type": "string" + } + }, + "required": [ + "command", + "env", + "args", + "type", + "url", + "headers" + ] + }, + "Model": { + "type": "object", + "properties": { + "can_reason": { + "description": "Enable can reason", + "type": "boolean" + }, + "context_window": { + "description": "Context Window", + "type": "integer" + }, + "cost_per_1m_in": { + "description": "Cost Per1 M In", + "type": "number" + }, + "cost_per_1m_in_cached": { + "description": "Cost Per1 M In Cached", + "type": "number" + }, + "cost_per_1m_out": { + "description": "Cost Per1 M Out", + "type": "number" + }, + "cost_per_1m_out_cached": { + "description": "Cost Per1 M Out Cached", + "type": "number" + }, + "default_max_tokens": { + "description": "Default Max Tokens", + "type": "integer" + }, + "has_reasoning_effort": { + "description": "Enable has reasoning effort", + "type": "boolean" + }, + "id": { + "description": "I D", + "type": "string" + }, + "model": { + "description": "Name", + "type": "string" + }, + "reasoning_effort": { + "description": "Reasoning Effort", + "type": "string" + }, + "supports_attachments": { + "description": "Enable supports images", + "type": "boolean" + } + }, + "required": [ + "id", + "model", + "cost_per_1m_in", + "cost_per_1m_out", + "cost_per_1m_in_cached", + "cost_per_1m_out_cached", + "context_window", + "default_max_tokens", + "can_reason", + "reasoning_effort", + "has_reasoning_effort", + "supports_attachments" + ] + }, + "Options": { + "description": "General application options and settings", + "type": "object", + "properties": { + "context_paths": { + "description": "Context Paths", + "type": "array", + "items": { + "type": "string" + }, + "default": [ + ".github/copilot-instructions.md", + ".cursorrules", + ".cursor/rules/", + "CLAUDE.md", + "CLAUDE.local.md", + "GEMINI.md", + "gemini.md", + "crush.md", + "crush.local.md", + "Crush.md", + "Crush.local.md", + "CRUSH.md", + "CRUSH.local.md" + ] + }, + "data_directory": { + "description": "Data Directory", + "type": "string", + "default": ".crush" + }, + "debug": { + "description": "Enable debug", + "type": "boolean", + "default": false + }, + "debug_lsp": { + "description": "Enable debug l s p", + "type": "boolean", + "default": false + }, + "disable_auto_summarize": { + "description": "Disable Auto Summarize", + "type": "boolean", + "default": false + }, + "tui": { + "description": "T U I", + "type": "object", + "properties": { + "compact_mode": { + "description": "Enable compact mode", + "type": "boolean", + "default": false + } + }, + "required": [ + "compact_mode" + ] + } + }, + "required": [ + "context_paths", + "tui", + "debug", + "debug_lsp", + "disable_auto_summarize", + "data_directory" + ] + }, + "PreferredModel": { + "description": "Large", + "type": "object", + "properties": { + "max_tokens": { + "description": "Max Tokens", + "type": "integer" + }, + "model_id": { + "description": "Model I D", + "type": "string", + "enum": [ + "claude-3-opus", + "claude-3-haiku", + "claude-3-5-sonnet-20241022", + "claude-3-5-haiku-20241022", + "gpt-4", + "gpt-3.5-turbo", + "gpt-4-turbo", + "gpt-4o", + "gpt-4o-mini", + "o1-preview", + "o1-mini", + "gemini-2.5-pro", + "gemini-2.5-flash", + "grok-beta", + "anthropic/claude-3.5-sonnet", + "anthropic/claude-3.5-haiku" + ] + }, + "provider": { + "description": "Provider", + "type": "string", + "enum": [ + "anthropic", + "openai", + "gemini", + "xai", + "openrouter" + ] + }, + "reasoning_effort": { + "description": "Reasoning Effort", + "type": "string" + }, + "think": { + "description": "Enable think", + "type": "boolean" + } + }, + "required": [ + "model_id", + "provider" + ] + }, + "PreferredModels": { + "description": "Preferred model configurations for large and small model types", + "type": "object", + "properties": { + "large": { + "description": "Large", + "type": "object", + "properties": { + "max_tokens": { + "description": "Max Tokens", + "type": "integer" + }, + "model_id": { + "description": "Model I D", + "type": "string", + "enum": [ + "claude-3-opus", + "claude-3-haiku", + "claude-3-5-sonnet-20241022", + "claude-3-5-haiku-20241022", + "gpt-4", + "gpt-3.5-turbo", + "gpt-4-turbo", + "gpt-4o", + "gpt-4o-mini", + "o1-preview", + "o1-mini", + "gemini-2.5-pro", + "gemini-2.5-flash", + "grok-beta", + "anthropic/claude-3.5-sonnet", + "anthropic/claude-3.5-haiku" + ] + }, + "provider": { + "description": "Provider", + "type": "string", + "enum": [ + "anthropic", + "openai", + "gemini", + "xai", + "openrouter" + ] + }, + "reasoning_effort": { + "description": "Reasoning Effort", + "type": "string" + }, + "think": { + "description": "Enable think", + "type": "boolean" + } + }, + "required": [ + "model_id", + "provider" + ] + }, + "small": { + "description": "Small", + "$ref": "#/definitions/PreferredModel" + } + }, + "required": [ + "large", + "small" + ] + }, + "ProviderConfig": { + "type": "object", + "properties": { + "api_key": { + "description": "API key for authenticating with the provider", + "type": "string" + }, + "base_url": { + "description": "Base URL for the provider API (required for custom providers)", + "type": "string" + }, + "default_large_model": { + "description": "Default Large Model", + "type": "string" + }, + "default_small_model": { + "description": "Default Small Model", + "type": "string" + }, + "disabled": { + "description": "Disabled", + "type": "boolean", + "default": false + }, + "extra_headers": { + "description": "Additional HTTP headers to send with requests", + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "extra_params": { + "description": "Additional provider-specific parameters", + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "id": { + "description": "I D", + "type": "string", + "enum": [ + "anthropic", + "openai", + "gemini", + "xai", + "openrouter" + ] + }, + "models": { + "description": "Models", + "type": "array", + "items": { + "type": "object", + "properties": { + "can_reason": { + "description": "Enable can reason", + "type": "boolean" + }, + "context_window": { + "description": "Context Window", + "type": "integer" + }, + "cost_per_1m_in": { + "description": "Cost Per1 M In", + "type": "number" + }, + "cost_per_1m_in_cached": { + "description": "Cost Per1 M In Cached", + "type": "number" + }, + "cost_per_1m_out": { + "description": "Cost Per1 M Out", + "type": "number" + }, + "cost_per_1m_out_cached": { + "description": "Cost Per1 M Out Cached", + "type": "number" + }, + "default_max_tokens": { + "description": "Default Max Tokens", + "type": "integer" + }, + "has_reasoning_effort": { + "description": "Enable has reasoning effort", + "type": "boolean" + }, + "id": { + "description": "I D", + "type": "string" + }, + "model": { + "description": "Name", + "type": "string" + }, + "reasoning_effort": { + "description": "Reasoning Effort", + "type": "string" + }, + "supports_attachments": { + "description": "Enable supports images", + "type": "boolean" + } + }, + "required": [ + "id", + "model", + "cost_per_1m_in", + "cost_per_1m_out", + "cost_per_1m_in_cached", + "cost_per_1m_out_cached", + "context_window", + "default_max_tokens", + "can_reason", + "reasoning_effort", + "has_reasoning_effort", + "supports_attachments" + ] + } + }, + "provider_type": { + "description": "Provider Type", "type": "string" } }, - "type": "object" + "required": [ + "id", + "provider_type", + "disabled" + ] }, - "wd": { - "description": "Working directory for the application", - "type": "string" + "TUIOptions": { + "description": "T U I", + "type": "object", + "properties": { + "compact_mode": { + "description": "Enable compact mode", + "type": "boolean", + "default": false + } + }, + "required": [ + "compact_mode" + ] } - }, - "title": "Crush Configuration", - "type": "object" + } } diff --git a/internal/config/config.go b/internal/config/config.go index 69a528d3c57eba4ef4b8802d800bfb0a7a764c14..3caf9f01c4afdba4dd2c29c43fc690dd360173ef 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -732,6 +732,34 @@ func defaultConfigBasedOnEnv() *Config { "project": os.Getenv("GOOGLE_CLOUD_PROJECT"), "location": os.Getenv("GOOGLE_CLOUD_LOCATION"), } + // Find the VertexAI provider definition to get default models + for _, p := range providers { + if p.ID == provider.InferenceProviderVertexAI { + providerConfig.DefaultLargeModel = p.DefaultLargeModelID + providerConfig.DefaultSmallModel = p.DefaultSmallModelID + for _, model := range p.Models { + configModel := Model{ + ID: model.ID, + Name: model.Name, + CostPer1MIn: model.CostPer1MIn, + CostPer1MOut: model.CostPer1MOut, + CostPer1MInCached: model.CostPer1MInCached, + CostPer1MOutCached: model.CostPer1MOutCached, + ContextWindow: model.ContextWindow, + DefaultMaxTokens: model.DefaultMaxTokens, + CanReason: model.CanReason, + SupportsImages: model.SupportsImages, + } + // Set reasoning effort for reasoning models + if model.HasReasoningEffort && model.DefaultReasoningEffort != "" { + configModel.HasReasoningEffort = model.HasReasoningEffort + configModel.ReasoningEffort = model.DefaultReasoningEffort + } + providerConfig.Models = append(providerConfig.Models, configModel) + } + break + } + } cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig } @@ -743,6 +771,34 @@ func defaultConfigBasedOnEnv() *Config { if providerConfig.ExtraParams["region"] == "" { providerConfig.ExtraParams["region"] = os.Getenv("AWS_REGION") } + // Find the Bedrock provider definition to get default models + for _, p := range providers { + if p.ID == provider.InferenceProviderBedrock { + providerConfig.DefaultLargeModel = p.DefaultLargeModelID + providerConfig.DefaultSmallModel = p.DefaultSmallModelID + for _, model := range p.Models { + configModel := Model{ + ID: model.ID, + Name: model.Name, + CostPer1MIn: model.CostPer1MIn, + CostPer1MOut: model.CostPer1MOut, + CostPer1MInCached: model.CostPer1MInCached, + CostPer1MOutCached: model.CostPer1MOutCached, + ContextWindow: model.ContextWindow, + DefaultMaxTokens: model.DefaultMaxTokens, + CanReason: model.CanReason, + SupportsImages: model.SupportsImages, + } + // Set reasoning effort for reasoning models + if model.HasReasoningEffort && model.DefaultReasoningEffort != "" { + configModel.HasReasoningEffort = model.HasReasoningEffort + configModel.ReasoningEffort = model.DefaultReasoningEffort + } + providerConfig.Models = append(providerConfig.Models, configModel) + } + break + } + } cfg.Providers[provider.InferenceProviderBedrock] = providerConfig } return cfg diff --git a/internal/config/config_test.go b/internal/config/config_test.go index f69b3c84701b07c6df948c7abca6e37a65e3c69e..b48a9eba0a92a9f9239d6f6e3526c24cc8790ac9 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -93,7 +93,7 @@ func TestInit_SingletonBehavior(t *testing.T) { require.NoError(t, err1) require.NoError(t, err2) - assert.Same(t, cfg1, cfg2) // Should be the same instance + assert.Same(t, cfg1, cfg2) } func TestGet_BeforeInitialization(t *testing.T) { @@ -124,7 +124,7 @@ func TestLoadConfig_NoConfigFiles(t *testing.T) { cfg, err := Init(cwdDir, false) require.NoError(t, err) - assert.Len(t, cfg.Providers, 0) // No providers without env vars or config files + assert.Len(t, cfg.Providers, 0) assert.Equal(t, defaultContextPaths, cfg.Options.ContextPaths) } @@ -133,7 +133,6 @@ func TestLoadConfig_OnlyGlobalConfig(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Create global config file globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { @@ -187,7 +186,6 @@ func TestLoadConfig_OnlyLocalConfig(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Create local config file localConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderAnthropic: { @@ -239,7 +237,6 @@ func TestLoadConfig_BothGlobalAndLocal(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Create global config globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { @@ -279,7 +276,6 @@ func TestLoadConfig_BothGlobalAndLocal(t *testing.T) { require.NoError(t, err) require.NoError(t, os.WriteFile(configPath, data, 0o644)) - // Create local config that overrides and adds localConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { @@ -327,14 +323,11 @@ func TestLoadConfig_BothGlobalAndLocal(t *testing.T) { require.NoError(t, err) assert.Len(t, cfg.Providers, 2) - // Check that local config overrode global openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] assert.Equal(t, "local-key", openaiProvider.APIKey) - // Check that local config added new provider assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic) - // Check that context paths were merged assert.Contains(t, cfg.Options.ContextPaths, "global-context.md") assert.Contains(t, cfg.Options.ContextPaths, "local-context.md") assert.True(t, cfg.Options.TUI.CompactMode) @@ -345,7 +338,6 @@ func TestLoadConfig_MalformedGlobalJSON(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Create malformed global config configPath := filepath.Join(testConfigDir, "crush.json") require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) require.NoError(t, os.WriteFile(configPath, []byte(`{invalid json`), 0o644)) @@ -359,7 +351,6 @@ func TestLoadConfig_MalformedLocalJSON(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Create malformed local config localConfigPath := filepath.Join(cwdDir, "crush.json") require.NoError(t, os.WriteFile(localConfigPath, []byte(`{invalid json`), 0o644)) @@ -409,7 +400,6 @@ func TestEnvVars_AllSupportedAPIKeys(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set all supported API keys os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key") os.Setenv("OPENAI_API_KEY", "test-openai-key") os.Setenv("GEMINI_API_KEY", "test-gemini-key") @@ -421,7 +411,6 @@ func TestEnvVars_AllSupportedAPIKeys(t *testing.T) { require.NoError(t, err) assert.Len(t, cfg.Providers, 5) - // Verify each provider is configured correctly anthropicProvider := cfg.Providers[provider.InferenceProviderAnthropic] assert.Equal(t, "test-anthropic-key", anthropicProvider.APIKey) assert.Equal(t, provider.TypeAnthropic, anthropicProvider.ProviderType) @@ -449,7 +438,6 @@ func TestEnvVars_PartialEnvironmentVariables(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set only some API keys os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key") os.Setenv("OPENAI_API_KEY", "test-openai-key") @@ -467,7 +455,6 @@ func TestEnvVars_VertexAIConfiguration(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set VertexAI environment variables os.Setenv("GOOGLE_GENAI_USE_VERTEXAI", "true") os.Setenv("GOOGLE_CLOUD_PROJECT", "test-project") os.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1") @@ -488,7 +475,6 @@ func TestEnvVars_VertexAIWithoutUseFlag(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set Google Cloud vars but not the use flag os.Setenv("GOOGLE_CLOUD_PROJECT", "test-project") os.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1") @@ -503,7 +489,6 @@ func TestEnvVars_AWSBedrockWithAccessKeys(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set AWS credentials os.Setenv("AWS_ACCESS_KEY_ID", "test-access-key") os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key") os.Setenv("AWS_DEFAULT_REGION", "us-east-1") @@ -523,7 +508,6 @@ func TestEnvVars_AWSBedrockWithProfile(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set AWS profile os.Setenv("AWS_PROFILE", "test-profile") os.Setenv("AWS_REGION", "eu-west-1") @@ -541,7 +525,6 @@ func TestEnvVars_AWSBedrockWithContainerCredentials(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set AWS container credentials os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/v2/credentials/test") os.Setenv("AWS_DEFAULT_REGION", "ap-southeast-1") @@ -556,7 +539,6 @@ func TestEnvVars_AWSBedrockRegionPriority(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set both region variables - AWS_DEFAULT_REGION should take priority os.Setenv("AWS_ACCESS_KEY_ID", "test-key") os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret") os.Setenv("AWS_DEFAULT_REGION", "us-west-2") @@ -574,7 +556,6 @@ func TestEnvVars_AWSBedrockFallbackRegion(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set only AWS_REGION (not AWS_DEFAULT_REGION) os.Setenv("AWS_ACCESS_KEY_ID", "test-key") os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret") os.Setenv("AWS_REGION", "us-east-1") @@ -591,7 +572,6 @@ func TestEnvVars_NoAWSCredentials(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Don't set any AWS credentials cfg, err := Init(cwdDir, false) require.NoError(t, err) @@ -603,15 +583,12 @@ func TestEnvVars_CustomEnvironmentVariables(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Test that environment variables are properly resolved from provider definitions - // This test assumes the provider system uses $VARIABLE_NAME format os.Setenv("ANTHROPIC_API_KEY", "resolved-anthropic-key") cfg, err := Init(cwdDir, false) require.NoError(t, err) if len(cfg.Providers) > 0 { - // Verify that the environment variable was resolved if anthropicProvider, exists := cfg.Providers[provider.InferenceProviderAnthropic]; exists { assert.Equal(t, "resolved-anthropic-key", anthropicProvider.APIKey) } @@ -623,11 +600,11 @@ func TestEnvVars_CombinedEnvironmentVariables(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set multiple types of environment variables os.Setenv("ANTHROPIC_API_KEY", "test-anthropic") os.Setenv("OPENAI_API_KEY", "test-openai") os.Setenv("GOOGLE_GENAI_USE_VERTEXAI", "true") os.Setenv("GOOGLE_CLOUD_PROJECT", "test-project") + os.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1") os.Setenv("AWS_ACCESS_KEY_ID", "test-aws-key") os.Setenv("AWS_SECRET_ACCESS_KEY", "test-aws-secret") os.Setenv("AWS_DEFAULT_REGION", "us-west-1") @@ -636,7 +613,6 @@ func TestEnvVars_CombinedEnvironmentVariables(t *testing.T) { require.NoError(t, err) - // Should have API key providers + VertexAI + Bedrock expectedProviders := []provider.InferenceProvider{ provider.InferenceProviderAnthropic, provider.InferenceProviderOpenAI, @@ -696,14 +672,11 @@ func TestHasAWSCredentials_NoCredentials(t *testing.T) { assert.False(t, hasAWSCredentials()) } -// Provider Configuration Tests - func TestProviderMerging_GlobalToBase(t *testing.T) { reset() testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Create global config with provider globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { @@ -719,6 +692,12 @@ func TestProviderMerging_GlobalToBase(t *testing.T) { ContextWindow: 8192, DefaultMaxTokens: 4096, }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + ContextWindow: 4096, + DefaultMaxTokens: 2048, + }, }, }, }, @@ -739,7 +718,7 @@ func TestProviderMerging_GlobalToBase(t *testing.T) { assert.Equal(t, "global-openai-key", openaiProvider.APIKey) assert.Equal(t, "gpt-4", openaiProvider.DefaultLargeModel) assert.Equal(t, "gpt-3.5-turbo", openaiProvider.DefaultSmallModel) - assert.Len(t, openaiProvider.Models, 1) + assert.Len(t, openaiProvider.Models, 2) } func TestProviderMerging_LocalToBase(t *testing.T) { @@ -747,7 +726,6 @@ func TestProviderMerging_LocalToBase(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Create local config with provider localConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderAnthropic: { @@ -755,6 +733,25 @@ func TestProviderMerging_LocalToBase(t *testing.T) { APIKey: "local-anthropic-key", ProviderType: provider.TypeAnthropic, DefaultLargeModel: "claude-3-opus", + DefaultSmallModel: "claude-3-haiku", + Models: []Model{ + { + ID: "claude-3-opus", + Name: "Claude 3 Opus", + ContextWindow: 200000, + DefaultMaxTokens: 4096, + CostPer1MIn: 15.0, + CostPer1MOut: 75.0, + }, + { + ID: "claude-3-haiku", + Name: "Claude 3 Haiku", + ContextWindow: 200000, + DefaultMaxTokens: 4096, + CostPer1MIn: 0.25, + CostPer1MOut: 1.25, + }, + }, }, }, } @@ -772,6 +769,8 @@ func TestProviderMerging_LocalToBase(t *testing.T) { anthropicProvider := cfg.Providers[provider.InferenceProviderAnthropic] assert.Equal(t, "local-anthropic-key", anthropicProvider.APIKey) assert.Equal(t, "claude-3-opus", anthropicProvider.DefaultLargeModel) + assert.Equal(t, "claude-3-haiku", anthropicProvider.DefaultSmallModel) + assert.Len(t, anthropicProvider.Models, 2) } func TestProviderMerging_ConflictingSettings(t *testing.T) { @@ -779,7 +778,6 @@ func TestProviderMerging_ConflictingSettings(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Create global config globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { @@ -788,6 +786,26 @@ func TestProviderMerging_ConflictingSettings(t *testing.T) { ProviderType: provider.TypeOpenAI, DefaultLargeModel: "gpt-4", DefaultSmallModel: "gpt-3.5-turbo", + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + ContextWindow: 4096, + DefaultMaxTokens: 2048, + }, + { + ID: "gpt-4-turbo", + Name: "GPT-4 Turbo", + ContextWindow: 128000, + DefaultMaxTokens: 4096, + }, + }, }, }, } @@ -804,7 +822,6 @@ func TestProviderMerging_ConflictingSettings(t *testing.T) { provider.InferenceProviderOpenAI: { APIKey: "local-key", DefaultLargeModel: "gpt-4-turbo", - // Test disabled separately - don't disable here as it causes nil pointer }, }, } @@ -819,11 +836,9 @@ func TestProviderMerging_ConflictingSettings(t *testing.T) { require.NoError(t, err) openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] - // Local should override global assert.Equal(t, "local-key", openaiProvider.APIKey) assert.Equal(t, "gpt-4-turbo", openaiProvider.DefaultLargeModel) - assert.False(t, openaiProvider.Disabled) // Should not be disabled - // Global values should remain where not overridden + assert.False(t, openaiProvider.Disabled) assert.Equal(t, "gpt-3.5-turbo", openaiProvider.DefaultSmallModel) } @@ -834,22 +849,51 @@ func TestProviderMerging_CustomVsKnownProviders(t *testing.T) { customProviderID := provider.InferenceProvider("custom-provider") - // Create config with both known and custom providers globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ - // Known provider - some fields should not be overrideable provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "openai-key", - BaseURL: "should-not-override", - ProviderType: provider.TypeAnthropic, // Should not override + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + BaseURL: "should-not-override", + ProviderType: provider.TypeAnthropic, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + ContextWindow: 4096, + DefaultMaxTokens: 2048, + }, + }, }, - // Custom provider - all fields should be configurable customProviderID: { - ID: customProviderID, - APIKey: "custom-key", - BaseURL: "https://custom.api.com", - ProviderType: provider.TypeOpenAI, + ID: customProviderID, + APIKey: "custom-key", + BaseURL: "https://custom.api.com", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "custom-large", + DefaultSmallModel: "custom-small", + Models: []Model{ + { + ID: "custom-large", + Name: "Custom Large", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + { + ID: "custom-small", + Name: "Custom Small", + ContextWindow: 4096, + DefaultMaxTokens: 2048, + }, + }, }, }, } @@ -882,14 +926,12 @@ func TestProviderMerging_CustomVsKnownProviders(t *testing.T) { require.NoError(t, err) - // Known provider should not have BaseURL/ProviderType overridden openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] assert.NotEqual(t, "https://should-not-change.com", openaiProvider.BaseURL) assert.NotEqual(t, provider.TypeGemini, openaiProvider.ProviderType) - // Custom provider should have all fields configurable customProvider := cfg.Providers[customProviderID] - assert.Equal(t, "custom-key", customProvider.APIKey) // Should preserve from global + assert.Equal(t, "custom-key", customProvider.APIKey) assert.Equal(t, "https://updated-custom.api.com", customProvider.BaseURL) assert.Equal(t, provider.TypeOpenAI, customProvider.ProviderType) } @@ -901,14 +943,12 @@ func TestProviderValidation_CustomProviderMissingBaseURL(t *testing.T) { customProviderID := provider.InferenceProvider("custom-provider") - // Create config with custom provider missing BaseURL globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ customProviderID: { ID: customProviderID, APIKey: "custom-key", ProviderType: provider.TypeOpenAI, - // Missing BaseURL }, }, } @@ -922,7 +962,6 @@ func TestProviderValidation_CustomProviderMissingBaseURL(t *testing.T) { cfg, err := Init(cwdDir, false) require.NoError(t, err) - // Provider should be filtered out due to validation failure assert.NotContains(t, cfg.Providers, customProviderID) } @@ -939,7 +978,6 @@ func TestProviderValidation_CustomProviderMissingAPIKey(t *testing.T) { ID: customProviderID, BaseURL: "https://custom.api.com", ProviderType: provider.TypeOpenAI, - // Missing APIKey }, }, } @@ -994,10 +1032,26 @@ func TestProviderValidation_KnownProviderValid(t *testing.T) { globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "openai-key", - ProviderType: provider.TypeOpenAI, - // BaseURL not required for known providers + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + ContextWindow: 4096, + DefaultMaxTokens: 2048, + }, + }, + }, }, } @@ -1022,10 +1076,48 @@ func TestProviderValidation_DisabledProvider(t *testing.T) { globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "openai-key", - ProviderType: provider.TypeOpenAI, - Disabled: true, + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + ProviderType: provider.TypeOpenAI, + Disabled: true, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + ContextWindow: 4096, + DefaultMaxTokens: 2048, + }, + }, + }, + provider.InferenceProviderAnthropic: { + ID: provider.InferenceProviderAnthropic, + APIKey: "anthropic-key", + ProviderType: provider.TypeAnthropic, + Disabled: false, // This one is enabled + DefaultLargeModel: "claude-3-opus", + DefaultSmallModel: "claude-3-haiku", + Models: []Model{ + { + ID: "claude-3-opus", + Name: "Claude 3 Opus", + ContextWindow: 200000, + DefaultMaxTokens: 4096, + }, + { + ID: "claude-3-haiku", + Name: "Claude 3 Haiku", + ContextWindow: 200000, + DefaultMaxTokens: 4096, + }, + }, }, }, } @@ -1039,9 +1131,10 @@ func TestProviderValidation_DisabledProvider(t *testing.T) { cfg, err := Init(cwdDir, false) require.NoError(t, err) - // Disabled providers should still be in the config but marked as disabled assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI) assert.True(t, cfg.Providers[provider.InferenceProviderOpenAI].Disabled) + assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic) + assert.False(t, cfg.Providers[provider.InferenceProviderAnthropic].Disabled) } func TestProviderModels_AddingNewModels(t *testing.T) { @@ -1052,9 +1145,11 @@ func TestProviderModels_AddingNewModels(t *testing.T) { globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "openai-key", - ProviderType: provider.TypeOpenAI, + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-4-turbo", Models: []Model{ { ID: "gpt-4", @@ -1098,7 +1193,7 @@ func TestProviderModels_AddingNewModels(t *testing.T) { require.NoError(t, err) openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] - assert.Len(t, openaiProvider.Models, 2) // Should have both models + assert.Len(t, openaiProvider.Models, 2) modelIDs := make([]string, len(openaiProvider.Models)) for i, model := range openaiProvider.Models { @@ -1116,9 +1211,11 @@ func TestProviderModels_DuplicateModelHandling(t *testing.T) { globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "openai-key", - ProviderType: provider.TypeOpenAI, + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-4", Models: []Model{ { ID: "gpt-4", @@ -1136,7 +1233,7 @@ func TestProviderModels_DuplicateModelHandling(t *testing.T) { provider.InferenceProviderOpenAI: { Models: []Model{ { - ID: "gpt-4", // Same ID as global + ID: "gpt-4", Name: "GPT-4 Updated", ContextWindow: 16384, DefaultMaxTokens: 8192, @@ -1162,13 +1259,12 @@ func TestProviderModels_DuplicateModelHandling(t *testing.T) { require.NoError(t, err) openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] - assert.Len(t, openaiProvider.Models, 1) // Should not duplicate + assert.Len(t, openaiProvider.Models, 1) - // Should keep the original model (global config) model := openaiProvider.Models[0] assert.Equal(t, "gpt-4", model.ID) - assert.Equal(t, "GPT-4", model.Name) // Original name - assert.Equal(t, int64(8192), model.ContextWindow) // Original context window + assert.Equal(t, "GPT-4", model.Name) + assert.Equal(t, int64(8192), model.ContextWindow) } func TestProviderModels_ModelCostAndCapabilities(t *testing.T) { @@ -1179,9 +1275,11 @@ func TestProviderModels_ModelCostAndCapabilities(t *testing.T) { globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "openai-key", - ProviderType: provider.TypeOpenAI, + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-4", Models: []Model{ { ID: "gpt-4", @@ -1224,14 +1322,11 @@ func TestProviderModels_ModelCostAndCapabilities(t *testing.T) { assert.True(t, model.SupportsImages) } -// Agent Configuration Tests - func TestDefaultAgents_CoderAgent(t *testing.T) { reset() testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider so we can test agent configuration os.Setenv("ANTHROPIC_API_KEY", "test-key") cfg, err := Init(cwdDir, false) @@ -1246,7 +1341,6 @@ func TestDefaultAgents_CoderAgent(t *testing.T) { assert.Equal(t, LargeModel, coderAgent.Model) assert.False(t, coderAgent.Disabled) assert.Equal(t, cfg.Options.ContextPaths, coderAgent.ContextPaths) - // Coder agent should have all tools available (nil means all tools) assert.Nil(t, coderAgent.AllowedTools) } @@ -1255,7 +1349,6 @@ func TestDefaultAgents_TaskAgent(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider so we can test agent configuration os.Setenv("ANTHROPIC_API_KEY", "test-key") cfg, err := Init(cwdDir, false) @@ -1271,11 +1364,9 @@ func TestDefaultAgents_TaskAgent(t *testing.T) { assert.False(t, taskAgent.Disabled) assert.Equal(t, cfg.Options.ContextPaths, taskAgent.ContextPaths) - // Task agent should have restricted tools expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"} assert.Equal(t, expectedTools, taskAgent.AllowedTools) - // Task agent should have no MCPs or LSPs by default assert.Equal(t, map[string][]string{}, taskAgent.AllowedMCP) assert.Equal(t, []string{}, taskAgent.AllowedLSP) } @@ -1285,10 +1376,8 @@ func TestAgentMerging_CustomAgent(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider os.Setenv("ANTHROPIC_API_KEY", "test-key") - // Create config with custom agent globalConfig := Config{ Agents: map[AgentID]Agent{ AgentID("custom-agent"): { @@ -1302,6 +1391,23 @@ func TestAgentMerging_CustomAgent(t *testing.T) { ContextPaths: []string{"custom-context.md"}, }, }, + MCP: map[string]MCP{ + "mcp1": { + Type: MCPStdio, + Command: "test-mcp-command", + Args: []string{"--test"}, + }, + }, + LSP: map[string]LSPConfig{ + "typescript": { + Command: "typescript-language-server", + Args: []string{"--stdio"}, + }, + "go": { + Command: "gopls", + Args: []string{}, + }, + }, } configPath := filepath.Join(testConfigDir, "crush.json") @@ -1314,7 +1420,6 @@ func TestAgentMerging_CustomAgent(t *testing.T) { require.NoError(t, err) - // Should have default agents plus custom agent assert.Contains(t, cfg.Agents, AgentCoder) assert.Contains(t, cfg.Agents, AgentTask) assert.Contains(t, cfg.Agents, AgentID("custom-agent")) @@ -1326,7 +1431,6 @@ func TestAgentMerging_CustomAgent(t *testing.T) { assert.Equal(t, []string{"glob", "grep"}, customAgent.AllowedTools) assert.Equal(t, map[string][]string{"mcp1": {"tool1", "tool2"}}, customAgent.AllowedMCP) assert.Equal(t, []string{"typescript", "go"}, customAgent.AllowedLSP) - // Context paths should be additive (default + custom) expectedContextPaths := append(defaultContextPaths, "custom-context.md") assert.Equal(t, expectedContextPaths, customAgent.ContextPaths) } @@ -1336,17 +1440,28 @@ func TestAgentMerging_ModifyDefaultCoderAgent(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider os.Setenv("ANTHROPIC_API_KEY", "test-key") - // Create config that modifies the default coder agent globalConfig := Config{ Agents: map[AgentID]Agent{ AgentCoder: { - Model: SmallModel, // Change from default LargeModel + Model: SmallModel, AllowedMCP: map[string][]string{"mcp1": {"tool1"}}, AllowedLSP: []string{"typescript"}, - ContextPaths: []string{"coder-specific.md"}, // Should be additive + ContextPaths: []string{"coder-specific.md"}, + }, + }, + MCP: map[string]MCP{ + "mcp1": { + Type: MCPStdio, + Command: "test-mcp-command", + Args: []string{"--test"}, + }, + }, + LSP: map[string]LSPConfig{ + "typescript": { + Command: "typescript-language-server", + Args: []string{"--stdio"}, }, }, } @@ -1362,16 +1477,13 @@ func TestAgentMerging_ModifyDefaultCoderAgent(t *testing.T) { require.NoError(t, err) coderAgent := cfg.Agents[AgentCoder] - // Should preserve default values for unspecified fields assert.Equal(t, AgentCoder, coderAgent.ID) assert.Equal(t, "Coder", coderAgent.Name) assert.Equal(t, "An agent that helps with executing coding tasks.", coderAgent.Description) - // Context paths should be additive (default + custom) expectedContextPaths := append(cfg.Options.ContextPaths, "coder-specific.md") assert.Equal(t, expectedContextPaths, coderAgent.ContextPaths) - // Should update specified fields assert.Equal(t, SmallModel, coderAgent.Model) assert.Equal(t, map[string][]string{"mcp1": {"tool1"}}, coderAgent.AllowedMCP) assert.Equal(t, []string{"typescript"}, coderAgent.AllowedLSP) @@ -1382,22 +1494,31 @@ func TestAgentMerging_ModifyDefaultTaskAgent(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider os.Setenv("ANTHROPIC_API_KEY", "test-key") - // Create config that modifies the default task agent - // Note: Only model, MCP, and LSP should be configurable for known agents globalConfig := Config{ Agents: map[AgentID]Agent{ AgentTask: { - Model: SmallModel, // Should be updated - AllowedMCP: map[string][]string{"search-mcp": nil}, // Should be updated - AllowedLSP: []string{"python"}, // Should be updated - // These should be ignored for known agents: - Name: "Search Agent", // Should be ignored - Description: "Custom search agent", // Should be ignored - Disabled: true, // Should be ignored - AllowedTools: []string{"glob", "grep", "view"}, // Should be ignored + Model: SmallModel, + AllowedMCP: map[string][]string{"search-mcp": nil}, + AllowedLSP: []string{"python"}, + Name: "Search Agent", + Description: "Custom search agent", + Disabled: true, + AllowedTools: []string{"glob", "grep", "view"}, + }, + }, + MCP: map[string]MCP{ + "search-mcp": { + Type: MCPStdio, + Command: "search-mcp-command", + Args: []string{"--search"}, + }, + }, + LSP: map[string]LSPConfig{ + "python": { + Command: "pylsp", + Args: []string{}, }, }, } @@ -1413,13 +1534,11 @@ func TestAgentMerging_ModifyDefaultTaskAgent(t *testing.T) { require.NoError(t, err) taskAgent := cfg.Agents[AgentTask] - // Should preserve default values for protected fields - assert.Equal(t, "Task", taskAgent.Name) // Should remain default - assert.Equal(t, "An agent that helps with searching for context and finding implementation details.", taskAgent.Description) // Should remain default - assert.False(t, taskAgent.Disabled) // Should remain default - assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools) // Should remain default + assert.Equal(t, "Task", taskAgent.Name) + assert.Equal(t, "An agent that helps with searching for context and finding implementation details.", taskAgent.Description) + assert.False(t, taskAgent.Disabled) + assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools) - // Should update configurable fields assert.Equal(t, SmallModel, taskAgent.Model) assert.Equal(t, map[string][]string{"search-mcp": nil}, taskAgent.AllowedMCP) assert.Equal(t, []string{"python"}, taskAgent.AllowedLSP) @@ -1430,10 +1549,8 @@ func TestAgentMerging_LocalOverridesGlobal(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider os.Setenv("ANTHROPIC_API_KEY", "test-key") - // Create global config with custom agent globalConfig := Config{ Agents: map[AgentID]Agent{ AgentID("test-agent"): { @@ -1464,6 +1581,13 @@ func TestAgentMerging_LocalOverridesGlobal(t *testing.T) { AllowedMCP: map[string][]string{"local-mcp": {"tool1"}}, }, }, + MCP: map[string]MCP{ + "local-mcp": { + Type: MCPStdio, + Command: "local-mcp-command", + Args: []string{"--local"}, + }, + }, } localConfigPath := filepath.Join(cwdDir, "crush.json") @@ -1476,7 +1600,6 @@ func TestAgentMerging_LocalOverridesGlobal(t *testing.T) { require.NoError(t, err) testAgent := cfg.Agents[AgentID("test-agent")] - // Local should override global assert.Equal(t, "Local Agent", testAgent.Name) assert.Equal(t, "Local description", testAgent.Description) assert.Equal(t, SmallModel, testAgent.Model) @@ -1490,10 +1613,8 @@ func TestAgentModelTypeAssignment(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider os.Setenv("ANTHROPIC_API_KEY", "test-key") - // Create config with agents using different model types globalConfig := Config{ Agents: map[AgentID]Agent{ AgentID("large-agent"): { @@ -1509,7 +1630,6 @@ func TestAgentModelTypeAssignment(t *testing.T) { AgentID("default-agent"): { ID: AgentID("default-agent"), Name: "Default Model Agent", - // No model specified - should default to LargeModel }, }, } @@ -1526,7 +1646,7 @@ func TestAgentModelTypeAssignment(t *testing.T) { assert.Equal(t, LargeModel, cfg.Agents[AgentID("large-agent")].Model) assert.Equal(t, SmallModel, cfg.Agents[AgentID("small-agent")].Model) - assert.Equal(t, LargeModel, cfg.Agents[AgentID("default-agent")].Model) // Should default to LargeModel + assert.Equal(t, LargeModel, cfg.Agents[AgentID("default-agent")].Model) } func TestAgentContextPathOverrides(t *testing.T) { @@ -1534,10 +1654,8 @@ func TestAgentContextPathOverrides(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider os.Setenv("ANTHROPIC_API_KEY", "test-key") - // Create config with custom context paths globalConfig := Config{ Options: Options{ ContextPaths: []string{"global-context.md", "shared-context.md"}, @@ -1551,7 +1669,6 @@ func TestAgentContextPathOverrides(t *testing.T) { AgentID("default-context-agent"): { ID: AgentID("default-context-agent"), Name: "Default Context Agent", - // No ContextPaths specified - should use global }, }, } @@ -1566,32 +1683,25 @@ func TestAgentContextPathOverrides(t *testing.T) { require.NoError(t, err) - // Agent with custom context paths should have default + global + custom paths (additive) customAgent := cfg.Agents[AgentID("custom-context-agent")] expectedCustomPaths := append(defaultContextPaths, "global-context.md", "shared-context.md", "agent-specific.md", "custom.md") assert.Equal(t, expectedCustomPaths, customAgent.ContextPaths) - // Agent without custom context paths should use global + defaults defaultAgent := cfg.Agents[AgentID("default-context-agent")] expectedContextPaths := append(defaultContextPaths, "global-context.md", "shared-context.md") assert.Equal(t, expectedContextPaths, defaultAgent.ContextPaths) - // Default agents should also use the merged context paths coderAgent := cfg.Agents[AgentCoder] assert.Equal(t, expectedContextPaths, coderAgent.ContextPaths) } -// Options and Settings Tests - func TestOptionsMerging_ContextPaths(t *testing.T) { reset() testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider os.Setenv("ANTHROPIC_API_KEY", "test-key") - // Create global config with context paths globalConfig := Config{ Options: Options{ ContextPaths: []string{"global1.md", "global2.md"}, @@ -1604,7 +1714,6 @@ func TestOptionsMerging_ContextPaths(t *testing.T) { require.NoError(t, err) require.NoError(t, os.WriteFile(configPath, data, 0o644)) - // Create local config with additional context paths localConfig := Config{ Options: Options{ ContextPaths: []string{"local1.md", "local2.md"}, @@ -1620,7 +1729,6 @@ func TestOptionsMerging_ContextPaths(t *testing.T) { require.NoError(t, err) - // Context paths should be merged: defaults + global + local expectedContextPaths := append(defaultContextPaths, "global1.md", "global2.md", "local1.md", "local2.md") assert.Equal(t, expectedContextPaths, cfg.Options.ContextPaths) } @@ -1630,14 +1738,12 @@ func TestOptionsMerging_TUIOptions(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider os.Setenv("ANTHROPIC_API_KEY", "test-key") - // Create global config with TUI options globalConfig := Config{ Options: Options{ TUI: TUIOptions{ - CompactMode: false, // Default value + CompactMode: false, }, }, } @@ -1648,7 +1754,6 @@ func TestOptionsMerging_TUIOptions(t *testing.T) { require.NoError(t, err) require.NoError(t, os.WriteFile(configPath, data, 0o644)) - // Create local config that enables compact mode localConfig := Config{ Options: Options{ TUI: TUIOptions{ @@ -1666,7 +1771,6 @@ func TestOptionsMerging_TUIOptions(t *testing.T) { require.NoError(t, err) - // Local config should override global assert.True(t, cfg.Options.TUI.CompactMode) } @@ -1675,10 +1779,8 @@ func TestOptionsMerging_DebugFlags(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider os.Setenv("ANTHROPIC_API_KEY", "test-key") - // Create global config with debug flags globalConfig := Config{ Options: Options{ Debug: false, @@ -1693,7 +1795,6 @@ func TestOptionsMerging_DebugFlags(t *testing.T) { require.NoError(t, err) require.NoError(t, os.WriteFile(configPath, data, 0o644)) - // Create local config that enables debug flags localConfig := Config{ Options: Options{ DebugLSP: true, @@ -1710,10 +1811,9 @@ func TestOptionsMerging_DebugFlags(t *testing.T) { require.NoError(t, err) - // Local config should override global for boolean flags - assert.False(t, cfg.Options.Debug) // Not set in local, remains global value - assert.True(t, cfg.Options.DebugLSP) // Set to true in local - assert.True(t, cfg.Options.DisableAutoSummarize) // Set to true in local + assert.False(t, cfg.Options.Debug) + assert.True(t, cfg.Options.DebugLSP) + assert.True(t, cfg.Options.DisableAutoSummarize) } func TestOptionsMerging_DataDirectory(t *testing.T) { @@ -1721,10 +1821,8 @@ func TestOptionsMerging_DataDirectory(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider os.Setenv("ANTHROPIC_API_KEY", "test-key") - // Create global config with custom data directory globalConfig := Config{ Options: Options{ DataDirectory: "global-data", @@ -1737,7 +1835,6 @@ func TestOptionsMerging_DataDirectory(t *testing.T) { require.NoError(t, err) require.NoError(t, os.WriteFile(configPath, data, 0o644)) - // Create local config with different data directory localConfig := Config{ Options: Options{ DataDirectory: "local-data", @@ -1753,7 +1850,6 @@ func TestOptionsMerging_DataDirectory(t *testing.T) { require.NoError(t, err) - // Local config should override global assert.Equal(t, "local-data", cfg.Options.DataDirectory) } @@ -1762,15 +1858,12 @@ func TestOptionsMerging_DefaultValues(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider os.Setenv("ANTHROPIC_API_KEY", "test-key") - // No config files - should use defaults cfg, err := Init(cwdDir, false) require.NoError(t, err) - // Should have default values assert.Equal(t, defaultDataDirectory, cfg.Options.DataDirectory) assert.Equal(t, defaultContextPaths, cfg.Options.ContextPaths) assert.False(t, cfg.Options.TUI.CompactMode) @@ -1784,10 +1877,8 @@ func TestOptionsMerging_DebugFlagFromInit(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider os.Setenv("ANTHROPIC_API_KEY", "test-key") - // Create config with debug false globalConfig := Config{ Options: Options{ Debug: false, @@ -1800,7 +1891,6 @@ func TestOptionsMerging_DebugFlagFromInit(t *testing.T) { require.NoError(t, err) require.NoError(t, os.WriteFile(configPath, data, 0o644)) - // Init with debug=true should override config cfg, err := Init(cwdDir, true) require.NoError(t, err) @@ -1895,85 +1985,20 @@ func TestModelSelection_PreferredModelSelection(t *testing.T) { assert.Equal(t, cfg.Models.Large.Provider, cfg.Models.Small.Provider) } -func TestModelSelection_GetAgentModel(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - // Set up a provider with known models - globalConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "test-key", - ProviderType: provider.TypeOpenAI, - DefaultLargeModel: "gpt-4", - DefaultSmallModel: "gpt-3.5-turbo", - Models: []Model{ - { - ID: "gpt-4", - Name: "GPT-4", - ContextWindow: 8192, - DefaultMaxTokens: 4096, - CanReason: true, - SupportsImages: true, - }, - { - ID: "gpt-3.5-turbo", - Name: "GPT-3.5 Turbo", - ContextWindow: 4096, - DefaultMaxTokens: 2048, - CanReason: false, - SupportsImages: false, - }, - }, - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - _, err = Init(cwdDir, false) - - require.NoError(t, err) - - // Test GetAgentModel for default agents - coderModel := GetAgentModel(AgentCoder) - assert.Equal(t, "gpt-4", coderModel.ID) // Coder uses LargeModel - assert.Equal(t, "GPT-4", coderModel.Name) - assert.True(t, coderModel.CanReason) - assert.True(t, coderModel.SupportsImages) - - taskModel := GetAgentModel(AgentTask) - assert.Equal(t, "gpt-4", taskModel.ID) // Task also uses LargeModel by default - assert.Equal(t, "GPT-4", taskModel.Name) -} - -func TestModelSelection_GetAgentModelWithCustomModelType(t *testing.T) { +func TestValidation_InvalidModelReference(t *testing.T) { reset() testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up provider and custom agent with SmallModel globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { ID: provider.InferenceProviderOpenAI, APIKey: "test-key", ProviderType: provider.TypeOpenAI, - DefaultLargeModel: "gpt-4", + DefaultLargeModel: "non-existent-model", DefaultSmallModel: "gpt-3.5-turbo", Models: []Model{ - { - ID: "gpt-4", - Name: "GPT-4", - ContextWindow: 8192, - DefaultMaxTokens: 4096, - }, { ID: "gpt-3.5-turbo", Name: "GPT-3.5 Turbo", @@ -1983,54 +2008,6 @@ func TestModelSelection_GetAgentModelWithCustomModelType(t *testing.T) { }, }, }, - Agents: map[AgentID]Agent{ - AgentID("small-agent"): { - ID: AgentID("small-agent"), - Name: "Small Agent", - Model: SmallModel, - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - _, err = Init(cwdDir, false) - - require.NoError(t, err) - - // Test GetAgentModel for custom agent with SmallModel - smallAgentModel := GetAgentModel(AgentID("small-agent")) - assert.Equal(t, "gpt-3.5-turbo", smallAgentModel.ID) - assert.Equal(t, "GPT-3.5 Turbo", smallAgentModel.Name) -} - -func TestModelSelection_GetAgentProvider(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - // Set up multiple providers - globalConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "openai-key", - ProviderType: provider.TypeOpenAI, - DefaultLargeModel: "gpt-4", - DefaultSmallModel: "gpt-3.5-turbo", - }, - provider.InferenceProviderAnthropic: { - ID: provider.InferenceProviderAnthropic, - APIKey: "anthropic-key", - ProviderType: provider.TypeAnthropic, - DefaultLargeModel: "claude-3-opus", - DefaultSmallModel: "claude-3-haiku", - }, - }, } configPath := filepath.Join(testConfigDir, "crush.json") @@ -2040,27 +2017,18 @@ func TestModelSelection_GetAgentProvider(t *testing.T) { require.NoError(t, os.WriteFile(configPath, data, 0o644)) _, err = Init(cwdDir, false) - - require.NoError(t, err) - - // Test GetAgentProvider - coderProvider := GetAgentProvider(AgentCoder) - assert.NotEmpty(t, coderProvider.ID) - assert.NotEmpty(t, coderProvider.APIKey) - assert.NotEmpty(t, coderProvider.ProviderType) + assert.Error(t, err) } -func TestModelSelection_GetProviderModel(t *testing.T) { +func TestValidation_EmptyAPIKey(t *testing.T) { reset() testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up provider with specific models globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { ID: provider.InferenceProviderOpenAI, - APIKey: "test-key", ProviderType: provider.TypeOpenAI, Models: []Model{ { @@ -2068,16 +2036,6 @@ func TestModelSelection_GetProviderModel(t *testing.T) { Name: "GPT-4", ContextWindow: 8192, DefaultMaxTokens: 4096, - CostPer1MIn: 30.0, - CostPer1MOut: 60.0, - }, - { - ID: "gpt-3.5-turbo", - Name: "GPT-3.5 Turbo", - ContextWindow: 4096, - DefaultMaxTokens: 2048, - CostPer1MIn: 1.5, - CostPer1MOut: 2.0, }, }, }, @@ -2091,107 +2049,22 @@ func TestModelSelection_GetProviderModel(t *testing.T) { require.NoError(t, os.WriteFile(configPath, data, 0o644)) _, err = Init(cwdDir, false) - - require.NoError(t, err) - - // Test GetProviderModel - gpt4Model := GetProviderModel(provider.InferenceProviderOpenAI, "gpt-4") - assert.Equal(t, "gpt-4", gpt4Model.ID) - assert.Equal(t, "GPT-4", gpt4Model.Name) - assert.Equal(t, int64(8192), gpt4Model.ContextWindow) - assert.Equal(t, 30.0, gpt4Model.CostPer1MIn) - - gpt35Model := GetProviderModel(provider.InferenceProviderOpenAI, "gpt-3.5-turbo") - assert.Equal(t, "gpt-3.5-turbo", gpt35Model.ID) - assert.Equal(t, "GPT-3.5 Turbo", gpt35Model.Name) - assert.Equal(t, 1.5, gpt35Model.CostPer1MIn) - - // Test non-existent model - nonExistentModel := GetProviderModel(provider.InferenceProviderOpenAI, "non-existent") - assert.Empty(t, nonExistentModel.ID) + assert.Error(t, err) } -func TestModelSelection_GetModel(t *testing.T) { +func TestValidation_InvalidAgentModelType(t *testing.T) { reset() testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up provider with models - globalConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "test-key", - ProviderType: provider.TypeOpenAI, - DefaultLargeModel: "gpt-4", - DefaultSmallModel: "gpt-3.5-turbo", - Models: []Model{ - { - ID: "gpt-4", - Name: "GPT-4", - ContextWindow: 8192, - DefaultMaxTokens: 4096, - }, - { - ID: "gpt-3.5-turbo", - Name: "GPT-3.5 Turbo", - ContextWindow: 4096, - DefaultMaxTokens: 2048, - }, - }, - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - _, err = Init(cwdDir, false) - - require.NoError(t, err) - - // Test GetModel - largeModel := GetModel(LargeModel) - assert.Equal(t, "gpt-4", largeModel.ID) - assert.Equal(t, "GPT-4", largeModel.Name) - - smallModel := GetModel(SmallModel) - assert.Equal(t, "gpt-3.5-turbo", smallModel.ID) - assert.Equal(t, "GPT-3.5 Turbo", smallModel.Name) -} - -func TestModelSelection_UpdatePreferredModel(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() + os.Setenv("ANTHROPIC_API_KEY", "test-key") - // Set up multiple providers with OpenAI first to ensure it's selected initially globalConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "openai-key", - ProviderType: provider.TypeOpenAI, - DefaultLargeModel: "gpt-4", - DefaultSmallModel: "gpt-3.5-turbo", - Models: []Model{ - {ID: "gpt-4", Name: "GPT-4"}, - {ID: "gpt-3.5-turbo", Name: "GPT-3.5 Turbo"}, - }, - }, - provider.InferenceProviderAnthropic: { - ID: provider.InferenceProviderAnthropic, - APIKey: "anthropic-key", - ProviderType: provider.TypeAnthropic, - DefaultLargeModel: "claude-3-opus", - DefaultSmallModel: "claude-3-haiku", - Models: []Model{ - {ID: "claude-3-opus", Name: "Claude 3 Opus"}, - {ID: "claude-3-haiku", Name: "Claude 3 Haiku"}, - }, + Agents: map[AgentID]Agent{ + AgentID("invalid-agent"): { + ID: AgentID("invalid-agent"), + Name: "Invalid Agent", + Model: ModelType("invalid"), }, }, } @@ -2203,97 +2076,5 @@ func TestModelSelection_UpdatePreferredModel(t *testing.T) { require.NoError(t, os.WriteFile(configPath, data, 0o644)) _, err = Init(cwdDir, false) - - require.NoError(t, err) - - // Get initial preferred models (should be OpenAI since it's listed first) - initialLargeModel := GetModel(LargeModel) - initialSmallModel := GetModel(SmallModel) - - // Verify initial models are OpenAI models - assert.Equal(t, "claude-3-opus", initialLargeModel.ID) - assert.Equal(t, "claude-3-haiku", initialSmallModel.ID) - - // Update preferred models to Anthropic - newLargeModel := PreferredModel{ - ModelID: "gpt-4", - Provider: provider.InferenceProviderOpenAI, - } - newSmallModel := PreferredModel{ - ModelID: "gpt-3.5-turbo", - Provider: provider.InferenceProviderOpenAI, - } - - err = UpdatePreferredModel(LargeModel, newLargeModel) - require.NoError(t, err) - - err = UpdatePreferredModel(SmallModel, newSmallModel) - require.NoError(t, err) - - // Verify models were updated - updatedLargeModel := GetModel(LargeModel) - assert.Equal(t, "gpt-4", updatedLargeModel.ID) - assert.NotEqual(t, initialLargeModel.ID, updatedLargeModel.ID) - - updatedSmallModel := GetModel(SmallModel) - assert.Equal(t, "gpt-3.5-turbo", updatedSmallModel.ID) - assert.NotEqual(t, initialSmallModel.ID, updatedSmallModel.ID) -} - -func TestModelSelection_InvalidModelType(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - // Set up a provider - os.Setenv("ANTHROPIC_API_KEY", "test-key") - - _, err := Init(cwdDir, false) - require.NoError(t, err) - - // Test UpdatePreferredModel with invalid model type - invalidModel := PreferredModel{ - ModelID: "some-model", - Provider: provider.InferenceProviderAnthropic, - } - - err = UpdatePreferredModel(ModelType("invalid"), invalidModel) assert.Error(t, err) - assert.Contains(t, err.Error(), "unknown model type") -} - -func TestModelSelection_NonExistentAgent(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - // Set up a provider - os.Setenv("ANTHROPIC_API_KEY", "test-key") - - _, err := Init(cwdDir, false) - require.NoError(t, err) - - // Test GetAgentModel with non-existent agent - nonExistentModel := GetAgentModel(AgentID("non-existent")) - assert.Empty(t, nonExistentModel.ID) - - // Test GetAgentProvider with non-existent agent - nonExistentProvider := GetAgentProvider(AgentID("non-existent")) - assert.Empty(t, nonExistentProvider.ID) -} - -func TestModelSelection_NonExistentProvider(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - // Set up a provider - os.Setenv("ANTHROPIC_API_KEY", "test-key") - - _, err := Init(cwdDir, false) - require.NoError(t, err) - - // Test GetProviderModel with non-existent provider - nonExistentModel := GetProviderModel(provider.InferenceProvider("non-existent"), "some-model") - assert.Empty(t, nonExistentModel.ID) } diff --git a/internal/config/provider_mock.go b/internal/config/provider_mock.go index af92cc2c33f0b0adbe65dbd728b29727c35aeaa8..73d39d761b15ae682b272f356c62234aaa3ca0ad 100644 --- a/internal/config/provider_mock.go +++ b/internal/config/provider_mock.go @@ -6,6 +6,7 @@ import ( // MockProviders returns a mock list of providers for testing. // This avoids making API calls during tests and provides consistent test data. +// Simplified version with only default models from each provider. func MockProviders() []provider.Provider { return []provider.Provider{ { @@ -14,43 +15,19 @@ func MockProviders() []provider.Provider { APIKey: "$ANTHROPIC_API_KEY", APIEndpoint: "$ANTHROPIC_API_ENDPOINT", Type: provider.TypeAnthropic, - DefaultLargeModelID: "claude-3-opus", - DefaultSmallModelID: "claude-3-haiku", + DefaultLargeModelID: "claude-sonnet-4-20250514", + DefaultSmallModelID: "claude-3-5-haiku-20241022", Models: []provider.Model{ { - ID: "claude-3-opus", - Name: "Claude 3 Opus", - CostPer1MIn: 15.0, - CostPer1MOut: 75.0, - CostPer1MInCached: 18.75, - CostPer1MOutCached: 1.5, - ContextWindow: 200000, - DefaultMaxTokens: 4096, - CanReason: false, - SupportsImages: true, - }, - { - ID: "claude-3-haiku", - Name: "Claude 3 Haiku", - CostPer1MIn: 0.25, - CostPer1MOut: 1.25, - CostPer1MInCached: 0.3, - CostPer1MOutCached: 0.03, - ContextWindow: 200000, - DefaultMaxTokens: 4096, - CanReason: false, - SupportsImages: true, - }, - { - ID: "claude-3-5-sonnet-20241022", - Name: "Claude 3.5 Sonnet", + ID: "claude-sonnet-4-20250514", + Name: "Claude Sonnet 4", CostPer1MIn: 3.0, CostPer1MOut: 15.0, CostPer1MInCached: 3.75, CostPer1MOutCached: 0.3, ContextWindow: 200000, - DefaultMaxTokens: 8192, - CanReason: false, + DefaultMaxTokens: 50000, + CanReason: true, SupportsImages: true, }, { @@ -61,7 +38,7 @@ func MockProviders() []provider.Provider { CostPer1MInCached: 1.0, CostPer1MOutCached: 0.08, ContextWindow: 200000, - DefaultMaxTokens: 8192, + DefaultMaxTokens: 5000, CanReason: false, SupportsImages: true, }, @@ -73,44 +50,22 @@ func MockProviders() []provider.Provider { APIKey: "$OPENAI_API_KEY", APIEndpoint: "$OPENAI_API_ENDPOINT", Type: provider.TypeOpenAI, - DefaultLargeModelID: "gpt-4", - DefaultSmallModelID: "gpt-3.5-turbo", + DefaultLargeModelID: "codex-mini-latest", + DefaultSmallModelID: "gpt-4o", Models: []provider.Model{ { - ID: "gpt-4", - Name: "GPT-4", - CostPer1MIn: 30.0, - CostPer1MOut: 60.0, - CostPer1MInCached: 0.0, - CostPer1MOutCached: 0.0, - ContextWindow: 8192, - DefaultMaxTokens: 4096, - CanReason: false, - SupportsImages: false, - }, - { - ID: "gpt-3.5-turbo", - Name: "GPT-3.5 Turbo", - CostPer1MIn: 1.0, - CostPer1MOut: 2.0, - CostPer1MInCached: 0.0, - CostPer1MOutCached: 0.0, - ContextWindow: 4096, - DefaultMaxTokens: 4096, - CanReason: false, - SupportsImages: false, - }, - { - ID: "gpt-4-turbo", - Name: "GPT-4 Turbo", - CostPer1MIn: 10.0, - CostPer1MOut: 30.0, - CostPer1MInCached: 0.0, - CostPer1MOutCached: 0.0, - ContextWindow: 128000, - DefaultMaxTokens: 4096, - CanReason: false, - SupportsImages: true, + ID: "codex-mini-latest", + Name: "Codex Mini", + CostPer1MIn: 1.5, + CostPer1MOut: 6.0, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.375, + ContextWindow: 200000, + DefaultMaxTokens: 50000, + CanReason: true, + HasReasoningEffort: true, + DefaultReasoningEffort: "medium", + SupportsImages: true, }, { ID: "gpt-4o", @@ -120,50 +75,10 @@ func MockProviders() []provider.Provider { CostPer1MInCached: 0.0, CostPer1MOutCached: 1.25, ContextWindow: 128000, - DefaultMaxTokens: 16384, + DefaultMaxTokens: 20000, CanReason: false, SupportsImages: true, }, - { - ID: "gpt-4o-mini", - Name: "GPT-4o-mini", - CostPer1MIn: 0.15, - CostPer1MOut: 0.6, - CostPer1MInCached: 0.0, - CostPer1MOutCached: 0.075, - ContextWindow: 128000, - DefaultMaxTokens: 16384, - CanReason: false, - SupportsImages: true, - }, - { - ID: "o1-preview", - Name: "o1-preview", - CostPer1MIn: 15.0, - CostPer1MOut: 60.0, - CostPer1MInCached: 0.0, - CostPer1MOutCached: 0.0, - ContextWindow: 128000, - DefaultMaxTokens: 32768, - CanReason: true, - HasReasoningEffort: true, - DefaultReasoningEffort: "medium", - SupportsImages: true, - }, - { - ID: "o1-mini", - Name: "o1-mini", - CostPer1MIn: 3.0, - CostPer1MOut: 12.0, - CostPer1MInCached: 0.0, - CostPer1MOutCached: 0.0, - ContextWindow: 128000, - DefaultMaxTokens: 65536, - CanReason: true, - HasReasoningEffort: true, - DefaultReasoningEffort: "medium", - SupportsImages: true, - }, }, }, { @@ -183,7 +98,7 @@ func MockProviders() []provider.Provider { CostPer1MInCached: 1.625, CostPer1MOutCached: 0.31, ContextWindow: 1048576, - DefaultMaxTokens: 65536, + DefaultMaxTokens: 50000, CanReason: true, SupportsImages: true, }, @@ -195,7 +110,7 @@ func MockProviders() []provider.Provider { CostPer1MInCached: 0.3833, CostPer1MOutCached: 0.075, ContextWindow: 1048576, - DefaultMaxTokens: 65535, + DefaultMaxTokens: 50000, CanReason: true, SupportsImages: true, }, @@ -207,18 +122,135 @@ func MockProviders() []provider.Provider { APIKey: "$XAI_API_KEY", APIEndpoint: "https://api.x.ai/v1", Type: provider.TypeXAI, - DefaultLargeModelID: "grok-beta", - DefaultSmallModelID: "grok-beta", + DefaultLargeModelID: "grok-3", + DefaultSmallModelID: "grok-3-mini", Models: []provider.Model{ { - ID: "grok-beta", - Name: "Grok Beta", - CostPer1MIn: 5.0, + ID: "grok-3", + Name: "Grok 3", + CostPer1MIn: 3.0, CostPer1MOut: 15.0, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.75, ContextWindow: 131072, - DefaultMaxTokens: 4096, + DefaultMaxTokens: 20000, CanReason: false, - SupportsImages: true, + SupportsImages: false, + }, + { + ID: "grok-3-mini", + Name: "Grok 3 Mini", + CostPer1MIn: 0.3, + CostPer1MOut: 0.5, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.075, + ContextWindow: 131072, + DefaultMaxTokens: 20000, + CanReason: true, + SupportsImages: false, + }, + }, + }, + { + Name: "Azure OpenAI", + ID: provider.InferenceProviderAzure, + APIKey: "$AZURE_OPENAI_API_KEY", + APIEndpoint: "$AZURE_OPENAI_API_ENDPOINT", + Type: provider.TypeAzure, + DefaultLargeModelID: "o4-mini", + DefaultSmallModelID: "gpt-4o", + Models: []provider.Model{ + { + ID: "o4-mini", + Name: "o4 Mini", + CostPer1MIn: 1.1, + CostPer1MOut: 4.4, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.275, + ContextWindow: 200000, + DefaultMaxTokens: 50000, + CanReason: true, + HasReasoningEffort: false, + DefaultReasoningEffort: "medium", + SupportsImages: true, + }, + { + ID: "gpt-4o", + Name: "GPT-4o", + CostPer1MIn: 2.5, + CostPer1MOut: 10.0, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 1.25, + ContextWindow: 128000, + DefaultMaxTokens: 20000, + CanReason: false, + SupportsImages: true, + }, + }, + }, + { + Name: "AWS Bedrock", + ID: provider.InferenceProviderBedrock, + Type: provider.TypeBedrock, + DefaultLargeModelID: "anthropic.claude-sonnet-4-20250514-v1:0", + DefaultSmallModelID: "anthropic.claude-3-5-haiku-20241022-v1:0", + Models: []provider.Model{ + { + ID: "anthropic.claude-sonnet-4-20250514-v1:0", + Name: "AWS Claude Sonnet 4", + CostPer1MIn: 3.0, + CostPer1MOut: 15.0, + CostPer1MInCached: 3.75, + CostPer1MOutCached: 0.3, + ContextWindow: 200000, + DefaultMaxTokens: 50000, + CanReason: true, + SupportsImages: true, + }, + { + ID: "anthropic.claude-3-5-haiku-20241022-v1:0", + Name: "AWS Claude 3.5 Haiku", + CostPer1MIn: 0.8, + CostPer1MOut: 4.0, + CostPer1MInCached: 1.0, + CostPer1MOutCached: 0.08, + ContextWindow: 200000, + DefaultMaxTokens: 50000, + CanReason: false, + SupportsImages: true, + }, + }, + }, + { + Name: "Google Vertex AI", + ID: provider.InferenceProviderVertexAI, + Type: provider.TypeVertexAI, + DefaultLargeModelID: "gemini-2.5-pro", + DefaultSmallModelID: "gemini-2.5-flash", + Models: []provider.Model{ + { + ID: "gemini-2.5-pro", + Name: "Gemini 2.5 Pro", + CostPer1MIn: 1.25, + CostPer1MOut: 10.0, + CostPer1MInCached: 1.625, + CostPer1MOutCached: 0.31, + ContextWindow: 1048576, + DefaultMaxTokens: 50000, + CanReason: true, + SupportsImages: true, + }, + { + ID: "gemini-2.5-flash", + Name: "Gemini 2.5 Flash", + CostPer1MIn: 0.3, + CostPer1MOut: 2.5, + CostPer1MInCached: 0.3833, + CostPer1MOutCached: 0.075, + ContextWindow: 1048576, + DefaultMaxTokens: 50000, + CanReason: true, + SupportsImages: true, }, }, }, @@ -228,28 +260,32 @@ func MockProviders() []provider.Provider { APIKey: "$OPENROUTER_API_KEY", APIEndpoint: "https://openrouter.ai/api/v1", Type: provider.TypeOpenAI, - DefaultLargeModelID: "anthropic/claude-3.5-sonnet", - DefaultSmallModelID: "anthropic/claude-3.5-haiku", + DefaultLargeModelID: "anthropic/claude-sonnet-4", + DefaultSmallModelID: "anthropic/claude-haiku-3.5", Models: []provider.Model{ { - ID: "anthropic/claude-3.5-sonnet", - Name: "Claude 3.5 Sonnet", - CostPer1MIn: 3.0, - CostPer1MOut: 15.0, - ContextWindow: 200000, - DefaultMaxTokens: 8192, - CanReason: false, - SupportsImages: true, + ID: "anthropic/claude-sonnet-4", + Name: "Anthropic: Claude Sonnet 4", + CostPer1MIn: 3.0, + CostPer1MOut: 15.0, + CostPer1MInCached: 3.75, + CostPer1MOutCached: 0.3, + ContextWindow: 200000, + DefaultMaxTokens: 32000, + CanReason: true, + SupportsImages: true, }, { - ID: "anthropic/claude-3.5-haiku", - Name: "Claude 3.5 Haiku", - CostPer1MIn: 0.8, - CostPer1MOut: 4.0, - ContextWindow: 200000, - DefaultMaxTokens: 8192, - CanReason: false, - SupportsImages: true, + ID: "anthropic/claude-haiku-3.5", + Name: "Anthropic: Claude 3.5 Haiku", + CostPer1MIn: 0.8, + CostPer1MOut: 4.0, + CostPer1MInCached: 1.0, + CostPer1MOutCached: 0.08, + ContextWindow: 200000, + DefaultMaxTokens: 4096, + CanReason: false, + SupportsImages: true, }, }, }, diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index 70224d194a6689d85602d6a0f7d92e03b02fa1b2..8f2a31f06ab121fa049e7ca8bed159976cb2e92f 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -1,7 +1,6 @@ package config import ( - "encoding/json" "testing" "github.com/charmbracelet/crush/internal/fur/provider" @@ -9,8 +8,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestMockProviders(t *testing.T) { - // Enable mock providers for testing +func TestProviders_MockEnabled(t *testing.T) { originalUseMock := UseMockProviders UseMockProviders = true defer func() { @@ -18,94 +16,38 @@ func TestMockProviders(t *testing.T) { ResetProviders() }() - // Reset providers to ensure we get fresh mock data ResetProviders() - providers := Providers() - require.NotEmpty(t, providers, "Mock providers should not be empty") + require.NotEmpty(t, providers) - // Verify we have the expected mock providers providerIDs := make(map[provider.InferenceProvider]bool) for _, p := range providers { providerIDs[p.ID] = true } - assert.True(t, providerIDs[provider.InferenceProviderAnthropic], "Should have Anthropic provider") - assert.True(t, providerIDs[provider.InferenceProviderOpenAI], "Should have OpenAI provider") - assert.True(t, providerIDs[provider.InferenceProviderGemini], "Should have Gemini provider") - - // Verify Anthropic provider details - var anthropicProvider provider.Provider - for _, p := range providers { - if p.ID == provider.InferenceProviderAnthropic { - anthropicProvider = p - break - } - } - - assert.Equal(t, "Anthropic", anthropicProvider.Name) - assert.Equal(t, provider.TypeAnthropic, anthropicProvider.Type) - assert.Equal(t, "claude-3-opus", anthropicProvider.DefaultLargeModelID) - assert.Equal(t, "claude-3-haiku", anthropicProvider.DefaultSmallModelID) - assert.Len(t, anthropicProvider.Models, 4, "Anthropic should have 4 models") - - // Verify model details - var opusModel provider.Model - for _, m := range anthropicProvider.Models { - if m.ID == "claude-3-opus" { - opusModel = m - break - } - } - - assert.Equal(t, "Claude 3 Opus", opusModel.Name) - assert.Equal(t, int64(200000), opusModel.ContextWindow) - assert.Equal(t, int64(4096), opusModel.DefaultMaxTokens) - assert.True(t, opusModel.SupportsImages) -} - -func TestProvidersWithoutMock(t *testing.T) { - // Ensure mock is disabled - originalUseMock := UseMockProviders - UseMockProviders = false - defer func() { - UseMockProviders = originalUseMock - ResetProviders() - }() - - // Reset providers to ensure we get fresh data - ResetProviders() - - // This will try to make an actual API call or use cached data - providers := Providers() - - // We can't guarantee what we'll get here since it depends on network/cache - // but we can at least verify the function doesn't panic - t.Logf("Got %d providers without mock", len(providers)) + assert.True(t, providerIDs[provider.InferenceProviderAnthropic]) + assert.True(t, providerIDs[provider.InferenceProviderOpenAI]) + assert.True(t, providerIDs[provider.InferenceProviderGemini]) } -func TestResetProviders(t *testing.T) { - // Enable mock providers +func TestProviders_ResetFunctionality(t *testing.T) { UseMockProviders = true defer func() { UseMockProviders = false ResetProviders() }() - // Get providers once providers1 := Providers() require.NotEmpty(t, providers1) - // Reset and get again ResetProviders() providers2 := Providers() require.NotEmpty(t, providers2) - // Should get the same mock data assert.Equal(t, len(providers1), len(providers2)) } -func TestReasoningEffortSupport(t *testing.T) { +func TestProviders_ModelCapabilities(t *testing.T) { originalUseMock := UseMockProviders UseMockProviders = true defer func() { @@ -125,156 +67,15 @@ func TestReasoningEffortSupport(t *testing.T) { } require.NotEmpty(t, openaiProvider.ID) - var reasoningModel, nonReasoningModel provider.Model + var foundReasoning, foundNonReasoning bool for _, model := range openaiProvider.Models { if model.CanReason && model.HasReasoningEffort { - reasoningModel = model - } else if !model.CanReason { - nonReasoningModel = model - } - } - - require.NotEmpty(t, reasoningModel.ID) - assert.Equal(t, "medium", reasoningModel.DefaultReasoningEffort) - assert.True(t, reasoningModel.HasReasoningEffort) - - require.NotEmpty(t, nonReasoningModel.ID) - assert.False(t, nonReasoningModel.HasReasoningEffort) - assert.Empty(t, nonReasoningModel.DefaultReasoningEffort) -} - -func TestReasoningEffortConfigTransfer(t *testing.T) { - originalUseMock := UseMockProviders - UseMockProviders = true - defer func() { - UseMockProviders = originalUseMock - ResetProviders() - }() - - ResetProviders() - t.Setenv("OPENAI_API_KEY", "test-openai-key") - - cfg, err := Init(t.TempDir(), false) - require.NoError(t, err) - - openaiProviderConfig, exists := cfg.Providers[provider.InferenceProviderOpenAI] - require.True(t, exists) - - var foundReasoning, foundNonReasoning bool - for _, model := range openaiProviderConfig.Models { - if model.CanReason && model.HasReasoningEffort && model.ReasoningEffort != "" { - assert.Equal(t, "medium", model.ReasoningEffort) - assert.True(t, model.HasReasoningEffort) foundReasoning = true } else if !model.CanReason { - assert.Empty(t, model.ReasoningEffort) - assert.False(t, model.HasReasoningEffort) foundNonReasoning = true } } - assert.True(t, foundReasoning, "Should find at least one reasoning model") - assert.True(t, foundNonReasoning, "Should find at least one non-reasoning model") -} - -func TestNewProviders(t *testing.T) { - originalUseMock := UseMockProviders - UseMockProviders = true - defer func() { - UseMockProviders = originalUseMock - ResetProviders() - }() - - ResetProviders() - providers := Providers() - require.NotEmpty(t, providers) - - var xaiProvider, openRouterProvider provider.Provider - for _, p := range providers { - switch p.ID { - case provider.InferenceProviderXAI: - xaiProvider = p - case provider.InferenceProviderOpenRouter: - openRouterProvider = p - } - } - - require.NotEmpty(t, xaiProvider.ID) - assert.Equal(t, "xAI", xaiProvider.Name) - assert.Equal(t, "grok-beta", xaiProvider.DefaultLargeModelID) - - require.NotEmpty(t, openRouterProvider.ID) - assert.Equal(t, "OpenRouter", openRouterProvider.Name) - assert.Equal(t, "anthropic/claude-3.5-sonnet", openRouterProvider.DefaultLargeModelID) -} - -func TestO1ModelsInMockProvider(t *testing.T) { - originalUseMock := UseMockProviders - UseMockProviders = true - defer func() { - UseMockProviders = originalUseMock - ResetProviders() - }() - - ResetProviders() - providers := Providers() - - var openaiProvider provider.Provider - for _, p := range providers { - if p.ID == provider.InferenceProviderOpenAI { - openaiProvider = p - break - } - } - require.NotEmpty(t, openaiProvider.ID) - - modelTests := []struct { - id string - name string - }{ - {"o1-preview", "o1-preview"}, - {"o1-mini", "o1-mini"}, - } - - for _, test := range modelTests { - var model provider.Model - var found bool - for _, m := range openaiProvider.Models { - if m.ID == test.id { - model = m - found = true - break - } - } - require.True(t, found, "Should find %s model", test.id) - assert.Equal(t, test.name, model.Name) - assert.True(t, model.CanReason) - assert.True(t, model.HasReasoningEffort) - assert.Equal(t, "medium", model.DefaultReasoningEffort) - } -} - -func TestPreferredModelReasoningEffort(t *testing.T) { - // Test that PreferredModel struct can hold reasoning effort - preferredModel := PreferredModel{ - ModelID: "o1-preview", - Provider: provider.InferenceProviderOpenAI, - ReasoningEffort: "high", - } - - assert.Equal(t, "o1-preview", preferredModel.ModelID) - assert.Equal(t, provider.InferenceProviderOpenAI, preferredModel.Provider) - assert.Equal(t, "high", preferredModel.ReasoningEffort) - - // Test JSON marshaling/unmarshaling - jsonData, err := json.Marshal(preferredModel) - require.NoError(t, err) - - var unmarshaled PreferredModel - err = json.Unmarshal(jsonData, &unmarshaled) - require.NoError(t, err) - - assert.Equal(t, preferredModel.ModelID, unmarshaled.ModelID) - assert.Equal(t, preferredModel.Provider, unmarshaled.Provider) - assert.Equal(t, preferredModel.ReasoningEffort, unmarshaled.ReasoningEffort) -} + assert.True(t, foundReasoning) + assert.True(t, foundNonReasoning) +} \ No newline at end of file diff --git a/internal/config/shell.go b/internal/config/shell.go index a12ecd1da3b82c113175a1f4825877a7fb94a95c..74931bfefc3a9e16e830fac2c3478a6f0d5396f2 100644 --- a/internal/config/shell.go +++ b/internal/config/shell.go @@ -71,4 +71,3 @@ func resolveCommandAPIKey(command string) (string, error) { logging.Debug("Command executed successfully", "command", command, "result", result) return result, nil } -