From b9bedbae80046a5ae03be38e897bba96661b28d2 Mon Sep 17 00:00:00 2001 From: Bryan Vaz <9157498+bryanvaz@users.noreply.github.com> Date: Wed, 25 Jun 2025 15:23:40 -0400 Subject: [PATCH 01/40] feat: add github copilot provider (#230) * feat: add github copilot * fix: add support for claude4 --- .gitignore | 1 + README.md | 72 +++- internal/config/config.go | 116 ++++- internal/llm/agent/agent.go | 26 +- internal/llm/models/copilot.go | 219 ++++++++++ internal/llm/models/models.go | 18 +- internal/llm/provider/anthropic.go | 24 +- internal/llm/provider/copilot.go | 671 +++++++++++++++++++++++++++++ internal/llm/provider/provider.go | 12 + internal/llm/tools/view.go | 2 + internal/logging/logger.go | 133 +++++- internal/logging/writer.go | 1 + opencode-schema.json | 29 +- 13 files changed, 1276 insertions(+), 48 deletions(-) create mode 100644 internal/llm/models/copilot.go create mode 100644 internal/llm/provider/copilot.go diff --git a/.gitignore b/.gitignore index 36ff9c73267bcc5c7b8ece367108972dad21c1e2..3a206a7f28440a2632f13402e74d0567613e0419 100644 --- a/.gitignore +++ b/.gitignore @@ -44,3 +44,4 @@ Thumbs.db .opencode/ opencode +opencode.md diff --git a/README.md b/README.md index 0d3e299daeafa9cf4e2a340fd5f1d77f99fb0efb..eee06acd9aa1cc12b2a829ade101adbbffc63185 100644 --- a/README.md +++ b/README.md @@ -96,22 +96,23 @@ You can enable or disable this feature in your configuration file: You can configure OpenCode using environment variables: -| Environment Variable | Purpose | -| -------------------------- | ------------------------------------------------------ | -| `ANTHROPIC_API_KEY` | For Claude models | -| `OPENAI_API_KEY` | For OpenAI models | -| `GEMINI_API_KEY` | For Google Gemini models | -| `VERTEXAI_PROJECT` | For Google Cloud VertexAI (Gemini) | -| `VERTEXAI_LOCATION` | For Google Cloud VertexAI (Gemini) | -| `GROQ_API_KEY` | For Groq models | -| `AWS_ACCESS_KEY_ID` | For AWS Bedrock (Claude) | -| `AWS_SECRET_ACCESS_KEY` | For AWS Bedrock (Claude) | 
-| `AWS_REGION` | For AWS Bedrock (Claude) | -| `AZURE_OPENAI_ENDPOINT` | For Azure OpenAI models | -| `AZURE_OPENAI_API_KEY` | For Azure OpenAI models (optional when using Entra ID) | -| `AZURE_OPENAI_API_VERSION` | For Azure OpenAI models | -| `LOCAL_ENDPOINT` | For self-hosted models | -| `SHELL` | Default shell to use (if not specified in config) | +| Environment Variable | Purpose | +| -------------------------- | -------------------------------------------------------------------------------- | +| `ANTHROPIC_API_KEY` | For Claude models | +| `OPENAI_API_KEY` | For OpenAI models | +| `GEMINI_API_KEY` | For Google Gemini models | +| `GITHUB_TOKEN` | For Github Copilot models (see [Using Github Copilot](#using-github-copilot)) | +| `VERTEXAI_PROJECT` | For Google Cloud VertexAI (Gemini) | +| `VERTEXAI_LOCATION` | For Google Cloud VertexAI (Gemini) | +| `GROQ_API_KEY` | For Groq models | +| `AWS_ACCESS_KEY_ID` | For AWS Bedrock (Claude) | +| `AWS_SECRET_ACCESS_KEY` | For AWS Bedrock (Claude) | +| `AWS_REGION` | For AWS Bedrock (Claude) | +| `AZURE_OPENAI_ENDPOINT` | For Azure OpenAI models | +| `AZURE_OPENAI_API_KEY` | For Azure OpenAI models (optional when using Entra ID) | +| `AZURE_OPENAI_API_VERSION` | For Azure OpenAI models | +| `LOCAL_ENDPOINT` | For self-hosted models | +| `SHELL` | Default shell to use (if not specified in config) | ### Shell Configuration @@ -146,6 +147,9 @@ This is useful if you want to use a different shell than your default system she "apiKey": "your-api-key", "disabled": false }, + "copilot": { + "disabled": false + }, "groq": { "apiKey": "your-api-key", "disabled": false @@ -216,6 +220,23 @@ OpenCode supports a variety of AI models from different providers: - Claude 3 Haiku - Claude 3 Opus +### GitHub Copilot + +- GPT-3.5 Turbo +- GPT-4 +- GPT-4o +- GPT-4o Mini +- GPT-4.1 +- Claude 3.5 Sonnet +- Claude 3.7 Sonnet +- Claude 3.7 Sonnet Thinking +- Claude Sonnet 4 +- O1 +- O3 Mini +- O4 Mini +- Gemini 2.0 Flash +- Gemini 2.5 Pro + ### 
Google

- Gemini 2.5
@@ -579,6 +600,25 @@ The AI assistant can access LSP features through the `diagnostics` tool, allowin

 While the LSP client implementation supports the full LSP protocol (including completions, hover, definition, etc.), currently only diagnostics are exposed to the AI assistant.

+## Using Github Copilot
+
+_Copilot support is currently experimental._
+
+### Requirements
+- [Copilot chat in the IDE](https://github.com/settings/copilot) enabled in GitHub settings
+- One of:
+  - VSCode Github Copilot chat extension
+  - Github `gh` CLI
+  - Neovim Github Copilot plugin (`copilot.vim` or `copilot.lua`)
+  - Github token with copilot permissions
+
+If using one of the above plugins or CLI tools, make sure you authenticate
+the tool with your GitHub account. This should create a GitHub token at one of the following locations:
+- ~/.config/github-copilot/[hosts,apps].json
+- $XDG_CONFIG_HOME/github-copilot/[hosts,apps].json
+
+If using an explicit github token, you may either set the $GITHUB_TOKEN environment variable or add it to the opencode.json config file at `providers.copilot.apiKey`.
+
 ## Using a self-hosted model provider

 OpenCode can also load and use models from a self-hosted (OpenAI-like) provider. 
diff --git a/internal/config/config.go b/internal/config/config.go index 5a0905bba239c0d7c79f669801ef9b3a5caa9cf9..630fac9b6e9375a6dd192ca9df6582bd1e4a92c3 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -7,6 +7,7 @@ import ( "log/slog" "os" "path/filepath" + "runtime" "strings" "github.com/opencode-ai/opencode/internal/llm/models" @@ -161,6 +162,7 @@ func Load(workingDir string, debug bool) (*Config, error) { } if os.Getenv("OPENCODE_DEV_DEBUG") == "true" { loggingFile := fmt.Sprintf("%s/%s", cfg.Data.Directory, "debug.log") + messagesPath := fmt.Sprintf("%s/%s", cfg.Data.Directory, "messages") // if file does not exist create it if _, err := os.Stat(loggingFile); os.IsNotExist(err) { @@ -172,6 +174,13 @@ func Load(workingDir string, debug bool) (*Config, error) { } } + if _, err := os.Stat(messagesPath); os.IsNotExist(err) { + if err := os.MkdirAll(messagesPath, 0o756); err != nil { + return cfg, fmt.Errorf("failed to create directory: %w", err) + } + } + logging.MessageDir = messagesPath + sloggingFileWriter, err := os.OpenFile(loggingFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666) if err != nil { return cfg, fmt.Errorf("failed to open log file: %w", err) @@ -245,6 +254,7 @@ func setDefaults(debug bool) { // environment variables and configuration file. 
func setProviderDefaults() { // Set all API keys we can find in the environment + // Note: Viper does not default if the json apiKey is "" if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" { viper.SetDefault("providers.anthropic.apiKey", apiKey) } @@ -267,16 +277,32 @@ func setProviderDefaults() { // api-key may be empty when using Entra ID credentials – that's okay viper.SetDefault("providers.azure.apiKey", os.Getenv("AZURE_OPENAI_API_KEY")) } + if apiKey, err := LoadGitHubToken(); err == nil && apiKey != "" { + viper.SetDefault("providers.copilot.apiKey", apiKey) + if viper.GetString("providers.copilot.apiKey") == "" { + viper.Set("providers.copilot.apiKey", apiKey) + } + } // Use this order to set the default models - // 1. Anthropic - // 2. OpenAI - // 3. Google Gemini - // 4. Groq - // 5. OpenRouter - // 6. AWS Bedrock - // 7. Azure - // 8. Google Cloud VertexAI + // 1. Copilot + // 2. Anthropic + // 3. OpenAI + // 4. Google Gemini + // 5. Groq + // 6. OpenRouter + // 7. AWS Bedrock + // 8. Azure + // 9. Google Cloud VertexAI + + // copilot configuration + if key := viper.GetString("providers.copilot.apiKey"); strings.TrimSpace(key) != "" { + viper.SetDefault("agents.coder.model", models.CopilotGPT4o) + viper.SetDefault("agents.summarizer.model", models.CopilotGPT4o) + viper.SetDefault("agents.task.model", models.CopilotGPT4o) + viper.SetDefault("agents.title.model", models.CopilotGPT4o) + return + } // Anthropic configuration if key := viper.GetString("providers.anthropic.apiKey"); strings.TrimSpace(key) != "" { @@ -399,6 +425,14 @@ func hasVertexAICredentials() bool { return false } +func hasCopilotCredentials() bool { + // Check for explicit Copilot parameters + if token, _ := LoadGitHubToken(); token != "" { + return true + } + return false +} + // readConfig handles the result of reading a configuration file. 
func readConfig(err error) error { if err == nil { @@ -440,6 +474,9 @@ func applyDefaultValues() { // It validates model IDs and providers, ensuring they are supported. func validateAgent(cfg *Config, name AgentName, agent Agent) error { // Check if model exists + // TODO: If a copilot model is specified, but model is not found, + // it might be new model. The https://api.githubcopilot.com/models + // endpoint should be queried to validate if the model is supported. model, modelExists := models.SupportedModels[agent.Model] if !modelExists { logging.Warn("unsupported model configured, reverting to default", @@ -584,6 +621,7 @@ func Validate() error { // Validate providers for provider, providerCfg := range cfg.Providers { if providerCfg.APIKey == "" && !providerCfg.Disabled { + fmt.Printf("provider has no API key, marking as disabled %s", provider) logging.Warn("provider has no API key, marking as disabled", "provider", provider) providerCfg.Disabled = true cfg.Providers[provider] = providerCfg @@ -631,6 +669,18 @@ func getProviderAPIKey(provider models.ModelProvider) string { // setDefaultModelForAgent sets a default model for an agent based on available providers func setDefaultModelForAgent(agent AgentName) bool { + if hasCopilotCredentials() { + maxTokens := int64(5000) + if agent == AgentTitle { + maxTokens = 80 + } + + cfg.Agents[agent] = Agent{ + Model: models.CopilotGPT4o, + MaxTokens: maxTokens, + } + return true + } // Check providers in order of preference if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" { maxTokens := int64(5000) @@ -878,3 +928,53 @@ func UpdateTheme(themeName string) error { config.TUI.Theme = themeName }) } + +// Tries to load Github token from all possible locations +func LoadGitHubToken() (string, error) { + // First check environment variable + if token := os.Getenv("GITHUB_TOKEN"); token != "" { + return token, nil + } + + // Get config directory + var configDir string + if xdgConfig := os.Getenv("XDG_CONFIG_HOME"); 
xdgConfig != "" { + configDir = xdgConfig + } else if runtime.GOOS == "windows" { + if localAppData := os.Getenv("LOCALAPPDATA"); localAppData != "" { + configDir = localAppData + } else { + configDir = filepath.Join(os.Getenv("HOME"), "AppData", "Local") + } + } else { + configDir = filepath.Join(os.Getenv("HOME"), ".config") + } + + // Try both hosts.json and apps.json files + filePaths := []string{ + filepath.Join(configDir, "github-copilot", "hosts.json"), + filepath.Join(configDir, "github-copilot", "apps.json"), + } + + for _, filePath := range filePaths { + data, err := os.ReadFile(filePath) + if err != nil { + continue + } + + var config map[string]map[string]interface{} + if err := json.Unmarshal(data, &config); err != nil { + continue + } + + for key, value := range config { + if strings.Contains(key, "github.com") { + if oauthToken, ok := value["oauth_token"].(string); ok { + return oauthToken, nil + } + } + } + } + + return "", fmt.Errorf("GitHub token not found in standard locations") +} diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 4f31fe75d688aa2c4fdd80a4f633fe35d45125cc..20b10fd374fdf67867e0b626058bc68b8d7496d5 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -162,6 +162,7 @@ func (a *agent) generateTitle(ctx context.Context, sessionID string, content str if err != nil { return err } + ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) parts := []message.ContentPart{message.TextContent{Text: content}} response, err := a.titleProvider.SendMessages( ctx, @@ -230,6 +231,7 @@ func (a *agent) Run(ctx context.Context, sessionID string, content string, attac } func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent { + cfg := config.Get() // List existing messages; if none, start title generation asynchronously. 
msgs, err := a.messages.List(ctx, sessionID) if err != nil { @@ -288,7 +290,13 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string } return a.err(fmt.Errorf("failed to process events: %w", err)) } - logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults) + if cfg.Debug { + seqId := (len(msgHistory) + 1) / 2 + toolResultFilepath := logging.WriteToolResultsJson(sessionID, seqId, toolResults) + logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", "{}", "filepath", toolResultFilepath) + } else { + logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults) + } if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil { // We are not done, we need to respond with the tool response msgHistory = append(msgHistory, agentMessage, *toolResults) @@ -312,6 +320,7 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string } func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) { + ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools) assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ @@ -325,7 +334,6 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg // Add the session and message ID into the context if needed by tools. ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID) - ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) // Process each event in the stream. 
for event := range eventChan { @@ -357,10 +365,17 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg default: // Continue processing var tool tools.BaseTool - for _, availableTools := range a.tools { - if availableTools.Info().Name == toolCall.Name { - tool = availableTools + for _, availableTool := range a.tools { + if availableTool.Info().Name == toolCall.Name { + tool = availableTool + break } + // Monkey patch for Copilot Sonnet-4 tool repetition obfuscation + // if strings.HasPrefix(toolCall.Name, availableTool.Info().Name) && + // strings.HasPrefix(toolCall.Name, availableTool.Info().Name+availableTool.Info().Name) { + // tool = availableTool + // break + // } } // Tool not found @@ -553,6 +568,7 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error { a.Publish(pubsub.CreatedEvent, event) return } + summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID) if len(msgs) == 0 { event = AgentEvent{ diff --git a/internal/llm/models/copilot.go b/internal/llm/models/copilot.go new file mode 100644 index 0000000000000000000000000000000000000000..f6ec91cddf6c8537e7b1ff66ff4240071dc62f3c --- /dev/null +++ b/internal/llm/models/copilot.go @@ -0,0 +1,219 @@ +package models + +const ( + ProviderCopilot ModelProvider = "copilot" + + // GitHub Copilot models + CopilotGTP35Turbo ModelID = "copilot.gpt-3.5-turbo" + CopilotGPT4o ModelID = "copilot.gpt-4o" + CopilotGPT4oMini ModelID = "copilot.gpt-4o-mini" + CopilotGPT41 ModelID = "copilot.gpt-4.1" + CopilotClaude35 ModelID = "copilot.claude-3.5-sonnet" + CopilotClaude37 ModelID = "copilot.claude-3.7-sonnet" + CopilotClaude4 ModelID = "copilot.claude-sonnet-4" + CopilotO1 ModelID = "copilot.o1" + CopilotO3Mini ModelID = "copilot.o3-mini" + CopilotO4Mini ModelID = "copilot.o4-mini" + CopilotGemini20 ModelID = "copilot.gemini-2.0-flash" + CopilotGemini25 ModelID = "copilot.gemini-2.5-pro" + CopilotGPT4 ModelID = "copilot.gpt-4" + 
CopilotClaude37Thought ModelID = "copilot.claude-3.7-sonnet-thought" +) + +var CopilotAnthropicModels = []ModelID{ + CopilotClaude35, + CopilotClaude37, + CopilotClaude37Thought, + CopilotClaude4, +} + +// GitHub Copilot models available through GitHub's API +var CopilotModels = map[ModelID]Model{ + CopilotGTP35Turbo: { + ID: CopilotGTP35Turbo, + Name: "GitHub Copilot GPT-3.5-turbo", + Provider: ProviderCopilot, + APIModel: "gpt-3.5-turbo", + CostPer1MIn: 0.0, // Included in GitHub Copilot subscription + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.0, + CostPer1MOut: 0.0, + ContextWindow: 16_384, + DefaultMaxTokens: 4096, + SupportsAttachments: true, + }, + CopilotGPT4o: { + ID: CopilotGPT4o, + Name: "GitHub Copilot GPT-4o", + Provider: ProviderCopilot, + APIModel: "gpt-4o", + CostPer1MIn: 0.0, // Included in GitHub Copilot subscription + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.0, + CostPer1MOut: 0.0, + ContextWindow: 128_000, + DefaultMaxTokens: 16384, + SupportsAttachments: true, + }, + CopilotGPT4oMini: { + ID: CopilotGPT4oMini, + Name: "GitHub Copilot GPT-4o Mini", + Provider: ProviderCopilot, + APIModel: "gpt-4o-mini", + CostPer1MIn: 0.0, // Included in GitHub Copilot subscription + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.0, + CostPer1MOut: 0.0, + ContextWindow: 128_000, + DefaultMaxTokens: 4096, + SupportsAttachments: true, + }, + CopilotGPT41: { + ID: CopilotGPT41, + Name: "GitHub Copilot GPT-4.1", + Provider: ProviderCopilot, + APIModel: "gpt-4.1", + CostPer1MIn: 0.0, // Included in GitHub Copilot subscription + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.0, + CostPer1MOut: 0.0, + ContextWindow: 128_000, + DefaultMaxTokens: 16384, + CanReason: true, + SupportsAttachments: true, + }, + CopilotClaude35: { + ID: CopilotClaude35, + Name: "GitHub Copilot Claude 3.5 Sonnet", + Provider: ProviderCopilot, + APIModel: "claude-3.5-sonnet", + CostPer1MIn: 0.0, // Included in GitHub Copilot subscription + CostPer1MInCached: 0.0, + 
CostPer1MOutCached: 0.0, + CostPer1MOut: 0.0, + ContextWindow: 90_000, + DefaultMaxTokens: 8192, + SupportsAttachments: true, + }, + CopilotClaude37: { + ID: CopilotClaude37, + Name: "GitHub Copilot Claude 3.7 Sonnet", + Provider: ProviderCopilot, + APIModel: "claude-3.7-sonnet", + CostPer1MIn: 0.0, // Included in GitHub Copilot subscription + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.0, + CostPer1MOut: 0.0, + ContextWindow: 200_000, + DefaultMaxTokens: 16384, + SupportsAttachments: true, + }, + CopilotClaude4: { + ID: CopilotClaude4, + Name: "GitHub Copilot Claude Sonnet 4", + Provider: ProviderCopilot, + APIModel: "claude-sonnet-4", + CostPer1MIn: 0.0, // Included in GitHub Copilot subscription + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.0, + CostPer1MOut: 0.0, + ContextWindow: 128_000, + DefaultMaxTokens: 16000, + SupportsAttachments: true, + }, + CopilotO1: { + ID: CopilotO1, + Name: "GitHub Copilot o1", + Provider: ProviderCopilot, + APIModel: "o1", + CostPer1MIn: 0.0, // Included in GitHub Copilot subscription + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.0, + CostPer1MOut: 0.0, + ContextWindow: 200_000, + DefaultMaxTokens: 100_000, + CanReason: true, + SupportsAttachments: false, + }, + CopilotO3Mini: { + ID: CopilotO3Mini, + Name: "GitHub Copilot o3-mini", + Provider: ProviderCopilot, + APIModel: "o3-mini", + CostPer1MIn: 0.0, // Included in GitHub Copilot subscription + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.0, + CostPer1MOut: 0.0, + ContextWindow: 200_000, + DefaultMaxTokens: 100_000, + CanReason: true, + SupportsAttachments: false, + }, + CopilotO4Mini: { + ID: CopilotO4Mini, + Name: "GitHub Copilot o4-mini", + Provider: ProviderCopilot, + APIModel: "o4-mini", + CostPer1MIn: 0.0, // Included in GitHub Copilot subscription + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.0, + CostPer1MOut: 0.0, + ContextWindow: 128_000, + DefaultMaxTokens: 16_384, + CanReason: true, + SupportsAttachments: true, + }, + CopilotGemini20: { + ID: 
CopilotGemini20, + Name: "GitHub Copilot Gemini 2.0 Flash", + Provider: ProviderCopilot, + APIModel: "gemini-2.0-flash-001", + CostPer1MIn: 0.0, // Included in GitHub Copilot subscription + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.0, + CostPer1MOut: 0.0, + ContextWindow: 1_000_000, + DefaultMaxTokens: 8192, + SupportsAttachments: true, + }, + CopilotGemini25: { + ID: CopilotGemini25, + Name: "GitHub Copilot Gemini 2.5 Pro", + Provider: ProviderCopilot, + APIModel: "gemini-2.5-pro", + CostPer1MIn: 0.0, // Included in GitHub Copilot subscription + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.0, + CostPer1MOut: 0.0, + ContextWindow: 128_000, + DefaultMaxTokens: 64000, + SupportsAttachments: true, + }, + CopilotGPT4: { + ID: CopilotGPT4, + Name: "GitHub Copilot GPT-4", + Provider: ProviderCopilot, + APIModel: "gpt-4", + CostPer1MIn: 0.0, // Included in GitHub Copilot subscription + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.0, + CostPer1MOut: 0.0, + ContextWindow: 32_768, + DefaultMaxTokens: 4096, + SupportsAttachments: true, + }, + CopilotClaude37Thought: { + ID: CopilotClaude37Thought, + Name: "GitHub Copilot Claude 3.7 Sonnet Thinking", + Provider: ProviderCopilot, + APIModel: "claude-3.7-sonnet-thought", + CostPer1MIn: 0.0, // Included in GitHub Copilot subscription + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.0, + CostPer1MOut: 0.0, + ContextWindow: 200_000, + DefaultMaxTokens: 16384, + CanReason: true, + SupportsAttachments: true, + }, +} diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go index 47d217184de54f7e2937286cd2c64c9e98c4a02b..2bcb508e9928722a1bd2cc4151bf43529645dc74 100644 --- a/internal/llm/models/models.go +++ b/internal/llm/models/models.go @@ -36,14 +36,15 @@ const ( // Providers in order of popularity var ProviderPopularity = map[ModelProvider]int{ - ProviderAnthropic: 1, - ProviderOpenAI: 2, - ProviderGemini: 3, - ProviderGROQ: 4, - ProviderOpenRouter: 5, - ProviderBedrock: 6, - ProviderAzure: 7, - 
ProviderVertexAI: 8, + ProviderCopilot: 1, + ProviderAnthropic: 2, + ProviderOpenAI: 3, + ProviderGemini: 4, + ProviderGROQ: 5, + ProviderOpenRouter: 6, + ProviderBedrock: 7, + ProviderAzure: 8, + ProviderVertexAI: 9, } var SupportedModels = map[ModelID]Model{ @@ -93,4 +94,5 @@ func init() { maps.Copy(SupportedModels, OpenRouterModels) maps.Copy(SupportedModels, XAIModels) maps.Copy(SupportedModels, VertexAIGeminiModels) + maps.Copy(SupportedModels, CopilotModels) } diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index badf6a3a07df27f6494bdbf9692f174e0a17a1ce..213d4b94a34beb4f858ecf2fed99da040a870f32 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -14,7 +14,7 @@ import ( "github.com/anthropics/anthropic-sdk-go/option" "github.com/opencode-ai/opencode/internal/config" "github.com/opencode-ai/opencode/internal/llm/models" - "github.com/opencode-ai/opencode/internal/llm/tools" + toolsPkg "github.com/opencode-ai/opencode/internal/llm/tools" "github.com/opencode-ai/opencode/internal/logging" "github.com/opencode-ai/opencode/internal/message" ) @@ -118,7 +118,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic return } -func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam { +func (a *anthropicClient) convertTools(tools []toolsPkg.BaseTool) []anthropic.ToolUnionParam { anthropicTools := make([]anthropic.ToolUnionParam, len(tools)) for i, tool := range tools { @@ -195,7 +195,7 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to } } -func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (resposne *ProviderResponse, err error) { +func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []toolsPkg.BaseTool) (resposne *ProviderResponse, err error) { preparedMessages := a.preparedMessages(a.convertMessages(messages), 
a.convertTools(tools)) cfg := config.Get() if cfg.Debug { @@ -244,12 +244,24 @@ func (a *anthropicClient) send(ctx context.Context, messages []message.Message, } } -func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { +func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []toolsPkg.BaseTool) <-chan ProviderEvent { preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools)) cfg := config.Get() + + var sessionId string + requestSeqId := (len(messages) + 1) / 2 if cfg.Debug { - // jsonData, _ := json.Marshal(preparedMessages) - // logging.Debug("Prepared messages", "messages", string(jsonData)) + if sid, ok := ctx.Value(toolsPkg.SessionIDContextKey).(string); ok { + sessionId = sid + } + jsonData, _ := json.Marshal(preparedMessages) + if sessionId != "" { + filepath := logging.WriteRequestMessageJson(sessionId, requestSeqId, preparedMessages) + logging.Debug("Prepared messages", "filepath", filepath) + } else { + logging.Debug("Prepared messages", "messages", string(jsonData)) + } + } attempts := 0 eventChan := make(chan ProviderEvent) diff --git a/internal/llm/provider/copilot.go b/internal/llm/provider/copilot.go new file mode 100644 index 0000000000000000000000000000000000000000..5d70e718ae781d086711884b550a3995d720edff --- /dev/null +++ b/internal/llm/provider/copilot.go @@ -0,0 +1,671 @@ +package provider + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "time" + + "github.com/openai/openai-go" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/shared" + "github.com/opencode-ai/opencode/internal/config" + "github.com/opencode-ai/opencode/internal/llm/models" + toolsPkg "github.com/opencode-ai/opencode/internal/llm/tools" + "github.com/opencode-ai/opencode/internal/logging" + "github.com/opencode-ai/opencode/internal/message" +) + +type copilotOptions struct { + 
reasoningEffort string + extraHeaders map[string]string + bearerToken string +} + +type CopilotOption func(*copilotOptions) + +type copilotClient struct { + providerOptions providerClientOptions + options copilotOptions + client openai.Client + httpClient *http.Client +} + +type CopilotClient ProviderClient + +// CopilotTokenResponse represents the response from GitHub's token exchange endpoint +type CopilotTokenResponse struct { + Token string `json:"token"` + ExpiresAt int64 `json:"expires_at"` +} + +func (c *copilotClient) isAnthropicModel() bool { + for _, modelId := range models.CopilotAnthropicModels { + if c.providerOptions.model.ID == modelId { + return true + } + } + return false +} + +// loadGitHubToken loads the GitHub OAuth token from the standard GitHub CLI/Copilot locations + +// exchangeGitHubToken exchanges a GitHub token for a Copilot bearer token +func (c *copilotClient) exchangeGitHubToken(githubToken string) (string, error) { + req, err := http.NewRequest("GET", "https://api.github.com/copilot_internal/v2/token", nil) + if err != nil { + return "", fmt.Errorf("failed to create token exchange request: %w", err) + } + + req.Header.Set("Authorization", "Token "+githubToken) + req.Header.Set("User-Agent", "OpenCode/1.0") + + resp, err := c.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("failed to exchange GitHub token: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body)) + } + + var tokenResp CopilotTokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + return "", fmt.Errorf("failed to decode token response: %w", err) + } + + return tokenResp.Token, nil +} + +func newCopilotClient(opts providerClientOptions) CopilotClient { + copilotOpts := copilotOptions{ + reasoningEffort: "medium", + } + // Apply copilot-specific options + for _, o := range 
opts.copilotOptions { + o(&copilotOpts) + } + + // Create HTTP client for token exchange + httpClient := &http.Client{ + Timeout: 30 * time.Second, + } + + var bearerToken string + + // If bearer token is already provided, use it + if copilotOpts.bearerToken != "" { + bearerToken = copilotOpts.bearerToken + } else { + // Try to get GitHub token from multiple sources + var githubToken string + + // 1. Environment variable + githubToken = os.Getenv("GITHUB_TOKEN") + + // 2. API key from options + if githubToken == "" { + githubToken = opts.apiKey + } + + // 3. Standard GitHub CLI/Copilot locations + if githubToken == "" { + var err error + githubToken, err = config.LoadGitHubToken() + if err != nil { + logging.Debug("Failed to load GitHub token from standard locations", "error", err) + } + } + + if githubToken == "" { + logging.Error("GitHub token is required for Copilot provider. Set GITHUB_TOKEN environment variable, configure it in opencode.json, or ensure GitHub CLI/Copilot is properly authenticated.") + return &copilotClient{ + providerOptions: opts, + options: copilotOpts, + httpClient: httpClient, + } + } + + // Create a temporary client for token exchange + tempClient := &copilotClient{ + providerOptions: opts, + options: copilotOpts, + httpClient: httpClient, + } + + // Exchange GitHub token for bearer token + var err error + bearerToken, err = tempClient.exchangeGitHubToken(githubToken) + if err != nil { + logging.Error("Failed to exchange GitHub token for Copilot bearer token", "error", err) + return &copilotClient{ + providerOptions: opts, + options: copilotOpts, + httpClient: httpClient, + } + } + } + + copilotOpts.bearerToken = bearerToken + + // GitHub Copilot API base URL + baseURL := "https://api.githubcopilot.com" + + openaiClientOptions := []option.RequestOption{ + option.WithBaseURL(baseURL), + option.WithAPIKey(bearerToken), // Use bearer token as API key + } + + // Add GitHub Copilot specific headers + openaiClientOptions = 
append(openaiClientOptions, + option.WithHeader("Editor-Version", "OpenCode/1.0"), + option.WithHeader("Editor-Plugin-Version", "OpenCode/1.0"), + option.WithHeader("Copilot-Integration-Id", "vscode-chat"), + ) + + // Add any extra headers + if copilotOpts.extraHeaders != nil { + for key, value := range copilotOpts.extraHeaders { + openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value)) + } + } + + client := openai.NewClient(openaiClientOptions...) + // logging.Debug("Copilot client created", "opts", opts, "copilotOpts", copilotOpts, "model", opts.model) + return &copilotClient{ + providerOptions: opts, + options: copilotOpts, + client: client, + httpClient: httpClient, + } +} + +func (c *copilotClient) convertMessages(messages []message.Message) (copilotMessages []openai.ChatCompletionMessageParamUnion) { + // Add system message first + copilotMessages = append(copilotMessages, openai.SystemMessage(c.providerOptions.systemMessage)) + + for _, msg := range messages { + switch msg.Role { + case message.User: + var content []openai.ChatCompletionContentPartUnionParam + textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()} + content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock}) + + for _, binaryContent := range msg.BinaryContent() { + imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(models.ProviderCopilot)} + imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL} + content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock}) + } + + copilotMessages = append(copilotMessages, openai.UserMessage(content)) + + case message.Assistant: + assistantMsg := openai.ChatCompletionAssistantMessageParam{ + Role: "assistant", + } + + if msg.Content().String() != "" { + assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{ + OfString: openai.String(msg.Content().String()), + 
} + } + + if len(msg.ToolCalls()) > 0 { + assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls())) + for i, call := range msg.ToolCalls() { + assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{ + ID: call.ID, + Type: "function", + Function: openai.ChatCompletionMessageToolCallFunctionParam{ + Name: call.Name, + Arguments: call.Input, + }, + } + } + } + + copilotMessages = append(copilotMessages, openai.ChatCompletionMessageParamUnion{ + OfAssistant: &assistantMsg, + }) + + case message.Tool: + for _, result := range msg.ToolResults() { + copilotMessages = append(copilotMessages, + openai.ToolMessage(result.Content, result.ToolCallID), + ) + } + } + } + + return +} + +func (c *copilotClient) convertTools(tools []toolsPkg.BaseTool) []openai.ChatCompletionToolParam { + copilotTools := make([]openai.ChatCompletionToolParam, len(tools)) + + for i, tool := range tools { + info := tool.Info() + copilotTools[i] = openai.ChatCompletionToolParam{ + Function: openai.FunctionDefinitionParam{ + Name: info.Name, + Description: openai.String(info.Description), + Parameters: openai.FunctionParameters{ + "type": "object", + "properties": info.Parameters, + "required": info.Required, + }, + }, + } + } + + return copilotTools +} + +func (c *copilotClient) finishReason(reason string) message.FinishReason { + switch reason { + case "stop": + return message.FinishReasonEndTurn + case "length": + return message.FinishReasonMaxTokens + case "tool_calls": + return message.FinishReasonToolUse + default: + return message.FinishReasonUnknown + } +} + +func (c *copilotClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams { + params := openai.ChatCompletionNewParams{ + Model: openai.ChatModel(c.providerOptions.model.APIModel), + Messages: messages, + Tools: tools, + } + + if c.providerOptions.model.CanReason == true { + params.MaxCompletionTokens 
= openai.Int(c.providerOptions.maxTokens) + switch c.options.reasoningEffort { + case "low": + params.ReasoningEffort = shared.ReasoningEffortLow + case "medium": + params.ReasoningEffort = shared.ReasoningEffortMedium + case "high": + params.ReasoningEffort = shared.ReasoningEffortHigh + default: + params.ReasoningEffort = shared.ReasoningEffortMedium + } + } else { + params.MaxTokens = openai.Int(c.providerOptions.maxTokens) + } + + return params +} + +func (c *copilotClient) send(ctx context.Context, messages []message.Message, tools []toolsPkg.BaseTool) (response *ProviderResponse, err error) { + params := c.preparedParams(c.convertMessages(messages), c.convertTools(tools)) + cfg := config.Get() + var sessionId string + requestSeqId := (len(messages) + 1) / 2 + if cfg.Debug { + // jsonData, _ := json.Marshal(params) + // logging.Debug("Prepared messages", "messages", string(jsonData)) + if sid, ok := ctx.Value(toolsPkg.SessionIDContextKey).(string); ok { + sessionId = sid + } + jsonData, _ := json.Marshal(params) + if sessionId != "" { + filepath := logging.WriteRequestMessageJson(sessionId, requestSeqId, params) + logging.Debug("Prepared messages", "filepath", filepath) + } else { + logging.Debug("Prepared messages", "messages", string(jsonData)) + } + } + + attempts := 0 + for { + attempts++ + copilotResponse, err := c.client.Chat.Completions.New( + ctx, + params, + ) + + // If there is an error we are going to see if we can retry the call + if err != nil { + retry, after, retryErr := c.shouldRetry(attempts, err) + if retryErr != nil { + return nil, retryErr + } + if retry { + logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... 
attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Duration(after) * time.Millisecond): + continue + } + } + return nil, retryErr + } + + content := "" + if copilotResponse.Choices[0].Message.Content != "" { + content = copilotResponse.Choices[0].Message.Content + } + + toolCalls := c.toolCalls(*copilotResponse) + finishReason := c.finishReason(string(copilotResponse.Choices[0].FinishReason)) + + if len(toolCalls) > 0 { + finishReason = message.FinishReasonToolUse + } + + return &ProviderResponse{ + Content: content, + ToolCalls: toolCalls, + Usage: c.usage(*copilotResponse), + FinishReason: finishReason, + }, nil + } +} + +func (c *copilotClient) stream(ctx context.Context, messages []message.Message, tools []toolsPkg.BaseTool) <-chan ProviderEvent { + params := c.preparedParams(c.convertMessages(messages), c.convertTools(tools)) + params.StreamOptions = openai.ChatCompletionStreamOptionsParam{ + IncludeUsage: openai.Bool(true), + } + + cfg := config.Get() + var sessionId string + requestSeqId := (len(messages) + 1) / 2 + if cfg.Debug { + if sid, ok := ctx.Value(toolsPkg.SessionIDContextKey).(string); ok { + sessionId = sid + } + jsonData, _ := json.Marshal(params) + if sessionId != "" { + filepath := logging.WriteRequestMessageJson(sessionId, requestSeqId, params) + logging.Debug("Prepared messages", "filepath", filepath) + } else { + logging.Debug("Prepared messages", "messages", string(jsonData)) + } + + } + + attempts := 0 + eventChan := make(chan ProviderEvent) + + go func() { + for { + attempts++ + copilotStream := c.client.Chat.Completions.NewStreaming( + ctx, + params, + ) + + acc := openai.ChatCompletionAccumulator{} + currentContent := "" + toolCalls := make([]message.ToolCall, 0) + + var currentToolCallId string + var currentToolCall openai.ChatCompletionMessageToolCall + var msgToolCalls 
[]openai.ChatCompletionMessageToolCall + for copilotStream.Next() { + chunk := copilotStream.Current() + acc.AddChunk(chunk) + + if cfg.Debug { + logging.AppendToStreamSessionLogJson(sessionId, requestSeqId, chunk) + } + + for _, choice := range chunk.Choices { + if choice.Delta.Content != "" { + eventChan <- ProviderEvent{ + Type: EventContentDelta, + Content: choice.Delta.Content, + } + currentContent += choice.Delta.Content + } + } + + if c.isAnthropicModel() { + // Monkeypatch adapter for Sonnet-4 multi-tool use + for _, choice := range chunk.Choices { + if choice.Delta.ToolCalls != nil && len(choice.Delta.ToolCalls) > 0 { + toolCall := choice.Delta.ToolCalls[0] + // Detect tool use start + if currentToolCallId == "" { + if toolCall.ID != "" { + currentToolCallId = toolCall.ID + currentToolCall = openai.ChatCompletionMessageToolCall{ + ID: toolCall.ID, + Type: "function", + Function: openai.ChatCompletionMessageToolCallFunction{ + Name: toolCall.Function.Name, + Arguments: toolCall.Function.Arguments, + }, + } + } + } else { + // Delta tool use + if toolCall.ID == "" { + currentToolCall.Function.Arguments += toolCall.Function.Arguments + } else { + // Detect new tool use + if toolCall.ID != currentToolCallId { + msgToolCalls = append(msgToolCalls, currentToolCall) + currentToolCallId = toolCall.ID + currentToolCall = openai.ChatCompletionMessageToolCall{ + ID: toolCall.ID, + Type: "function", + Function: openai.ChatCompletionMessageToolCallFunction{ + Name: toolCall.Function.Name, + Arguments: toolCall.Function.Arguments, + }, + } + } + } + } + } + if choice.FinishReason == "tool_calls" { + msgToolCalls = append(msgToolCalls, currentToolCall) + acc.ChatCompletion.Choices[0].Message.ToolCalls = msgToolCalls + } + } + } + } + + err := copilotStream.Err() + if err == nil || errors.Is(err, io.EOF) { + if cfg.Debug { + respFilepath := logging.WriteChatResponseJson(sessionId, requestSeqId, acc.ChatCompletion) + logging.Debug("Chat completion response", "filepath", 
respFilepath) + } + // Stream completed successfully + finishReason := c.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason)) + if len(acc.ChatCompletion.Choices[0].Message.ToolCalls) > 0 { + toolCalls = append(toolCalls, c.toolCalls(acc.ChatCompletion)...) + } + if len(toolCalls) > 0 { + finishReason = message.FinishReasonToolUse + } + + eventChan <- ProviderEvent{ + Type: EventComplete, + Response: &ProviderResponse{ + Content: currentContent, + ToolCalls: toolCalls, + Usage: c.usage(acc.ChatCompletion), + FinishReason: finishReason, + }, + } + close(eventChan) + return + } + + // If there is an error we are going to see if we can retry the call + retry, after, retryErr := c.shouldRetry(attempts, err) + if retryErr != nil { + eventChan <- ProviderEvent{Type: EventError, Error: retryErr} + close(eventChan) + return + } + // shouldRetry is not catching the max retries... + // TODO: Figure out why + if attempts > maxRetries { + logging.Warn("Maximum retry attempts reached for rate limit", "attempts", attempts, "max_retries", maxRetries) + retry = false + } + if retry { + logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... 
attempt %d of %d (paused for %d ms)", attempts, maxRetries, after), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + // context cancelled + if ctx.Err() == nil { + eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} + } + close(eventChan) + return + case <-time.After(time.Duration(after) * time.Millisecond): + continue + } + } + eventChan <- ProviderEvent{Type: EventError, Error: retryErr} + close(eventChan) + return + } + }() + + return eventChan +} + +func (c *copilotClient) shouldRetry(attempts int, err error) (bool, int64, error) { + var apierr *openai.Error + if !errors.As(err, &apierr) { + return false, 0, err + } + + // Check for token expiration (401 Unauthorized) + if apierr.StatusCode == 401 { + // Try to refresh the bearer token + var githubToken string + + // 1. Environment variable + githubToken = os.Getenv("GITHUB_TOKEN") + + // 2. API key from options + if githubToken == "" { + githubToken = c.providerOptions.apiKey + } + + // 3. Standard GitHub CLI/Copilot locations + if githubToken == "" { + var err error + githubToken, err = config.LoadGitHubToken() + if err != nil { + logging.Debug("Failed to load GitHub token from standard locations during retry", "error", err) + } + } + + if githubToken != "" { + newBearerToken, tokenErr := c.exchangeGitHubToken(githubToken) + if tokenErr == nil { + c.options.bearerToken = newBearerToken + // Update the client with the new token + // Note: This is a simplified approach. 
In a production system, + // you might want to recreate the entire client with the new token + logging.Info("Refreshed Copilot bearer token") + return true, 1000, nil // Retry immediately with new token + } + logging.Error("Failed to refresh Copilot bearer token", "error", tokenErr) + } + return false, 0, fmt.Errorf("authentication failed: %w", err) + } + logging.Debug("Copilot API Error", "status", apierr.StatusCode, "headers", apierr.Response.Header, "body", apierr.RawJSON()) + + if apierr.StatusCode != 429 && apierr.StatusCode != 500 { + return false, 0, err + } + + if apierr.StatusCode == 500 { + logging.Warn("Copilot API returned 500 error, retrying", "error", err) + } + + if attempts > maxRetries { + return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries) + } + + retryMs := 0 + retryAfterValues := apierr.Response.Header.Values("Retry-After") + + backoffMs := 2000 * (1 << (attempts - 1)) + jitterMs := int(float64(backoffMs) * 0.2) + retryMs = backoffMs + jitterMs + if len(retryAfterValues) > 0 { + if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil { + retryMs = retryMs * 1000 + } + } + return true, int64(retryMs), nil +} + +func (c *copilotClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall { + var toolCalls []message.ToolCall + + if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 { + for _, call := range completion.Choices[0].Message.ToolCalls { + toolCall := message.ToolCall{ + ID: call.ID, + Name: call.Function.Name, + Input: call.Function.Arguments, + Type: "function", + Finished: true, + } + toolCalls = append(toolCalls, toolCall) + } + } + + return toolCalls +} + +func (c *copilotClient) usage(completion openai.ChatCompletion) TokenUsage { + cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens + inputTokens := completion.Usage.PromptTokens - cachedTokens + + return TokenUsage{ + InputTokens: inputTokens, + OutputTokens: 
completion.Usage.CompletionTokens, + CacheCreationTokens: 0, // GitHub Copilot doesn't provide this directly + CacheReadTokens: cachedTokens, + } +} + +func WithCopilotReasoningEffort(effort string) CopilotOption { + return func(options *copilotOptions) { + defaultReasoningEffort := "medium" + switch effort { + case "low", "medium", "high": + defaultReasoningEffort = effort + default: + logging.Warn("Invalid reasoning effort, using default: medium") + } + options.reasoningEffort = defaultReasoningEffort + } +} + +func WithCopilotExtraHeaders(headers map[string]string) CopilotOption { + return func(options *copilotOptions) { + options.extraHeaders = headers + } +} + +func WithCopilotBearerToken(bearerToken string) CopilotOption { + return func(options *copilotOptions) { + options.bearerToken = bearerToken + } +} + diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 08175450a6d85953e996c08f436982a1981053b6..d5be0ba0e8b1a88165c8dddfa84a1acda2e1a0dc 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -68,6 +68,7 @@ type providerClientOptions struct { openaiOptions []OpenAIOption geminiOptions []GeminiOption bedrockOptions []BedrockOption + copilotOptions []CopilotOption } type ProviderClientOption func(*providerClientOptions) @@ -88,6 +89,11 @@ func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption o(&clientOptions) } switch providerName { + case models.ProviderCopilot: + return &baseProvider[CopilotClient]{ + options: clientOptions, + client: newCopilotClient(clientOptions), + }, nil case models.ProviderAnthropic: return &baseProvider[AnthropicClient]{ options: clientOptions, @@ -233,3 +239,9 @@ func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption { options.bedrockOptions = bedrockOptions } } + +func WithCopilotOptions(copilotOptions ...CopilotOption) ProviderClientOption { + return func(options *providerClientOptions) { + 
options.copilotOptions = copilotOptions + } +} diff --git a/internal/llm/tools/view.go b/internal/llm/tools/view.go index 6d800ce6ee27902a5c99767b9954e91f2c650428..7802817226d728ac20e63ecf38d9d827dd182fa2 100644 --- a/internal/llm/tools/view.go +++ b/internal/llm/tools/view.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/opencode-ai/opencode/internal/config" + "github.com/opencode-ai/opencode/internal/logging" "github.com/opencode-ai/opencode/internal/lsp" ) @@ -97,6 +98,7 @@ func (v *viewTool) Info() ToolInfo { // Run implements Tool. func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { var params ViewParams + logging.Debug("view tool params", "params", call.Input) if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil } diff --git a/internal/logging/logger.go b/internal/logging/logger.go index 7ae2e7b87ab7f3f71811c793118c79e2a72a3bbf..51787d00e2c6b4c37e0b1b9c3f23a9fbd6cd6f65 100644 --- a/internal/logging/logger.go +++ b/internal/logging/logger.go @@ -4,16 +4,33 @@ import ( "fmt" "log/slog" "os" + // "path/filepath" + "encoding/json" + "runtime" "runtime/debug" + "sync" "time" ) +func getCaller() string { + var caller string + if _, file, line, ok := runtime.Caller(2); ok { + // caller = fmt.Sprintf("%s:%d", filepath.Base(file), line) + caller = fmt.Sprintf("%s:%d", file, line) + } else { + caller = "unknown" + } + return caller +} func Info(msg string, args ...any) { - slog.Info(msg, args...) + source := getCaller() + slog.Info(msg, append([]any{"source", source}, args...)...) } func Debug(msg string, args ...any) { - slog.Debug(msg, args...) + // slog.Debug(msg, args...) + source := getCaller() + slog.Debug(msg, append([]any{"source", source}, args...)...) 
} func Warn(msg string, args ...any) { @@ -76,3 +93,115 @@ func RecoverPanic(name string, cleanup func()) { } } } + +// Message Logging for Debug +var MessageDir string + +func GetSessionPrefix(sessionId string) string { + return sessionId[:8] +} + +var sessionLogMutex sync.Mutex + +func AppendToSessionLogFile(sessionId string, filename string, content string) string { + if MessageDir == "" || sessionId == "" { + return "" + } + sessionPrefix := GetSessionPrefix(sessionId) + + sessionLogMutex.Lock() + defer sessionLogMutex.Unlock() + + sessionPath := fmt.Sprintf("%s/%s", MessageDir, sessionPrefix) + if _, err := os.Stat(sessionPath); os.IsNotExist(err) { + if err := os.MkdirAll(sessionPath, 0o766); err != nil { + Error("Failed to create session directory", "dirpath", sessionPath, "error", err) + return "" + } + } + + filePath := fmt.Sprintf("%s/%s", sessionPath, filename) + + f, err := os.OpenFile(filePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + Error("Failed to open session log file", "filepath", filePath, "error", err) + return "" + } + defer f.Close() + + // Append chunk to file + _, err = f.WriteString(content) + if err != nil { + Error("Failed to write chunk to session log file", "filepath", filePath, "error", err) + return "" + } + return filePath +} + +func WriteRequestMessageJson(sessionId string, requestSeqId int, message any) string { + if MessageDir == "" || sessionId == "" || requestSeqId <= 0 { + return "" + } + msgJson, err := json.Marshal(message) + if err != nil { + Error("Failed to marshal message", "session_id", sessionId, "request_seq_id", requestSeqId, "error", err) + return "" + } + return WriteRequestMessage(sessionId, requestSeqId, string(msgJson)) +} + +func WriteRequestMessage(sessionId string, requestSeqId int, message string) string { + if MessageDir == "" || sessionId == "" || requestSeqId <= 0 { + return "" + } + filename := fmt.Sprintf("%d_request.json", requestSeqId) + + return AppendToSessionLogFile(sessionId, 
filename, message) +} + +func AppendToStreamSessionLogJson(sessionId string, requestSeqId int, jsonableChunk any) string { + if MessageDir == "" || sessionId == "" || requestSeqId <= 0 { + return "" + } + chunkJson, err := json.Marshal(jsonableChunk) + if err != nil { + Error("Failed to marshal message", "session_id", sessionId, "request_seq_id", requestSeqId, "error", err) + return "" + } + return AppendToStreamSessionLog(sessionId, requestSeqId, string(chunkJson)) +} + +func AppendToStreamSessionLog(sessionId string, requestSeqId int, chunk string) string { + if MessageDir == "" || sessionId == "" || requestSeqId <= 0 { + return "" + } + filename := fmt.Sprintf("%d_response_stream.log", requestSeqId) + return AppendToSessionLogFile(sessionId, filename, chunk) +} + +func WriteChatResponseJson(sessionId string, requestSeqId int, response any) string { + if MessageDir == "" || sessionId == "" || requestSeqId <= 0 { + return "" + } + responseJson, err := json.Marshal(response) + if err != nil { + Error("Failed to marshal response", "session_id", sessionId, "request_seq_id", requestSeqId, "error", err) + return "" + } + filename := fmt.Sprintf("%d_response.json", requestSeqId) + + return AppendToSessionLogFile(sessionId, filename, string(responseJson)) +} + +func WriteToolResultsJson(sessionId string, requestSeqId int, toolResults any) string { + if MessageDir == "" || sessionId == "" || requestSeqId <= 0 { + return "" + } + toolResultsJson, err := json.Marshal(toolResults) + if err != nil { + Error("Failed to marshal tool results", "session_id", sessionId, "request_seq_id", requestSeqId, "error", err) + return "" + } + filename := fmt.Sprintf("%d_tool_results.json", requestSeqId) + return AppendToSessionLogFile(sessionId, filename, string(toolResultsJson)) +} diff --git a/internal/logging/writer.go b/internal/logging/writer.go index 50f3367db015af253869262ce139d4d36c962254..5c0e3c80392cc92830038f2302a0bea1e3c4fdb3 100644 --- a/internal/logging/writer.go +++ 
b/internal/logging/writer.go @@ -45,6 +45,7 @@ type writer struct{} func (w *writer) Write(p []byte) (int, error) { d := logfmt.NewDecoder(bytes.NewReader(p)) + for d.ScanRecord() { msg := LogMessage{ ID: fmt.Sprintf("%d", time.Now().UnixNano()), diff --git a/opencode-schema.json b/opencode-schema.json index dc139fda374964b1254d5df12c42751c84d29e7a..406c75f8c7945cb1418f17cca9ba1aee9c4b2959 100644 --- a/opencode-schema.json +++ b/opencode-schema.json @@ -77,7 +77,18 @@ "openrouter.o4-mini", "openrouter.claude-3.5-haiku", "claude-4-opus", - "openrouter.o1-pro" + "openrouter.o1-pro", + "copilot.gpt-4o", + "copilot.gpt-4o-mini", + "copilot.gpt-4.1", + "copilot.claude-3.5-sonnet", + "copilot.claude-3.7-sonnet", + "copilot.claude-sonnet-4", + "copilot.o1", + "copilot.o3-mini", + "copilot.o4-mini", + "copilot.gemini-2.0-flash", + "copilot.gemini-2.5-pro" ], "type": "string" }, @@ -176,7 +187,18 @@ "openrouter.o4-mini", "openrouter.claude-3.5-haiku", "claude-4-opus", - "openrouter.o1-pro" + "openrouter.o1-pro", + "copilot.gpt-4o", + "copilot.gpt-4o-mini", + "copilot.gpt-4.1", + "copilot.claude-3.5-sonnet", + "copilot.claude-3.7-sonnet", + "copilot.claude-sonnet-4", + "copilot.o1", + "copilot.o3-mini", + "copilot.o4-mini", + "copilot.gemini-2.0-flash", + "copilot.gemini-2.5-pro" ], "type": "string" }, @@ -360,7 +382,8 @@ "openrouter", "bedrock", "azure", - "vertexai" + "vertexai", + "copilot" ], "type": "string" } From 4427df587f3c636002b66dba467a338fe948c828 Mon Sep 17 00:00:00 2001 From: Tai Groot Date: Thu, 26 Jun 2025 23:44:20 -0700 Subject: [PATCH 02/40] fixup early return for ollama (#266) --- internal/llm/models/local.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/internal/llm/models/local.go b/internal/llm/models/local.go index 5d8412c86a0f3f4ccf305763171f6acfdaea6eb1..db0ea11c60b23151fe5f0b724c86683ef431b41c 100644 --- a/internal/llm/models/local.go +++ b/internal/llm/models/local.go @@ -81,6 +81,7 @@ func listLocalModels(modelsEndpoint string) 
[]localModel { "error", err, "endpoint", modelsEndpoint, ) + return []localModel{} } defer res.Body.Close() @@ -89,6 +90,7 @@ func listLocalModels(modelsEndpoint string) []localModel { "status", res.StatusCode, "endpoint", modelsEndpoint, ) + return []localModel{} } var modelList localModelList @@ -97,6 +99,7 @@ func listLocalModels(modelsEndpoint string) []localModel { "error", err, "endpoint", modelsEndpoint, ) + return []localModel{} } var supportedModels []localModel From 85fe73a307225441023e4e0232bc791d16804283 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Tue, 24 Jun 2025 12:22:36 +0200 Subject: [PATCH 03/40] chore: initial setup for the new config wip --- go.mod | 5 +- internal/config/config.go | 26 +- internal/config_v2/config.go | 440 ++++++++++++++++++ internal/config_v2/config_test.go | 33 ++ internal/config_v2/fs.go | 71 +++ internal/config_v2/provider.go | 69 +++ internal/llm/models/anthropic.go | 2 +- internal/llm/models/azure.go | 2 +- internal/llm/models/gemini.go | 2 +- internal/llm/models/groq.go | 2 +- internal/llm/models/local.go | 2 +- internal/llm/models/models.go | 46 +- internal/llm/models/openai.go | 2 +- internal/llm/models/openrouter.go | 2 +- internal/llm/models/vertexai.go | 2 +- internal/llm/models/xai.go | 2 +- internal/llm/prompt/coder.go | 2 +- internal/llm/prompt/prompt.go | 2 +- internal/llm/prompt/summarizer.go | 2 +- internal/llm/prompt/task.go | 2 +- internal/llm/prompt/title.go | 2 +- internal/llm/provider/provider.go | 2 +- internal/llm/provider/vertexai.go | 4 +- internal/message/content.go | 2 +- .../tui/components/chat/sidebar/sidebar.go | 10 +- internal/tui/components/chat/splash/keys.go | 18 + internal/tui/components/chat/splash/splash.go | 85 ++++ .../tui/components/dialogs/models/models.go | 12 +- internal/tui/components/logo/logo.go | 3 +- 29 files changed, 794 insertions(+), 60 deletions(-) create mode 100644 internal/config_v2/config.go create mode 100644 internal/config_v2/config_test.go create mode 100644 
internal/config_v2/fs.go create mode 100644 internal/config_v2/provider.go create mode 100644 internal/tui/components/chat/splash/keys.go create mode 100644 internal/tui/components/chat/splash/splash.go diff --git a/go.mod b/go.mod index 86b3a8541ddc01679f4a6e25a02ea6599b6b3754..8bc77cf3ffe7cdd96131027fe09f5b8f1a50796a 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module github.com/charmbracelet/crush -go 1.24.0 +go 1.24.3 + +replace github.com/charmbracelet/fur => ../fur require ( github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 @@ -15,6 +17,7 @@ require ( github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250607113720-eb5e1cf3b09e github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.3.0.20250609143341-c76fa36f1b94 github.com/charmbracelet/fang v0.1.0 + github.com/charmbracelet/fur v0.0.0-00010101000000-000000000000 github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.1.0.20250523195325-2d1af06b557c github.com/charmbracelet/x/ansi v0.9.3-0.20250602153603-fb931ed90413 diff --git a/internal/config/config.go b/internal/config/config.go index 3b794dee60cb13a0c42413aedce3199ea5998352..3944cb1374582f9af0eeb7bfadd05ef5f9a8c198 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -76,17 +76,17 @@ type TUIConfig struct { // Config is the main configuration structure for the application. 
type Config struct { - Data Data `json:"data"` - WorkingDir string `json:"wd,omitempty"` - MCPServers map[string]MCPServer `json:"mcpServers,omitempty"` - Providers map[models.ModelProvider]Provider `json:"providers,omitempty"` - LSP map[string]LSPConfig `json:"lsp,omitempty"` - Agents map[AgentName]Agent `json:"agents,omitempty"` - Debug bool `json:"debug,omitempty"` - DebugLSP bool `json:"debugLSP,omitempty"` - ContextPaths []string `json:"contextPaths,omitempty"` - TUI TUIConfig `json:"tui"` - AutoCompact bool `json:"autoCompact,omitempty"` + Data Data `json:"data"` + WorkingDir string `json:"wd,omitempty"` + MCPServers map[string]MCPServer `json:"mcpServers,omitempty"` + Providers map[models.InferenceProvider]Provider `json:"providers,omitempty"` + LSP map[string]LSPConfig `json:"lsp,omitempty"` + Agents map[AgentName]Agent `json:"agents,omitempty"` + Debug bool `json:"debug,omitempty"` + DebugLSP bool `json:"debugLSP,omitempty"` + ContextPaths []string `json:"contextPaths,omitempty"` + TUI TUIConfig `json:"tui"` + AutoCompact bool `json:"autoCompact,omitempty"` } // Application constants @@ -128,7 +128,7 @@ func Load(workingDir string, debug bool) (*Config, error) { cfg = &Config{ WorkingDir: workingDir, MCPServers: make(map[string]MCPServer), - Providers: make(map[models.ModelProvider]Provider), + Providers: make(map[models.InferenceProvider]Provider), LSP: make(map[string]LSPConfig), } @@ -640,7 +640,7 @@ func Validate() error { } // getProviderAPIKey gets the API key for a provider from environment variables -func getProviderAPIKey(provider models.ModelProvider) string { +func getProviderAPIKey(provider models.InferenceProvider) string { switch provider { case models.ProviderAnthropic: return os.Getenv("ANTHROPIC_API_KEY") diff --git a/internal/config_v2/config.go b/internal/config_v2/config.go new file mode 100644 index 0000000000000000000000000000000000000000..4ab12a83fe6de3e94105cf00d4045f652dd26cae --- /dev/null +++ b/internal/config_v2/config.go @@ 
-0,0 +1,440 @@ +package configv2 + +import ( + "encoding/json" + "errors" + "maps" + "os" + "path/filepath" + "slices" + "strings" + "sync" + + "github.com/charmbracelet/crush/internal/logging" + "github.com/charmbracelet/fur/pkg/provider" +) + +const ( + defaultDataDirectory = ".crush" + defaultLogLevel = "info" + appName = "crush" + + MaxTokensFallbackDefault = 4096 +) + +type Model struct { + ID string `json:"id"` + Name string `json:"model"` + CostPer1MIn float64 `json:"cost_per_1m_in"` + CostPer1MOut float64 `json:"cost_per_1m_out"` + CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` + CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` + ContextWindow int64 `json:"context_window"` + DefaultMaxTokens int64 `json:"default_max_tokens"` + CanReason bool `json:"can_reason"` + ReasoningEffort string `json:"reasoning_effort"` + SupportsImages bool `json:"supports_attachments"` +} + +type VertexAIOptions struct { + APIKey string `json:"api_key,omitempty"` + Project string `json:"project,omitempty"` + Location string `json:"location,omitempty"` +} + +type ProviderConfig struct { + BaseURL string `json:"base_url,omitempty"` + ProviderType provider.Type `json:"provider_type"` + APIKey string `json:"api_key,omitempty"` + Disabled bool `json:"disabled"` + ExtraHeaders map[string]string `json:"extra_headers,omitempty"` + // used for e.x for vertex to set the project + ExtraParams map[string]string `json:"extra_params,omitempty"` + + DefaultModel string `json:"default_model"` +} + +type Agent struct { + Name string `json:"name"` + // This is the id of the system prompt used by the agent + // TODO: still needs to be implemented + PromptID string `json:"prompt_id"` + Disabled bool `json:"disabled"` + + Provider provider.InferenceProvider `json:"provider"` + Model Model `json:"model"` + + // The available tools for the agent + // if this is empty, all tools are available + AllowedTools []string `json:"allowed_tools"` + + // this tells us which MCPs are available 
for this agent + // if this is empty all mcps are available + // the string array is the list of tools from the MCP the agent has available + // if the string array is empty, all tools from the MCP are available + MCP map[string][]string `json:"mcp"` + + // The list of LSPs that this agent can use + // if this is empty, all LSPs are available + LSP []string `json:"lsp"` + + // Overrides the context paths for this agent + ContextPaths []string `json:"context_paths"` +} + +type MCPType string + +const ( + MCPStdio MCPType = "stdio" + MCPSse MCPType = "sse" +) + +type MCP struct { + Command string `json:"command"` + Env []string `json:"env"` + Args []string `json:"args"` + Type MCPType `json:"type"` + URL string `json:"url"` + Headers map[string]string `json:"headers"` +} + +type LSPConfig struct { + Disabled bool `json:"enabled"` + Command string `json:"command"` + Args []string `json:"args"` + Options any `json:"options"` +} + +type TUIOptions struct { + CompactMode bool `json:"compact_mode"` + // Here we can add themes later or any TUI related options +} + +type Options struct { + ContextPaths []string `json:"context_paths"` + TUI TUIOptions `json:"tui"` + Debug bool `json:"debug"` + DebugLSP bool `json:"debug_lsp"` + DisableAutoSummarize bool `json:"disable_auto_summarize"` + // Relative to the cwd + DataDirectory string `json:"data_directory"` +} + +type Config struct { + // List of configured providers + Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"` + + // List of configured agents + Agents map[string]Agent `json:"agents,omitempty"` + + // List of configured MCPs + MCP map[string]MCP `json:"mcp,omitempty"` + + // List of configured LSPs + LSP map[string]LSPConfig `json:"lsp,omitempty"` + + // Miscellaneous options + Options Options `json:"options"` + + // Used to add models that are not already in the repository + Models map[provider.InferenceProvider][]provider.Model `json:"models,omitempty"` +} + +var ( + instance *Config 
// The single instance of the Singleton + cwd string + once sync.Once // Ensures the initialization happens only once +) + +func loadConfig(cwd string) (*Config, error) { + // First read the global config file + cfgPath := ConfigPath() + + cfg := defaultConfigBasedOnEnv() + + var globalCfg *Config + if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) { + // some other error occurred while checking the file + return nil, err + } else if err == nil { + // config file exists, read it + file, err := os.ReadFile(cfgPath) + if err != nil { + return nil, err + } + globalCfg = &Config{} + if err := json.Unmarshal(file, globalCfg); err != nil { + return nil, err + } + } else { + // config file does not exist, create a new one + globalCfg = &Config{} + } + + var localConfig *Config + // Global config loaded, now read the local config file + localConfigPath := filepath.Join(cwd, "crush.json") + if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) { + // some other error occurred while checking the file + return nil, err + } else if err == nil { + // local config file exists, read it + file, err := os.ReadFile(localConfigPath) + if err != nil { + return nil, err + } + localConfig = &Config{} + if err := json.Unmarshal(file, localConfig); err != nil { + return nil, err + } + } + + // merge options + cfg.Options = mergeOptions(cfg.Options, globalCfg.Options) + if localConfig != nil { + cfg.Options = mergeOptions(cfg.Options, localConfig.Options) + } + + mergeProviderConfigs(cfg, globalCfg, localConfig) + return cfg, nil +} + +func InitConfig(workingDir string) *Config { + once.Do(func() { + cwd = workingDir + cfg, err := loadConfig(cwd) + if err != nil { + // TODO: Handle this better + panic("Failed to load config: " + err.Error()) + } + instance = cfg + }) + + return instance +} + +func GetConfig() *Config { + if instance == nil { + // TODO: Handle this better + panic("Config not initialized. 
Call InitConfig first.") + } + return instance +} + +func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig { + if other.APIKey != "" { + base.APIKey = other.APIKey + } + // Only change these options if the provider is not a known provider + if !slices.Contains(provider.KnownProviders(), p) { + if other.BaseURL != "" { + base.BaseURL = other.BaseURL + } + if other.ProviderType != "" { + base.ProviderType = other.ProviderType + } + if len(base.ExtraHeaders) > 0 { + if base.ExtraHeaders == nil { + base.ExtraHeaders = make(map[string]string) + } + maps.Copy(base.ExtraHeaders, other.ExtraHeaders) + } + if len(other.ExtraParams) > 0 { + if base.ExtraParams == nil { + base.ExtraParams = make(map[string]string) + } + maps.Copy(base.ExtraParams, other.ExtraParams) + } + } + + if other.Disabled { + base.Disabled = other.Disabled + } + + return base +} + +func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error { + if !slices.Contains(provider.KnownProviders(), p) { + if providerConfig.ProviderType != provider.TypeOpenAI { + return errors.New("invalid provider type: " + string(providerConfig.ProviderType)) + } + if providerConfig.BaseURL == "" { + return errors.New("base URL must be set for custom providers") + } + if providerConfig.APIKey == "" { + return errors.New("API key must be set for custom providers") + } + } + return nil +} + +func mergeOptions(base, other Options) Options { + result := base + + if len(other.ContextPaths) > 0 { + base.ContextPaths = append(base.ContextPaths, other.ContextPaths...) 
+ } + + if other.TUI.CompactMode { + result.TUI.CompactMode = other.TUI.CompactMode + } + + if other.Debug { + result.Debug = other.Debug + } + + if other.DebugLSP { + result.DebugLSP = other.DebugLSP + } + + if other.DisableAutoSummarize { + result.DisableAutoSummarize = other.DisableAutoSummarize + } + + if other.DataDirectory != "" { + result.DataDirectory = other.DataDirectory + } + + return result +} + +func mergeProviderConfigs(base, global, local *Config) { + if global != nil { + for providerName, globalProvider := range global.Providers { + if _, ok := base.Providers[providerName]; !ok { + base.Providers[providerName] = globalProvider + } else { + base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider) + } + } + } + if local != nil { + for providerName, localProvider := range local.Providers { + if _, ok := base.Providers[providerName]; !ok { + base.Providers[providerName] = localProvider + } else { + base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], localProvider) + } + } + } + + finalProviders := make(map[provider.InferenceProvider]ProviderConfig) + for providerName, providerConfig := range base.Providers { + err := validateProvider(providerName, providerConfig) + if err != nil { + logging.Warn("Skipping provider", "name", providerName, "error", err) + continue + } + finalProviders[providerName] = providerConfig + } + base.Providers = finalProviders +} + +func providerDefaultConfig(providerName provider.InferenceProvider) ProviderConfig { + switch providerName { + case provider.InferenceProviderAnthropic: + return ProviderConfig{ + ProviderType: provider.TypeAnthropic, + } + case provider.InferenceProviderOpenAI: + return ProviderConfig{ + ProviderType: provider.TypeOpenAI, + } + case provider.InferenceProviderGemini: + return ProviderConfig{ + ProviderType: provider.TypeGemini, + } + case provider.InferenceProviderBedrock: + return ProviderConfig{ + ProviderType: 
provider.TypeBedrock, + } + case provider.InferenceProviderAzure: + return ProviderConfig{ + ProviderType: provider.TypeAzure, + } + case provider.InferenceProviderOpenRouter: + return ProviderConfig{ + ProviderType: provider.TypeOpenAI, + BaseURL: "https://openrouter.ai/api/v1", + ExtraHeaders: map[string]string{ + "HTTP-Referer": "crush.charm.land", + "X-Title": "Crush", + }, + } + case provider.InferenceProviderXAI: + return ProviderConfig{ + ProviderType: provider.TypeXAI, + BaseURL: "https://api.x.ai/v1", + } + case provider.InferenceProviderVertexAI: + return ProviderConfig{ + ProviderType: provider.TypeVertexAI, + } + default: + return ProviderConfig{ + ProviderType: provider.TypeOpenAI, + } + } +} + +func defaultConfigBasedOnEnv() *Config { + cfg := &Config{ + Options: Options{ + DataDirectory: defaultDataDirectory, + }, + Providers: make(map[provider.InferenceProvider]ProviderConfig), + } + + providers := Providers() + + for _, p := range providers { + if strings.HasPrefix(p.APIKey, "$") { + envVar := strings.TrimPrefix(p.APIKey, "$") + if apiKey := os.Getenv(envVar); apiKey != "" { + providerConfig := providerDefaultConfig(p.ID) + providerConfig.APIKey = apiKey + providerConfig.DefaultModel = p.DefaultModelID + cfg.Providers[p.ID] = providerConfig + } + } + } + // TODO: support local models + + if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" { + providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI) + providerConfig.ExtraParams = map[string]string{ + "project": os.Getenv("GOOGLE_CLOUD_PROJECT"), + "location": os.Getenv("GOOGLE_CLOUD_LOCATION"), + } + cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig + } + + if hasAWSCredentials() { + providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock) + cfg.Providers[provider.InferenceProviderBedrock] = providerConfig + } + return cfg +} + +func hasAWSCredentials() bool { + if os.Getenv("AWS_ACCESS_KEY_ID") != "" && 
os.Getenv("AWS_SECRET_ACCESS_KEY") != "" { + return true + } + + if os.Getenv("AWS_PROFILE") != "" || os.Getenv("AWS_DEFAULT_PROFILE") != "" { + return true + } + + if os.Getenv("AWS_REGION") != "" || os.Getenv("AWS_DEFAULT_REGION") != "" { + return true + } + + if os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" || + os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" { + return true + } + + return false +} + +func WorkingDirectory() string { + return cwd +} diff --git a/internal/config_v2/config_test.go b/internal/config_v2/config_test.go new file mode 100644 index 0000000000000000000000000000000000000000..50b829271dcd42213141ecf2b9b72f5890480668 --- /dev/null +++ b/internal/config_v2/config_test.go @@ -0,0 +1,33 @@ +package configv2 + +import ( + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func resetEnvVars() { + os.Setenv("ANTHROPIC_API_KEY", "") + os.Setenv("OPENAI_API_KEY", "") + os.Setenv("GEMINI_API_KEY", "") + os.Setenv("XAI_API_KEY", "") + os.Setenv("OPENROUTER_API_KEY", "") +} + +func TestConfigWithEnv(t *testing.T) { + resetEnvVars() + testConfigDir = t.TempDir() + + cwdDir := t.TempDir() + + os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key") + os.Setenv("OPENAI_API_KEY", "test-openai-key") + os.Setenv("GEMINI_API_KEY", "test-gemini-key") + os.Setenv("XAI_API_KEY", "test-xai-key") + os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key") + cfg := InitConfig(cwdDir) + fmt.Println(cfg) + assert.Len(t, cfg.Providers, 5) +} diff --git a/internal/config_v2/fs.go b/internal/config_v2/fs.go new file mode 100644 index 0000000000000000000000000000000000000000..976267a2a68efb718449f59b3720d0d186720cdf --- /dev/null +++ b/internal/config_v2/fs.go @@ -0,0 +1,71 @@ +package configv2 + +import ( + "fmt" + "os" + "path/filepath" + "runtime" +) + +var testConfigDir string + +func baseConfigPath() string { + if testConfigDir != "" { + return testConfigDir + } + + xdgConfigHome := os.Getenv("XDG_CONFIG_HOME") + if xdgConfigHome != 
"" { + return filepath.Join(xdgConfigHome, "crush") + } + + // return the path to the main config directory + // for windows, it should be in `%LOCALAPPDATA%/crush/` + // for linux and macOS, it should be in `$HOME/.config/crush/` + if runtime.GOOS == "windows" { + localAppData := os.Getenv("LOCALAPPDATA") + if localAppData == "" { + localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local") + } + return filepath.Join(localAppData, appName) + } + + return filepath.Join(os.Getenv("HOME"), ".config", appName) +} + +func baseDataPath() string { + if testConfigDir != "" { + return testConfigDir + } + + xdgDataHome := os.Getenv("XDG_DATA_HOME") + if xdgDataHome != "" { + return filepath.Join(xdgDataHome, appName) + } + + // return the path to the main data directory + // for windows, it should be in `%LOCALAPPDATA%/crush/` + // for linux and macOS, it should be in `$HOME/.local/share/crush/` + if runtime.GOOS == "windows" { + localAppData := os.Getenv("LOCALAPPDATA") + if localAppData == "" { + localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local") + } + return filepath.Join(localAppData, appName) + } + + return filepath.Join(os.Getenv("HOME"), ".local", "share", appName) +} + +func ConfigPath() string { + return filepath.Join(baseConfigPath(), fmt.Sprintf("%s.json", appName)) +} + +func CrushInitialized() bool { + cfgPath := ConfigPath() + if _, err := os.Stat(cfgPath); os.IsNotExist(err) { + // config file does not exist, so Crush is not initialized + return false + } + return true +} diff --git a/internal/config_v2/provider.go b/internal/config_v2/provider.go new file mode 100644 index 0000000000000000000000000000000000000000..94fe2d44d74e3039dcdeaa0dc76e95b840a03125 --- /dev/null +++ b/internal/config_v2/provider.go @@ -0,0 +1,69 @@ +package configv2 + +import ( + "encoding/json" + "os" + "path/filepath" + "sync" + + "github.com/charmbracelet/fur/pkg/client" + "github.com/charmbracelet/fur/pkg/provider" +) + +var fur = 
client.New() + +var ( + providerOnc sync.Once // Ensures the initialization happens only once + providerList []provider.Provider +) + +func providersPath() string { + return filepath.Join(baseDataPath(), "providers.json") +} + +func saveProviders(providers []provider.Provider) error { + path := providersPath() + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + + data, err := json.MarshalIndent(providers, "", " ") + if err != nil { + return err + } + + return os.WriteFile(path, data, 0o644) +} + +func loadProviders() ([]provider.Provider, error) { + path := providersPath() + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + var providers []provider.Provider + err = json.Unmarshal(data, &providers) + return providers, err +} + +func Providers() []provider.Provider { + providerOnc.Do(func() { + // Try to get providers from upstream API + if providers, err := fur.GetProviders(); err == nil { + providerList = providers + // Save providers locally for future fallback + _ = saveProviders(providers) + } else { + // If upstream fails, try to load from local cache + if localProviders, localErr := loadProviders(); localErr == nil { + providerList = localProviders + } else { + // If both fail, return empty list + providerList = []provider.Provider{} + } + } + }) + return providerList +} diff --git a/internal/llm/models/anthropic.go b/internal/llm/models/anthropic.go index 9da03a835126956a74ee16888397abb343811ec4..85c47def3d94034297265c506c5870f2b449d286 100644 --- a/internal/llm/models/anthropic.go +++ b/internal/llm/models/anthropic.go @@ -1,7 +1,7 @@ package models const ( - ProviderAnthropic ModelProvider = "anthropic" + ProviderAnthropic InferenceProvider = "anthropic" // Models Claude35Sonnet ModelID = "claude-3.5-sonnet" diff --git a/internal/llm/models/azure.go b/internal/llm/models/azure.go index 416597302f362b4f2d7c605f7166ced2b200885a..eb7ae293ee053d953f5bcbb20120089ca6bae95b 100644 --- 
a/internal/llm/models/azure.go +++ b/internal/llm/models/azure.go @@ -1,6 +1,6 @@ package models -const ProviderAzure ModelProvider = "azure" +const ProviderAzure InferenceProvider = "azure" const ( AzureGPT41 ModelID = "azure.gpt-4.1" diff --git a/internal/llm/models/gemini.go b/internal/llm/models/gemini.go index 794ec3f0a06a0e9975d110cd4fb89e1427a32552..9749c6d3409acf7b05cd67690504e2cb3ac4fd39 100644 --- a/internal/llm/models/gemini.go +++ b/internal/llm/models/gemini.go @@ -1,7 +1,7 @@ package models const ( - ProviderGemini ModelProvider = "gemini" + ProviderGemini InferenceProvider = "gemini" // Models Gemini25Flash ModelID = "gemini-2.5-flash" diff --git a/internal/llm/models/groq.go b/internal/llm/models/groq.go index 19917f20bb2647e296db681e30b1b0f379bf7349..39288962c8e42a1acec8a01b3157b10d9b00b5dc 100644 --- a/internal/llm/models/groq.go +++ b/internal/llm/models/groq.go @@ -1,7 +1,7 @@ package models const ( - ProviderGROQ ModelProvider = "groq" + ProviderGROQ InferenceProvider = "groq" // GROQ QWENQwq ModelID = "qwen-qwq" diff --git a/internal/llm/models/local.go b/internal/llm/models/local.go index 3a50fdf48fe86167600eceee3cce26b6caac900e..c469e99fd65d5befbfffe5126a31c88eae68e150 100644 --- a/internal/llm/models/local.go +++ b/internal/llm/models/local.go @@ -16,7 +16,7 @@ import ( ) const ( - ProviderLocal ModelProvider = "local" + ProviderLocal InferenceProvider = "local" localModelsPath = "v1/models" lmStudioBetaModelsPath = "api/v0/models" diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go index 50e8723989ccb268a9f515b4c693662654fa38d5..0aefc170d32d1023f0d246a2cc7522e895453a88 100644 --- a/internal/llm/models/models.go +++ b/internal/llm/models/models.go @@ -3,23 +3,23 @@ package models import "maps" type ( - ModelID string - ModelProvider string + ModelID string + InferenceProvider string ) type Model struct { - ID ModelID `json:"id"` - Name string `json:"name"` - Provider ModelProvider `json:"provider"` - APIModel string 
`json:"api_model"` - CostPer1MIn float64 `json:"cost_per_1m_in"` - CostPer1MOut float64 `json:"cost_per_1m_out"` - CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` - CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` - ContextWindow int64 `json:"context_window"` - DefaultMaxTokens int64 `json:"default_max_tokens"` - CanReason bool `json:"can_reason"` - SupportsAttachments bool `json:"supports_attachments"` + ID ModelID `json:"id"` + Name string `json:"name"` + Provider InferenceProvider `json:"provider"` + APIModel string `json:"api_model"` + CostPer1MIn float64 `json:"cost_per_1m_in"` + CostPer1MOut float64 `json:"cost_per_1m_out"` + CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` + CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` + ContextWindow int64 `json:"context_window"` + DefaultMaxTokens int64 `json:"default_max_tokens"` + CanReason bool `json:"can_reason"` + SupportsAttachments bool `json:"supports_attachments"` } // Model IDs @@ -29,9 +29,9 @@ const ( // GEMINI ) const ( - ProviderBedrock ModelProvider = "bedrock" + ProviderBedrock InferenceProvider = "bedrock" // ForTests - ProviderMock ModelProvider = "__mock" + ProviderMock InferenceProvider = "__mock" ) var SupportedModels = map[ModelID]Model{ @@ -48,6 +48,20 @@ var SupportedModels = map[ModelID]Model{ }, } +var KnownProviders = []InferenceProvider{ + ProviderAnthropic, + ProviderOpenAI, + ProviderGemini, + ProviderAzure, + ProviderGROQ, + ProviderLocal, + ProviderOpenRouter, + ProviderVertexAI, + ProviderBedrock, + ProviderXAI, + ProviderMock, +} + func init() { maps.Copy(SupportedModels, AnthropicModels) maps.Copy(SupportedModels, OpenAIModels) diff --git a/internal/llm/models/openai.go b/internal/llm/models/openai.go index abe0e30c53207c0a7cacfe3d27f43495cfec6e8e..e4173277cbdfe1e579068d2981df1e70b9943cb1 100644 --- a/internal/llm/models/openai.go +++ b/internal/llm/models/openai.go @@ -1,7 +1,7 @@ package models const ( - ProviderOpenAI ModelProvider = "openai" + 
ProviderOpenAI InferenceProvider = "openai" GPT41 ModelID = "gpt-4.1" GPT41Mini ModelID = "gpt-4.1-mini" diff --git a/internal/llm/models/openrouter.go b/internal/llm/models/openrouter.go index 10ad5a0409937e060616eda6c188991d979e3ea1..8884e03442d30787fd505ca6a6c518d299748752 100644 --- a/internal/llm/models/openrouter.go +++ b/internal/llm/models/openrouter.go @@ -1,7 +1,7 @@ package models const ( - ProviderOpenRouter ModelProvider = "openrouter" + ProviderOpenRouter InferenceProvider = "openrouter" OpenRouterGPT41 ModelID = "openrouter.gpt-4.1" OpenRouterGPT41Mini ModelID = "openrouter.gpt-4.1-mini" diff --git a/internal/llm/models/vertexai.go b/internal/llm/models/vertexai.go index d71dfc0bed0a8071c89ab22883990413384cd56f..c9b5744b62c28e2529cac44b1e97234158d2eacf 100644 --- a/internal/llm/models/vertexai.go +++ b/internal/llm/models/vertexai.go @@ -1,7 +1,7 @@ package models const ( - ProviderVertexAI ModelProvider = "vertexai" + ProviderVertexAI InferenceProvider = "vertexai" // Models VertexAIGemini25Flash ModelID = "vertexai.gemini-2.5-flash" diff --git a/internal/llm/models/xai.go b/internal/llm/models/xai.go index 00caf3b89750c0789f75f6273d49e38a4cdf6282..a59eac97ee6fee5db5550663083062099512eddc 100644 --- a/internal/llm/models/xai.go +++ b/internal/llm/models/xai.go @@ -1,7 +1,7 @@ package models const ( - ProviderXAI ModelProvider = "xai" + ProviderXAI InferenceProvider = "xai" XAIGrok3Beta ModelID = "grok-3-beta" XAIGrok3MiniBeta ModelID = "grok-3-mini-beta" diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index ea31bfa0297c1ce207e188a7f162e26831927636..b272f4e9f263ff596d06aae787e8b5a1c3ac2aec 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -13,7 +13,7 @@ import ( "github.com/charmbracelet/crush/internal/llm/tools" ) -func CoderPrompt(provider models.ModelProvider) string { +func CoderPrompt(provider models.InferenceProvider) string { basePrompt := baseAnthropicCoderPrompt switch provider { case 
models.ProviderOpenAI: diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go index 9065fd6bc0bb4a69eab19479b9a18b270dddee02..ed75d29c500cce16f16d06892ad8fcabc254a08d 100644 --- a/internal/llm/prompt/prompt.go +++ b/internal/llm/prompt/prompt.go @@ -12,7 +12,7 @@ import ( "github.com/charmbracelet/crush/internal/logging" ) -func GetAgentPrompt(agentName config.AgentName, provider models.ModelProvider) string { +func GetAgentPrompt(agentName config.AgentName, provider models.InferenceProvider) string { basePrompt := "" switch agentName { case config.AgentCoder: diff --git a/internal/llm/prompt/summarizer.go b/internal/llm/prompt/summarizer.go index 87a0f95c66af8b51d07a3a4e792c07dea7dab503..f5a1de0f8619252d99082c6ca54e152cc25a7bc7 100644 --- a/internal/llm/prompt/summarizer.go +++ b/internal/llm/prompt/summarizer.go @@ -2,7 +2,7 @@ package prompt import "github.com/charmbracelet/crush/internal/llm/models" -func SummarizerPrompt(_ models.ModelProvider) string { +func SummarizerPrompt(_ models.InferenceProvider) string { return `You are a helpful AI assistant tasked with summarizing conversations. When asked to summarize, provide a detailed but concise summary of the conversation. diff --git a/internal/llm/prompt/task.go b/internal/llm/prompt/task.go index 53fd67dc2f88928b4fbe9773db0cd1487bcd811a..89acf1f02121ea008359eaa5201222061dad0cff 100644 --- a/internal/llm/prompt/task.go +++ b/internal/llm/prompt/task.go @@ -6,7 +6,7 @@ import ( "github.com/charmbracelet/crush/internal/llm/models" ) -func TaskPrompt(_ models.ModelProvider) string { +func TaskPrompt(_ models.InferenceProvider) string { agentPrompt := `You are an agent for Crush. Given the user's prompt, you should use the tools available to you to answer the user's question. Notes: 1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. 
One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is .", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...". diff --git a/internal/llm/prompt/title.go b/internal/llm/prompt/title.go index 03e47288507fa66bb88605bff4b2194b889cc3f7..0b3177b37857c24d299df0d6e64393cd60ea23eb 100644 --- a/internal/llm/prompt/title.go +++ b/internal/llm/prompt/title.go @@ -2,7 +2,7 @@ package prompt import "github.com/charmbracelet/crush/internal/llm/models" -func TitlePrompt(_ models.ModelProvider) string { +func TitlePrompt(_ models.InferenceProvider) string { return `you will generate a short title based on the first message a user begins a conversation with - ensure it is not more than 50 characters long - the title should be a summary of the user's message diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index d63f73f67cf0455bcccbb06ae70e3dde6a09557c..40c7317fba3eb944ba83421bcee8bf1702882fcb 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -82,7 +82,7 @@ type baseProvider[C ProviderClient] struct { client C } -func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption) (Provider, error) { +func NewProvider(providerName models.InferenceProvider, opts ...ProviderClientOption) (Provider, error) { clientOptions := providerClientOptions{} for _, o := range opts { o(&clientOptions) diff --git a/internal/llm/provider/vertexai.go b/internal/llm/provider/vertexai.go index fe2de2f4588f9dbe583e4f8af85e61eea67d5648..0e01ef9c3a082faea86bd6c76d6e9d53c9f8a933 100644 --- a/internal/llm/provider/vertexai.go +++ b/internal/llm/provider/vertexai.go @@ -17,8 +17,8 @@ func newVertexAIClient(opts providerClientOptions) VertexAIClient { } client, err := genai.NewClient(context.Background(), &genai.ClientConfig{ - Project: 
os.Getenv("VERTEXAI_PROJECT"), - Location: os.Getenv("VERTEXAI_LOCATION"), + Project: os.Getenv("GOOGLE_CLOUD_PROJECT"), + Location: os.Getenv("GOOGLE_CLOUD_LOCATION"), Backend: genai.BackendVertexAI, }) if err != nil { diff --git a/internal/message/content.go b/internal/message/content.go index 383134b596e62a5fc18b2c8404d770fc6a2d4112..b9e83ba4dd7fcc96216755a3871f0553b58d88d7 100644 --- a/internal/message/content.go +++ b/internal/message/content.go @@ -71,7 +71,7 @@ type BinaryContent struct { Data []byte } -func (bc BinaryContent) String(provider models.ModelProvider) string { +func (bc BinaryContent) String(provider models.InferenceProvider) string { base64Encoded := base64.StdEncoding.EncodeToString(bc.Data) if provider == models.ProviderOpenAI { return "data:" + bc.MIMEType + ";base64," + base64Encoded diff --git a/internal/tui/components/chat/sidebar/sidebar.go b/internal/tui/components/chat/sidebar/sidebar.go index afd067adc06d99c3c9da911812750631423231e6..405bd1f0f8c7891db1958e70f97e290dd9a8d411 100644 --- a/internal/tui/components/chat/sidebar/sidebar.go +++ b/internal/tui/components/chat/sidebar/sidebar.go @@ -76,7 +76,7 @@ func NewSidebarCmp(history history.Service, lspClients map[string]*lsp.Client, c } func (m *sidebarCmp) Init() tea.Cmd { - m.logo = m.logoBlock(false) + m.logo = m.logoBlock() m.cwd = cwd() return nil } @@ -231,9 +231,9 @@ func (m *sidebarCmp) loadSessionFiles() tea.Msg { func (m *sidebarCmp) SetSize(width, height int) tea.Cmd { if width < logoBreakpoint && (m.width == 0 || m.width >= logoBreakpoint) { - m.logo = m.logoBlock(true) + m.logo = m.logoBlock() } else if width >= logoBreakpoint && (m.width == 0 || m.width < logoBreakpoint) { - m.logo = m.logoBlock(false) + m.logo = m.logoBlock() } m.width = width @@ -245,9 +245,9 @@ func (m *sidebarCmp) GetSize() (int, int) { return m.width, m.height } -func (m *sidebarCmp) logoBlock(compact bool) string { +func (m *sidebarCmp) logoBlock() string { t := styles.CurrentTheme() - return 
logo.Render(version.Version, compact, logo.Opts{ + return logo.Render(version.Version, true, logo.Opts{ FieldColor: t.Primary, TitleColorA: t.Secondary, TitleColorB: t.Primary, diff --git a/internal/tui/components/chat/splash/keys.go b/internal/tui/components/chat/splash/keys.go new file mode 100644 index 0000000000000000000000000000000000000000..df715c89e86971a0f788915737bf41a212c65b5a --- /dev/null +++ b/internal/tui/components/chat/splash/keys.go @@ -0,0 +1,18 @@ +package splash + +import ( + "github.com/charmbracelet/bubbles/v2/key" +) + +type KeyMap struct { + Cancel key.Binding +} + +func DefaultKeyMap() KeyMap { + return KeyMap{ + Cancel: key.NewBinding( + key.WithKeys("esc"), + key.WithHelp("esc", "cancel"), + ), + } +} diff --git a/internal/tui/components/chat/splash/splash.go b/internal/tui/components/chat/splash/splash.go new file mode 100644 index 0000000000000000000000000000000000000000..75718b25471088607a66a19c7b8a56d36bd5d2d1 --- /dev/null +++ b/internal/tui/components/chat/splash/splash.go @@ -0,0 +1,85 @@ +package splash + +import ( + "github.com/charmbracelet/bubbles/v2/key" + tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/crush/internal/tui/components/core/layout" + "github.com/charmbracelet/crush/internal/tui/components/logo" + "github.com/charmbracelet/crush/internal/tui/styles" + "github.com/charmbracelet/crush/internal/tui/util" + "github.com/charmbracelet/crush/internal/version" + "github.com/charmbracelet/lipgloss/v2" +) + +type Splash interface { + util.Model + layout.Sizeable + layout.Help +} + +type splashCmp struct { + width, height int + keyMap KeyMap + logoRendered string +} + +func New() Splash { + return &splashCmp{ + width: 0, + height: 0, + keyMap: DefaultKeyMap(), + logoRendered: "", + } +} + +// GetSize implements SplashPage. +func (s *splashCmp) GetSize() (int, int) { + return s.width, s.height +} + +// Init implements SplashPage. 
+func (s *splashCmp) Init() tea.Cmd { + return nil +} + +// SetSize implements SplashPage. +func (s *splashCmp) SetSize(width int, height int) tea.Cmd { + s.width = width + s.height = height + s.logoRendered = s.logoBlock() + return nil +} + +// Update implements SplashPage. +func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + return s, s.SetSize(msg.Width, msg.Height) + } + return s, nil +} + +// View implements SplashPage. +func (s *splashCmp) View() tea.View { + content := lipgloss.JoinVertical(lipgloss.Left, s.logoRendered) + return tea.NewView(content) +} + +func (m *splashCmp) logoBlock() string { + t := styles.CurrentTheme() + return logo.Render(version.Version, false, logo.Opts{ + FieldColor: t.Primary, + TitleColorA: t.Secondary, + TitleColorB: t.Primary, + CharmColor: t.Secondary, + VersionColor: t.Primary, + Width: m.width - 2, // -2 for padding + }) +} + +// Bindings implements SplashPage. +func (s *splashCmp) Bindings() []key.Binding { + return []key.Binding{ + s.keyMap.Cancel, + } +} diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index 906d87a9dfe65c1ec09bd5abaf4f9d6865545038..0197b7141560a67008ceac64c31756bd19fff74a 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -80,7 +80,7 @@ func NewModelDialogCmp() ModelDialog { } } -var ProviderPopularity = map[models.ModelProvider]int{ +var ProviderPopularity = map[models.InferenceProvider]int{ models.ProviderAnthropic: 1, models.ProviderOpenAI: 2, models.ProviderGemini: 3, @@ -92,7 +92,7 @@ var ProviderPopularity = map[models.ModelProvider]int{ models.ProviderXAI: 9, } -var ProviderName = map[models.ModelProvider]string{ +var ProviderName = map[models.InferenceProvider]string{ models.ProviderAnthropic: "Anthropic", models.ProviderOpenAI: "OpenAI", models.ProviderGemini: "Gemini", @@ -195,8 +195,8 @@ func 
GetSelectedModel(cfg *config.Config) models.Model { return models.SupportedModels[selectedModelID] } -func getEnabledProviders(cfg *config.Config) []models.ModelProvider { - var providers []models.ModelProvider +func getEnabledProviders(cfg *config.Config) []models.InferenceProvider { + var providers []models.InferenceProvider for providerID, provider := range cfg.Providers { if !provider.Disabled { providers = append(providers, providerID) @@ -204,7 +204,7 @@ func getEnabledProviders(cfg *config.Config) []models.ModelProvider { } // Sort by provider popularity - slices.SortFunc(providers, func(a, b models.ModelProvider) int { + slices.SortFunc(providers, func(a, b models.InferenceProvider) int { rA := ProviderPopularity[a] rB := ProviderPopularity[b] @@ -220,7 +220,7 @@ func getEnabledProviders(cfg *config.Config) []models.ModelProvider { return providers } -func getModelsForProvider(provider models.ModelProvider) []models.Model { +func getModelsForProvider(provider models.InferenceProvider) []models.Model { var providerModels []models.Model for _, model := range models.SupportedModels { if model.Provider == provider { diff --git a/internal/tui/components/logo/logo.go b/internal/tui/components/logo/logo.go index 4b044c9dbd45284c72b7d03636d7399555e5f388..9d170ee6c2a0036ad9e4ca8b11c1a373fbb15080 100644 --- a/internal/tui/components/logo/logo.go +++ b/internal/tui/components/logo/logo.go @@ -26,6 +26,7 @@ type Opts struct { TitleColorB color.Color // right gradient ramp point CharmColor color.Color // Charm™ text color VersionColor color.Color // Version text color + Width int // width of the rendered logo, used for truncation } // Render renders the Crush logo. Set the argument to true to render the narrow @@ -76,7 +77,7 @@ func Render(version string, compact bool, o Opts) string { } // Right field. - const rightWidth = 15 + rightWidth := max(15, o.Width-crushWidth-leftWidth) // 2 for the gap. 
const stepDownAt = 0 rightField := new(strings.Builder) for i := range fieldHeight { From 399e7f7084e5828879717695ae68d5c7317a93a9 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Thu, 26 Jun 2025 14:11:17 +0200 Subject: [PATCH 04/40] chore: move fur structs, small provider changes --- go.mod | 3 - internal/config_v2/config.go | 326 ++++++++++++++++++++++++----- internal/config_v2/config_test.go | 4 +- internal/config_v2/provider.go | 4 +- internal/fur/client/client.go | 63 ++++++ internal/fur/provider/provider.go | 72 +++++++ internal/llm/agent/agent.go | 24 +-- internal/llm/provider/anthropic.go | 73 ++----- internal/llm/provider/bedrock.go | 19 +- internal/llm/provider/gemini.go | 19 -- internal/llm/provider/openai.go | 47 +---- internal/llm/provider/provider.go | 72 +++---- internal/llm/provider/vertexai.go | 6 - 13 files changed, 477 insertions(+), 255 deletions(-) create mode 100644 internal/fur/client/client.go create mode 100644 internal/fur/provider/provider.go diff --git a/go.mod b/go.mod index 8bc77cf3ffe7cdd96131027fe09f5b8f1a50796a..99e808c14a24cd34d5274c74eba183229f51dd07 100644 --- a/go.mod +++ b/go.mod @@ -2,8 +2,6 @@ module github.com/charmbracelet/crush go 1.24.3 -replace github.com/charmbracelet/fur => ../fur - require ( github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 github.com/JohannesKaufmann/html-to-markdown v1.6.0 @@ -17,7 +15,6 @@ require ( github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250607113720-eb5e1cf3b09e github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.3.0.20250609143341-c76fa36f1b94 github.com/charmbracelet/fang v0.1.0 - github.com/charmbracelet/fur v0.0.0-00010101000000-000000000000 github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.1.0.20250523195325-2d1af06b557c github.com/charmbracelet/x/ansi v0.9.3-0.20250602153603-fb931ed90413 diff --git a/internal/config_v2/config.go b/internal/config_v2/config.go index 
4ab12a83fe6de3e94105cf00d4045f652dd26cae..0de27aa1b16cf71e97c655561a5c8ce20bae0838 100644 --- a/internal/config_v2/config.go +++ b/internal/config_v2/config.go @@ -10,8 +10,8 @@ import ( "strings" "sync" + "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/logging" - "github.com/charmbracelet/fur/pkg/provider" ) const ( @@ -22,6 +22,29 @@ const ( MaxTokensFallbackDefault = 4096 ) +var defaultContextPaths = []string{ + ".github/copilot-instructions.md", + ".cursorrules", + ".cursor/rules/", + "CLAUDE.md", + "CLAUDE.local.md", + "crush.md", + "crush.local.md", + "Crush.md", + "Crush.local.md", + "CRUSH.md", + "CRUSH.local.md", +} + +type AgentID string + +const ( + AgentCoder AgentID = "coder" + AgentTask AgentID = "task" + AgentTitle AgentID = "title" + AgentSummarize AgentID = "summarize" +) + type Model struct { ID string `json:"id"` Name string `json:"model"` @@ -43,40 +66,43 @@ type VertexAIOptions struct { } type ProviderConfig struct { - BaseURL string `json:"base_url,omitempty"` - ProviderType provider.Type `json:"provider_type"` - APIKey string `json:"api_key,omitempty"` - Disabled bool `json:"disabled"` - ExtraHeaders map[string]string `json:"extra_headers,omitempty"` + ID provider.InferenceProvider `json:"id"` + BaseURL string `json:"base_url,omitempty"` + ProviderType provider.Type `json:"provider_type"` + APIKey string `json:"api_key,omitempty"` + Disabled bool `json:"disabled"` + ExtraHeaders map[string]string `json:"extra_headers,omitempty"` // used for e.x for vertex to set the project ExtraParams map[string]string `json:"extra_params,omitempty"` - DefaultModel string `json:"default_model"` + DefaultLargeModel string `json:"default_large_model,omitempty"` + DefaultSmallModel string `json:"default_small_model,omitempty"` + + Models []Model `json:"models,omitempty"` } type Agent struct { - Name string `json:"name"` + Name string `json:"name"` + Description string `json:"description,omitempty"` // This is the 
id of the system prompt used by the agent - // TODO: still needs to be implemented - PromptID string `json:"prompt_id"` - Disabled bool `json:"disabled"` + Disabled bool `json:"disabled"` Provider provider.InferenceProvider `json:"provider"` - Model Model `json:"model"` + Model string `json:"model"` // The available tools for the agent - // if this is empty, all tools are available + // if this is nil, all tools are available AllowedTools []string `json:"allowed_tools"` // this tells us which MCPs are available for this agent // if this is empty all mcps are available - // the string array is the list of tools from the MCP the agent has available - // if the string array is empty, all tools from the MCP are available - MCP map[string][]string `json:"mcp"` + // the string array is the list of tools from the AllowedMCP the agent has available + // if the string array is nil, all tools from the AllowedMCP are available + AllowedMCP map[string][]string `json:"allowed_mcp"` // The list of LSPs that this agent can use - // if this is empty, all LSPs are available - LSP []string `json:"lsp"` + // if this is nil, all LSPs are available + AllowedLSP []string `json:"allowed_lsp"` // Overrides the context paths for this agent ContextPaths []string `json:"context_paths"` @@ -125,7 +151,7 @@ type Config struct { Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"` // List of configured agents - Agents map[string]Agent `json:"agents,omitempty"` + Agents map[AgentID]Agent `json:"agents,omitempty"` // List of configured MCPs MCP map[string]MCP `json:"mcp,omitempty"` @@ -135,15 +161,13 @@ type Config struct { // Miscellaneous options Options Options `json:"options"` - - // Used to add models that are not already in the repository - Models map[provider.InferenceProvider][]provider.Model `json:"models,omitempty"` } var ( instance *Config // The single instance of the Singleton cwd string once sync.Once // Ensures the initialization happens only once + 
) func loadConfig(cwd string) (*Config, error) { @@ -190,10 +214,73 @@ func loadConfig(cwd string) (*Config, error) { } // merge options - cfg.Options = mergeOptions(cfg.Options, globalCfg.Options) - cfg.Options = mergeOptions(cfg.Options, localConfig.Options) + mergeOptions(cfg, globalCfg, localConfig) mergeProviderConfigs(cfg, globalCfg, localConfig) + // no providers found the app is not initialized yet + if len(cfg.Providers) == 0 { + return cfg, nil + } + preferredProvider := getPreferredProvider(cfg.Providers) + + if preferredProvider == nil { + return nil, errors.New("no valid providers configured") + } + + agents := map[AgentID]Agent{ + AgentCoder: { + Name: "Coder", + Description: "An agent that helps with executing coding tasks.", + Provider: preferredProvider.ID, + Model: preferredProvider.DefaultLargeModel, + ContextPaths: cfg.Options.ContextPaths, + // All tools allowed + }, + AgentTask: { + Name: "Task", + Description: "An agent that helps with searching for context and finding implementation details.", + Provider: preferredProvider.ID, + Model: preferredProvider.DefaultLargeModel, + ContextPaths: cfg.Options.ContextPaths, + AllowedTools: []string{ + "glob", + "grep", + "ls", + "sourcegraph", + "view", + }, + // NO MCPs or LSPs by default + AllowedMCP: map[string][]string{}, + AllowedLSP: []string{}, + }, + AgentTitle: { + Name: "Title", + Description: "An agent that helps with generating titles for sessions.", + Provider: preferredProvider.ID, + Model: preferredProvider.DefaultSmallModel, + ContextPaths: cfg.Options.ContextPaths, + AllowedTools: []string{}, + // NO MCPs or LSPs by default + AllowedMCP: map[string][]string{}, + AllowedLSP: []string{}, + }, + AgentSummarize: { + Name: "Summarize", + Description: "An agent that helps with summarizing sessions.", + Provider: preferredProvider.ID, + Model: preferredProvider.DefaultSmallModel, + ContextPaths: cfg.Options.ContextPaths, + AllowedTools: []string{}, + // NO MCPs or LSPs by default + 
AllowedMCP: map[string][]string{}, + AllowedLSP: []string{}, + }, + } + cfg.Agents = agents + mergeAgents(cfg, globalCfg, localConfig) + mergeMCPs(cfg, globalCfg, localConfig) + mergeLSPs(cfg, globalCfg, localConfig) + return cfg, nil } @@ -219,6 +306,22 @@ func GetConfig() *Config { return instance } +func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig { + providers := Providers() + for _, p := range providers { + if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled { + return &providerConfig + } + } + // if none found return the first configured provider + for _, providerConfig := range configuredProviders { + if !providerConfig.Disabled { + return &providerConfig + } + } + return nil +} + func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig { if other.APIKey != "" { base.APIKey = other.APIKey @@ -249,6 +352,26 @@ func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfi base.Disabled = other.Disabled } + if other.DefaultLargeModel != "" { + base.DefaultLargeModel = other.DefaultLargeModel + } + // Add new models if they don't exist + if other.Models != nil { + for _, model := range other.Models { + // check if the model already exists + exists := false + for _, existingModel := range base.Models { + if existingModel.ID == model.ID { + exists = true + break + } + } + if !exists { + base.Models = append(base.Models, model) + } + } + } + return base } @@ -267,39 +390,114 @@ func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfi return nil } -func mergeOptions(base, other Options) Options { - result := base +func mergeOptions(base, global, local *Config) { + for _, cfg := range []*Config{global, local} { + if cfg == nil { + continue + } + baseOptions := base.Options + other := cfg.Options + if len(other.ContextPaths) > 0 { + baseOptions.ContextPaths = append(baseOptions.ContextPaths, 
other.ContextPaths...) + } - if len(other.ContextPaths) > 0 { - base.ContextPaths = append(base.ContextPaths, other.ContextPaths...) - } + if other.TUI.CompactMode { + baseOptions.TUI.CompactMode = other.TUI.CompactMode + } - if other.TUI.CompactMode { - result.TUI.CompactMode = other.TUI.CompactMode - } + if other.Debug { + baseOptions.Debug = other.Debug + } - if other.Debug { - result.Debug = other.Debug - } + if other.DebugLSP { + baseOptions.DebugLSP = other.DebugLSP + } - if other.DebugLSP { - result.DebugLSP = other.DebugLSP + if other.DisableAutoSummarize { + baseOptions.DisableAutoSummarize = other.DisableAutoSummarize + } + + if other.DataDirectory != "" { + baseOptions.DataDirectory = other.DataDirectory + } + base.Options = baseOptions } +} - if other.DisableAutoSummarize { - result.DisableAutoSummarize = other.DisableAutoSummarize +func mergeAgents(base, global, local *Config) { + for _, cfg := range []*Config{global, local} { + if cfg == nil { + continue + } + for agentID, globalAgent := range cfg.Agents { + if _, ok := base.Agents[agentID]; !ok { + base.Agents[agentID] = globalAgent + } else { + switch agentID { + case AgentCoder: + baseAgent := base.Agents[agentID] + baseAgent.Model = globalAgent.Model + baseAgent.Provider = globalAgent.Provider + baseAgent.AllowedMCP = globalAgent.AllowedMCP + baseAgent.AllowedLSP = globalAgent.AllowedLSP + base.Agents[agentID] = baseAgent + case AgentTask: + baseAgent := base.Agents[agentID] + baseAgent.Model = globalAgent.Model + baseAgent.Provider = globalAgent.Provider + base.Agents[agentID] = baseAgent + case AgentTitle: + baseAgent := base.Agents[agentID] + baseAgent.Model = globalAgent.Model + baseAgent.Provider = globalAgent.Provider + base.Agents[agentID] = baseAgent + case AgentSummarize: + baseAgent := base.Agents[agentID] + baseAgent.Model = globalAgent.Model + baseAgent.Provider = globalAgent.Provider + base.Agents[agentID] = baseAgent + default: + baseAgent := base.Agents[agentID] + baseAgent.Name = 
globalAgent.Name + baseAgent.Description = globalAgent.Description + baseAgent.Disabled = globalAgent.Disabled + baseAgent.Provider = globalAgent.Provider + baseAgent.Model = globalAgent.Model + baseAgent.AllowedTools = globalAgent.AllowedTools + baseAgent.AllowedMCP = globalAgent.AllowedMCP + baseAgent.AllowedLSP = globalAgent.AllowedLSP + base.Agents[agentID] = baseAgent + + } + } + } } +} - if other.DataDirectory != "" { - result.DataDirectory = other.DataDirectory +func mergeMCPs(base, global, local *Config) { + for _, cfg := range []*Config{global, local} { + if cfg == nil { + continue + } + maps.Copy(base.MCP, cfg.MCP) } +} - return result +func mergeLSPs(base, global, local *Config) { + for _, cfg := range []*Config{global, local} { + if cfg == nil { + continue + } + maps.Copy(base.LSP, cfg.LSP) + } } func mergeProviderConfigs(base, global, local *Config) { - if global != nil { - for providerName, globalProvider := range global.Providers { + for _, cfg := range []*Config{global, local} { + if cfg == nil { + continue + } + for providerName, globalProvider := range cfg.Providers { if _, ok := base.Providers[providerName]; !ok { base.Providers[providerName] = globalProvider } else { @@ -307,15 +505,6 @@ func mergeProviderConfigs(base, global, local *Config) { } } } - if local != nil { - for providerName, localProvider := range local.Providers { - if _, ok := base.Providers[providerName]; !ok { - base.Providers[providerName] = localProvider - } else { - base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], localProvider) - } - } - } finalProviders := make(map[provider.InferenceProvider]ProviderConfig) for providerName, providerConfig := range base.Providers { @@ -328,30 +517,36 @@ func mergeProviderConfigs(base, global, local *Config) { base.Providers = finalProviders } -func providerDefaultConfig(providerName provider.InferenceProvider) ProviderConfig { - switch providerName { +func providerDefaultConfig(providerId 
provider.InferenceProvider) ProviderConfig { + switch providerId { case provider.InferenceProviderAnthropic: return ProviderConfig{ + ID: providerId, ProviderType: provider.TypeAnthropic, } case provider.InferenceProviderOpenAI: return ProviderConfig{ + ID: providerId, ProviderType: provider.TypeOpenAI, } case provider.InferenceProviderGemini: return ProviderConfig{ + ID: providerId, ProviderType: provider.TypeGemini, } case provider.InferenceProviderBedrock: return ProviderConfig{ + ID: providerId, ProviderType: provider.TypeBedrock, } case provider.InferenceProviderAzure: return ProviderConfig{ + ID: providerId, ProviderType: provider.TypeAzure, } case provider.InferenceProviderOpenRouter: return ProviderConfig{ + ID: providerId, ProviderType: provider.TypeOpenAI, BaseURL: "https://openrouter.ai/api/v1", ExtraHeaders: map[string]string{ @@ -361,15 +556,18 @@ func providerDefaultConfig(providerName provider.InferenceProvider) ProviderConf } case provider.InferenceProviderXAI: return ProviderConfig{ + ID: providerId, ProviderType: provider.TypeXAI, BaseURL: "https://api.x.ai/v1", } case provider.InferenceProviderVertexAI: return ProviderConfig{ + ID: providerId, ProviderType: provider.TypeVertexAI, } default: return ProviderConfig{ + ID: providerId, ProviderType: provider.TypeOpenAI, } } @@ -379,6 +577,7 @@ func defaultConfigBasedOnEnv() *Config { cfg := &Config{ Options: Options{ DataDirectory: defaultDataDirectory, + ContextPaths: defaultContextPaths, }, Providers: make(map[provider.InferenceProvider]ProviderConfig), } @@ -391,7 +590,22 @@ func defaultConfigBasedOnEnv() *Config { if apiKey := os.Getenv(envVar); apiKey != "" { providerConfig := providerDefaultConfig(p.ID) providerConfig.APIKey = apiKey - providerConfig.DefaultModel = p.DefaultModelID + providerConfig.DefaultLargeModel = p.DefaultLargeModelID + providerConfig.DefaultSmallModel = p.DefaultSmallModelID + for _, model := range p.Models { + providerConfig.Models = append(providerConfig.Models, Model{ + 
ID: model.ID, + Name: model.Name, + CostPer1MIn: model.CostPer1MIn, + CostPer1MOut: model.CostPer1MOut, + CostPer1MInCached: model.CostPer1MInCached, + CostPer1MOutCached: model.CostPer1MOutCached, + ContextWindow: model.ContextWindow, + DefaultMaxTokens: model.DefaultMaxTokens, + CanReason: model.CanReason, + SupportsImages: model.SupportsImages, + }) + } cfg.Providers[p.ID] = providerConfig } } diff --git a/internal/config_v2/config_test.go b/internal/config_v2/config_test.go index 50b829271dcd42213141ecf2b9b72f5890480668..9bcfcdc78375e1a3a35726b513f04e3cb1e2c3b3 100644 --- a/internal/config_v2/config_test.go +++ b/internal/config_v2/config_test.go @@ -1,6 +1,7 @@ package configv2 import ( + "encoding/json" "fmt" "os" "testing" @@ -28,6 +29,7 @@ func TestConfigWithEnv(t *testing.T) { os.Setenv("XAI_API_KEY", "test-xai-key") os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key") cfg := InitConfig(cwdDir) - fmt.Println(cfg) + data, _ := json.MarshalIndent(cfg, "", " ") + fmt.Println(string(data)) assert.Len(t, cfg.Providers, 5) } diff --git a/internal/config_v2/provider.go b/internal/config_v2/provider.go index 94fe2d44d74e3039dcdeaa0dc76e95b840a03125..ec6b5bdb701876af4705c9e78fcc55a87646edd2 100644 --- a/internal/config_v2/provider.go +++ b/internal/config_v2/provider.go @@ -6,8 +6,8 @@ import ( "path/filepath" "sync" - "github.com/charmbracelet/fur/pkg/client" - "github.com/charmbracelet/fur/pkg/provider" + "github.com/charmbracelet/crush/internal/fur/client" + "github.com/charmbracelet/crush/internal/fur/provider" ) var fur = client.New() diff --git a/internal/fur/client/client.go b/internal/fur/client/client.go new file mode 100644 index 0000000000000000000000000000000000000000..263e8317ce8ac92d8820ba5288f2e40d2616e0e1 --- /dev/null +++ b/internal/fur/client/client.go @@ -0,0 +1,63 @@ +// Package client provides a client for interacting with the fur service. 
+package client + +import ( + "encoding/json" + "fmt" + "net/http" + "os" + + "github.com/charmbracelet/crush/internal/fur/provider" +) + +const defaultURL = "http://localhost:8080" + +// Client represents a client for the fur service. +type Client struct { + baseURL string + httpClient *http.Client +} + +// New creates a new client instance +// Uses FUR_URL environment variable or falls back to localhost:8080. +func New() *Client { + baseURL := os.Getenv("FUR_URL") + if baseURL == "" { + baseURL = defaultURL + } + + return &Client{ + baseURL: baseURL, + httpClient: &http.Client{}, + } +} + +// NewWithURL creates a new client with a specific URL. +func NewWithURL(url string) *Client { + return &Client{ + baseURL: url, + httpClient: &http.Client{}, + } +} + +// GetProviders retrieves all available providers from the service. +func (c *Client) GetProviders() ([]provider.Provider, error) { + url := fmt.Sprintf("%s/providers", c.baseURL) + + resp, err := c.httpClient.Get(url) //nolint:noctx + if err != nil { + return nil, fmt.Errorf("failed to make request: %w", err) + } + defer resp.Body.Close() //nolint:errcheck + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + var providers []provider.Provider + if err := json.NewDecoder(resp.Body).Decode(&providers); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + return providers, nil +} diff --git a/internal/fur/provider/provider.go b/internal/fur/provider/provider.go new file mode 100644 index 0000000000000000000000000000000000000000..85275f1155eff219c87d85fce3cdcc436f4a4e47 --- /dev/null +++ b/internal/fur/provider/provider.go @@ -0,0 +1,72 @@ +// Package provider provides types and constants for AI providers. +package provider + +// Type represents the type of AI provider. +type Type string + +// All the supported AI provider types. 
+const ( + TypeOpenAI Type = "openai" + TypeAnthropic Type = "anthropic" + TypeGemini Type = "gemini" + TypeAzure Type = "azure" + TypeBedrock Type = "bedrock" + TypeVertexAI Type = "vertexai" + TypeXAI Type = "xai" + TypeOpenRouter Type = "openrouter" +) + +// InferenceProvider represents the inference provider identifier. +type InferenceProvider string + +// All the inference providers supported by the system. +const ( + InferenceProviderOpenAI InferenceProvider = "openai" + InferenceProviderAnthropic InferenceProvider = "anthropic" + InferenceProviderGemini InferenceProvider = "gemini" + InferenceProviderAzure InferenceProvider = "azure" + InferenceProviderBedrock InferenceProvider = "bedrock" + InferenceProviderVertexAI InferenceProvider = "vertexai" + InferenceProviderXAI InferenceProvider = "xai" + InferenceProviderOpenRouter InferenceProvider = "openrouter" +) + +// Provider represents an AI provider configuration. +type Provider struct { + Name string `json:"name"` + ID InferenceProvider `json:"id"` + APIKey string `json:"api_key,omitempty"` + APIEndpoint string `json:"api_endpoint,omitempty"` + Type Type `json:"type,omitempty"` + DefaultLargeModelID string `json:"default_large_model_id,omitempty"` + DefaultSmallModelID string `json:"default_small_model_id,omitempty"` + Models []Model `json:"models,omitempty"` +} + +// Model represents an AI model configuration. +type Model struct { + ID string `json:"id"` + Name string `json:"model"` + CostPer1MIn float64 `json:"cost_per_1m_in"` + CostPer1MOut float64 `json:"cost_per_1m_out"` + CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` + CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` + ContextWindow int64 `json:"context_window"` + DefaultMaxTokens int64 `json:"default_max_tokens"` + CanReason bool `json:"can_reason"` + SupportsImages bool `json:"supports_attachments"` +} + +// KnownProviders returns all the known inference providers. 
+func KnownProviders() []InferenceProvider { + return []InferenceProvider{ + InferenceProviderOpenAI, + InferenceProviderAnthropic, + InferenceProviderGemini, + InferenceProviderAzure, + InferenceProviderBedrock, + InferenceProviderVertexAI, + InferenceProviderXAI, + InferenceProviderOpenRouter, + } +} diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 26d952441317607323e1170d6a06559f9173605d..ea2a3bd2b11735c1f0422e859adcfa65a82fdb98 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -734,21 +734,15 @@ func createAgentProvider(agentName config.AgentName) (provider.Provider, error) provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)), provider.WithMaxTokens(maxTokens), } - if (model.Provider == models.ProviderOpenAI || model.Provider == models.ProviderLocal) && model.CanReason { - opts = append( - opts, - provider.WithOpenAIOptions( - provider.WithReasoningEffort(agentConfig.ReasoningEffort), - ), - ) - } else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder { - opts = append( - opts, - provider.WithAnthropicOptions( - provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn), - ), - ) - } + // TODO: reimplement + // if model.Provider == models.ProviderOpenAI || model.Provider == models.ProviderLocal && model.CanReason { + // opts = append( + // opts, + // provider.WithOpenAIOptions( + // provider.WithReasoningEffort(agentConfig.ReasoningEffort), + // ), + // ) + // } agentProvider, err := provider.NewProvider( model.Provider, opts..., diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 634040ebc3dad8d6dc9e7642ebbe95ac3b051c63..709a56263e0a8880d444c8ee7e9cab1373e67344 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -19,40 +19,25 @@ import ( "github.com/charmbracelet/crush/internal/message" ) -type anthropicOptions struct { - useBedrock bool - 
disableCache bool - shouldThink func(userMessage string) bool -} - -type AnthropicOption func(*anthropicOptions) - type anthropicClient struct { providerOptions providerClientOptions - options anthropicOptions client anthropic.Client } type AnthropicClient ProviderClient -func newAnthropicClient(opts providerClientOptions) AnthropicClient { - anthropicOpts := anthropicOptions{} - for _, o := range opts.anthropicOptions { - o(&anthropicOpts) - } - +func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient { anthropicClientOptions := []option.RequestOption{} if opts.apiKey != "" { anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey)) } - if anthropicOpts.useBedrock { + if useBedrock { anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background())) } client := anthropic.NewClient(anthropicClientOptions...) return &anthropicClient{ providerOptions: opts, - options: anthropicOpts, client: client, } } @@ -66,7 +51,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic switch msg.Role { case message.User: content := anthropic.NewTextBlock(msg.Content().String()) - if cache && !a.options.disableCache { + if cache && !a.providerOptions.disableCache { content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{ Type: "ephemeral", } @@ -84,7 +69,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic blocks := []anthropic.ContentBlockParamUnion{} if msg.Content().String() != "" { content := anthropic.NewTextBlock(msg.Content().String()) - if cache && !a.options.disableCache { + if cache && !a.providerOptions.disableCache { content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{ Type: "ephemeral", } @@ -132,7 +117,7 @@ func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolU }, } - if i == len(tools)-1 && !a.options.disableCache { + if i == len(tools)-1 && 
!a.providerOptions.disableCache { toolParam.CacheControl = anthropic.CacheControlEphemeralParam{ Type: "ephemeral", } @@ -161,21 +146,22 @@ func (a *anthropicClient) finishReason(reason string) message.FinishReason { func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams { var thinkingParam anthropic.ThinkingConfigParamUnion - lastMessage := messages[len(messages)-1] - isUser := lastMessage.Role == anthropic.MessageParamRoleUser - messageContent := "" + // TODO: Implement a proper thinking function + // lastMessage := messages[len(messages)-1] + // isUser := lastMessage.Role == anthropic.MessageParamRoleUser + // messageContent := "" temperature := anthropic.Float(0) - if isUser { - for _, m := range lastMessage.Content { - if m.OfText != nil && m.OfText.Text != "" { - messageContent = m.OfText.Text - } - } - if messageContent != "" && a.options.shouldThink != nil && a.options.shouldThink(messageContent) { - thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(a.providerOptions.maxTokens) * 0.8)) - temperature = anthropic.Float(1) - } - } + // if isUser { + // for _, m := range lastMessage.Content { + // if m.OfText != nil && m.OfText.Text != "" { + // messageContent = m.OfText.Text + // } + // } + // if messageContent != "" && a.shouldThink != nil && a.options.shouldThink(messageContent) { + // thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(a.providerOptions.maxTokens) * 0.8)) + // temperature = anthropic.Float(1) + // } + // } return anthropic.MessageNewParams{ Model: anthropic.Model(a.providerOptions.model.APIModel), @@ -439,24 +425,7 @@ func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage { } } -func WithAnthropicBedrock(useBedrock bool) AnthropicOption { - return func(options *anthropicOptions) { - options.useBedrock = useBedrock - } -} - -func WithAnthropicDisableCache() AnthropicOption { - return func(options 
*anthropicOptions) { - options.disableCache = true - } -} - +// TODO: check if we need func DefaultShouldThinkFn(s string) bool { return strings.Contains(strings.ToLower(s), "think") } - -func WithAnthropicShouldThinkFn(fn func(string) bool) AnthropicOption { - return func(options *anthropicOptions) { - options.shouldThink = fn - } -} diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go index 8d3a86198aab5a38742e33b167f2545efd808873..8db9c1e84a4e8496be77e69e612de4abb9ce0c07 100644 --- a/internal/llm/provider/bedrock.go +++ b/internal/llm/provider/bedrock.go @@ -11,22 +11,14 @@ import ( "github.com/charmbracelet/crush/internal/message" ) -type bedrockOptions struct { - // Bedrock specific options can be added here -} - -type BedrockOption func(*bedrockOptions) - type bedrockClient struct { providerOptions providerClientOptions - options bedrockOptions childProvider ProviderClient } type BedrockClient ProviderClient func newBedrockClient(opts providerClientOptions) BedrockClient { - bedrockOpts := bedrockOptions{} // Apply bedrock specific options if they are added in the future // Get AWS region from environment @@ -41,7 +33,6 @@ func newBedrockClient(opts providerClientOptions) BedrockClient { if len(region) < 2 { return &bedrockClient{ providerOptions: opts, - options: bedrockOpts, childProvider: nil, // Will cause an error when used } } @@ -55,14 +46,11 @@ func newBedrockClient(opts providerClientOptions) BedrockClient { if strings.Contains(string(opts.model.APIModel), "anthropic") { // Create Anthropic client with Bedrock configuration anthropicOpts := opts - anthropicOpts.anthropicOptions = append(anthropicOpts.anthropicOptions, - WithAnthropicBedrock(true), - WithAnthropicDisableCache(), - ) + // TODO: later find a way to check if the AWS account has caching enabled + opts.disableCache = true // Disable cache for Bedrock return &bedrockClient{ providerOptions: opts, - options: bedrockOpts, - childProvider: 
newAnthropicClient(anthropicOpts), + childProvider: newAnthropicClient(anthropicOpts, true), } } @@ -70,7 +58,6 @@ func newBedrockClient(opts providerClientOptions) BedrockClient { // This will cause an error when used return &bedrockClient{ providerOptions: opts, - options: bedrockOpts, childProvider: nil, } } diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index 9481d8d545aab12a3739fe99b4af61f4ed99a514..dd54dac4491634de06a31ee00f1ffd13ea935076 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -17,26 +17,14 @@ import ( "google.golang.org/genai" ) -type geminiOptions struct { - disableCache bool -} - -type GeminiOption func(*geminiOptions) - type geminiClient struct { providerOptions providerClientOptions - options geminiOptions client *genai.Client } type GeminiClient ProviderClient func newGeminiClient(opts providerClientOptions) GeminiClient { - geminiOpts := geminiOptions{} - for _, o := range opts.geminiOptions { - o(&geminiOpts) - } - client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI}) if err != nil { logging.Error("Failed to create Gemini client", "error", err) @@ -45,7 +33,6 @@ func newGeminiClient(opts providerClientOptions) GeminiClient { return &geminiClient{ providerOptions: opts, - options: geminiOpts, client: client, } } @@ -452,12 +439,6 @@ func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage { } } -func WithGeminiDisableCache() GeminiOption { - return func(options *geminiOptions) { - options.disableCache = true - } -} - // Helper functions func parseJsonToMap(jsonStr string) (map[string]any, error) { var result map[string]any diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 05658dd6db760a1d05a88ae4931de5c70d9cc453..334312f9e8c41f5d68251d9e7bbd890074fa3982 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -19,14 
+19,9 @@ import ( ) type openaiOptions struct { - baseURL string - disableCache bool reasoningEffort string - extraHeaders map[string]string } -type OpenAIOption func(*openaiOptions) - type openaiClient struct { providerOptions providerClientOptions options openaiOptions @@ -39,20 +34,17 @@ func newOpenAIClient(opts providerClientOptions) OpenAIClient { openaiOpts := openaiOptions{ reasoningEffort: "medium", } - for _, o := range opts.openaiOptions { - o(&openaiOpts) - } openaiClientOptions := []option.RequestOption{} if opts.apiKey != "" { openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey)) } - if openaiOpts.baseURL != "" { - openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(openaiOpts.baseURL)) + if opts.baseURL != "" { + openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(opts.baseURL)) } - if openaiOpts.extraHeaders != nil { - for key, value := range openaiOpts.extraHeaders { + if opts.extraHeaders != nil { + for key, value := range opts.extraHeaders { openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value)) } } @@ -392,34 +384,3 @@ func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage { CacheReadTokens: cachedTokens, } } - -func WithOpenAIBaseURL(baseURL string) OpenAIOption { - return func(options *openaiOptions) { - options.baseURL = baseURL - } -} - -func WithOpenAIExtraHeaders(headers map[string]string) OpenAIOption { - return func(options *openaiOptions) { - options.extraHeaders = headers - } -} - -func WithOpenAIDisableCache() OpenAIOption { - return func(options *openaiOptions) { - options.disableCache = true - } -} - -func WithReasoningEffort(effort string) OpenAIOption { - return func(options *openaiOptions) { - defaultReasoningEffort := "medium" - switch effort { - case "low", "medium", "high": - defaultReasoningEffort = effort - default: - logging.Warn("Invalid reasoning effort, using default: medium") - } - options.reasoningEffort = 
defaultReasoningEffort - } -} diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 40c7317fba3eb944ba83421bcee8bf1702882fcb..86c47c2e0c24f2f99d91eb51c946da7bbf90dfa0 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -3,6 +3,7 @@ package provider import ( "context" "fmt" + "maps" "os" "github.com/charmbracelet/crush/internal/llm/models" @@ -59,15 +60,13 @@ type Provider interface { } type providerClientOptions struct { + baseURL string apiKey string model models.Model + disableCache bool maxTokens int64 systemMessage string - - anthropicOptions []AnthropicOption - openaiOptions []OpenAIOption - geminiOptions []GeminiOption - bedrockOptions []BedrockOption + extraHeaders map[string]string } type ProviderClientOption func(*providerClientOptions) @@ -91,7 +90,7 @@ func NewProvider(providerName models.InferenceProvider, opts ...ProviderClientOp case models.ProviderAnthropic: return &baseProvider[AnthropicClient]{ options: clientOptions, - client: newAnthropicClient(clientOptions), + client: newAnthropicClient(clientOptions, false), }, nil case models.ProviderOpenAI: return &baseProvider[OpenAIClient]{ @@ -109,9 +108,7 @@ func NewProvider(providerName models.InferenceProvider, opts ...ProviderClientOp client: newBedrockClient(clientOptions), }, nil case models.ProviderGROQ: - clientOptions.openaiOptions = append(clientOptions.openaiOptions, - WithOpenAIBaseURL("https://api.groq.com/openai/v1"), - ) + clientOptions.baseURL = "https://api.groq.com/openai/v1" return &baseProvider[OpenAIClient]{ options: clientOptions, client: newOpenAIClient(clientOptions), @@ -127,29 +124,23 @@ func NewProvider(providerName models.InferenceProvider, opts ...ProviderClientOp client: newVertexAIClient(clientOptions), }, nil case models.ProviderOpenRouter: - clientOptions.openaiOptions = append(clientOptions.openaiOptions, - WithOpenAIBaseURL("https://openrouter.ai/api/v1"), - WithOpenAIExtraHeaders(map[string]string{ 
- "HTTP-Referer": "crush.charm.land", - "X-Title": "Crush", - }), - ) + clientOptions.baseURL = "https://openrouter.ai/api/v1" + clientOptions.extraHeaders = map[string]string{ + "HTTP-Referer": "crush.charm.land", + "X-Title": "Crush", + } return &baseProvider[OpenAIClient]{ options: clientOptions, client: newOpenAIClient(clientOptions), }, nil case models.ProviderXAI: - clientOptions.openaiOptions = append(clientOptions.openaiOptions, - WithOpenAIBaseURL("https://api.x.ai/v1"), - ) + clientOptions.baseURL = "https://api.x.ai/v1" return &baseProvider[OpenAIClient]{ options: clientOptions, client: newOpenAIClient(clientOptions), }, nil case models.ProviderLocal: - clientOptions.openaiOptions = append(clientOptions.openaiOptions, - WithOpenAIBaseURL(os.Getenv("LOCAL_ENDPOINT")), - ) + clientOptions.baseURL = os.Getenv("LOCAL_ENDPOINT") return &baseProvider[OpenAIClient]{ options: clientOptions, client: newOpenAIClient(clientOptions), @@ -186,50 +177,47 @@ func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message return p.client.stream(ctx, messages, tools) } -func WithAPIKey(apiKey string) ProviderClientOption { - return func(options *providerClientOptions) { - options.apiKey = apiKey - } -} - -func WithModel(model models.Model) ProviderClientOption { +func WithBaseURL(baseURL string) ProviderClientOption { return func(options *providerClientOptions) { - options.model = model + options.baseURL = baseURL } } -func WithMaxTokens(maxTokens int64) ProviderClientOption { +func WithAPIKey(apiKey string) ProviderClientOption { return func(options *providerClientOptions) { - options.maxTokens = maxTokens + options.apiKey = apiKey } } -func WithSystemMessage(systemMessage string) ProviderClientOption { +func WithModel(model models.Model) ProviderClientOption { return func(options *providerClientOptions) { - options.systemMessage = systemMessage + options.model = model } } -func WithAnthropicOptions(anthropicOptions ...AnthropicOption) 
ProviderClientOption { +func WithDisableCache(disableCache bool) ProviderClientOption { return func(options *providerClientOptions) { - options.anthropicOptions = anthropicOptions + options.disableCache = disableCache } } -func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption { +func WithExtraHeaders(extraHeaders map[string]string) ProviderClientOption { return func(options *providerClientOptions) { - options.openaiOptions = openaiOptions + if options.extraHeaders == nil { + options.extraHeaders = make(map[string]string) + } + maps.Copy(options.extraHeaders, extraHeaders) } } -func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption { +func WithMaxTokens(maxTokens int64) ProviderClientOption { return func(options *providerClientOptions) { - options.geminiOptions = geminiOptions + options.maxTokens = maxTokens } } -func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption { +func WithSystemMessage(systemMessage string) ProviderClientOption { return func(options *providerClientOptions) { - options.bedrockOptions = bedrockOptions + options.systemMessage = systemMessage } } diff --git a/internal/llm/provider/vertexai.go b/internal/llm/provider/vertexai.go index 0e01ef9c3a082faea86bd6c76d6e9d53c9f8a933..49374d33fa81ab42e9f0c4d6e7905bfa37a6154e 100644 --- a/internal/llm/provider/vertexai.go +++ b/internal/llm/provider/vertexai.go @@ -11,11 +11,6 @@ import ( type VertexAIClient ProviderClient func newVertexAIClient(opts providerClientOptions) VertexAIClient { - geminiOpts := geminiOptions{} - for _, o := range opts.geminiOptions { - o(&geminiOpts) - } - client, err := genai.NewClient(context.Background(), &genai.ClientConfig{ Project: os.Getenv("GOOGLE_CLOUD_PROJECT"), Location: os.Getenv("GOOGLE_CLOUD_LOCATION"), @@ -28,7 +23,6 @@ func newVertexAIClient(opts providerClientOptions) VertexAIClient { return &geminiClient{ providerOptions: opts, - options: geminiOpts, client: client, } } From 
6e0feda9c412babd6f0b1ca67c4495d51f8ccd05 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Thu, 26 Jun 2025 14:31:34 +0200 Subject: [PATCH 05/40] chore: add name to tools --- internal/config_v2/config.go | 50 +++++++++++++++++-------------- internal/llm/agent/agent-tool.go | 4 +++ internal/llm/agent/mcp-tools.go | 4 +++ internal/llm/provider/provider.go | 22 -------------- internal/llm/tools/bash.go | 4 +++ internal/llm/tools/diagnostics.go | 4 +++ internal/llm/tools/edit.go | 4 +++ internal/llm/tools/fetch.go | 4 +++ internal/llm/tools/glob.go | 4 +++ internal/llm/tools/grep.go | 4 +++ internal/llm/tools/ls.go | 4 +++ internal/llm/tools/sourcegraph.go | 4 +++ internal/llm/tools/tools.go | 1 + internal/llm/tools/view.go | 4 +++ internal/llm/tools/write.go | 4 +++ 15 files changed, 77 insertions(+), 44 deletions(-) diff --git a/internal/config_v2/config.go b/internal/config_v2/config.go index 0de27aa1b16cf71e97c655561a5c8ce20bae0838..9f7f2ad14356531150cca4f05952fb390c716c68 100644 --- a/internal/config_v2/config.go +++ b/internal/config_v2/config.go @@ -82,8 +82,9 @@ type ProviderConfig struct { } type Agent struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` + ID AgentID `json:"id"` + Name string `json:"name"` + Description string `json:"description,omitempty"` // This is the id of the system prompt used by the agent Disabled bool `json:"disabled"` @@ -229,6 +230,7 @@ func loadConfig(cwd string) (*Config, error) { agents := map[AgentID]Agent{ AgentCoder: { + ID: AgentCoder, Name: "Coder", Description: "An agent that helps with executing coding tasks.", Provider: preferredProvider.ID, @@ -237,6 +239,7 @@ func loadConfig(cwd string) (*Config, error) { // All tools allowed }, AgentTask: { + ID: AgentTask, Name: "Task", Description: "An agent that helps with searching for context and finding implementation details.", Provider: preferredProvider.ID, @@ -254,6 +257,7 @@ func loadConfig(cwd string) (*Config, error) { AllowedLSP: 
[]string{}, }, AgentTitle: { + ID: AgentTitle, Name: "Title", Description: "An agent that helps with generating titles for sessions.", Provider: preferredProvider.ID, @@ -265,6 +269,7 @@ func loadConfig(cwd string) (*Config, error) { AllowedLSP: []string{}, }, AgentSummarize: { + ID: AgentSummarize, Name: "Summarize", Description: "An agent that helps with summarizing sessions.", Provider: preferredProvider.ID, @@ -429,43 +434,44 @@ func mergeAgents(base, global, local *Config) { if cfg == nil { continue } - for agentID, globalAgent := range cfg.Agents { + for agentID, newAgent := range cfg.Agents { if _, ok := base.Agents[agentID]; !ok { - base.Agents[agentID] = globalAgent + newAgent.ID = agentID // Ensure the ID is set correctly + base.Agents[agentID] = newAgent } else { switch agentID { case AgentCoder: baseAgent := base.Agents[agentID] - baseAgent.Model = globalAgent.Model - baseAgent.Provider = globalAgent.Provider - baseAgent.AllowedMCP = globalAgent.AllowedMCP - baseAgent.AllowedLSP = globalAgent.AllowedLSP + baseAgent.Model = newAgent.Model + baseAgent.Provider = newAgent.Provider + baseAgent.AllowedMCP = newAgent.AllowedMCP + baseAgent.AllowedLSP = newAgent.AllowedLSP base.Agents[agentID] = baseAgent case AgentTask: baseAgent := base.Agents[agentID] - baseAgent.Model = globalAgent.Model - baseAgent.Provider = globalAgent.Provider + baseAgent.Model = newAgent.Model + baseAgent.Provider = newAgent.Provider base.Agents[agentID] = baseAgent case AgentTitle: baseAgent := base.Agents[agentID] - baseAgent.Model = globalAgent.Model - baseAgent.Provider = globalAgent.Provider + baseAgent.Model = newAgent.Model + baseAgent.Provider = newAgent.Provider base.Agents[agentID] = baseAgent case AgentSummarize: baseAgent := base.Agents[agentID] - baseAgent.Model = globalAgent.Model - baseAgent.Provider = globalAgent.Provider + baseAgent.Model = newAgent.Model + baseAgent.Provider = newAgent.Provider base.Agents[agentID] = baseAgent default: baseAgent := 
base.Agents[agentID] - baseAgent.Name = globalAgent.Name - baseAgent.Description = globalAgent.Description - baseAgent.Disabled = globalAgent.Disabled - baseAgent.Provider = globalAgent.Provider - baseAgent.Model = globalAgent.Model - baseAgent.AllowedTools = globalAgent.AllowedTools - baseAgent.AllowedMCP = globalAgent.AllowedMCP - baseAgent.AllowedLSP = globalAgent.AllowedLSP + baseAgent.Name = newAgent.Name + baseAgent.Description = newAgent.Description + baseAgent.Disabled = newAgent.Disabled + baseAgent.Provider = newAgent.Provider + baseAgent.Model = newAgent.Model + baseAgent.AllowedTools = newAgent.AllowedTools + baseAgent.AllowedMCP = newAgent.AllowedMCP + baseAgent.AllowedLSP = newAgent.AllowedLSP base.Agents[agentID] = baseAgent } diff --git a/internal/llm/agent/agent-tool.go b/internal/llm/agent/agent-tool.go index de4a86ac36d62ef0990a58d6abeb9a53572bc215..9e5e9bc7844b055c52464032dfc4d75495f9e426 100644 --- a/internal/llm/agent/agent-tool.go +++ b/internal/llm/agent/agent-tool.go @@ -26,6 +26,10 @@ type AgentParams struct { Prompt string `json:"prompt"` } +func (b *agentTool) Name() string { + return AgentToolName +} + func (b *agentTool) Info() tools.ToolInfo { return tools.ToolInfo{ Name: AgentToolName, diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index 55b6983d053a70ccebf56f7c6d239246acf8c317..8fde02755eb320b8925891a3eca938c3cd7911f9 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -32,6 +32,10 @@ type MCPClient interface { Close() error } +func (b *mcpTool) Name() string { + return fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name) +} + func (b *mcpTool) Info() tools.ToolInfo { return tools.ToolInfo{ Name: fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name), diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 86c47c2e0c24f2f99d91eb51c946da7bbf90dfa0..0d98b74c3292c0aa066dfd0676445e587b800b57 100644 --- a/internal/llm/provider/provider.go +++ 
b/internal/llm/provider/provider.go @@ -3,7 +3,6 @@ package provider import ( "context" "fmt" - "maps" "os" "github.com/charmbracelet/crush/internal/llm/models" @@ -177,18 +176,6 @@ func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message return p.client.stream(ctx, messages, tools) } -func WithBaseURL(baseURL string) ProviderClientOption { - return func(options *providerClientOptions) { - options.baseURL = baseURL - } -} - -func WithAPIKey(apiKey string) ProviderClientOption { - return func(options *providerClientOptions) { - options.apiKey = apiKey - } -} - func WithModel(model models.Model) ProviderClientOption { return func(options *providerClientOptions) { options.model = model @@ -201,15 +188,6 @@ func WithDisableCache(disableCache bool) ProviderClientOption { } } -func WithExtraHeaders(extraHeaders map[string]string) ProviderClientOption { - return func(options *providerClientOptions) { - if options.extraHeaders == nil { - options.extraHeaders = make(map[string]string) - } - maps.Copy(options.extraHeaders, extraHeaders) - } -} - func WithMaxTokens(maxTokens int64) ProviderClientOption { return func(options *providerClientOptions) { options.maxTokens = maxTokens diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index abbd19113db746cd8e82c5cdebc02c4b8fc28b99..5f8b41338c8c5ef6f771e80fbd4e1355b27eb036 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -250,6 +250,10 @@ func NewBashTool(permission permission.Service) BaseTool { } } +func (b *bashTool) Name() string { + return BashToolName +} + func (b *bashTool) Info() ToolInfo { return ToolInfo{ Name: BashToolName, diff --git a/internal/llm/tools/diagnostics.go b/internal/llm/tools/diagnostics.go index e48132414ad338635ada6294134e3d7f2d938c7d..5e0d19f9e5b3f72ea7588e4c4a15b71ae7b262d9 100644 --- a/internal/llm/tools/diagnostics.go +++ b/internal/llm/tools/diagnostics.go @@ -51,6 +51,10 @@ func NewDiagnosticsTool(lspClients map[string]*lsp.Client) 
BaseTool { } } +func (b *diagnosticsTool) Name() string { + return DiagnosticsToolName +} + func (b *diagnosticsTool) Info() ToolInfo { return ToolInfo{ Name: DiagnosticsToolName, diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 52a202e32cd95869e3f76ebe6e3e01ca2f4802cc..b72112f43e140edd7298e802ab88ba2747784d7c 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -106,6 +106,10 @@ func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Servi } } +func (e *editTool) Name() string { + return EditToolName +} + func (e *editTool) Info() ToolInfo { return ToolInfo{ Name: EditToolName, diff --git a/internal/llm/tools/fetch.go b/internal/llm/tools/fetch.go index 7acf23bae61df88792dd805317bdf8a67095dd0d..ac73ddbf3b0033cf503bdc8cfa2ef065a0072477 100644 --- a/internal/llm/tools/fetch.go +++ b/internal/llm/tools/fetch.go @@ -79,6 +79,10 @@ func NewFetchTool(permissions permission.Service) BaseTool { } } +func (t *fetchTool) Name() string { + return FetchToolName +} + func (t *fetchTool) Info() ToolInfo { return ToolInfo{ Name: FetchToolName, diff --git a/internal/llm/tools/glob.go b/internal/llm/tools/glob.go index 39471a8ae81f9c31c4bccccf273f88044184f6ad..25c80860b791a5b601366d455f5ddd1ea91523ed 100644 --- a/internal/llm/tools/glob.go +++ b/internal/llm/tools/glob.go @@ -74,6 +74,10 @@ func NewGlobTool() BaseTool { return &globTool{} } +func (g *globTool) Name() string { + return GlobToolName +} + func (g *globTool) Info() ToolInfo { return ToolInfo{ Name: GlobToolName, diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go index 7321ad4ed9a2d713cb8685b3f4b23dbc62cc6c2a..c3e13766884f17932187ad63cb5ffaacdf375b45 100644 --- a/internal/llm/tools/grep.go +++ b/internal/llm/tools/grep.go @@ -140,6 +140,10 @@ func NewGrepTool() BaseTool { return &grepTool{} } +func (g *grepTool) Name() string { + return GrepToolName +} + func (g *grepTool) Info() ToolInfo { return ToolInfo{ Name: GrepToolName, diff 
--git a/internal/llm/tools/ls.go b/internal/llm/tools/ls.go index aeda758a66bc9dc796dc8d4e3722f8711aa39f67..a51b5bdb5dccb7c209d9cdc28e94dad328e8c093 100644 --- a/internal/llm/tools/ls.go +++ b/internal/llm/tools/ls.go @@ -74,6 +74,10 @@ func NewLsTool() BaseTool { return &lsTool{} } +func (l *lsTool) Name() string { + return LSToolName +} + func (l *lsTool) Info() ToolInfo { return ToolInfo{ Name: LSToolName, diff --git a/internal/llm/tools/sourcegraph.go b/internal/llm/tools/sourcegraph.go index 29518b7b818da5746d195ea8b7da521d80429962..fcc9bb57428b45a4620417775c61acb380b4d7ad 100644 --- a/internal/llm/tools/sourcegraph.go +++ b/internal/llm/tools/sourcegraph.go @@ -138,6 +138,10 @@ func NewSourcegraphTool() BaseTool { } } +func (t *sourcegraphTool) Name() string { + return SourcegraphToolName +} + func (t *sourcegraphTool) Info() ToolInfo { return ToolInfo{ Name: SourcegraphToolName, diff --git a/internal/llm/tools/tools.go b/internal/llm/tools/tools.go index bf0f8df0bacdc34fc6a3f5f89f3591b8b6fd8f15..41c0515616032b117f3c09a0056cac9e86b62c66 100644 --- a/internal/llm/tools/tools.go +++ b/internal/llm/tools/tools.go @@ -68,6 +68,7 @@ type ToolCall struct { type BaseTool interface { Info() ToolInfo + Name() string Run(ctx context.Context, params ToolCall) (ToolResponse, error) } diff --git a/internal/llm/tools/view.go b/internal/llm/tools/view.go index 7d9cdaefcb98e0d6289396c9c44d45c4aa0b03c5..750efef73795f115e3ad90e4da9a2d955ee10529 100644 --- a/internal/llm/tools/view.go +++ b/internal/llm/tools/view.go @@ -77,6 +77,10 @@ func NewViewTool(lspClients map[string]*lsp.Client) BaseTool { } } +func (v *viewTool) Name() string { + return ViewToolName +} + func (v *viewTool) Info() ToolInfo { return ToolInfo{ Name: ViewToolName, diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 4186954d03acc4768233435c71c2ac5d78606119..0c213cec1f4e0a9bc8fc205a183206c0842f9688 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -84,6 
+84,10 @@ func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Serv } } +func (w *writeTool) Name() string { + return WriteToolName +} + func (w *writeTool) Info() ToolInfo { return ToolInfo{ Name: WriteToolName, From 565ab85eb91102ff7a9b03f950ac13ae4391f6ab Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 27 Jun 2025 10:55:44 +0200 Subject: [PATCH 06/40] chore: move to the new config --- cmd/root.go | 3 +- cmd/schema/main.go | 21 +- internal/app/app.go | 21 +- internal/config/config.go | 1295 +++++++---------- internal/{config_v2 => config}/config_test.go | 4 +- internal/{config_v2 => config}/fs.go | 2 +- internal/config/init.go | 19 +- internal/{config_v2 => config}/provider.go | 2 +- internal/config_v2/config.go | 660 --------- internal/db/connect.go | 5 +- internal/db/messages.sql.go | 14 +- ...0250627000000_add_provider_to_messages.sql | 11 + internal/db/models.go | 1 + internal/db/sql/messages.sql | 3 +- internal/fur/client/client.go | 2 +- internal/fur/provider/provider.go | 15 +- internal/llm/agent/agent-tool.go | 27 +- internal/llm/agent/agent.go | 301 ++-- internal/llm/agent/mcp-tools.go | 8 +- internal/llm/agent/tools.go | 50 - internal/llm/models/anthropic.go | 111 -- internal/llm/models/azure.go | 168 --- internal/llm/models/gemini.go | 67 - internal/llm/models/groq.go | 87 -- internal/llm/models/local.go | 206 --- internal/llm/models/models.go | 74 - internal/llm/models/openai.go | 181 --- internal/llm/models/openrouter.go | 276 ---- internal/llm/models/vertexai.go | 38 - internal/llm/models/xai.go | 61 - internal/llm/prompt/coder.go | 18 +- internal/llm/prompt/prompt.go | 54 +- internal/llm/prompt/prompt_test.go | 15 +- internal/llm/prompt/summarizer.go | 6 +- internal/llm/prompt/task.go | 4 +- internal/llm/prompt/title.go | 6 +- internal/llm/provider/anthropic.go | 10 +- internal/llm/provider/bedrock.go | 15 +- internal/llm/provider/gemini.go | 8 +- internal/llm/provider/openai.go | 10 +- 
internal/llm/provider/provider.go | 134 +- internal/llm/provider/vertexai.go | 7 +- internal/lsp/client.go | 28 +- internal/lsp/handlers.go | 4 +- internal/lsp/transport.go | 38 +- internal/lsp/watcher/watcher.go | 50 +- internal/message/content.go | 9 +- internal/message/message.go | 12 +- internal/tui/components/chat/header/header.go | 7 +- .../tui/components/chat/messages/messages.go | 8 +- .../tui/components/chat/sidebar/sidebar.go | 8 +- .../tui/components/dialogs/commands/loader.go | 2 +- internal/tui/components/dialogs/init/init.go | 6 +- .../tui/components/dialogs/models/models.go | 124 +- internal/tui/page/chat/chat.go | 10 +- internal/tui/tui.go | 7 +- 56 files changed, 1096 insertions(+), 3237 deletions(-) rename internal/{config_v2 => config}/config_test.go (94%) rename internal/{config_v2 => config}/fs.go (99%) rename internal/{config_v2 => config}/provider.go (98%) delete mode 100644 internal/config_v2/config.go create mode 100644 internal/db/migrations/20250627000000_add_provider_to_messages.sql delete mode 100644 internal/llm/agent/tools.go delete mode 100644 internal/llm/models/anthropic.go delete mode 100644 internal/llm/models/azure.go delete mode 100644 internal/llm/models/gemini.go delete mode 100644 internal/llm/models/groq.go delete mode 100644 internal/llm/models/local.go delete mode 100644 internal/llm/models/models.go delete mode 100644 internal/llm/models/openai.go delete mode 100644 internal/llm/models/openrouter.go delete mode 100644 internal/llm/models/vertexai.go delete mode 100644 internal/llm/models/xai.go diff --git a/cmd/root.go b/cmd/root.go index 2b5f79cf0337c386196d783ad9d18e2e1380aa5b..d741b859178e6c524b4b4e3a61863f144840812c 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -72,7 +72,8 @@ to assist developers in writing, debugging, and understanding code directly from } cwd = c } - _, err := config.Load(cwd, debug) + + _, err := config.Init(cwd, debug) if err != nil { return err } diff --git a/cmd/schema/main.go 
b/cmd/schema/main.go index da5353c0ec7353bfa3ec9b35760b735ecc2c9ccd..9eb88769fd84772628df5332d3dcc1b1b234ac90 100644 --- a/cmd/schema/main.go +++ b/cmd/schema/main.go @@ -1,3 +1,4 @@ +// TODO: FIX THIS package main import ( @@ -6,7 +7,6 @@ import ( "os" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/llm/models" ) // JSONSchemaType represents a JSON Schema type @@ -192,22 +192,10 @@ func generateSchema() map[string]any { }, } - // Add known providers - knownProviders := []string{ - string(models.ProviderAnthropic), - string(models.ProviderOpenAI), - string(models.ProviderGemini), - string(models.ProviderGROQ), - string(models.ProviderOpenRouter), - string(models.ProviderBedrock), - string(models.ProviderAzure), - string(models.ProviderVertexAI), - } - providerSchema["additionalProperties"].(map[string]any)["properties"].(map[string]any)["provider"] = map[string]any{ "type": "string", "description": "Provider type", - "enum": knownProviders, + "enum": []string{}, } schema["properties"].(map[string]any)["providers"] = providerSchema @@ -241,9 +229,7 @@ func generateSchema() map[string]any { // Add model enum modelEnum := []string{} - for modelID := range models.SupportedModels { - modelEnum = append(modelEnum, string(modelID)) - } + agentSchema["additionalProperties"].(map[string]any)["properties"].(map[string]any)["model"].(map[string]any)["enum"] = modelEnum // Add specific agent properties @@ -251,7 +237,6 @@ func generateSchema() map[string]any { knownAgents := []string{ string(config.AgentCoder), string(config.AgentTask), - string(config.AgentTitle), } for _, agentName := range knownAgents { diff --git a/internal/app/app.go b/internal/app/app.go index e7472059a9f3fad360172c353f5d9a188529d177..75042e89648779cf50a4376aa01aa3b6ac8e72a0 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -9,7 +9,7 @@ import ( "sync" "time" - "github.com/charmbracelet/crush/internal/config" + configv2 
"github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/db" "github.com/charmbracelet/crush/internal/format" "github.com/charmbracelet/crush/internal/history" @@ -55,18 +55,21 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) { // Initialize LSP clients in the background go app.initLSPClients(ctx) + cfg := configv2.Get() + + coderAgentCfg := cfg.Agents[configv2.AgentCoder] + if coderAgentCfg.ID == "" { + return nil, fmt.Errorf("coder agent configuration is missing") + } + var err error app.CoderAgent, err = agent.NewAgent( - config.AgentCoder, + coderAgentCfg, + app.Permissions, app.Sessions, app.Messages, - agent.CoderAgentTools( - app.Permissions, - app.Sessions, - app.Messages, - app.History, - app.LSPClients, - ), + app.History, + app.LSPClients, ) if err != nil { logging.Error("Failed to create coder agent", err) diff --git a/internal/config/config.go b/internal/config/config.go index 3944cb1374582f9af0eeb7bfadd05ef5f9a8c198..13444a5ccc8e99bdaa57a6156151b45a40176c09 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,67 +1,132 @@ -// Package config manages application configuration from various sources. package config import ( "encoding/json" + "errors" "fmt" "log/slog" + "maps" "os" "path/filepath" + "slices" "strings" + "sync" - "github.com/charmbracelet/crush/internal/llm/models" + "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/logging" - "github.com/spf13/afero" - "github.com/spf13/viper" ) -// MCPType defines the type of MCP (Model Control Protocol) server. -type MCPType string - -// Supported MCP types const ( - MCPStdio MCPType = "stdio" - MCPSse MCPType = "sse" + defaultDataDirectory = ".crush" + defaultLogLevel = "info" + appName = "crush" + + MaxTokensFallbackDefault = 4096 ) -// MCPServer defines the configuration for a Model Control Protocol server. 
-type MCPServer struct { - Command string `json:"command"` - Env []string `json:"env"` - Args []string `json:"args"` - Type MCPType `json:"type"` - URL string `json:"url"` - Headers map[string]string `json:"headers"` +var defaultContextPaths = []string{ + ".github/copilot-instructions.md", + ".cursorrules", + ".cursor/rules/", + "CLAUDE.md", + "CLAUDE.local.md", + "GEMINI.md", + "gemini.md", + "crush.md", + "crush.local.md", + "Crush.md", + "Crush.local.md", + "CRUSH.md", + "CRUSH.local.md", } -type AgentName string +type AgentID string const ( - AgentCoder AgentName = "coder" - AgentSummarizer AgentName = "summarizer" - AgentTask AgentName = "task" - AgentTitle AgentName = "title" + AgentCoder AgentID = "coder" + AgentTask AgentID = "task" ) -// Agent defines configuration for different LLM models and their token limits. -type Agent struct { - Model models.ModelID `json:"model"` - MaxTokens int64 `json:"maxTokens"` - ReasoningEffort string `json:"reasoningEffort"` // For openai models low,medium,heigh +type Model struct { + ID string `json:"id"` + Name string `json:"model"` + CostPer1MIn float64 `json:"cost_per_1m_in"` + CostPer1MOut float64 `json:"cost_per_1m_out"` + CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` + CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` + ContextWindow int64 `json:"context_window"` + DefaultMaxTokens int64 `json:"default_max_tokens"` + CanReason bool `json:"can_reason"` + ReasoningEffort string `json:"reasoning_effort"` + SupportsImages bool `json:"supports_attachments"` } -// Provider defines configuration for an LLM provider. -type Provider struct { - APIKey string `json:"apiKey"` - Disabled bool `json:"disabled"` +type VertexAIOptions struct { + APIKey string `json:"api_key,omitempty"` + Project string `json:"project,omitempty"` + Location string `json:"location,omitempty"` } -// Data defines storage configuration. 
-type Data struct { - Directory string `json:"directory,omitempty"` +type ProviderConfig struct { + ID provider.InferenceProvider `json:"id"` + BaseURL string `json:"base_url,omitempty"` + ProviderType provider.Type `json:"provider_type"` + APIKey string `json:"api_key,omitempty"` + Disabled bool `json:"disabled"` + ExtraHeaders map[string]string `json:"extra_headers,omitempty"` + // used for e.x for vertex to set the project + ExtraParams map[string]string `json:"extra_params,omitempty"` + + DefaultLargeModel string `json:"default_large_model,omitempty"` + DefaultSmallModel string `json:"default_small_model,omitempty"` + + Models []Model `json:"models,omitempty"` +} + +type Agent struct { + ID AgentID `json:"id"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + // This is the id of the system prompt used by the agent + Disabled bool `json:"disabled"` + + Provider provider.InferenceProvider `json:"provider"` + Model string `json:"model"` + + // The available tools for the agent + // if this is nil, all tools are available + AllowedTools []string `json:"allowed_tools"` + + // this tells us which MCPs are available for this agent + // if this is empty all mcps are available + // the string array is the list of tools from the AllowedMCP the agent has available + // if the string array is nil, all tools from the AllowedMCP are available + AllowedMCP map[string][]string `json:"allowed_mcp"` + + // The list of LSPs that this agent can use + // if this is nil, all LSPs are available + AllowedLSP []string `json:"allowed_lsp"` + + // Overrides the context paths for this agent + ContextPaths []string `json:"context_paths"` +} + +type MCPType string + +const ( + MCPStdio MCPType = "stdio" + MCPSse MCPType = "sse" +) + +type MCP struct { + Command string `json:"command"` + Env []string `json:"env"` + Args []string `json:"args"` + Type MCPType `json:"type"` + URL string `json:"url"` + Headers map[string]string `json:"headers"` } -// LSPConfig 
defines configuration for Language Server Protocol integration. type LSPConfig struct { Disabled bool `json:"enabled"` Command string `json:"command"` @@ -69,98 +134,72 @@ type LSPConfig struct { Options any `json:"options"` } -// TUIConfig defines the configuration for the Terminal User Interface. -type TUIConfig struct { - Theme string `json:"theme,omitempty"` +type TUIOptions struct { + CompactMode bool `json:"compact_mode"` + // Here we can add themes later or any TUI related options } -// Config is the main configuration structure for the application. -type Config struct { - Data Data `json:"data"` - WorkingDir string `json:"wd,omitempty"` - MCPServers map[string]MCPServer `json:"mcpServers,omitempty"` - Providers map[models.InferenceProvider]Provider `json:"providers,omitempty"` - LSP map[string]LSPConfig `json:"lsp,omitempty"` - Agents map[AgentName]Agent `json:"agents,omitempty"` - Debug bool `json:"debug,omitempty"` - DebugLSP bool `json:"debugLSP,omitempty"` - ContextPaths []string `json:"contextPaths,omitempty"` - TUI TUIConfig `json:"tui"` - AutoCompact bool `json:"autoCompact,omitempty"` +type Options struct { + ContextPaths []string `json:"context_paths"` + TUI TUIOptions `json:"tui"` + Debug bool `json:"debug"` + DebugLSP bool `json:"debug_lsp"` + DisableAutoSummarize bool `json:"disable_auto_summarize"` + // Relative to the cwd + DataDirectory string `json:"data_directory"` } -// Application constants -const ( - defaultDataDirectory = ".crush" - defaultLogLevel = "info" - appName = "crush" - - MaxTokensFallbackDefault = 4096 -) +type PreferredModel struct { + ModelID string `json:"model_id"` + Provider provider.InferenceProvider `json:"provider"` +} -var defaultContextPaths = []string{ - ".github/copilot-instructions.md", - ".cursorrules", - ".cursor/rules/", - "CLAUDE.md", - "CLAUDE.local.md", - "GEMINI.md", - "gemini.md", - "crush.md", - "crush.local.md", - "Crush.md", - "Crush.local.md", - "CRUSH.md", - "CRUSH.local.md", +type PreferredModels 
struct { + Large PreferredModel `json:"large"` + Small PreferredModel `json:"small"` } -// Global configuration instance -var cfg *Config +type Config struct { + Models PreferredModels `json:"models"` + // List of configured providers + Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"` -// Load initializes the configuration from environment variables and config files. -// If debug is true, debug mode is enabled and log level is set to debug. -// It returns an error if configuration loading fails. -func Load(workingDir string, debug bool) (*Config, error) { - if cfg != nil { - return cfg, nil - } + // List of configured agents + Agents map[AgentID]Agent `json:"agents,omitempty"` - cfg = &Config{ - WorkingDir: workingDir, - MCPServers: make(map[string]MCPServer), - Providers: make(map[models.InferenceProvider]Provider), - LSP: make(map[string]LSPConfig), - } + // List of configured MCPs + MCP map[string]MCP `json:"mcp,omitempty"` - configureViper() - setDefaults(debug) + // List of configured LSPs + LSP map[string]LSPConfig `json:"lsp,omitempty"` - // Read global config - if err := readConfig(viper.ReadInConfig()); err != nil { - return cfg, err - } + // Miscellaneous options + Options Options `json:"options"` +} - // Load and merge local config - mergeLocalConfig(workingDir) +var ( + instance *Config // The single instance of the Singleton + cwd string + once sync.Once // Ensures the initialization happens only once - setProviderDefaults() +) - // Apply configuration to the struct - if err := viper.Unmarshal(cfg); err != nil { - return cfg, fmt.Errorf("failed to unmarshal config: %w", err) - } +func loadConfig(cwd string, debug bool) (*Config, error) { + // First read the global config file + cfgPath := ConfigPath() - applyDefaultValues() + cfg := defaultConfigBasedOnEnv() + cfg.Options.Debug = debug defaultLevel := slog.LevelInfo - if cfg.Debug { + if cfg.Options.Debug { defaultLevel = slog.LevelDebug } if 
os.Getenv("CRUSH_DEV_DEBUG") == "true" { - loggingFile := fmt.Sprintf("%s/%s", cfg.Data.Directory, "debug.log") + loggingFile := fmt.Sprintf("%s/%s", cfg.Options.DataDirectory, "debug.log") // if file does not exist create it if _, err := os.Stat(loggingFile); os.IsNotExist(err) { - if err := os.MkdirAll(cfg.Data.Directory, 0o755); err != nil { + if err := os.MkdirAll(cfg.Options.DataDirectory, 0o755); err != nil { return cfg, fmt.Errorf("failed to create directory: %w", err) } if _, err := os.Create(loggingFile); err != nil { @@ -184,734 +223,530 @@ func Load(workingDir string, debug bool) (*Config, error) { })) slog.SetDefault(logger) } - - // Validate configuration - if err := Validate(); err != nil { - return cfg, fmt.Errorf("config validation failed: %w", err) - } - - if cfg.Agents == nil { - cfg.Agents = make(map[AgentName]Agent) - } - - // Override the max tokens for title agent - cfg.Agents[AgentTitle] = Agent{ - Model: cfg.Agents[AgentTitle].Model, - MaxTokens: 80, - } - return cfg, nil -} - -type configFinder struct { - appName string - dotPrefix bool - paths []string -} - -func (f configFinder) Find(fsys afero.Fs) ([]string, error) { - var configFiles []string - configName := fmt.Sprintf("%s.json", f.appName) - if f.dotPrefix { - configName = fmt.Sprintf(".%s.json", f.appName) - } - paths := []string{} - for _, p := range f.paths { - if p == "" { - continue + var globalCfg *Config + if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) { + // some other error occurred while checking the file + return nil, err + } else if err == nil { + // config file exists, read it + file, err := os.ReadFile(cfgPath) + if err != nil { + return nil, err } - paths = append(paths, os.ExpandEnv(p)) - } - - for _, path := range paths { - if path == "" { - continue + globalCfg = &Config{} + if err := json.Unmarshal(file, globalCfg); err != nil { + return nil, err } - - configPath := filepath.Join(path, configName) - if exists, err := afero.Exists(fsys, configPath); 
err == nil && exists { - configFiles = append(configFiles, configPath) + } else { + // config file does not exist, create a new one + globalCfg = &Config{} + } + + var localConfig *Config + // Global config loaded, now read the local config file + localConfigPath := filepath.Join(cwd, "crush.json") + if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) { + // some other error occurred while checking the file + return nil, err + } else if err == nil { + // local config file exists, read it + file, err := os.ReadFile(localConfigPath) + if err != nil { + return nil, err + } + localConfig = &Config{} + if err := json.Unmarshal(file, localConfig); err != nil { + return nil, err } } - return configFiles, nil -} -// configureViper sets up viper's configuration paths and environment variables. -func configureViper() { - viper.SetConfigType("json") - - // Create the three finders - windowsFinder := configFinder{appName: appName, dotPrefix: false, paths: []string{ - "$USERPROFILE", - fmt.Sprintf("$APPDATA/%s", appName), - fmt.Sprintf("$LOCALAPPDATA/%s", appName), - }} - - unixFinder := configFinder{appName: appName, dotPrefix: false, paths: []string{ - "$HOME", - fmt.Sprintf("$XDG_CONFIG_HOME/%s", appName), - fmt.Sprintf("$HOME/.config/%s", appName), - }} - - localFinder := configFinder{appName: appName, dotPrefix: true, paths: []string{ - ".", - }} - - // Use all finders with viper - viper.SetOptions(viper.WithFinder(viper.Finders(windowsFinder, unixFinder, localFinder))) - viper.SetEnvPrefix(strings.ToUpper(appName)) - viper.AutomaticEnv() -} + // merge options + mergeOptions(cfg, globalCfg, localConfig) -// setDefaults configures default values for configuration options. 
-func setDefaults(debug bool) { - viper.SetDefault("data.directory", defaultDataDirectory) - viper.SetDefault("contextPaths", defaultContextPaths) - viper.SetDefault("tui.theme", "crush") - viper.SetDefault("autoCompact", true) - - if debug { - viper.SetDefault("debug", true) - viper.Set("log.level", "debug") - } else { - viper.SetDefault("debug", false) - viper.SetDefault("log.level", defaultLogLevel) + mergeProviderConfigs(cfg, globalCfg, localConfig) + // no providers found the app is not initialized yet + if len(cfg.Providers) == 0 { + return cfg, nil } -} + preferredProvider := getPreferredProvider(cfg.Providers) + cfg.Models = PreferredModels{ + Large: PreferredModel{ + ModelID: preferredProvider.DefaultLargeModel, + Provider: preferredProvider.ID, + }, + Small: PreferredModel{ + ModelID: preferredProvider.DefaultSmallModel, + Provider: preferredProvider.ID, + }, + } + + mergeModels(cfg, globalCfg, localConfig) + + if preferredProvider == nil { + return nil, errors.New("no valid providers configured") + } + + agents := map[AgentID]Agent{ + AgentCoder: { + ID: AgentCoder, + Name: "Coder", + Description: "An agent that helps with executing coding tasks.", + Provider: cfg.Models.Large.Provider, + Model: cfg.Models.Large.ModelID, + ContextPaths: cfg.Options.ContextPaths, + // All tools allowed + }, + AgentTask: { + ID: AgentTask, + Name: "Task", + Description: "An agent that helps with searching for context and finding implementation details.", + Provider: cfg.Models.Large.Provider, + Model: cfg.Models.Large.ModelID, + ContextPaths: cfg.Options.ContextPaths, + AllowedTools: []string{ + "glob", + "grep", + "ls", + "sourcegraph", + "view", + }, + // NO MCPs or LSPs by default + AllowedMCP: map[string][]string{}, + AllowedLSP: []string{}, + }, + } + cfg.Agents = agents + mergeAgents(cfg, globalCfg, localConfig) + mergeMCPs(cfg, globalCfg, localConfig) + mergeLSPs(cfg, globalCfg, localConfig) -// setProviderDefaults configures LLM provider defaults based on provider 
provided by -// environment variables and configuration file. -func setProviderDefaults() { - // Set all API keys we can find in the environment - if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" { - viper.SetDefault("providers.anthropic.apiKey", apiKey) - } - if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" { - viper.SetDefault("providers.openai.apiKey", apiKey) - } - if apiKey := os.Getenv("GEMINI_API_KEY"); apiKey != "" { - viper.SetDefault("providers.gemini.apiKey", apiKey) - } - if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" { - viper.SetDefault("providers.groq.apiKey", apiKey) - } - if apiKey := os.Getenv("OPENROUTER_API_KEY"); apiKey != "" { - viper.SetDefault("providers.openrouter.apiKey", apiKey) - } - if apiKey := os.Getenv("XAI_API_KEY"); apiKey != "" { - viper.SetDefault("providers.xai.apiKey", apiKey) - } - if apiKey := os.Getenv("AZURE_OPENAI_ENDPOINT"); apiKey != "" { - // api-key may be empty when using Entra ID credentials – that's okay - viper.SetDefault("providers.azure.apiKey", os.Getenv("AZURE_OPENAI_API_KEY")) - } - - // Use this order to set the default models - // 1. Anthropic - // 2. OpenAI - // 3. Google Gemini - // 4. Groq - // 5. OpenRouter - // 6. AWS Bedrock - // 7. Azure - // 8. 
Google Cloud VertexAI - - // Anthropic configuration - if key := viper.GetString("providers.anthropic.apiKey"); strings.TrimSpace(key) != "" { - viper.SetDefault("agents.coder.model", models.Claude4Sonnet) - viper.SetDefault("agents.summarizer.model", models.Claude4Sonnet) - viper.SetDefault("agents.task.model", models.Claude4Sonnet) - viper.SetDefault("agents.title.model", models.Claude4Sonnet) - return - } - - // OpenAI configuration - if key := viper.GetString("providers.openai.apiKey"); strings.TrimSpace(key) != "" { - viper.SetDefault("agents.coder.model", models.GPT41) - viper.SetDefault("agents.summarizer.model", models.GPT41) - viper.SetDefault("agents.task.model", models.GPT41Mini) - viper.SetDefault("agents.title.model", models.GPT41Mini) - return - } - - // Google Gemini configuration - if key := viper.GetString("providers.gemini.apiKey"); strings.TrimSpace(key) != "" { - viper.SetDefault("agents.coder.model", models.Gemini25) - viper.SetDefault("agents.summarizer.model", models.Gemini25) - viper.SetDefault("agents.task.model", models.Gemini25Flash) - viper.SetDefault("agents.title.model", models.Gemini25Flash) - return - } - - // Groq configuration - if key := viper.GetString("providers.groq.apiKey"); strings.TrimSpace(key) != "" { - viper.SetDefault("agents.coder.model", models.QWENQwq) - viper.SetDefault("agents.summarizer.model", models.QWENQwq) - viper.SetDefault("agents.task.model", models.QWENQwq) - viper.SetDefault("agents.title.model", models.QWENQwq) - return - } - - // OpenRouter configuration - if key := viper.GetString("providers.openrouter.apiKey"); strings.TrimSpace(key) != "" { - viper.SetDefault("agents.coder.model", models.OpenRouterClaude37Sonnet) - viper.SetDefault("agents.summarizer.model", models.OpenRouterClaude37Sonnet) - viper.SetDefault("agents.task.model", models.OpenRouterClaude37Sonnet) - viper.SetDefault("agents.title.model", models.OpenRouterClaude35Haiku) - return - } - - // XAI configuration - if key := 
viper.GetString("providers.xai.apiKey"); strings.TrimSpace(key) != "" { - viper.SetDefault("agents.coder.model", models.XAIGrok3Beta) - viper.SetDefault("agents.summarizer.model", models.XAIGrok3Beta) - viper.SetDefault("agents.task.model", models.XAIGrok3Beta) - viper.SetDefault("agents.title.model", models.XAiGrok3MiniFastBeta) - return - } - - // AWS Bedrock configuration - if hasAWSCredentials() { - viper.SetDefault("agents.coder.model", models.BedrockClaude37Sonnet) - viper.SetDefault("agents.summarizer.model", models.BedrockClaude37Sonnet) - viper.SetDefault("agents.task.model", models.BedrockClaude37Sonnet) - viper.SetDefault("agents.title.model", models.BedrockClaude37Sonnet) - return - } - - // Azure OpenAI configuration - if os.Getenv("AZURE_OPENAI_ENDPOINT") != "" { - viper.SetDefault("agents.coder.model", models.AzureGPT41) - viper.SetDefault("agents.summarizer.model", models.AzureGPT41) - viper.SetDefault("agents.task.model", models.AzureGPT41Mini) - viper.SetDefault("agents.title.model", models.AzureGPT41Mini) - return - } - - // Google Cloud VertexAI configuration - if hasVertexAICredentials() { - viper.SetDefault("agents.coder.model", models.VertexAIGemini25) - viper.SetDefault("agents.summarizer.model", models.VertexAIGemini25) - viper.SetDefault("agents.task.model", models.VertexAIGemini25Flash) - viper.SetDefault("agents.title.model", models.VertexAIGemini25Flash) - return - } + return cfg, nil } -// hasAWSCredentials checks if AWS credentials are available in the environment. 
-func hasAWSCredentials() bool { - // Check for explicit AWS credentials - if os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "" { - return true - } - - // Check for AWS profile - if os.Getenv("AWS_PROFILE") != "" || os.Getenv("AWS_DEFAULT_PROFILE") != "" { - return true - } +func Init(workingDir string, debug bool) (*Config, error) { + var err error + once.Do(func() { + cwd = workingDir + instance, err = loadConfig(cwd, debug) + if err != nil { + logging.Error("Failed to load config", "error", err) + } + }) - // Check for AWS region - if os.Getenv("AWS_REGION") != "" || os.Getenv("AWS_DEFAULT_REGION") != "" { - return true - } + return instance, err +} - // Check if running on EC2 with instance profile - if os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" || - os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" { - return true +func Get() *Config { + if instance == nil { + // TODO: Handle this better + panic("Config not initialized. Call InitConfig first.") } - - return false + return instance } -// hasVertexAICredentials checks if VertexAI credentials are available in the environment. 
-func hasVertexAICredentials() bool { - // Check for explicit VertexAI parameters - if os.Getenv("VERTEXAI_PROJECT") != "" && os.Getenv("VERTEXAI_LOCATION") != "" { - return true +func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig { + providers := Providers() + for _, p := range providers { + if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled { + return &providerConfig + } } - // Check for Google Cloud project and location - if os.Getenv("GOOGLE_CLOUD_PROJECT") != "" && (os.Getenv("GOOGLE_CLOUD_REGION") != "" || os.Getenv("GOOGLE_CLOUD_LOCATION") != "") { - return true + // if none found return the first configured provider + for _, providerConfig := range configuredProviders { + if !providerConfig.Disabled { + return &providerConfig + } } - return false + return nil } -// readConfig handles the result of reading a configuration file. -func readConfig(err error) error { - if err == nil { - return nil +func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig { + if other.APIKey != "" { + base.APIKey = other.APIKey } - - // It's okay if the config file doesn't exist - if _, ok := err.(viper.ConfigFileNotFoundError); ok { - return nil + // Only change these options if the provider is not a known provider + if !slices.Contains(provider.KnownProviders(), p) { + if other.BaseURL != "" { + base.BaseURL = other.BaseURL + } + if other.ProviderType != "" { + base.ProviderType = other.ProviderType + } + if len(base.ExtraHeaders) > 0 { + if base.ExtraHeaders == nil { + base.ExtraHeaders = make(map[string]string) + } + maps.Copy(base.ExtraHeaders, other.ExtraHeaders) + } + if len(other.ExtraParams) > 0 { + if base.ExtraParams == nil { + base.ExtraParams = make(map[string]string) + } + maps.Copy(base.ExtraParams, other.ExtraParams) + } } - return fmt.Errorf("failed to read config: %w", err) -} - -// mergeLocalConfig loads and merges configuration 
from the local directory. -func mergeLocalConfig(workingDir string) { - local := viper.New() - local.SetConfigName(fmt.Sprintf(".%s", appName)) - local.SetConfigType("json") - local.AddConfigPath(workingDir) - - // Merge local config if it exists - if err := local.ReadInConfig(); err == nil { - viper.MergeConfigMap(local.AllSettings()) + if other.Disabled { + base.Disabled = other.Disabled } -} -// applyDefaultValues sets default values for configuration fields that need processing. -func applyDefaultValues() { - // Set default MCP type if not specified - for k, v := range cfg.MCPServers { - if v.Type == "" { - v.Type = MCPStdio - cfg.MCPServers[k] = v - } + if other.DefaultLargeModel != "" { + base.DefaultLargeModel = other.DefaultLargeModel } -} - -// It validates model IDs and providers, ensuring they are supported. -func validateAgent(cfg *Config, name AgentName, agent Agent) error { - // Check if model exists - model, modelExists := models.SupportedModels[agent.Model] - if !modelExists { - logging.Warn("unsupported model configured, reverting to default", - "agent", name, - "configured_model", agent.Model) - - // Set default model based on available providers - if setDefaultModelForAgent(name) { - logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model) - } else { - return fmt.Errorf("no valid provider available for agent %s", name) - } - return nil - } - - // Check if provider for the model is configured - provider := model.Provider - providerCfg, providerExists := cfg.Providers[provider] - - if !providerExists { - // Provider not configured, check if we have environment variables - apiKey := getProviderAPIKey(provider) - if apiKey == "" { - logging.Warn("provider not configured for model, reverting to default", - "agent", name, - "model", agent.Model, - "provider", provider) - - // Set default model based on available providers - if setDefaultModelForAgent(name) { - logging.Info("set default model for agent", "agent", name, 
"model", cfg.Agents[name].Model) - } else { - return fmt.Errorf("no valid provider available for agent %s", name) + // Add new models if they don't exist + if other.Models != nil { + for _, model := range other.Models { + // check if the model already exists + exists := false + for _, existingModel := range base.Models { + if existingModel.ID == model.ID { + exists = true + break + } } - } else { - // Add provider with API key from environment - cfg.Providers[provider] = Provider{ - APIKey: apiKey, - } - logging.Info("added provider from environment", "provider", provider) - } - } else if providerCfg.Disabled || providerCfg.APIKey == "" { - // Provider is disabled or has no API key - logging.Warn("provider is disabled or has no API key, reverting to default", - "agent", name, - "model", agent.Model, - "provider", provider) - - // Set default model based on available providers - if setDefaultModelForAgent(name) { - logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model) - } else { - return fmt.Errorf("no valid provider available for agent %s", name) - } - } - - // Validate max tokens - if agent.MaxTokens <= 0 { - logging.Warn("invalid max tokens, setting to default", - "agent", name, - "model", agent.Model, - "max_tokens", agent.MaxTokens) - - // Update the agent with default max tokens - updatedAgent := cfg.Agents[name] - if model.DefaultMaxTokens > 0 { - updatedAgent.MaxTokens = model.DefaultMaxTokens - } else { - updatedAgent.MaxTokens = MaxTokensFallbackDefault - } - cfg.Agents[name] = updatedAgent - } else if model.ContextWindow > 0 && agent.MaxTokens > model.ContextWindow/2 { - // Ensure max tokens doesn't exceed half the context window (reasonable limit) - logging.Warn("max tokens exceeds half the context window, adjusting", - "agent", name, - "model", agent.Model, - "max_tokens", agent.MaxTokens, - "context_window", model.ContextWindow) - - // Update the agent with adjusted max tokens - updatedAgent := cfg.Agents[name] - 
updatedAgent.MaxTokens = model.ContextWindow / 2 - cfg.Agents[name] = updatedAgent - } - - // Validate reasoning effort for models that support reasoning - if model.CanReason && provider == models.ProviderOpenAI || provider == models.ProviderLocal { - if agent.ReasoningEffort == "" { - // Set default reasoning effort for models that support it - logging.Info("setting default reasoning effort for model that supports reasoning", - "agent", name, - "model", agent.Model) - - // Update the agent with default reasoning effort - updatedAgent := cfg.Agents[name] - updatedAgent.ReasoningEffort = "medium" - cfg.Agents[name] = updatedAgent - } else { - // Check if reasoning effort is valid (low, medium, high) - effort := strings.ToLower(agent.ReasoningEffort) - if effort != "low" && effort != "medium" && effort != "high" { - logging.Warn("invalid reasoning effort, setting to medium", - "agent", name, - "model", agent.Model, - "reasoning_effort", agent.ReasoningEffort) - - // Update the agent with valid reasoning effort - updatedAgent := cfg.Agents[name] - updatedAgent.ReasoningEffort = "medium" - cfg.Agents[name] = updatedAgent + if !exists { + base.Models = append(base.Models, model) } } - } else if !model.CanReason && agent.ReasoningEffort != "" { - // Model doesn't support reasoning but reasoning effort is set - logging.Warn("model doesn't support reasoning but reasoning effort is set, ignoring", - "agent", name, - "model", agent.Model, - "reasoning_effort", agent.ReasoningEffort) - - // Update the agent to remove reasoning effort - updatedAgent := cfg.Agents[name] - updatedAgent.ReasoningEffort = "" - cfg.Agents[name] = updatedAgent } - return nil + return base } -// Validate checks if the configuration is valid and applies defaults where needed. 
-func Validate() error { - if cfg == nil { - return fmt.Errorf("config not loaded") - } - - // Validate agent models - for name, agent := range cfg.Agents { - if err := validateAgent(cfg, name, agent); err != nil { - return err +func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error { + if !slices.Contains(provider.KnownProviders(), p) { + if providerConfig.ProviderType != provider.TypeOpenAI { + return errors.New("invalid provider type: " + string(providerConfig.ProviderType)) } - } - - // Validate providers - for provider, providerCfg := range cfg.Providers { - if providerCfg.APIKey == "" && !providerCfg.Disabled { - logging.Warn("provider has no API key, marking as disabled", "provider", provider) - providerCfg.Disabled = true - cfg.Providers[provider] = providerCfg + if providerConfig.BaseURL == "" { + return errors.New("base URL must be set for custom providers") } - } - - // Validate LSP configurations - for language, lspConfig := range cfg.LSP { - if lspConfig.Command == "" && !lspConfig.Disabled { - logging.Warn("LSP configuration has no command, marking as disabled", "language", language) - lspConfig.Disabled = true - cfg.LSP[language] = lspConfig + if providerConfig.APIKey == "" { + return errors.New("API key must be set for custom providers") } } - return nil } -// getProviderAPIKey gets the API key for a provider from environment variables -func getProviderAPIKey(provider models.InferenceProvider) string { - switch provider { - case models.ProviderAnthropic: - return os.Getenv("ANTHROPIC_API_KEY") - case models.ProviderOpenAI: - return os.Getenv("OPENAI_API_KEY") - case models.ProviderGemini: - return os.Getenv("GEMINI_API_KEY") - case models.ProviderGROQ: - return os.Getenv("GROQ_API_KEY") - case models.ProviderAzure: - return os.Getenv("AZURE_OPENAI_API_KEY") - case models.ProviderOpenRouter: - return os.Getenv("OPENROUTER_API_KEY") - case models.ProviderBedrock: - if hasAWSCredentials() { - return 
"aws-credentials-available" - } - case models.ProviderVertexAI: - if hasVertexAICredentials() { - return "vertex-ai-credentials-available" - } - } - return "" -} - -// setDefaultModelForAgent sets a default model for an agent based on available providers -func setDefaultModelForAgent(agent AgentName) bool { - // Check providers in order of preference - if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" { - maxTokens := int64(5000) - if agent == AgentTitle { - maxTokens = 80 +func mergeModels(base, global, local *Config) { + for _, cfg := range []*Config{global, local} { + if cfg == nil { + continue } - cfg.Agents[agent] = Agent{ - Model: models.Claude37Sonnet, - MaxTokens: maxTokens, + if cfg.Models.Large.ModelID != "" && cfg.Models.Large.Provider != "" { + base.Models.Large = cfg.Models.Large } - return true - } - if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" { - var model models.ModelID - maxTokens := int64(5000) - reasoningEffort := "" - - switch agent { - case AgentTitle: - model = models.GPT41Mini - maxTokens = 80 - case AgentTask: - model = models.GPT41Mini - default: - model = models.GPT41 + if cfg.Models.Small.ModelID != "" && cfg.Models.Small.Provider != "" { + base.Models.Small = cfg.Models.Small } + } +} - // Check if model supports reasoning - if modelInfo, ok := models.SupportedModels[model]; ok && modelInfo.CanReason { - reasoningEffort = "medium" +func mergeOptions(base, global, local *Config) { + for _, cfg := range []*Config{global, local} { + if cfg == nil { + continue + } + baseOptions := base.Options + other := cfg.Options + if len(other.ContextPaths) > 0 { + baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...) 
} - cfg.Agents[agent] = Agent{ - Model: model, - MaxTokens: maxTokens, - ReasoningEffort: reasoningEffort, + if other.TUI.CompactMode { + baseOptions.TUI.CompactMode = other.TUI.CompactMode } - return true - } - if apiKey := os.Getenv("OPENROUTER_API_KEY"); apiKey != "" { - var model models.ModelID - maxTokens := int64(5000) - reasoningEffort := "" + if other.Debug { + baseOptions.Debug = other.Debug + } - switch agent { - case AgentTitle: - model = models.OpenRouterClaude35Haiku - maxTokens = 80 - case AgentTask: - model = models.OpenRouterClaude37Sonnet - default: - model = models.OpenRouterClaude37Sonnet + if other.DebugLSP { + baseOptions.DebugLSP = other.DebugLSP } - // Check if model supports reasoning - if modelInfo, ok := models.SupportedModels[model]; ok && modelInfo.CanReason { - reasoningEffort = "medium" + if other.DisableAutoSummarize { + baseOptions.DisableAutoSummarize = other.DisableAutoSummarize } - cfg.Agents[agent] = Agent{ - Model: model, - MaxTokens: maxTokens, - ReasoningEffort: reasoningEffort, + if other.DataDirectory != "" { + baseOptions.DataDirectory = other.DataDirectory } - return true + base.Options = baseOptions } +} - if apiKey := os.Getenv("GEMINI_API_KEY"); apiKey != "" { - var model models.ModelID - maxTokens := int64(5000) - - if agent == AgentTitle { - model = models.Gemini25Flash - maxTokens = 80 - } else { - model = models.Gemini25 +func mergeAgents(base, global, local *Config) { + for _, cfg := range []*Config{global, local} { + if cfg == nil { + continue } - - cfg.Agents[agent] = Agent{ - Model: model, - MaxTokens: maxTokens, + for agentID, newAgent := range cfg.Agents { + if _, ok := base.Agents[agentID]; !ok { + newAgent.ID = agentID // Ensure the ID is set correctly + base.Agents[agentID] = newAgent + } else { + switch agentID { + case AgentCoder: + baseAgent := base.Agents[agentID] + if newAgent.Model != "" && newAgent.Provider != "" { + baseAgent.Model = newAgent.Model + baseAgent.Provider = newAgent.Provider + } + 
baseAgent.AllowedMCP = newAgent.AllowedMCP + baseAgent.AllowedLSP = newAgent.AllowedLSP + base.Agents[agentID] = baseAgent + default: + baseAgent := base.Agents[agentID] + baseAgent.Name = newAgent.Name + baseAgent.Description = newAgent.Description + baseAgent.Disabled = newAgent.Disabled + if newAgent.Model == "" || newAgent.Provider == "" { + baseAgent.Provider = base.Models.Large.Provider + baseAgent.Model = base.Models.Large.ModelID + } + baseAgent.AllowedTools = newAgent.AllowedTools + baseAgent.AllowedMCP = newAgent.AllowedMCP + baseAgent.AllowedLSP = newAgent.AllowedLSP + base.Agents[agentID] = baseAgent + } + } } - return true } +} - if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" { - maxTokens := int64(5000) - if agent == AgentTitle { - maxTokens = 80 +func mergeMCPs(base, global, local *Config) { + for _, cfg := range []*Config{global, local} { + if cfg == nil { + continue } + maps.Copy(base.MCP, cfg.MCP) + } +} - cfg.Agents[agent] = Agent{ - Model: models.QWENQwq, - MaxTokens: maxTokens, +func mergeLSPs(base, global, local *Config) { + for _, cfg := range []*Config{global, local} { + if cfg == nil { + continue } - return true + maps.Copy(base.LSP, cfg.LSP) } +} - if hasAWSCredentials() { - maxTokens := int64(5000) - if agent == AgentTitle { - maxTokens = 80 +func mergeProviderConfigs(base, global, local *Config) { + for _, cfg := range []*Config{global, local} { + if cfg == nil { + continue } - - cfg.Agents[agent] = Agent{ - Model: models.BedrockClaude37Sonnet, - MaxTokens: maxTokens, - ReasoningEffort: "medium", // Claude models support reasoning + for providerName, globalProvider := range cfg.Providers { + if _, ok := base.Providers[providerName]; !ok { + base.Providers[providerName] = globalProvider + } else { + base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider) + } } - return true } - if hasVertexAICredentials() { - var model models.ModelID - maxTokens := int64(5000) - - if agent == 
AgentTitle { - model = models.VertexAIGemini25Flash - maxTokens = 80 - } else { - model = models.VertexAIGemini25 + finalProviders := make(map[provider.InferenceProvider]ProviderConfig) + for providerName, providerConfig := range base.Providers { + err := validateProvider(providerName, providerConfig) + if err != nil { + logging.Warn("Skipping provider", "name", providerName, "error", err) } + finalProviders[providerName] = providerConfig + } + base.Providers = finalProviders +} - cfg.Agents[agent] = Agent{ - Model: model, - MaxTokens: maxTokens, +func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig { + switch providerId { + case provider.InferenceProviderAnthropic: + return ProviderConfig{ + ID: providerId, + ProviderType: provider.TypeAnthropic, + } + case provider.InferenceProviderOpenAI: + return ProviderConfig{ + ID: providerId, + ProviderType: provider.TypeOpenAI, + } + case provider.InferenceProviderGemini: + return ProviderConfig{ + ID: providerId, + ProviderType: provider.TypeGemini, + } + case provider.InferenceProviderBedrock: + return ProviderConfig{ + ID: providerId, + ProviderType: provider.TypeBedrock, + } + case provider.InferenceProviderAzure: + return ProviderConfig{ + ID: providerId, + ProviderType: provider.TypeAzure, + } + case provider.InferenceProviderOpenRouter: + return ProviderConfig{ + ID: providerId, + ProviderType: provider.TypeOpenAI, + BaseURL: "https://openrouter.ai/api/v1", + ExtraHeaders: map[string]string{ + "HTTP-Referer": "crush.charm.land", + "X-Title": "Crush", + }, + } + case provider.InferenceProviderXAI: + return ProviderConfig{ + ID: providerId, + ProviderType: provider.TypeXAI, + BaseURL: "https://api.x.ai/v1", + } + case provider.InferenceProviderVertexAI: + return ProviderConfig{ + ID: providerId, + ProviderType: provider.TypeVertexAI, + } + default: + return ProviderConfig{ + ID: providerId, + ProviderType: provider.TypeOpenAI, } - return true } - - return false } -func 
updateCfgFile(updateCfg func(config *Config)) error { - if cfg == nil { - return fmt.Errorf("config not loaded") +func defaultConfigBasedOnEnv() *Config { + cfg := &Config{ + Options: Options{ + DataDirectory: defaultDataDirectory, + ContextPaths: defaultContextPaths, + }, + Providers: make(map[provider.InferenceProvider]ProviderConfig), + } + + providers := Providers() + + for _, p := range providers { + if strings.HasPrefix(p.APIKey, "$") { + envVar := strings.TrimPrefix(p.APIKey, "$") + if apiKey := os.Getenv(envVar); apiKey != "" { + providerConfig := providerDefaultConfig(p.ID) + providerConfig.APIKey = apiKey + providerConfig.DefaultLargeModel = p.DefaultLargeModelID + providerConfig.DefaultSmallModel = p.DefaultSmallModelID + baseURL := p.APIEndpoint + if strings.HasPrefix(baseURL, "$") { + envVar := strings.TrimPrefix(baseURL, "$") + if url := os.Getenv(envVar); url != "" { + baseURL = url + } + } + providerConfig.BaseURL = baseURL + for _, model := range p.Models { + providerConfig.Models = append(providerConfig.Models, Model{ + ID: model.ID, + Name: model.Name, + CostPer1MIn: model.CostPer1MIn, + CostPer1MOut: model.CostPer1MOut, + CostPer1MInCached: model.CostPer1MInCached, + CostPer1MOutCached: model.CostPer1MOutCached, + ContextWindow: model.ContextWindow, + DefaultMaxTokens: model.DefaultMaxTokens, + CanReason: model.CanReason, + SupportsImages: model.SupportsImages, + }) + } + cfg.Providers[p.ID] = providerConfig + } + } } + // TODO: support local models - // Get the config file path - configFile := viper.ConfigFileUsed() - var configData []byte - if configFile == "" { - homeDir, err := os.UserHomeDir() - if err != nil { - return fmt.Errorf("failed to get home directory: %w", err) + if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" { + providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI) + providerConfig.ExtraParams = map[string]string{ + "project": os.Getenv("GOOGLE_CLOUD_PROJECT"), + "location": 
os.Getenv("GOOGLE_CLOUD_LOCATION"), } - configFile = filepath.Join(homeDir, fmt.Sprintf(".%s.json", appName)) - logging.Info("config file not found, creating new one", "path", configFile) - configData = []byte(`{}`) - } else { - // Read the existing config file - data, err := os.ReadFile(configFile) - if err != nil { - return fmt.Errorf("failed to read config file: %w", err) - } - configData = data + cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig } - // Parse the JSON - var userCfg *Config - if err := json.Unmarshal(configData, &userCfg); err != nil { - return fmt.Errorf("failed to parse config file: %w", err) + if hasAWSCredentials() { + providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock) + providerConfig.ExtraParams = map[string]string{ + "region": os.Getenv("AWS_DEFAULT_REGION"), + } + if providerConfig.ExtraParams["region"] == "" { + providerConfig.ExtraParams["region"] = os.Getenv("AWS_REGION") + } + cfg.Providers[provider.InferenceProviderBedrock] = providerConfig } + return cfg +} - updateCfg(userCfg) +func hasAWSCredentials() bool { + if os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "" { + return true + } - // Write the updated config back to file - updatedData, err := json.MarshalIndent(userCfg, "", " ") - if err != nil { - return fmt.Errorf("failed to marshal config: %w", err) + if os.Getenv("AWS_PROFILE") != "" || os.Getenv("AWS_DEFAULT_PROFILE") != "" { + return true } - if err := os.WriteFile(configFile, updatedData, 0o644); err != nil { - return fmt.Errorf("failed to write config file: %w", err) + if os.Getenv("AWS_REGION") != "" || os.Getenv("AWS_DEFAULT_REGION") != "" { + return true } - return nil -} + if os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" || + os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" { + return true + } -// Get returns the current configuration. -// It's safe to call this function multiple times. 
-func Get() *Config { - return cfg + return false } -// WorkingDirectory returns the current working directory from the configuration. func WorkingDirectory() string { - if cfg == nil { - panic("config not loaded") - } - return cfg.WorkingDir + return cwd } -func UpdateAgentModel(agentName AgentName, modelID models.ModelID) error { - if cfg == nil { - panic("config not loaded") +func GetAgentModel(agentID AgentID) Model { + cfg := Get() + agent, ok := cfg.Agents[agentID] + if !ok { + logging.Error("Agent not found", "agent_id", agentID) + return Model{} } - existingAgentCfg := cfg.Agents[agentName] - - model, ok := models.SupportedModels[modelID] + providerConfig, ok := cfg.Providers[agent.Provider] if !ok { - return fmt.Errorf("model %s not supported", modelID) + logging.Error("Provider not found for agent", "agent_id", agentID, "provider", agent.Provider) + return Model{} } - maxTokens := existingAgentCfg.MaxTokens - if model.DefaultMaxTokens > 0 { - maxTokens = model.DefaultMaxTokens + for _, model := range providerConfig.Models { + if model.ID == agent.Model { + return model + } } - newAgentCfg := Agent{ - Model: modelID, - MaxTokens: maxTokens, - ReasoningEffort: existingAgentCfg.ReasoningEffort, - } - cfg.Agents[agentName] = newAgentCfg + logging.Error("Model not found for agent", "agent_id", agentID, "model", agent.Model) + return Model{} +} - if err := validateAgent(cfg, agentName, newAgentCfg); err != nil { - // revert config update on failure - cfg.Agents[agentName] = existingAgentCfg - return fmt.Errorf("failed to update agent model: %w", err) +func GetProviderModel(provider provider.InferenceProvider, modelID string) Model { + cfg := Get() + providerConfig, ok := cfg.Providers[provider] + if !ok { + logging.Error("Provider not found", "provider", provider) + return Model{} } - return updateCfgFile(func(config *Config) { - if config.Agents == nil { - config.Agents = make(map[AgentName]Agent) + for _, model := range providerConfig.Models { + if model.ID 
== modelID { + return model } - config.Agents[agentName] = newAgentCfg - }) -} - -// UpdateTheme updates the theme in the configuration and writes it to the config file. -func UpdateTheme(themeName string) error { - if cfg == nil { - return fmt.Errorf("config not loaded") } - // Update the in-memory config - cfg.TUI.Theme = themeName - - // Update the file config - return updateCfgFile(func(config *Config) { - config.TUI.Theme = themeName - }) + logging.Error("Model not found for provider", "provider", provider, "model_id", modelID) + return Model{} } diff --git a/internal/config_v2/config_test.go b/internal/config/config_test.go similarity index 94% rename from internal/config_v2/config_test.go rename to internal/config/config_test.go index 9bcfcdc78375e1a3a35726b513f04e3cb1e2c3b3..2942c206aa4bb8b81ff3f3fca9a444411359e515 100644 --- a/internal/config_v2/config_test.go +++ b/internal/config/config_test.go @@ -1,4 +1,4 @@ -package configv2 +package config import ( "encoding/json" @@ -28,7 +28,7 @@ func TestConfigWithEnv(t *testing.T) { os.Setenv("GEMINI_API_KEY", "test-gemini-key") os.Setenv("XAI_API_KEY", "test-xai-key") os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key") - cfg := InitConfig(cwdDir) + cfg, _ := Init(cwdDir, false) data, _ := json.MarshalIndent(cfg, "", " ") fmt.Println(string(data)) assert.Len(t, cfg.Providers, 5) diff --git a/internal/config_v2/fs.go b/internal/config/fs.go similarity index 99% rename from internal/config_v2/fs.go rename to internal/config/fs.go index 976267a2a68efb718449f59b3720d0d186720cdf..efa622cf937846370616042de4fe2bcd6f33b7a1 100644 --- a/internal/config_v2/fs.go +++ b/internal/config/fs.go @@ -1,4 +1,4 @@ -package configv2 +package config import ( "fmt" diff --git a/internal/config/init.go b/internal/config/init.go index 1b603fbb846aba45230cd0f4683cb465e14db69a..f17e1db28e41cc44e168765e55e88311423e1102 100644 --- a/internal/config/init.go +++ b/internal/config/init.go @@ -17,23 +17,20 @@ type ProjectInitFlag struct { 
Initialized bool `json:"initialized"` } -// ShouldShowInitDialog checks if the initialization dialog should be shown for the current directory -func ShouldShowInitDialog() (bool, error) { - if cfg == nil { +// ProjectNeedsInitialization checks if the current project needs initialization +func ProjectNeedsInitialization() (bool, error) { + if instance == nil { return false, fmt.Errorf("config not loaded") } - // Create the flag file path - flagFilePath := filepath.Join(cfg.Data.Directory, InitFlagFilename) + flagFilePath := filepath.Join(instance.Options.DataDirectory, InitFlagFilename) // Check if the flag file exists _, err := os.Stat(flagFilePath) if err == nil { - // File exists, don't show the dialog return false, nil } - // If the error is not "file not found", return the error if !os.IsNotExist(err) { return false, fmt.Errorf("failed to check init flag file: %w", err) } @@ -44,11 +41,9 @@ func ShouldShowInitDialog() (bool, error) { return false, fmt.Errorf("failed to check for CRUSH.md files: %w", err) } if crushExists { - // CRUSH.md already exists, don't show the dialog return false, nil } - // File doesn't exist, show the dialog return true, nil } @@ -75,13 +70,11 @@ func crushMdExists(dir string) (bool, error) { // MarkProjectInitialized marks the current project as initialized func MarkProjectInitialized() error { - if cfg == nil { + if instance == nil { return fmt.Errorf("config not loaded") } - // Create the flag file path - flagFilePath := filepath.Join(cfg.Data.Directory, InitFlagFilename) + flagFilePath := filepath.Join(instance.Options.DataDirectory, InitFlagFilename) - // Create an empty file to mark the project as initialized file, err := os.Create(flagFilePath) if err != nil { return fmt.Errorf("failed to create init flag file: %w", err) diff --git a/internal/config_v2/provider.go b/internal/config/provider.go similarity index 98% rename from internal/config_v2/provider.go rename to internal/config/provider.go index 
ec6b5bdb701876af4705c9e78fcc55a87646edd2..4c2b61ff6d5d86f62a8a1833a6ea91b500bbc7b0 100644 --- a/internal/config_v2/provider.go +++ b/internal/config/provider.go @@ -1,4 +1,4 @@ -package configv2 +package config import ( "encoding/json" diff --git a/internal/config_v2/config.go b/internal/config_v2/config.go deleted file mode 100644 index 9f7f2ad14356531150cca4f05952fb390c716c68..0000000000000000000000000000000000000000 --- a/internal/config_v2/config.go +++ /dev/null @@ -1,660 +0,0 @@ -package configv2 - -import ( - "encoding/json" - "errors" - "maps" - "os" - "path/filepath" - "slices" - "strings" - "sync" - - "github.com/charmbracelet/crush/internal/fur/provider" - "github.com/charmbracelet/crush/internal/logging" -) - -const ( - defaultDataDirectory = ".crush" - defaultLogLevel = "info" - appName = "crush" - - MaxTokensFallbackDefault = 4096 -) - -var defaultContextPaths = []string{ - ".github/copilot-instructions.md", - ".cursorrules", - ".cursor/rules/", - "CLAUDE.md", - "CLAUDE.local.md", - "crush.md", - "crush.local.md", - "Crush.md", - "Crush.local.md", - "CRUSH.md", - "CRUSH.local.md", -} - -type AgentID string - -const ( - AgentCoder AgentID = "coder" - AgentTask AgentID = "task" - AgentTitle AgentID = "title" - AgentSummarize AgentID = "summarize" -) - -type Model struct { - ID string `json:"id"` - Name string `json:"model"` - CostPer1MIn float64 `json:"cost_per_1m_in"` - CostPer1MOut float64 `json:"cost_per_1m_out"` - CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` - CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` - ContextWindow int64 `json:"context_window"` - DefaultMaxTokens int64 `json:"default_max_tokens"` - CanReason bool `json:"can_reason"` - ReasoningEffort string `json:"reasoning_effort"` - SupportsImages bool `json:"supports_attachments"` -} - -type VertexAIOptions struct { - APIKey string `json:"api_key,omitempty"` - Project string `json:"project,omitempty"` - Location string `json:"location,omitempty"` -} - -type 
ProviderConfig struct { - ID provider.InferenceProvider `json:"id"` - BaseURL string `json:"base_url,omitempty"` - ProviderType provider.Type `json:"provider_type"` - APIKey string `json:"api_key,omitempty"` - Disabled bool `json:"disabled"` - ExtraHeaders map[string]string `json:"extra_headers,omitempty"` - // used for e.x for vertex to set the project - ExtraParams map[string]string `json:"extra_params,omitempty"` - - DefaultLargeModel string `json:"default_large_model,omitempty"` - DefaultSmallModel string `json:"default_small_model,omitempty"` - - Models []Model `json:"models,omitempty"` -} - -type Agent struct { - ID AgentID `json:"id"` - Name string `json:"name"` - Description string `json:"description,omitempty"` - // This is the id of the system prompt used by the agent - Disabled bool `json:"disabled"` - - Provider provider.InferenceProvider `json:"provider"` - Model string `json:"model"` - - // The available tools for the agent - // if this is nil, all tools are available - AllowedTools []string `json:"allowed_tools"` - - // this tells us which MCPs are available for this agent - // if this is empty all mcps are available - // the string array is the list of tools from the AllowedMCP the agent has available - // if the string array is nil, all tools from the AllowedMCP are available - AllowedMCP map[string][]string `json:"allowed_mcp"` - - // The list of LSPs that this agent can use - // if this is nil, all LSPs are available - AllowedLSP []string `json:"allowed_lsp"` - - // Overrides the context paths for this agent - ContextPaths []string `json:"context_paths"` -} - -type MCPType string - -const ( - MCPStdio MCPType = "stdio" - MCPSse MCPType = "sse" -) - -type MCP struct { - Command string `json:"command"` - Env []string `json:"env"` - Args []string `json:"args"` - Type MCPType `json:"type"` - URL string `json:"url"` - Headers map[string]string `json:"headers"` -} - -type LSPConfig struct { - Disabled bool `json:"enabled"` - Command string 
`json:"command"` - Args []string `json:"args"` - Options any `json:"options"` -} - -type TUIOptions struct { - CompactMode bool `json:"compact_mode"` - // Here we can add themes later or any TUI related options -} - -type Options struct { - ContextPaths []string `json:"context_paths"` - TUI TUIOptions `json:"tui"` - Debug bool `json:"debug"` - DebugLSP bool `json:"debug_lsp"` - DisableAutoSummarize bool `json:"disable_auto_summarize"` - // Relative to the cwd - DataDirectory string `json:"data_directory"` -} - -type Config struct { - // List of configured providers - Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"` - - // List of configured agents - Agents map[AgentID]Agent `json:"agents,omitempty"` - - // List of configured MCPs - MCP map[string]MCP `json:"mcp,omitempty"` - - // List of configured LSPs - LSP map[string]LSPConfig `json:"lsp,omitempty"` - - // Miscellaneous options - Options Options `json:"options"` -} - -var ( - instance *Config // The single instance of the Singleton - cwd string - once sync.Once // Ensures the initialization happens only once - -) - -func loadConfig(cwd string) (*Config, error) { - // First read the global config file - cfgPath := ConfigPath() - - cfg := defaultConfigBasedOnEnv() - - var globalCfg *Config - if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) { - // some other error occurred while checking the file - return nil, err - } else if err == nil { - // config file exists, read it - file, err := os.ReadFile(cfgPath) - if err != nil { - return nil, err - } - globalCfg = &Config{} - if err := json.Unmarshal(file, globalCfg); err != nil { - return nil, err - } - } else { - // config file does not exist, create a new one - globalCfg = &Config{} - } - - var localConfig *Config - // Global config loaded, now read the local config file - localConfigPath := filepath.Join(cwd, "crush.json") - if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) { - // some other 
error occurred while checking the file - return nil, err - } else if err == nil { - // local config file exists, read it - file, err := os.ReadFile(localConfigPath) - if err != nil { - return nil, err - } - localConfig = &Config{} - if err := json.Unmarshal(file, localConfig); err != nil { - return nil, err - } - } - - // merge options - mergeOptions(cfg, globalCfg, localConfig) - - mergeProviderConfigs(cfg, globalCfg, localConfig) - // no providers found the app is not initialized yet - if len(cfg.Providers) == 0 { - return cfg, nil - } - preferredProvider := getPreferredProvider(cfg.Providers) - - if preferredProvider == nil { - return nil, errors.New("no valid providers configured") - } - - agents := map[AgentID]Agent{ - AgentCoder: { - ID: AgentCoder, - Name: "Coder", - Description: "An agent that helps with executing coding tasks.", - Provider: preferredProvider.ID, - Model: preferredProvider.DefaultLargeModel, - ContextPaths: cfg.Options.ContextPaths, - // All tools allowed - }, - AgentTask: { - ID: AgentTask, - Name: "Task", - Description: "An agent that helps with searching for context and finding implementation details.", - Provider: preferredProvider.ID, - Model: preferredProvider.DefaultLargeModel, - ContextPaths: cfg.Options.ContextPaths, - AllowedTools: []string{ - "glob", - "grep", - "ls", - "sourcegraph", - "view", - }, - // NO MCPs or LSPs by default - AllowedMCP: map[string][]string{}, - AllowedLSP: []string{}, - }, - AgentTitle: { - ID: AgentTitle, - Name: "Title", - Description: "An agent that helps with generating titles for sessions.", - Provider: preferredProvider.ID, - Model: preferredProvider.DefaultSmallModel, - ContextPaths: cfg.Options.ContextPaths, - AllowedTools: []string{}, - // NO MCPs or LSPs by default - AllowedMCP: map[string][]string{}, - AllowedLSP: []string{}, - }, - AgentSummarize: { - ID: AgentSummarize, - Name: "Summarize", - Description: "An agent that helps with summarizing sessions.", - Provider: preferredProvider.ID, - 
Model: preferredProvider.DefaultSmallModel, - ContextPaths: cfg.Options.ContextPaths, - AllowedTools: []string{}, - // NO MCPs or LSPs by default - AllowedMCP: map[string][]string{}, - AllowedLSP: []string{}, - }, - } - cfg.Agents = agents - mergeAgents(cfg, globalCfg, localConfig) - mergeMCPs(cfg, globalCfg, localConfig) - mergeLSPs(cfg, globalCfg, localConfig) - - return cfg, nil -} - -func InitConfig(workingDir string) *Config { - once.Do(func() { - cwd = workingDir - cfg, err := loadConfig(cwd) - if err != nil { - // TODO: Handle this better - panic("Failed to load config: " + err.Error()) - } - instance = cfg - }) - - return instance -} - -func GetConfig() *Config { - if instance == nil { - // TODO: Handle this better - panic("Config not initialized. Call InitConfig first.") - } - return instance -} - -func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig { - providers := Providers() - for _, p := range providers { - if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled { - return &providerConfig - } - } - // if none found return the first configured provider - for _, providerConfig := range configuredProviders { - if !providerConfig.Disabled { - return &providerConfig - } - } - return nil -} - -func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig { - if other.APIKey != "" { - base.APIKey = other.APIKey - } - // Only change these options if the provider is not a known provider - if !slices.Contains(provider.KnownProviders(), p) { - if other.BaseURL != "" { - base.BaseURL = other.BaseURL - } - if other.ProviderType != "" { - base.ProviderType = other.ProviderType - } - if len(base.ExtraHeaders) > 0 { - if base.ExtraHeaders == nil { - base.ExtraHeaders = make(map[string]string) - } - maps.Copy(base.ExtraHeaders, other.ExtraHeaders) - } - if len(other.ExtraParams) > 0 { - if base.ExtraParams == nil { - base.ExtraParams = 
make(map[string]string) - } - maps.Copy(base.ExtraParams, other.ExtraParams) - } - } - - if other.Disabled { - base.Disabled = other.Disabled - } - - if other.DefaultLargeModel != "" { - base.DefaultLargeModel = other.DefaultLargeModel - } - // Add new models if they don't exist - if other.Models != nil { - for _, model := range other.Models { - // check if the model already exists - exists := false - for _, existingModel := range base.Models { - if existingModel.ID == model.ID { - exists = true - break - } - } - if !exists { - base.Models = append(base.Models, model) - } - } - } - - return base -} - -func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error { - if !slices.Contains(provider.KnownProviders(), p) { - if providerConfig.ProviderType != provider.TypeOpenAI { - return errors.New("invalid provider type: " + string(providerConfig.ProviderType)) - } - if providerConfig.BaseURL == "" { - return errors.New("base URL must be set for custom providers") - } - if providerConfig.APIKey == "" { - return errors.New("API key must be set for custom providers") - } - } - return nil -} - -func mergeOptions(base, global, local *Config) { - for _, cfg := range []*Config{global, local} { - if cfg == nil { - continue - } - baseOptions := base.Options - other := cfg.Options - if len(other.ContextPaths) > 0 { - baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...) 
- } - - if other.TUI.CompactMode { - baseOptions.TUI.CompactMode = other.TUI.CompactMode - } - - if other.Debug { - baseOptions.Debug = other.Debug - } - - if other.DebugLSP { - baseOptions.DebugLSP = other.DebugLSP - } - - if other.DisableAutoSummarize { - baseOptions.DisableAutoSummarize = other.DisableAutoSummarize - } - - if other.DataDirectory != "" { - baseOptions.DataDirectory = other.DataDirectory - } - base.Options = baseOptions - } -} - -func mergeAgents(base, global, local *Config) { - for _, cfg := range []*Config{global, local} { - if cfg == nil { - continue - } - for agentID, newAgent := range cfg.Agents { - if _, ok := base.Agents[agentID]; !ok { - newAgent.ID = agentID // Ensure the ID is set correctly - base.Agents[agentID] = newAgent - } else { - switch agentID { - case AgentCoder: - baseAgent := base.Agents[agentID] - baseAgent.Model = newAgent.Model - baseAgent.Provider = newAgent.Provider - baseAgent.AllowedMCP = newAgent.AllowedMCP - baseAgent.AllowedLSP = newAgent.AllowedLSP - base.Agents[agentID] = baseAgent - case AgentTask: - baseAgent := base.Agents[agentID] - baseAgent.Model = newAgent.Model - baseAgent.Provider = newAgent.Provider - base.Agents[agentID] = baseAgent - case AgentTitle: - baseAgent := base.Agents[agentID] - baseAgent.Model = newAgent.Model - baseAgent.Provider = newAgent.Provider - base.Agents[agentID] = baseAgent - case AgentSummarize: - baseAgent := base.Agents[agentID] - baseAgent.Model = newAgent.Model - baseAgent.Provider = newAgent.Provider - base.Agents[agentID] = baseAgent - default: - baseAgent := base.Agents[agentID] - baseAgent.Name = newAgent.Name - baseAgent.Description = newAgent.Description - baseAgent.Disabled = newAgent.Disabled - baseAgent.Provider = newAgent.Provider - baseAgent.Model = newAgent.Model - baseAgent.AllowedTools = newAgent.AllowedTools - baseAgent.AllowedMCP = newAgent.AllowedMCP - baseAgent.AllowedLSP = newAgent.AllowedLSP - base.Agents[agentID] = baseAgent - - } - } - } - } -} - -func 
mergeMCPs(base, global, local *Config) { - for _, cfg := range []*Config{global, local} { - if cfg == nil { - continue - } - maps.Copy(base.MCP, cfg.MCP) - } -} - -func mergeLSPs(base, global, local *Config) { - for _, cfg := range []*Config{global, local} { - if cfg == nil { - continue - } - maps.Copy(base.LSP, cfg.LSP) - } -} - -func mergeProviderConfigs(base, global, local *Config) { - for _, cfg := range []*Config{global, local} { - if cfg == nil { - continue - } - for providerName, globalProvider := range cfg.Providers { - if _, ok := base.Providers[providerName]; !ok { - base.Providers[providerName] = globalProvider - } else { - base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider) - } - } - } - - finalProviders := make(map[provider.InferenceProvider]ProviderConfig) - for providerName, providerConfig := range base.Providers { - err := validateProvider(providerName, providerConfig) - if err != nil { - logging.Warn("Skipping provider", "name", providerName, "error", err) - } - finalProviders[providerName] = providerConfig - } - base.Providers = finalProviders -} - -func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig { - switch providerId { - case provider.InferenceProviderAnthropic: - return ProviderConfig{ - ID: providerId, - ProviderType: provider.TypeAnthropic, - } - case provider.InferenceProviderOpenAI: - return ProviderConfig{ - ID: providerId, - ProviderType: provider.TypeOpenAI, - } - case provider.InferenceProviderGemini: - return ProviderConfig{ - ID: providerId, - ProviderType: provider.TypeGemini, - } - case provider.InferenceProviderBedrock: - return ProviderConfig{ - ID: providerId, - ProviderType: provider.TypeBedrock, - } - case provider.InferenceProviderAzure: - return ProviderConfig{ - ID: providerId, - ProviderType: provider.TypeAzure, - } - case provider.InferenceProviderOpenRouter: - return ProviderConfig{ - ID: providerId, - ProviderType: 
provider.TypeOpenAI, - BaseURL: "https://openrouter.ai/api/v1", - ExtraHeaders: map[string]string{ - "HTTP-Referer": "crush.charm.land", - "X-Title": "Crush", - }, - } - case provider.InferenceProviderXAI: - return ProviderConfig{ - ID: providerId, - ProviderType: provider.TypeXAI, - BaseURL: "https://api.x.ai/v1", - } - case provider.InferenceProviderVertexAI: - return ProviderConfig{ - ID: providerId, - ProviderType: provider.TypeVertexAI, - } - default: - return ProviderConfig{ - ID: providerId, - ProviderType: provider.TypeOpenAI, - } - } -} - -func defaultConfigBasedOnEnv() *Config { - cfg := &Config{ - Options: Options{ - DataDirectory: defaultDataDirectory, - ContextPaths: defaultContextPaths, - }, - Providers: make(map[provider.InferenceProvider]ProviderConfig), - } - - providers := Providers() - - for _, p := range providers { - if strings.HasPrefix(p.APIKey, "$") { - envVar := strings.TrimPrefix(p.APIKey, "$") - if apiKey := os.Getenv(envVar); apiKey != "" { - providerConfig := providerDefaultConfig(p.ID) - providerConfig.APIKey = apiKey - providerConfig.DefaultLargeModel = p.DefaultLargeModelID - providerConfig.DefaultSmallModel = p.DefaultSmallModelID - for _, model := range p.Models { - providerConfig.Models = append(providerConfig.Models, Model{ - ID: model.ID, - Name: model.Name, - CostPer1MIn: model.CostPer1MIn, - CostPer1MOut: model.CostPer1MOut, - CostPer1MInCached: model.CostPer1MInCached, - CostPer1MOutCached: model.CostPer1MOutCached, - ContextWindow: model.ContextWindow, - DefaultMaxTokens: model.DefaultMaxTokens, - CanReason: model.CanReason, - SupportsImages: model.SupportsImages, - }) - } - cfg.Providers[p.ID] = providerConfig - } - } - } - // TODO: support local models - - if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" { - providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI) - providerConfig.ExtraParams = map[string]string{ - "project": os.Getenv("GOOGLE_CLOUD_PROJECT"), - "location": 
os.Getenv("GOOGLE_CLOUD_LOCATION"), - } - cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig - } - - if hasAWSCredentials() { - providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock) - cfg.Providers[provider.InferenceProviderBedrock] = providerConfig - } - return cfg -} - -func hasAWSCredentials() bool { - if os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "" { - return true - } - - if os.Getenv("AWS_PROFILE") != "" || os.Getenv("AWS_DEFAULT_PROFILE") != "" { - return true - } - - if os.Getenv("AWS_REGION") != "" || os.Getenv("AWS_DEFAULT_REGION") != "" { - return true - } - - if os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" || - os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" { - return true - } - - return false -} - -func WorkingDirectory() string { - return cwd -} diff --git a/internal/db/connect.go b/internal/db/connect.go index 9212ce1f097e6877a9ce9b368e77d76e739b673f..6452dabdb3a5de6ffb0f618062631dfe4b25102d 100644 --- a/internal/db/connect.go +++ b/internal/db/connect.go @@ -1,7 +1,6 @@ package db import ( - "context" "database/sql" "fmt" "os" @@ -16,8 +15,8 @@ import ( "github.com/pressly/goose/v3" ) -func Connect(ctx context.Context) (*sql.DB, error) { - dataDir := config.Get().Data.Directory +func Connect() (*sql.DB, error) { + dataDir := config.Get().Options.DataDirectory if dataDir == "" { return nil, fmt.Errorf("data.dir is not set") } diff --git a/internal/db/messages.sql.go b/internal/db/messages.sql.go index 2acfe18fdbc63312c49d65e9e3acb1bd24cf4d7e..81f322921db87dde7ade48ce64322aa01004d255 100644 --- a/internal/db/messages.sql.go +++ b/internal/db/messages.sql.go @@ -17,12 +17,13 @@ INSERT INTO messages ( role, parts, model, + provider, created_at, updated_at ) VALUES ( - ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') + ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') ) -RETURNING id, session_id, role, parts, model, created_at, updated_at, 
finished_at +RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at, provider ` type CreateMessageParams struct { @@ -31,6 +32,7 @@ type CreateMessageParams struct { Role string `json:"role"` Parts string `json:"parts"` Model sql.NullString `json:"model"` + Provider sql.NullString `json:"provider"` } func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (Message, error) { @@ -40,6 +42,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M arg.Role, arg.Parts, arg.Model, + arg.Provider, ) var i Message err := row.Scan( @@ -51,6 +54,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M &i.CreatedAt, &i.UpdatedAt, &i.FinishedAt, + &i.Provider, ) return i, err } @@ -76,7 +80,7 @@ func (q *Queries) DeleteSessionMessages(ctx context.Context, sessionID string) e } const getMessage = `-- name: GetMessage :one -SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at +SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider FROM messages WHERE id = ? LIMIT 1 ` @@ -93,12 +97,13 @@ func (q *Queries) GetMessage(ctx context.Context, id string) (Message, error) { &i.CreatedAt, &i.UpdatedAt, &i.FinishedAt, + &i.Provider, ) return i, err } const listMessagesBySession = `-- name: ListMessagesBySession :many -SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at +SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider FROM messages WHERE session_id = ? 
ORDER BY created_at ASC @@ -122,6 +127,7 @@ func (q *Queries) ListMessagesBySession(ctx context.Context, sessionID string) ( &i.CreatedAt, &i.UpdatedAt, &i.FinishedAt, + &i.Provider, ); err != nil { return nil, err } diff --git a/internal/db/migrations/20250627000000_add_provider_to_messages.sql b/internal/db/migrations/20250627000000_add_provider_to_messages.sql new file mode 100644 index 0000000000000000000000000000000000000000..9bf0ed9749c49640f10407c97deb032f60baaac2 --- /dev/null +++ b/internal/db/migrations/20250627000000_add_provider_to_messages.sql @@ -0,0 +1,11 @@ +-- +goose Up +-- +goose StatementBegin +-- Add provider column to messages table +ALTER TABLE messages ADD COLUMN provider TEXT; +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +-- Remove provider column from messages table +ALTER TABLE messages DROP COLUMN provider; +-- +goose StatementEnd \ No newline at end of file diff --git a/internal/db/models.go b/internal/db/models.go index ec19f99b213e041331b5d6a14dee3648bc14c1de..ec3e6e10ad990d0f1a3d03a7533c8b1aed184447 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -27,6 +27,7 @@ type Message struct { CreatedAt int64 `json:"created_at"` UpdatedAt int64 `json:"updated_at"` FinishedAt sql.NullInt64 `json:"finished_at"` + Provider sql.NullString `json:"provider"` } type Session struct { diff --git a/internal/db/sql/messages.sql b/internal/db/sql/messages.sql index a59cebe7d00fe5fd7cbd449df681df45e832979a..ea946177591d1e145a59475a1ca9272f3191d4d6 100644 --- a/internal/db/sql/messages.sql +++ b/internal/db/sql/messages.sql @@ -16,10 +16,11 @@ INSERT INTO messages ( role, parts, model, + provider, created_at, updated_at ) VALUES ( - ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') + ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') ) RETURNING *; diff --git a/internal/fur/client/client.go b/internal/fur/client/client.go index 
263e8317ce8ac92d8820ba5288f2e40d2616e0e1..5f0ddeaeee708d4b5475403ce1874591f7e9bb2c 100644 --- a/internal/fur/client/client.go +++ b/internal/fur/client/client.go @@ -10,7 +10,7 @@ import ( "github.com/charmbracelet/crush/internal/fur/provider" ) -const defaultURL = "http://localhost:8080" +const defaultURL = "https://fur.charmcli.dev" // Client represents a client for the fur service. type Client struct { diff --git a/internal/fur/provider/provider.go b/internal/fur/provider/provider.go index 85275f1155eff219c87d85fce3cdcc436f4a4e47..8545694dea70b410a3a1912b82313bde2852d942 100644 --- a/internal/fur/provider/provider.go +++ b/internal/fur/provider/provider.go @@ -6,14 +6,13 @@ type Type string // All the supported AI provider types. const ( - TypeOpenAI Type = "openai" - TypeAnthropic Type = "anthropic" - TypeGemini Type = "gemini" - TypeAzure Type = "azure" - TypeBedrock Type = "bedrock" - TypeVertexAI Type = "vertexai" - TypeXAI Type = "xai" - TypeOpenRouter Type = "openrouter" + TypeOpenAI Type = "openai" + TypeAnthropic Type = "anthropic" + TypeGemini Type = "gemini" + TypeAzure Type = "azure" + TypeBedrock Type = "bedrock" + TypeVertexAI Type = "vertexai" + TypeXAI Type = "xai" ) // InferenceProvider represents the inference provider identifier. 
diff --git a/internal/llm/agent/agent-tool.go b/internal/llm/agent/agent-tool.go index 9e5e9bc7844b055c52464032dfc4d75495f9e426..ae15c5867e7321f2ac29e1809f5eb7effb830fdc 100644 --- a/internal/llm/agent/agent-tool.go +++ b/internal/llm/agent/agent-tool.go @@ -5,17 +5,15 @@ import ( "encoding/json" "fmt" - "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/llm/tools" - "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/session" ) type agentTool struct { - sessions session.Service - messages message.Service - lspClients map[string]*lsp.Client + agent Service + sessions session.Service + messages message.Service } const ( @@ -58,17 +56,12 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.ToolResponse{}, fmt.Errorf("session_id and message_id are required") } - agent, err := NewAgent(config.AgentTask, b.sessions, b.messages, TaskAgentTools(b.lspClients)) - if err != nil { - return tools.ToolResponse{}, fmt.Errorf("error creating agent: %s", err) - } - session, err := b.sessions.CreateTaskSession(ctx, call.ID, sessionID, "New Agent Session") if err != nil { return tools.ToolResponse{}, fmt.Errorf("error creating session: %s", err) } - done, err := agent.Run(ctx, session.ID, params.Prompt) + done, err := b.agent.Run(ctx, session.ID, params.Prompt) if err != nil { return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", err) } @@ -101,13 +94,13 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes } func NewAgentTool( - Sessions session.Service, - Messages message.Service, - LspClients map[string]*lsp.Client, + agent Service, + sessions session.Service, + messages message.Service, ) tools.BaseTool { return &agentTool{ - sessions: Sessions, - messages: Messages, - lspClients: LspClients, + sessions: sessions, + messages: messages, + agent: agent, } } diff --git 
a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index ea2a3bd2b11735c1f0422e859adcfa65a82fdb98..f9e97b164aa98fe1ae76490fdfcf336efb43098f 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -4,16 +4,18 @@ import ( "context" "errors" "fmt" + "slices" "strings" "sync" "time" - "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/llm/models" + configv2 "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/history" "github.com/charmbracelet/crush/internal/llm/prompt" "github.com/charmbracelet/crush/internal/llm/provider" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/logging" + "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/pubsub" @@ -47,71 +49,198 @@ type AgentEvent struct { type Service interface { pubsub.Suscriber[AgentEvent] - Model() models.Model + Model() configv2.Model Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) Cancel(sessionID string) CancelAll() IsSessionBusy(sessionID string) bool IsBusy() bool - Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error) + Update(model configv2.PreferredModel) (configv2.Model, error) Summarize(ctx context.Context, sessionID string) error } type agent struct { *pubsub.Broker[AgentEvent] + agentCfg configv2.Agent sessions session.Service messages message.Service - tools []tools.BaseTool - provider provider.Provider + tools []tools.BaseTool + provider provider.Provider + providerID string - titleProvider provider.Provider - summarizeProvider provider.Provider + titleProvider provider.Provider + summarizeProvider provider.Provider + summarizeProviderID string activeRequests sync.Map } +var agentPromptMap = 
map[configv2.AgentID]prompt.PromptID{ + configv2.AgentCoder: prompt.PromptCoder, + configv2.AgentTask: prompt.PromptTask, +} + func NewAgent( - agentName config.AgentName, + agentCfg configv2.Agent, + // These services are needed in the tools + permissions permission.Service, sessions session.Service, messages message.Service, - agentTools []tools.BaseTool, + history history.Service, + lspClients map[string]*lsp.Client, ) (Service, error) { - agentProvider, err := createAgentProvider(agentName) + ctx := context.Background() + cfg := configv2.Get() + otherTools := GetMcpTools(ctx, permissions) + if len(lspClients) > 0 { + otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients)) + } + + allTools := []tools.BaseTool{ + tools.NewBashTool(permissions), + tools.NewEditTool(lspClients, permissions, history), + tools.NewFetchTool(permissions), + tools.NewGlobTool(), + tools.NewGrepTool(), + tools.NewLsTool(), + tools.NewSourcegraphTool(), + tools.NewViewTool(lspClients), + tools.NewWriteTool(lspClients, permissions, history), + } + + if agentCfg.ID == configv2.AgentCoder { + taskAgentCfg := configv2.Get().Agents[configv2.AgentTask] + if taskAgentCfg.ID == "" { + return nil, fmt.Errorf("task agent not found in config") + } + taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients) + if err != nil { + return nil, fmt.Errorf("failed to create task agent: %w", err) + } + + allTools = append( + allTools, + NewAgentTool( + taskAgent, + sessions, + messages, + ), + ) + } + + allTools = append(allTools, otherTools...) 
+ var providerCfg configv2.ProviderConfig + for _, p := range cfg.Providers { + if p.ID == agentCfg.Provider { + providerCfg = p + break + } + } + if providerCfg.ID == "" { + return nil, fmt.Errorf("provider %s not found in config", agentCfg.Provider) + } + + var model configv2.Model + for _, m := range providerCfg.Models { + if m.ID == agentCfg.Model { + model = m + break + } + } + if model.ID == "" { + return nil, fmt.Errorf("model %s not found in provider %s", agentCfg.Model, agentCfg.Provider) + } + + promptID := agentPromptMap[agentCfg.ID] + if promptID == "" { + promptID = prompt.PromptDefault + } + opts := []provider.ProviderClientOption{ + provider.WithModel(model), + provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)), + provider.WithMaxTokens(model.DefaultMaxTokens), + } + agentProvider, err := provider.NewProviderV2(providerCfg, opts...) if err != nil { return nil, err } - var titleProvider provider.Provider - // Only generate titles for the coder agent - if agentName == config.AgentCoder { - titleProvider, err = createAgentProvider(config.AgentTitle) - if err != nil { - return nil, err + + smallModelCfg := cfg.Models.Small + var smallModel configv2.Model + + var smallModelProviderCfg configv2.ProviderConfig + if smallModelCfg.Provider == providerCfg.ID { + smallModelProviderCfg = providerCfg + } else { + for _, p := range cfg.Providers { + if p.ID == smallModelCfg.Provider { + smallModelProviderCfg = p + break + } + } + if smallModelProviderCfg.ID == "" { + return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider) } } - var summarizeProvider provider.Provider - if agentName == config.AgentCoder { - summarizeProvider, err = createAgentProvider(config.AgentSummarizer) - if err != nil { - return nil, err + for _, m := range smallModelProviderCfg.Models { + if m.ID == smallModelCfg.ModelID { + smallModel = m + break + } + } + if smallModel.ID == "" { + return nil, fmt.Errorf("model %s not found in provider %s", 
smallModelCfg.ModelID, smallModelProviderCfg.ID) + } + + titleOpts := []provider.ProviderClientOption{ + provider.WithModel(smallModel), + provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)), + provider.WithMaxTokens(40), + } + titleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...) + if err != nil { + return nil, err + } + summarizeOpts := []provider.ProviderClientOption{ + provider.WithModel(smallModel), + provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)), + provider.WithMaxTokens(smallModel.DefaultMaxTokens), + } + summarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...) + if err != nil { + return nil, err + } + + agentTools := []tools.BaseTool{} + if agentCfg.AllowedTools == nil { + agentTools = allTools + } else { + for _, tool := range allTools { + if slices.Contains(agentCfg.AllowedTools, tool.Name()) { + agentTools = append(agentTools, tool) + } } } agent := &agent{ - Broker: pubsub.NewBroker[AgentEvent](), - provider: agentProvider, - messages: messages, - sessions: sessions, - tools: agentTools, - titleProvider: titleProvider, - summarizeProvider: summarizeProvider, - activeRequests: sync.Map{}, + Broker: pubsub.NewBroker[AgentEvent](), + agentCfg: agentCfg, + provider: agentProvider, + providerID: string(providerCfg.ID), + messages: messages, + sessions: sessions, + tools: agentTools, + titleProvider: titleProvider, + summarizeProvider: summarizeProvider, + summarizeProviderID: string(smallModelProviderCfg.ID), + activeRequests: sync.Map{}, } return agent, nil } -func (a *agent) Model() models.Model { +func (a *agent) Model() configv2.Model { return a.provider.Model() } @@ -207,7 +336,7 @@ func (a *agent) err(err error) AgentEvent { } func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) { - if 
!a.provider.Model().SupportsAttachments && attachments != nil { + if !a.provider.Model().SupportsImages && attachments != nil { attachments = nil } events := make(chan AgentEvent) @@ -327,9 +456,10 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools) assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ - Role: message.Assistant, - Parts: []message.ContentPart{}, - Model: a.provider.Model().ID, + Role: message.Assistant, + Parts: []message.ContentPart{}, + Model: a.provider.Model().ID, + Provider: a.providerID, }) if err != nil { return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err) @@ -424,8 +554,9 @@ out: parts = append(parts, tr) } msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{ - Role: message.Tool, - Parts: parts, + Role: message.Tool, + Parts: parts, + Provider: a.providerID, }) if err != nil { return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err) @@ -484,7 +615,7 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg return nil } -func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error { +func (a *agent) TrackUsage(ctx context.Context, sessionID string, model configv2.Model, usage provider.TokenUsage) error { sess, err := a.sessions.Get(ctx, sessionID) if err != nil { return fmt.Errorf("failed to get session: %w", err) @@ -506,21 +637,48 @@ func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.M return nil } -func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error) { +func (a *agent) Update(modelCfg configv2.PreferredModel) (configv2.Model, error) { if a.IsBusy() { - return models.Model{}, fmt.Errorf("cannot change model while processing requests") + return 
configv2.Model{}, fmt.Errorf("cannot change model while processing requests") } - if err := config.UpdateAgentModel(agentName, modelID); err != nil { - return models.Model{}, fmt.Errorf("failed to update config: %w", err) + cfg := configv2.Get() + var providerCfg configv2.ProviderConfig + for _, p := range cfg.Providers { + if p.ID == modelCfg.Provider { + providerCfg = p + break + } + } + if providerCfg.ID == "" { + return configv2.Model{}, fmt.Errorf("provider %s not found in config", modelCfg.Provider) } - provider, err := createAgentProvider(agentName) - if err != nil { - return models.Model{}, fmt.Errorf("failed to create provider for model %s: %w", modelID, err) + var model configv2.Model + for _, m := range providerCfg.Models { + if m.ID == modelCfg.ModelID { + model = m + break + } + } + if model.ID == "" { + return configv2.Model{}, fmt.Errorf("model %s not found in provider %s", modelCfg.ModelID, modelCfg.Provider) } - a.provider = provider + promptID := agentPromptMap[a.agentCfg.ID] + if promptID == "" { + promptID = prompt.PromptDefault + } + opts := []provider.ProviderClientOption{ + provider.WithModel(model), + provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)), + provider.WithMaxTokens(model.DefaultMaxTokens), + } + agentProvider, err := provider.NewProviderV2(providerCfg, opts...) 
+ if err != nil { + return configv2.Model{}, err + } + a.provider = agentProvider return a.provider.Model(), nil } @@ -654,7 +812,8 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error { Time: time.Now().Unix(), }, }, - Model: a.summarizeProvider.Model().ID, + Model: a.summarizeProvider.Model().ID, + Provider: a.summarizeProviderID, }) if err != nil { event = AgentEvent{ @@ -705,51 +864,3 @@ func (a *agent) CancelAll() { return true }) } - -func createAgentProvider(agentName config.AgentName) (provider.Provider, error) { - cfg := config.Get() - agentConfig, ok := cfg.Agents[agentName] - if !ok { - return nil, fmt.Errorf("agent %s not found", agentName) - } - model, ok := models.SupportedModels[agentConfig.Model] - if !ok { - return nil, fmt.Errorf("model %s not supported", agentConfig.Model) - } - - providerCfg, ok := cfg.Providers[model.Provider] - if !ok { - return nil, fmt.Errorf("provider %s not supported", model.Provider) - } - if providerCfg.Disabled { - return nil, fmt.Errorf("provider %s is not enabled", model.Provider) - } - maxTokens := model.DefaultMaxTokens - if agentConfig.MaxTokens > 0 { - maxTokens = agentConfig.MaxTokens - } - opts := []provider.ProviderClientOption{ - provider.WithAPIKey(providerCfg.APIKey), - provider.WithModel(model), - provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)), - provider.WithMaxTokens(maxTokens), - } - // TODO: reimplement - // if model.Provider == models.ProviderOpenAI || model.Provider == models.ProviderLocal && model.CanReason { - // opts = append( - // opts, - // provider.WithOpenAIOptions( - // provider.WithReasoningEffort(agentConfig.ReasoningEffort), - // ), - // ) - // } - agentProvider, err := provider.NewProvider( - model.Provider, - opts..., - ) - if err != nil { - return nil, fmt.Errorf("could not create provider: %v", err) - } - - return agentProvider, nil -} diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index 
8fde02755eb320b8925891a3eca938c3cd7911f9..1950324fa3ed4dbd9de358d18023247b0bb429e7 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -18,7 +18,7 @@ import ( type mcpTool struct { mcpName string tool mcp.Tool - mcpConfig config.MCPServer + mcpConfig config.MCP permissions permission.Service } @@ -128,7 +128,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes return tools.NewTextErrorResponse("invalid mcp type"), nil } -func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPServer) tools.BaseTool { +func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCP) tools.BaseTool { return &mcpTool{ mcpName: name, tool: tool, @@ -139,7 +139,7 @@ func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpC var mcpTools []tools.BaseTool -func getTools(ctx context.Context, name string, m config.MCPServer, permissions permission.Service, c MCPClient) []tools.BaseTool { +func getTools(ctx context.Context, name string, m config.MCP, permissions permission.Service, c MCPClient) []tools.BaseTool { var stdioTools []tools.BaseTool initRequest := mcp.InitializeRequest{} initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION @@ -170,7 +170,7 @@ func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.Ba if len(mcpTools) > 0 { return mcpTools } - for name, m := range config.Get().MCPServers { + for name, m := range config.Get().MCP { switch m.Type { case config.MCPStdio: c, err := client.NewStdioMCPClient( diff --git a/internal/llm/agent/tools.go b/internal/llm/agent/tools.go deleted file mode 100644 index 0fe2c530ca6dc30916fd2dfa094ad6303bf39443..0000000000000000000000000000000000000000 --- a/internal/llm/agent/tools.go +++ /dev/null @@ -1,50 +0,0 @@ -package agent - -import ( - "context" - - "github.com/charmbracelet/crush/internal/history" - 
"github.com/charmbracelet/crush/internal/llm/tools" - "github.com/charmbracelet/crush/internal/lsp" - "github.com/charmbracelet/crush/internal/message" - "github.com/charmbracelet/crush/internal/permission" - "github.com/charmbracelet/crush/internal/session" -) - -func CoderAgentTools( - permissions permission.Service, - sessions session.Service, - messages message.Service, - history history.Service, - lspClients map[string]*lsp.Client, -) []tools.BaseTool { - ctx := context.Background() - otherTools := GetMcpTools(ctx, permissions) - if len(lspClients) > 0 { - otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients)) - } - return append( - []tools.BaseTool{ - tools.NewBashTool(permissions), - tools.NewEditTool(lspClients, permissions, history), - tools.NewFetchTool(permissions), - tools.NewGlobTool(), - tools.NewGrepTool(), - tools.NewLsTool(), - tools.NewSourcegraphTool(), - tools.NewViewTool(lspClients), - tools.NewWriteTool(lspClients, permissions, history), - NewAgentTool(sessions, messages, lspClients), - }, otherTools..., - ) -} - -func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool { - return []tools.BaseTool{ - tools.NewGlobTool(), - tools.NewGrepTool(), - tools.NewLsTool(), - tools.NewSourcegraphTool(), - tools.NewViewTool(lspClients), - } -} diff --git a/internal/llm/models/anthropic.go b/internal/llm/models/anthropic.go deleted file mode 100644 index 85c47def3d94034297265c506c5870f2b449d286..0000000000000000000000000000000000000000 --- a/internal/llm/models/anthropic.go +++ /dev/null @@ -1,111 +0,0 @@ -package models - -const ( - ProviderAnthropic InferenceProvider = "anthropic" - - // Models - Claude35Sonnet ModelID = "claude-3.5-sonnet" - Claude3Haiku ModelID = "claude-3-haiku" - Claude37Sonnet ModelID = "claude-3.7-sonnet" - Claude35Haiku ModelID = "claude-3.5-haiku" - Claude3Opus ModelID = "claude-3-opus" - Claude4Opus ModelID = "claude-4-opus" - Claude4Sonnet ModelID = "claude-4-sonnet" -) - -// 
https://docs.anthropic.com/en/docs/about-claude/models/all-models -var AnthropicModels = map[ModelID]Model{ - Claude35Sonnet: { - ID: Claude35Sonnet, - Name: "Claude 3.5 Sonnet", - Provider: ProviderAnthropic, - APIModel: "claude-3-5-sonnet-latest", - CostPer1MIn: 3.0, - CostPer1MInCached: 3.75, - CostPer1MOutCached: 0.30, - CostPer1MOut: 15.0, - ContextWindow: 200000, - DefaultMaxTokens: 5000, - SupportsAttachments: true, - }, - Claude3Haiku: { - ID: Claude3Haiku, - Name: "Claude 3 Haiku", - Provider: ProviderAnthropic, - APIModel: "claude-3-haiku-20240307", // doesn't support "-latest" - CostPer1MIn: 0.25, - CostPer1MInCached: 0.30, - CostPer1MOutCached: 0.03, - CostPer1MOut: 1.25, - ContextWindow: 200000, - DefaultMaxTokens: 4096, - SupportsAttachments: true, - }, - Claude37Sonnet: { - ID: Claude37Sonnet, - Name: "Claude 3.7 Sonnet", - Provider: ProviderAnthropic, - APIModel: "claude-3-7-sonnet-latest", - CostPer1MIn: 3.0, - CostPer1MInCached: 3.75, - CostPer1MOutCached: 0.30, - CostPer1MOut: 15.0, - ContextWindow: 200000, - DefaultMaxTokens: 50000, - CanReason: true, - SupportsAttachments: true, - }, - Claude35Haiku: { - ID: Claude35Haiku, - Name: "Claude 3.5 Haiku", - Provider: ProviderAnthropic, - APIModel: "claude-3-5-haiku-latest", - CostPer1MIn: 0.80, - CostPer1MInCached: 1.0, - CostPer1MOutCached: 0.08, - CostPer1MOut: 4.0, - ContextWindow: 200000, - DefaultMaxTokens: 4096, - SupportsAttachments: true, - }, - Claude3Opus: { - ID: Claude3Opus, - Name: "Claude 3 Opus", - Provider: ProviderAnthropic, - APIModel: "claude-3-opus-latest", - CostPer1MIn: 15.0, - CostPer1MInCached: 18.75, - CostPer1MOutCached: 1.50, - CostPer1MOut: 75.0, - ContextWindow: 200000, - DefaultMaxTokens: 4096, - SupportsAttachments: true, - }, - Claude4Sonnet: { - ID: Claude4Sonnet, - Name: "Claude 4 Sonnet", - Provider: ProviderAnthropic, - APIModel: "claude-sonnet-4-20250514", - CostPer1MIn: 3.0, - CostPer1MInCached: 3.75, - CostPer1MOutCached: 0.30, - CostPer1MOut: 15.0, - 
ContextWindow: 200000, - DefaultMaxTokens: 50000, - CanReason: true, - SupportsAttachments: true, - }, - Claude4Opus: { - ID: Claude4Opus, - Name: "Claude 4 Opus", - Provider: ProviderAnthropic, - APIModel: "claude-opus-4-20250514", - CostPer1MIn: 15.0, - CostPer1MInCached: 18.75, - CostPer1MOutCached: 1.50, - CostPer1MOut: 75.0, - ContextWindow: 200000, - DefaultMaxTokens: 4096, - SupportsAttachments: true, - }, -} diff --git a/internal/llm/models/azure.go b/internal/llm/models/azure.go deleted file mode 100644 index eb7ae293ee053d953f5bcbb20120089ca6bae95b..0000000000000000000000000000000000000000 --- a/internal/llm/models/azure.go +++ /dev/null @@ -1,168 +0,0 @@ -package models - -const ProviderAzure InferenceProvider = "azure" - -const ( - AzureGPT41 ModelID = "azure.gpt-4.1" - AzureGPT41Mini ModelID = "azure.gpt-4.1-mini" - AzureGPT41Nano ModelID = "azure.gpt-4.1-nano" - AzureGPT45Preview ModelID = "azure.gpt-4.5-preview" - AzureGPT4o ModelID = "azure.gpt-4o" - AzureGPT4oMini ModelID = "azure.gpt-4o-mini" - AzureO1 ModelID = "azure.o1" - AzureO1Mini ModelID = "azure.o1-mini" - AzureO3 ModelID = "azure.o3" - AzureO3Mini ModelID = "azure.o3-mini" - AzureO4Mini ModelID = "azure.o4-mini" -) - -var AzureModels = map[ModelID]Model{ - AzureGPT41: { - ID: AzureGPT41, - Name: "Azure OpenAI – GPT 4.1", - Provider: ProviderAzure, - APIModel: "gpt-4.1", - CostPer1MIn: OpenAIModels[GPT41].CostPer1MIn, - CostPer1MInCached: OpenAIModels[GPT41].CostPer1MInCached, - CostPer1MOut: OpenAIModels[GPT41].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[GPT41].CostPer1MOutCached, - ContextWindow: OpenAIModels[GPT41].ContextWindow, - DefaultMaxTokens: OpenAIModels[GPT41].DefaultMaxTokens, - SupportsAttachments: true, - }, - AzureGPT41Mini: { - ID: AzureGPT41Mini, - Name: "Azure OpenAI – GPT 4.1 mini", - Provider: ProviderAzure, - APIModel: "gpt-4.1-mini", - CostPer1MIn: OpenAIModels[GPT41Mini].CostPer1MIn, - CostPer1MInCached: OpenAIModels[GPT41Mini].CostPer1MInCached, - 
CostPer1MOut: OpenAIModels[GPT41Mini].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[GPT41Mini].CostPer1MOutCached, - ContextWindow: OpenAIModels[GPT41Mini].ContextWindow, - DefaultMaxTokens: OpenAIModels[GPT41Mini].DefaultMaxTokens, - SupportsAttachments: true, - }, - AzureGPT41Nano: { - ID: AzureGPT41Nano, - Name: "Azure OpenAI – GPT 4.1 nano", - Provider: ProviderAzure, - APIModel: "gpt-4.1-nano", - CostPer1MIn: OpenAIModels[GPT41Nano].CostPer1MIn, - CostPer1MInCached: OpenAIModels[GPT41Nano].CostPer1MInCached, - CostPer1MOut: OpenAIModels[GPT41Nano].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[GPT41Nano].CostPer1MOutCached, - ContextWindow: OpenAIModels[GPT41Nano].ContextWindow, - DefaultMaxTokens: OpenAIModels[GPT41Nano].DefaultMaxTokens, - SupportsAttachments: true, - }, - AzureGPT45Preview: { - ID: AzureGPT45Preview, - Name: "Azure OpenAI – GPT 4.5 preview", - Provider: ProviderAzure, - APIModel: "gpt-4.5-preview", - CostPer1MIn: OpenAIModels[GPT45Preview].CostPer1MIn, - CostPer1MInCached: OpenAIModels[GPT45Preview].CostPer1MInCached, - CostPer1MOut: OpenAIModels[GPT45Preview].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[GPT45Preview].CostPer1MOutCached, - ContextWindow: OpenAIModels[GPT45Preview].ContextWindow, - DefaultMaxTokens: OpenAIModels[GPT45Preview].DefaultMaxTokens, - SupportsAttachments: true, - }, - AzureGPT4o: { - ID: AzureGPT4o, - Name: "Azure OpenAI – GPT-4o", - Provider: ProviderAzure, - APIModel: "gpt-4o", - CostPer1MIn: OpenAIModels[GPT4o].CostPer1MIn, - CostPer1MInCached: OpenAIModels[GPT4o].CostPer1MInCached, - CostPer1MOut: OpenAIModels[GPT4o].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[GPT4o].CostPer1MOutCached, - ContextWindow: OpenAIModels[GPT4o].ContextWindow, - DefaultMaxTokens: OpenAIModels[GPT4o].DefaultMaxTokens, - SupportsAttachments: true, - }, - AzureGPT4oMini: { - ID: AzureGPT4oMini, - Name: "Azure OpenAI – GPT-4o mini", - Provider: ProviderAzure, - APIModel: "gpt-4o-mini", - CostPer1MIn: 
OpenAIModels[GPT4oMini].CostPer1MIn, - CostPer1MInCached: OpenAIModels[GPT4oMini].CostPer1MInCached, - CostPer1MOut: OpenAIModels[GPT4oMini].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[GPT4oMini].CostPer1MOutCached, - ContextWindow: OpenAIModels[GPT4oMini].ContextWindow, - DefaultMaxTokens: OpenAIModels[GPT4oMini].DefaultMaxTokens, - SupportsAttachments: true, - }, - AzureO1: { - ID: AzureO1, - Name: "Azure OpenAI – O1", - Provider: ProviderAzure, - APIModel: "o1", - CostPer1MIn: OpenAIModels[O1].CostPer1MIn, - CostPer1MInCached: OpenAIModels[O1].CostPer1MInCached, - CostPer1MOut: OpenAIModels[O1].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[O1].CostPer1MOutCached, - ContextWindow: OpenAIModels[O1].ContextWindow, - DefaultMaxTokens: OpenAIModels[O1].DefaultMaxTokens, - CanReason: OpenAIModels[O1].CanReason, - SupportsAttachments: true, - }, - AzureO1Mini: { - ID: AzureO1Mini, - Name: "Azure OpenAI – O1 mini", - Provider: ProviderAzure, - APIModel: "o1-mini", - CostPer1MIn: OpenAIModels[O1Mini].CostPer1MIn, - CostPer1MInCached: OpenAIModels[O1Mini].CostPer1MInCached, - CostPer1MOut: OpenAIModels[O1Mini].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[O1Mini].CostPer1MOutCached, - ContextWindow: OpenAIModels[O1Mini].ContextWindow, - DefaultMaxTokens: OpenAIModels[O1Mini].DefaultMaxTokens, - CanReason: OpenAIModels[O1Mini].CanReason, - SupportsAttachments: true, - }, - AzureO3: { - ID: AzureO3, - Name: "Azure OpenAI – O3", - Provider: ProviderAzure, - APIModel: "o3", - CostPer1MIn: OpenAIModels[O3].CostPer1MIn, - CostPer1MInCached: OpenAIModels[O3].CostPer1MInCached, - CostPer1MOut: OpenAIModels[O3].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[O3].CostPer1MOutCached, - ContextWindow: OpenAIModels[O3].ContextWindow, - DefaultMaxTokens: OpenAIModels[O3].DefaultMaxTokens, - CanReason: OpenAIModels[O3].CanReason, - SupportsAttachments: true, - }, - AzureO3Mini: { - ID: AzureO3Mini, - Name: "Azure OpenAI – O3 mini", - Provider: ProviderAzure, - APIModel: 
"o3-mini", - CostPer1MIn: OpenAIModels[O3Mini].CostPer1MIn, - CostPer1MInCached: OpenAIModels[O3Mini].CostPer1MInCached, - CostPer1MOut: OpenAIModels[O3Mini].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[O3Mini].CostPer1MOutCached, - ContextWindow: OpenAIModels[O3Mini].ContextWindow, - DefaultMaxTokens: OpenAIModels[O3Mini].DefaultMaxTokens, - CanReason: OpenAIModels[O3Mini].CanReason, - SupportsAttachments: false, - }, - AzureO4Mini: { - ID: AzureO4Mini, - Name: "Azure OpenAI – O4 mini", - Provider: ProviderAzure, - APIModel: "o4-mini", - CostPer1MIn: OpenAIModels[O4Mini].CostPer1MIn, - CostPer1MInCached: OpenAIModels[O4Mini].CostPer1MInCached, - CostPer1MOut: OpenAIModels[O4Mini].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[O4Mini].CostPer1MOutCached, - ContextWindow: OpenAIModels[O4Mini].ContextWindow, - DefaultMaxTokens: OpenAIModels[O4Mini].DefaultMaxTokens, - CanReason: OpenAIModels[O4Mini].CanReason, - SupportsAttachments: true, - }, -} diff --git a/internal/llm/models/gemini.go b/internal/llm/models/gemini.go deleted file mode 100644 index 9749c6d3409acf7b05cd67690504e2cb3ac4fd39..0000000000000000000000000000000000000000 --- a/internal/llm/models/gemini.go +++ /dev/null @@ -1,67 +0,0 @@ -package models - -const ( - ProviderGemini InferenceProvider = "gemini" - - // Models - Gemini25Flash ModelID = "gemini-2.5-flash" - Gemini25 ModelID = "gemini-2.5" - Gemini20Flash ModelID = "gemini-2.0-flash" - Gemini20FlashLite ModelID = "gemini-2.0-flash-lite" -) - -var GeminiModels = map[ModelID]Model{ - Gemini25Flash: { - ID: Gemini25Flash, - Name: "Gemini 2.5 Flash", - Provider: ProviderGemini, - APIModel: "gemini-2.5-flash-preview-04-17", - CostPer1MIn: 0.15, - CostPer1MInCached: 0, - CostPer1MOutCached: 0, - CostPer1MOut: 0.60, - ContextWindow: 1000000, - DefaultMaxTokens: 50000, - SupportsAttachments: true, - }, - Gemini25: { - ID: Gemini25, - Name: "Gemini 2.5 Pro", - Provider: ProviderGemini, - APIModel: "gemini-2.5-pro-preview-05-06", - CostPer1MIn: 1.25, 
- CostPer1MInCached: 0, - CostPer1MOutCached: 0, - CostPer1MOut: 10, - ContextWindow: 1000000, - DefaultMaxTokens: 50000, - SupportsAttachments: true, - }, - - Gemini20Flash: { - ID: Gemini20Flash, - Name: "Gemini 2.0 Flash", - Provider: ProviderGemini, - APIModel: "gemini-2.0-flash", - CostPer1MIn: 0.10, - CostPer1MInCached: 0, - CostPer1MOutCached: 0, - CostPer1MOut: 0.40, - ContextWindow: 1000000, - DefaultMaxTokens: 6000, - SupportsAttachments: true, - }, - Gemini20FlashLite: { - ID: Gemini20FlashLite, - Name: "Gemini 2.0 Flash Lite", - Provider: ProviderGemini, - APIModel: "gemini-2.0-flash-lite", - CostPer1MIn: 0.05, - CostPer1MInCached: 0, - CostPer1MOutCached: 0, - CostPer1MOut: 0.30, - ContextWindow: 1000000, - DefaultMaxTokens: 6000, - SupportsAttachments: true, - }, -} diff --git a/internal/llm/models/groq.go b/internal/llm/models/groq.go deleted file mode 100644 index 39288962c8e42a1acec8a01b3157b10d9b00b5dc..0000000000000000000000000000000000000000 --- a/internal/llm/models/groq.go +++ /dev/null @@ -1,87 +0,0 @@ -package models - -const ( - ProviderGROQ InferenceProvider = "groq" - - // GROQ - QWENQwq ModelID = "qwen-qwq" - - // GROQ preview models - Llama4Scout ModelID = "meta-llama/llama-4-scout-17b-16e-instruct" - Llama4Maverick ModelID = "meta-llama/llama-4-maverick-17b-128e-instruct" - Llama3_3_70BVersatile ModelID = "llama-3.3-70b-versatile" - DeepseekR1DistillLlama70b ModelID = "deepseek-r1-distill-llama-70b" -) - -var GroqModels = map[ModelID]Model{ - // - // GROQ - QWENQwq: { - ID: QWENQwq, - Name: "Qwen Qwq", - Provider: ProviderGROQ, - APIModel: "qwen-qwq-32b", - CostPer1MIn: 0.29, - CostPer1MInCached: 0.275, - CostPer1MOutCached: 0.0, - CostPer1MOut: 0.39, - ContextWindow: 128_000, - DefaultMaxTokens: 50000, - // for some reason, the groq api doesn't like the reasoningEffort parameter - CanReason: false, - SupportsAttachments: false, - }, - - Llama4Scout: { - ID: Llama4Scout, - Name: "Llama4Scout", - Provider: ProviderGROQ, - APIModel: 
"meta-llama/llama-4-scout-17b-16e-instruct", - CostPer1MIn: 0.11, - CostPer1MInCached: 0, - CostPer1MOutCached: 0, - CostPer1MOut: 0.34, - ContextWindow: 128_000, // 10M when? - SupportsAttachments: true, - }, - - Llama4Maverick: { - ID: Llama4Maverick, - Name: "Llama4Maverick", - Provider: ProviderGROQ, - APIModel: "meta-llama/llama-4-maverick-17b-128e-instruct", - CostPer1MIn: 0.20, - CostPer1MInCached: 0, - CostPer1MOutCached: 0, - CostPer1MOut: 0.20, - ContextWindow: 128_000, - SupportsAttachments: true, - }, - - Llama3_3_70BVersatile: { - ID: Llama3_3_70BVersatile, - Name: "Llama3_3_70BVersatile", - Provider: ProviderGROQ, - APIModel: "llama-3.3-70b-versatile", - CostPer1MIn: 0.59, - CostPer1MInCached: 0, - CostPer1MOutCached: 0, - CostPer1MOut: 0.79, - ContextWindow: 128_000, - SupportsAttachments: false, - }, - - DeepseekR1DistillLlama70b: { - ID: DeepseekR1DistillLlama70b, - Name: "DeepseekR1DistillLlama70b", - Provider: ProviderGROQ, - APIModel: "deepseek-r1-distill-llama-70b", - CostPer1MIn: 0.75, - CostPer1MInCached: 0, - CostPer1MOutCached: 0, - CostPer1MOut: 0.99, - ContextWindow: 128_000, - CanReason: true, - SupportsAttachments: false, - }, -} diff --git a/internal/llm/models/local.go b/internal/llm/models/local.go deleted file mode 100644 index c469e99fd65d5befbfffe5126a31c88eae68e150..0000000000000000000000000000000000000000 --- a/internal/llm/models/local.go +++ /dev/null @@ -1,206 +0,0 @@ -package models - -import ( - "cmp" - "context" - "encoding/json" - "net/http" - "net/url" - "os" - "regexp" - "strings" - "unicode" - - "github.com/charmbracelet/crush/internal/logging" - "github.com/spf13/viper" -) - -const ( - ProviderLocal InferenceProvider = "local" - - localModelsPath = "v1/models" - lmStudioBetaModelsPath = "api/v0/models" -) - -func init() { - if endpoint := os.Getenv("LOCAL_ENDPOINT"); endpoint != "" { - localEndpoint, err := url.Parse(endpoint) - if err != nil { - logging.Debug("Failed to parse local endpoint", - "error", err, - 
"endpoint", endpoint, - ) - return - } - - load := func(url *url.URL, path string) []localModel { - url.Path = path - return listLocalModels(url.String()) - } - - models := load(localEndpoint, lmStudioBetaModelsPath) - - if len(models) == 0 { - models = load(localEndpoint, localModelsPath) - } - - if len(models) == 0 { - logging.Debug("No local models found", - "endpoint", endpoint, - ) - return - } - - loadLocalModels(models) - - viper.SetDefault("providers.local.apiKey", "dummy") - } -} - -type localModelList struct { - Data []localModel `json:"data"` -} - -type localModel struct { - ID string `json:"id"` - Object string `json:"object"` - Type string `json:"type"` - Publisher string `json:"publisher"` - Arch string `json:"arch"` - CompatibilityType string `json:"compatibility_type"` - Quantization string `json:"quantization"` - State string `json:"state"` - MaxContextLength int64 `json:"max_context_length"` - LoadedContextLength int64 `json:"loaded_context_length"` -} - -func listLocalModels(modelsEndpoint string) []localModel { - res, err := http.NewRequestWithContext(context.Background(), http.MethodGet, modelsEndpoint, nil) - if err != nil { - logging.Debug("Failed to list local models", - "error", err, - "endpoint", modelsEndpoint, - ) - } - defer res.Body.Close() - - if res.Response.StatusCode != http.StatusOK { - logging.Debug("Failed to list local models", - "status", res.Response.Status, - "endpoint", modelsEndpoint, - ) - } - - var modelList localModelList - if err = json.NewDecoder(res.Body).Decode(&modelList); err != nil { - logging.Debug("Failed to list local models", - "error", err, - "endpoint", modelsEndpoint, - ) - } - - var supportedModels []localModel - for _, model := range modelList.Data { - if strings.HasSuffix(modelsEndpoint, lmStudioBetaModelsPath) { - if model.Object != "model" || model.Type != "llm" { - logging.Debug("Skipping unsupported LMStudio model", - "endpoint", modelsEndpoint, - "id", model.ID, - "object", model.Object, - "type", 
model.Type, - ) - - continue - } - } - - supportedModels = append(supportedModels, model) - } - - return supportedModels -} - -func loadLocalModels(models []localModel) { - for i, m := range models { - model := convertLocalModel(m) - SupportedModels[model.ID] = model - - if i == 0 || m.State == "loaded" { - viper.SetDefault("agents.coder.model", model.ID) - viper.SetDefault("agents.summarizer.model", model.ID) - viper.SetDefault("agents.task.model", model.ID) - viper.SetDefault("agents.title.model", model.ID) - } - } -} - -func convertLocalModel(model localModel) Model { - return Model{ - ID: ModelID("local." + model.ID), - Name: friendlyModelName(model.ID), - Provider: ProviderLocal, - APIModel: model.ID, - ContextWindow: cmp.Or(model.LoadedContextLength, 4096), - DefaultMaxTokens: cmp.Or(model.LoadedContextLength, 4096), - CanReason: true, - SupportsAttachments: true, - } -} - -var modelInfoRegex = regexp.MustCompile(`(?i)^([a-z0-9]+)(?:[-_]?([rv]?\d[\.\d]*))?(?:[-_]?([a-z]+))?.*`) - -func friendlyModelName(modelID string) string { - mainID := modelID - tag := "" - - if slash := strings.LastIndex(mainID, "/"); slash != -1 { - mainID = mainID[slash+1:] - } - - if at := strings.Index(modelID, "@"); at != -1 { - mainID = modelID[:at] - tag = modelID[at+1:] - } - - match := modelInfoRegex.FindStringSubmatch(mainID) - if match == nil { - return modelID - } - - capitalize := func(s string) string { - if s == "" { - return "" - } - runes := []rune(s) - runes[0] = unicode.ToUpper(runes[0]) - return string(runes) - } - - family := capitalize(match[1]) - version := "" - label := "" - - if len(match) > 2 && match[2] != "" { - version = strings.ToUpper(match[2]) - } - - if len(match) > 3 && match[3] != "" { - label = capitalize(match[3]) - } - - var parts []string - if family != "" { - parts = append(parts, family) - } - if version != "" { - parts = append(parts, version) - } - if label != "" { - parts = append(parts, label) - } - if tag != "" { - parts = append(parts, tag) 
- } - - return strings.Join(parts, " ") -} diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go deleted file mode 100644 index 0aefc170d32d1023f0d246a2cc7522e895453a88..0000000000000000000000000000000000000000 --- a/internal/llm/models/models.go +++ /dev/null @@ -1,74 +0,0 @@ -package models - -import "maps" - -type ( - ModelID string - InferenceProvider string -) - -type Model struct { - ID ModelID `json:"id"` - Name string `json:"name"` - Provider InferenceProvider `json:"provider"` - APIModel string `json:"api_model"` - CostPer1MIn float64 `json:"cost_per_1m_in"` - CostPer1MOut float64 `json:"cost_per_1m_out"` - CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` - CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` - ContextWindow int64 `json:"context_window"` - DefaultMaxTokens int64 `json:"default_max_tokens"` - CanReason bool `json:"can_reason"` - SupportsAttachments bool `json:"supports_attachments"` -} - -// Model IDs -const ( // GEMINI - // Bedrock - BedrockClaude37Sonnet ModelID = "bedrock.claude-3.7-sonnet" -) - -const ( - ProviderBedrock InferenceProvider = "bedrock" - // ForTests - ProviderMock InferenceProvider = "__mock" -) - -var SupportedModels = map[ModelID]Model{ - // Bedrock - BedrockClaude37Sonnet: { - ID: BedrockClaude37Sonnet, - Name: "Bedrock: Claude 3.7 Sonnet", - Provider: ProviderBedrock, - APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0", - CostPer1MIn: 3.0, - CostPer1MInCached: 3.75, - CostPer1MOutCached: 0.30, - CostPer1MOut: 15.0, - }, -} - -var KnownProviders = []InferenceProvider{ - ProviderAnthropic, - ProviderOpenAI, - ProviderGemini, - ProviderAzure, - ProviderGROQ, - ProviderLocal, - ProviderOpenRouter, - ProviderVertexAI, - ProviderBedrock, - ProviderXAI, - ProviderMock, -} - -func init() { - maps.Copy(SupportedModels, AnthropicModels) - maps.Copy(SupportedModels, OpenAIModels) - maps.Copy(SupportedModels, GeminiModels) - maps.Copy(SupportedModels, GroqModels) - maps.Copy(SupportedModels, 
AzureModels) - maps.Copy(SupportedModels, OpenRouterModels) - maps.Copy(SupportedModels, XAIModels) - maps.Copy(SupportedModels, VertexAIGeminiModels) -} diff --git a/internal/llm/models/openai.go b/internal/llm/models/openai.go deleted file mode 100644 index e4173277cbdfe1e579068d2981df1e70b9943cb1..0000000000000000000000000000000000000000 --- a/internal/llm/models/openai.go +++ /dev/null @@ -1,181 +0,0 @@ -package models - -const ( - ProviderOpenAI InferenceProvider = "openai" - - GPT41 ModelID = "gpt-4.1" - GPT41Mini ModelID = "gpt-4.1-mini" - GPT41Nano ModelID = "gpt-4.1-nano" - GPT45Preview ModelID = "gpt-4.5-preview" - GPT4o ModelID = "gpt-4o" - GPT4oMini ModelID = "gpt-4o-mini" - O1 ModelID = "o1" - O1Pro ModelID = "o1-pro" - O1Mini ModelID = "o1-mini" - O3 ModelID = "o3" - O3Mini ModelID = "o3-mini" - O4Mini ModelID = "o4-mini" -) - -var OpenAIModels = map[ModelID]Model{ - GPT41: { - ID: GPT41, - Name: "GPT 4.1", - Provider: ProviderOpenAI, - APIModel: "gpt-4.1", - CostPer1MIn: 2.00, - CostPer1MInCached: 0.50, - CostPer1MOutCached: 0.0, - CostPer1MOut: 8.00, - ContextWindow: 1_047_576, - DefaultMaxTokens: 20000, - SupportsAttachments: true, - }, - GPT41Mini: { - ID: GPT41Mini, - Name: "GPT 4.1 mini", - Provider: ProviderOpenAI, - APIModel: "gpt-4.1", - CostPer1MIn: 0.40, - CostPer1MInCached: 0.10, - CostPer1MOutCached: 0.0, - CostPer1MOut: 1.60, - ContextWindow: 200_000, - DefaultMaxTokens: 20000, - SupportsAttachments: true, - }, - GPT41Nano: { - ID: GPT41Nano, - Name: "GPT 4.1 nano", - Provider: ProviderOpenAI, - APIModel: "gpt-4.1-nano", - CostPer1MIn: 0.10, - CostPer1MInCached: 0.025, - CostPer1MOutCached: 0.0, - CostPer1MOut: 0.40, - ContextWindow: 1_047_576, - DefaultMaxTokens: 20000, - SupportsAttachments: true, - }, - GPT45Preview: { - ID: GPT45Preview, - Name: "GPT 4.5 preview", - Provider: ProviderOpenAI, - APIModel: "gpt-4.5-preview", - CostPer1MIn: 75.00, - CostPer1MInCached: 37.50, - CostPer1MOutCached: 0.0, - CostPer1MOut: 150.00, - 
ContextWindow: 128_000, - DefaultMaxTokens: 15000, - SupportsAttachments: true, - }, - GPT4o: { - ID: GPT4o, - Name: "GPT 4o", - Provider: ProviderOpenAI, - APIModel: "gpt-4o", - CostPer1MIn: 2.50, - CostPer1MInCached: 1.25, - CostPer1MOutCached: 0.0, - CostPer1MOut: 10.00, - ContextWindow: 128_000, - DefaultMaxTokens: 4096, - SupportsAttachments: true, - }, - GPT4oMini: { - ID: GPT4oMini, - Name: "GPT 4o mini", - Provider: ProviderOpenAI, - APIModel: "gpt-4o-mini", - CostPer1MIn: 0.15, - CostPer1MInCached: 0.075, - CostPer1MOutCached: 0.0, - CostPer1MOut: 0.60, - ContextWindow: 128_000, - SupportsAttachments: true, - }, - O1: { - ID: O1, - Name: "O1", - Provider: ProviderOpenAI, - APIModel: "o1", - CostPer1MIn: 15.00, - CostPer1MInCached: 7.50, - CostPer1MOutCached: 0.0, - CostPer1MOut: 60.00, - ContextWindow: 200_000, - DefaultMaxTokens: 50000, - CanReason: true, - SupportsAttachments: true, - }, - O1Pro: { - ID: O1Pro, - Name: "o1 pro", - Provider: ProviderOpenAI, - APIModel: "o1-pro", - CostPer1MIn: 150.00, - CostPer1MInCached: 0.0, - CostPer1MOutCached: 0.0, - CostPer1MOut: 600.00, - ContextWindow: 200_000, - DefaultMaxTokens: 50000, - CanReason: true, - SupportsAttachments: true, - }, - O1Mini: { - ID: O1Mini, - Name: "o1 mini", - Provider: ProviderOpenAI, - APIModel: "o1-mini", - CostPer1MIn: 1.10, - CostPer1MInCached: 0.55, - CostPer1MOutCached: 0.0, - CostPer1MOut: 4.40, - ContextWindow: 128_000, - DefaultMaxTokens: 50000, - CanReason: true, - SupportsAttachments: true, - }, - O3: { - ID: O3, - Name: "o3", - Provider: ProviderOpenAI, - APIModel: "o3", - CostPer1MIn: 10.00, - CostPer1MInCached: 2.50, - CostPer1MOutCached: 0.0, - CostPer1MOut: 40.00, - ContextWindow: 200_000, - CanReason: true, - SupportsAttachments: true, - }, - O3Mini: { - ID: O3Mini, - Name: "o3 mini", - Provider: ProviderOpenAI, - APIModel: "o3-mini", - CostPer1MIn: 1.10, - CostPer1MInCached: 0.55, - CostPer1MOutCached: 0.0, - CostPer1MOut: 4.40, - ContextWindow: 200_000, - 
DefaultMaxTokens: 50000, - CanReason: true, - SupportsAttachments: false, - }, - O4Mini: { - ID: O4Mini, - Name: "o4 mini", - Provider: ProviderOpenAI, - APIModel: "o4-mini", - CostPer1MIn: 1.10, - CostPer1MInCached: 0.275, - CostPer1MOutCached: 0.0, - CostPer1MOut: 4.40, - ContextWindow: 128_000, - DefaultMaxTokens: 50000, - CanReason: true, - SupportsAttachments: true, - }, -} diff --git a/internal/llm/models/openrouter.go b/internal/llm/models/openrouter.go deleted file mode 100644 index 8884e03442d30787fd505ca6a6c518d299748752..0000000000000000000000000000000000000000 --- a/internal/llm/models/openrouter.go +++ /dev/null @@ -1,276 +0,0 @@ -package models - -const ( - ProviderOpenRouter InferenceProvider = "openrouter" - - OpenRouterGPT41 ModelID = "openrouter.gpt-4.1" - OpenRouterGPT41Mini ModelID = "openrouter.gpt-4.1-mini" - OpenRouterGPT41Nano ModelID = "openrouter.gpt-4.1-nano" - OpenRouterGPT45Preview ModelID = "openrouter.gpt-4.5-preview" - OpenRouterGPT4o ModelID = "openrouter.gpt-4o" - OpenRouterGPT4oMini ModelID = "openrouter.gpt-4o-mini" - OpenRouterO1 ModelID = "openrouter.o1" - OpenRouterO1Pro ModelID = "openrouter.o1-pro" - OpenRouterO1Mini ModelID = "openrouter.o1-mini" - OpenRouterO3 ModelID = "openrouter.o3" - OpenRouterO3Mini ModelID = "openrouter.o3-mini" - OpenRouterO4Mini ModelID = "openrouter.o4-mini" - OpenRouterGemini25Flash ModelID = "openrouter.gemini-2.5-flash" - OpenRouterGemini25 ModelID = "openrouter.gemini-2.5" - OpenRouterClaude35Sonnet ModelID = "openrouter.claude-3.5-sonnet" - OpenRouterClaude3Haiku ModelID = "openrouter.claude-3-haiku" - OpenRouterClaude37Sonnet ModelID = "openrouter.claude-3.7-sonnet" - OpenRouterClaude35Haiku ModelID = "openrouter.claude-3.5-haiku" - OpenRouterClaude3Opus ModelID = "openrouter.claude-3-opus" - OpenRouterDeepSeekR1Free ModelID = "openrouter.deepseek-r1-free" -) - -var OpenRouterModels = map[ModelID]Model{ - OpenRouterGPT41: { - ID: OpenRouterGPT41, - Name: "OpenRouter – GPT 4.1", - Provider: 
ProviderOpenRouter, - APIModel: "openai/gpt-4.1", - CostPer1MIn: OpenAIModels[GPT41].CostPer1MIn, - CostPer1MInCached: OpenAIModels[GPT41].CostPer1MInCached, - CostPer1MOut: OpenAIModels[GPT41].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[GPT41].CostPer1MOutCached, - ContextWindow: OpenAIModels[GPT41].ContextWindow, - DefaultMaxTokens: OpenAIModels[GPT41].DefaultMaxTokens, - }, - OpenRouterGPT41Mini: { - ID: OpenRouterGPT41Mini, - Name: "OpenRouter – GPT 4.1 mini", - Provider: ProviderOpenRouter, - APIModel: "openai/gpt-4.1-mini", - CostPer1MIn: OpenAIModels[GPT41Mini].CostPer1MIn, - CostPer1MInCached: OpenAIModels[GPT41Mini].CostPer1MInCached, - CostPer1MOut: OpenAIModels[GPT41Mini].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[GPT41Mini].CostPer1MOutCached, - ContextWindow: OpenAIModels[GPT41Mini].ContextWindow, - DefaultMaxTokens: OpenAIModels[GPT41Mini].DefaultMaxTokens, - }, - OpenRouterGPT41Nano: { - ID: OpenRouterGPT41Nano, - Name: "OpenRouter – GPT 4.1 nano", - Provider: ProviderOpenRouter, - APIModel: "openai/gpt-4.1-nano", - CostPer1MIn: OpenAIModels[GPT41Nano].CostPer1MIn, - CostPer1MInCached: OpenAIModels[GPT41Nano].CostPer1MInCached, - CostPer1MOut: OpenAIModels[GPT41Nano].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[GPT41Nano].CostPer1MOutCached, - ContextWindow: OpenAIModels[GPT41Nano].ContextWindow, - DefaultMaxTokens: OpenAIModels[GPT41Nano].DefaultMaxTokens, - }, - OpenRouterGPT45Preview: { - ID: OpenRouterGPT45Preview, - Name: "OpenRouter – GPT 4.5 preview", - Provider: ProviderOpenRouter, - APIModel: "openai/gpt-4.5-preview", - CostPer1MIn: OpenAIModels[GPT45Preview].CostPer1MIn, - CostPer1MInCached: OpenAIModels[GPT45Preview].CostPer1MInCached, - CostPer1MOut: OpenAIModels[GPT45Preview].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[GPT45Preview].CostPer1MOutCached, - ContextWindow: OpenAIModels[GPT45Preview].ContextWindow, - DefaultMaxTokens: OpenAIModels[GPT45Preview].DefaultMaxTokens, - }, - OpenRouterGPT4o: { - ID: 
OpenRouterGPT4o, - Name: "OpenRouter – GPT 4o", - Provider: ProviderOpenRouter, - APIModel: "openai/gpt-4o", - CostPer1MIn: OpenAIModels[GPT4o].CostPer1MIn, - CostPer1MInCached: OpenAIModels[GPT4o].CostPer1MInCached, - CostPer1MOut: OpenAIModels[GPT4o].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[GPT4o].CostPer1MOutCached, - ContextWindow: OpenAIModels[GPT4o].ContextWindow, - DefaultMaxTokens: OpenAIModels[GPT4o].DefaultMaxTokens, - }, - OpenRouterGPT4oMini: { - ID: OpenRouterGPT4oMini, - Name: "OpenRouter – GPT 4o mini", - Provider: ProviderOpenRouter, - APIModel: "openai/gpt-4o-mini", - CostPer1MIn: OpenAIModels[GPT4oMini].CostPer1MIn, - CostPer1MInCached: OpenAIModels[GPT4oMini].CostPer1MInCached, - CostPer1MOut: OpenAIModels[GPT4oMini].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[GPT4oMini].CostPer1MOutCached, - ContextWindow: OpenAIModels[GPT4oMini].ContextWindow, - }, - OpenRouterO1: { - ID: OpenRouterO1, - Name: "OpenRouter – O1", - Provider: ProviderOpenRouter, - APIModel: "openai/o1", - CostPer1MIn: OpenAIModels[O1].CostPer1MIn, - CostPer1MInCached: OpenAIModels[O1].CostPer1MInCached, - CostPer1MOut: OpenAIModels[O1].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[O1].CostPer1MOutCached, - ContextWindow: OpenAIModels[O1].ContextWindow, - DefaultMaxTokens: OpenAIModels[O1].DefaultMaxTokens, - CanReason: OpenAIModels[O1].CanReason, - }, - OpenRouterO1Pro: { - ID: OpenRouterO1Pro, - Name: "OpenRouter – o1 pro", - Provider: ProviderOpenRouter, - APIModel: "openai/o1-pro", - CostPer1MIn: OpenAIModels[O1Pro].CostPer1MIn, - CostPer1MInCached: OpenAIModels[O1Pro].CostPer1MInCached, - CostPer1MOut: OpenAIModels[O1Pro].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[O1Pro].CostPer1MOutCached, - ContextWindow: OpenAIModels[O1Pro].ContextWindow, - DefaultMaxTokens: OpenAIModels[O1Pro].DefaultMaxTokens, - CanReason: OpenAIModels[O1Pro].CanReason, - }, - OpenRouterO1Mini: { - ID: OpenRouterO1Mini, - Name: "OpenRouter – o1 mini", - Provider: ProviderOpenRouter, 
- APIModel: "openai/o1-mini", - CostPer1MIn: OpenAIModels[O1Mini].CostPer1MIn, - CostPer1MInCached: OpenAIModels[O1Mini].CostPer1MInCached, - CostPer1MOut: OpenAIModels[O1Mini].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[O1Mini].CostPer1MOutCached, - ContextWindow: OpenAIModels[O1Mini].ContextWindow, - DefaultMaxTokens: OpenAIModels[O1Mini].DefaultMaxTokens, - CanReason: OpenAIModels[O1Mini].CanReason, - }, - OpenRouterO3: { - ID: OpenRouterO3, - Name: "OpenRouter – o3", - Provider: ProviderOpenRouter, - APIModel: "openai/o3", - CostPer1MIn: OpenAIModels[O3].CostPer1MIn, - CostPer1MInCached: OpenAIModels[O3].CostPer1MInCached, - CostPer1MOut: OpenAIModels[O3].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[O3].CostPer1MOutCached, - ContextWindow: OpenAIModels[O3].ContextWindow, - DefaultMaxTokens: OpenAIModels[O3].DefaultMaxTokens, - CanReason: OpenAIModels[O3].CanReason, - }, - OpenRouterO3Mini: { - ID: OpenRouterO3Mini, - Name: "OpenRouter – o3 mini", - Provider: ProviderOpenRouter, - APIModel: "openai/o3-mini-high", - CostPer1MIn: OpenAIModels[O3Mini].CostPer1MIn, - CostPer1MInCached: OpenAIModels[O3Mini].CostPer1MInCached, - CostPer1MOut: OpenAIModels[O3Mini].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[O3Mini].CostPer1MOutCached, - ContextWindow: OpenAIModels[O3Mini].ContextWindow, - DefaultMaxTokens: OpenAIModels[O3Mini].DefaultMaxTokens, - CanReason: OpenAIModels[O3Mini].CanReason, - }, - OpenRouterO4Mini: { - ID: OpenRouterO4Mini, - Name: "OpenRouter – o4 mini", - Provider: ProviderOpenRouter, - APIModel: "openai/o4-mini-high", - CostPer1MIn: OpenAIModels[O4Mini].CostPer1MIn, - CostPer1MInCached: OpenAIModels[O4Mini].CostPer1MInCached, - CostPer1MOut: OpenAIModels[O4Mini].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[O4Mini].CostPer1MOutCached, - ContextWindow: OpenAIModels[O4Mini].ContextWindow, - DefaultMaxTokens: OpenAIModels[O4Mini].DefaultMaxTokens, - CanReason: OpenAIModels[O4Mini].CanReason, - }, - OpenRouterGemini25Flash: { - ID: 
OpenRouterGemini25Flash, - Name: "OpenRouter – Gemini 2.5 Flash", - Provider: ProviderOpenRouter, - APIModel: "google/gemini-2.5-flash-preview:thinking", - CostPer1MIn: GeminiModels[Gemini25Flash].CostPer1MIn, - CostPer1MInCached: GeminiModels[Gemini25Flash].CostPer1MInCached, - CostPer1MOut: GeminiModels[Gemini25Flash].CostPer1MOut, - CostPer1MOutCached: GeminiModels[Gemini25Flash].CostPer1MOutCached, - ContextWindow: GeminiModels[Gemini25Flash].ContextWindow, - DefaultMaxTokens: GeminiModels[Gemini25Flash].DefaultMaxTokens, - }, - OpenRouterGemini25: { - ID: OpenRouterGemini25, - Name: "OpenRouter – Gemini 2.5 Pro", - Provider: ProviderOpenRouter, - APIModel: "google/gemini-2.5-pro-preview-03-25", - CostPer1MIn: GeminiModels[Gemini25].CostPer1MIn, - CostPer1MInCached: GeminiModels[Gemini25].CostPer1MInCached, - CostPer1MOut: GeminiModels[Gemini25].CostPer1MOut, - CostPer1MOutCached: GeminiModels[Gemini25].CostPer1MOutCached, - ContextWindow: GeminiModels[Gemini25].ContextWindow, - DefaultMaxTokens: GeminiModels[Gemini25].DefaultMaxTokens, - }, - OpenRouterClaude35Sonnet: { - ID: OpenRouterClaude35Sonnet, - Name: "OpenRouter – Claude 3.5 Sonnet", - Provider: ProviderOpenRouter, - APIModel: "anthropic/claude-3.5-sonnet", - CostPer1MIn: AnthropicModels[Claude35Sonnet].CostPer1MIn, - CostPer1MInCached: AnthropicModels[Claude35Sonnet].CostPer1MInCached, - CostPer1MOut: AnthropicModels[Claude35Sonnet].CostPer1MOut, - CostPer1MOutCached: AnthropicModels[Claude35Sonnet].CostPer1MOutCached, - ContextWindow: AnthropicModels[Claude35Sonnet].ContextWindow, - DefaultMaxTokens: AnthropicModels[Claude35Sonnet].DefaultMaxTokens, - }, - OpenRouterClaude3Haiku: { - ID: OpenRouterClaude3Haiku, - Name: "OpenRouter – Claude 3 Haiku", - Provider: ProviderOpenRouter, - APIModel: "anthropic/claude-3-haiku", - CostPer1MIn: AnthropicModels[Claude3Haiku].CostPer1MIn, - CostPer1MInCached: AnthropicModels[Claude3Haiku].CostPer1MInCached, - CostPer1MOut: 
AnthropicModels[Claude3Haiku].CostPer1MOut, - CostPer1MOutCached: AnthropicModels[Claude3Haiku].CostPer1MOutCached, - ContextWindow: AnthropicModels[Claude3Haiku].ContextWindow, - DefaultMaxTokens: AnthropicModels[Claude3Haiku].DefaultMaxTokens, - }, - OpenRouterClaude37Sonnet: { - ID: OpenRouterClaude37Sonnet, - Name: "OpenRouter – Claude 3.7 Sonnet", - Provider: ProviderOpenRouter, - APIModel: "anthropic/claude-3.7-sonnet", - CostPer1MIn: AnthropicModels[Claude37Sonnet].CostPer1MIn, - CostPer1MInCached: AnthropicModels[Claude37Sonnet].CostPer1MInCached, - CostPer1MOut: AnthropicModels[Claude37Sonnet].CostPer1MOut, - CostPer1MOutCached: AnthropicModels[Claude37Sonnet].CostPer1MOutCached, - ContextWindow: AnthropicModels[Claude37Sonnet].ContextWindow, - DefaultMaxTokens: AnthropicModels[Claude37Sonnet].DefaultMaxTokens, - CanReason: AnthropicModels[Claude37Sonnet].CanReason, - }, - OpenRouterClaude35Haiku: { - ID: OpenRouterClaude35Haiku, - Name: "OpenRouter – Claude 3.5 Haiku", - Provider: ProviderOpenRouter, - APIModel: "anthropic/claude-3.5-haiku", - CostPer1MIn: AnthropicModels[Claude35Haiku].CostPer1MIn, - CostPer1MInCached: AnthropicModels[Claude35Haiku].CostPer1MInCached, - CostPer1MOut: AnthropicModels[Claude35Haiku].CostPer1MOut, - CostPer1MOutCached: AnthropicModels[Claude35Haiku].CostPer1MOutCached, - ContextWindow: AnthropicModels[Claude35Haiku].ContextWindow, - DefaultMaxTokens: AnthropicModels[Claude35Haiku].DefaultMaxTokens, - }, - OpenRouterClaude3Opus: { - ID: OpenRouterClaude3Opus, - Name: "OpenRouter – Claude 3 Opus", - Provider: ProviderOpenRouter, - APIModel: "anthropic/claude-3-opus", - CostPer1MIn: AnthropicModels[Claude3Opus].CostPer1MIn, - CostPer1MInCached: AnthropicModels[Claude3Opus].CostPer1MInCached, - CostPer1MOut: AnthropicModels[Claude3Opus].CostPer1MOut, - CostPer1MOutCached: AnthropicModels[Claude3Opus].CostPer1MOutCached, - ContextWindow: AnthropicModels[Claude3Opus].ContextWindow, - DefaultMaxTokens: 
AnthropicModels[Claude3Opus].DefaultMaxTokens, - }, - - OpenRouterDeepSeekR1Free: { - ID: OpenRouterDeepSeekR1Free, - Name: "OpenRouter – DeepSeek R1 Free", - Provider: ProviderOpenRouter, - APIModel: "deepseek/deepseek-r1-0528:free", - CostPer1MIn: 0, - CostPer1MInCached: 0, - CostPer1MOut: 0, - CostPer1MOutCached: 0, - ContextWindow: 163_840, - DefaultMaxTokens: 10000, - }, -} diff --git a/internal/llm/models/vertexai.go b/internal/llm/models/vertexai.go deleted file mode 100644 index c9b5744b62c28e2529cac44b1e97234158d2eacf..0000000000000000000000000000000000000000 --- a/internal/llm/models/vertexai.go +++ /dev/null @@ -1,38 +0,0 @@ -package models - -const ( - ProviderVertexAI InferenceProvider = "vertexai" - - // Models - VertexAIGemini25Flash ModelID = "vertexai.gemini-2.5-flash" - VertexAIGemini25 ModelID = "vertexai.gemini-2.5" -) - -var VertexAIGeminiModels = map[ModelID]Model{ - VertexAIGemini25Flash: { - ID: VertexAIGemini25Flash, - Name: "VertexAI: Gemini 2.5 Flash", - Provider: ProviderVertexAI, - APIModel: "gemini-2.5-flash-preview-04-17", - CostPer1MIn: GeminiModels[Gemini25Flash].CostPer1MIn, - CostPer1MInCached: GeminiModels[Gemini25Flash].CostPer1MInCached, - CostPer1MOut: GeminiModels[Gemini25Flash].CostPer1MOut, - CostPer1MOutCached: GeminiModels[Gemini25Flash].CostPer1MOutCached, - ContextWindow: GeminiModels[Gemini25Flash].ContextWindow, - DefaultMaxTokens: GeminiModels[Gemini25Flash].DefaultMaxTokens, - SupportsAttachments: true, - }, - VertexAIGemini25: { - ID: VertexAIGemini25, - Name: "VertexAI: Gemini 2.5 Pro", - Provider: ProviderVertexAI, - APIModel: "gemini-2.5-pro-preview-03-25", - CostPer1MIn: GeminiModels[Gemini25].CostPer1MIn, - CostPer1MInCached: GeminiModels[Gemini25].CostPer1MInCached, - CostPer1MOut: GeminiModels[Gemini25].CostPer1MOut, - CostPer1MOutCached: GeminiModels[Gemini25].CostPer1MOutCached, - ContextWindow: GeminiModels[Gemini25].ContextWindow, - DefaultMaxTokens: GeminiModels[Gemini25].DefaultMaxTokens, - 
SupportsAttachments: true, - }, -} diff --git a/internal/llm/models/xai.go b/internal/llm/models/xai.go deleted file mode 100644 index a59eac97ee6fee5db5550663083062099512eddc..0000000000000000000000000000000000000000 --- a/internal/llm/models/xai.go +++ /dev/null @@ -1,61 +0,0 @@ -package models - -const ( - ProviderXAI InferenceProvider = "xai" - - XAIGrok3Beta ModelID = "grok-3-beta" - XAIGrok3MiniBeta ModelID = "grok-3-mini-beta" - XAIGrok3FastBeta ModelID = "grok-3-fast-beta" - XAiGrok3MiniFastBeta ModelID = "grok-3-mini-fast-beta" -) - -var XAIModels = map[ModelID]Model{ - XAIGrok3Beta: { - ID: XAIGrok3Beta, - Name: "Grok3 Beta", - Provider: ProviderXAI, - APIModel: "grok-3-beta", - CostPer1MIn: 3.0, - CostPer1MInCached: 0, - CostPer1MOut: 15, - CostPer1MOutCached: 0, - ContextWindow: 131_072, - DefaultMaxTokens: 20_000, - }, - XAIGrok3MiniBeta: { - ID: XAIGrok3MiniBeta, - Name: "Grok3 Mini Beta", - Provider: ProviderXAI, - APIModel: "grok-3-mini-beta", - CostPer1MIn: 0.3, - CostPer1MInCached: 0, - CostPer1MOut: 0.5, - CostPer1MOutCached: 0, - ContextWindow: 131_072, - DefaultMaxTokens: 20_000, - }, - XAIGrok3FastBeta: { - ID: XAIGrok3FastBeta, - Name: "Grok3 Fast Beta", - Provider: ProviderXAI, - APIModel: "grok-3-fast-beta", - CostPer1MIn: 5, - CostPer1MInCached: 0, - CostPer1MOut: 25, - CostPer1MOutCached: 0, - ContextWindow: 131_072, - DefaultMaxTokens: 20_000, - }, - XAiGrok3MiniFastBeta: { - ID: XAiGrok3MiniFastBeta, - Name: "Grok3 Mini Fast Beta", - Provider: ProviderXAI, - APIModel: "grok-3-mini-fast-beta", - CostPer1MIn: 0.6, - CostPer1MInCached: 0, - CostPer1MOut: 4.0, - CostPer1MOutCached: 0, - ContextWindow: 131_072, - DefaultMaxTokens: 20_000, - }, -} diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index b272f4e9f263ff596d06aae787e8b5a1c3ac2aec..9f1e5e7c19e739167bb9ab2bd359218e88fd4367 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -9,19 +9,27 @@ import ( "time" 
"github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/llm/models" + "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/tools" + "github.com/charmbracelet/crush/internal/logging" ) -func CoderPrompt(provider models.InferenceProvider) string { +func CoderPrompt(p provider.InferenceProvider, contextFiles ...string) string { basePrompt := baseAnthropicCoderPrompt - switch provider { - case models.ProviderOpenAI: + switch p { + case provider.InferenceProviderOpenAI: basePrompt = baseOpenAICoderPrompt } envInfo := getEnvironmentInfo() - return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation()) + basePrompt = fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation()) + + contextContent := getContextFromPaths(contextFiles) + logging.Debug("Context content", "Context", contextContent) + if contextContent != "" { + return fmt.Sprintf("%s\n\n# Project-Specific Context\n Make sure to follow the instructions in the context below\n%s", basePrompt, contextContent) + } + return basePrompt } const baseOpenAICoderPrompt = ` diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go index ed75d29c500cce16f16d06892ad8fcabc254a08d..36148edd9c71790c3a4cb06d551cdee06272c8b7 100644 --- a/internal/llm/prompt/prompt.go +++ b/internal/llm/prompt/prompt.go @@ -1,60 +1,44 @@ package prompt import ( - "fmt" "os" "path/filepath" "strings" "sync" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/llm/models" - "github.com/charmbracelet/crush/internal/logging" + "github.com/charmbracelet/crush/internal/fur/provider" ) -func GetAgentPrompt(agentName config.AgentName, provider models.InferenceProvider) string { +type PromptID string + +const ( + PromptCoder PromptID = "coder" + PromptTitle PromptID = "title" + PromptTask PromptID = "task" + PromptSummarizer PromptID = "summarizer" + PromptDefault PromptID = "default" +) + +func 
GetPrompt(promptID PromptID, provider provider.InferenceProvider, contextPaths ...string) string { basePrompt := "" - switch agentName { - case config.AgentCoder: + switch promptID { + case PromptCoder: basePrompt = CoderPrompt(provider) - case config.AgentTitle: + case PromptTitle: basePrompt = TitlePrompt(provider) - case config.AgentTask: + case PromptTask: basePrompt = TaskPrompt(provider) - case config.AgentSummarizer: + case PromptSummarizer: basePrompt = SummarizerPrompt(provider) default: basePrompt = "You are a helpful assistant" } - - if agentName == config.AgentCoder || agentName == config.AgentTask { - // Add context from project-specific instruction files if they exist - contextContent := getContextFromPaths() - logging.Debug("Context content", "Context", contextContent) - if contextContent != "" { - return fmt.Sprintf("%s\n\n# Project-Specific Context\n Make sure to follow the instructions in the context below\n%s", basePrompt, contextContent) - } - } return basePrompt } -var ( - onceContext sync.Once - contextContent string -) - -func getContextFromPaths() string { - onceContext.Do(func() { - var ( - cfg = config.Get() - workDir = cfg.WorkingDir - contextPaths = cfg.ContextPaths - ) - - contextContent = processContextPaths(workDir, contextPaths) - }) - - return contextContent +func getContextFromPaths(contextPaths []string) string { + return processContextPaths(config.WorkingDirectory(), contextPaths) } func processContextPaths(workDir string, paths []string) string { diff --git a/internal/llm/prompt/prompt_test.go b/internal/llm/prompt/prompt_test.go index a350c55a32260173dabd56e22d9e514e97b3e5a3..41e3fe92c7fb5615b6c93e2aa89bad35820567ef 100644 --- a/internal/llm/prompt/prompt_test.go +++ b/internal/llm/prompt/prompt_test.go @@ -15,16 +15,10 @@ func TestGetContextFromPaths(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - _, err := config.Load(tmpDir, false) + _, err := config.Init(tmpDir, false) if err != nil { t.Fatalf("Failed to load config: 
%v", err) } - cfg := config.Get() - cfg.WorkingDir = tmpDir - cfg.ContextPaths = []string{ - "file.txt", - "directory/", - } testFiles := []string{ "file.txt", "directory/file_a.txt", @@ -34,7 +28,12 @@ func TestGetContextFromPaths(t *testing.T) { createTestFiles(t, tmpDir, testFiles) - context := getContextFromPaths() + context := getContextFromPaths( + []string{ + "file.txt", + "directory/", + }, + ) expectedContext := fmt.Sprintf("# From:%s/file.txt\nfile.txt: test content\n# From:%s/directory/file_a.txt\ndirectory/file_a.txt: test content\n# From:%s/directory/file_b.txt\ndirectory/file_b.txt: test content\n# From:%s/directory/file_c.txt\ndirectory/file_c.txt: test content", tmpDir, tmpDir, tmpDir, tmpDir) assert.Equal(t, expectedContext, context) } diff --git a/internal/llm/prompt/summarizer.go b/internal/llm/prompt/summarizer.go index f5a1de0f8619252d99082c6ca54e152cc25a7bc7..77d98184bcf985ebb2bc569205b6b4cc77b3d601 100644 --- a/internal/llm/prompt/summarizer.go +++ b/internal/llm/prompt/summarizer.go @@ -1,8 +1,10 @@ package prompt -import "github.com/charmbracelet/crush/internal/llm/models" +import ( + "github.com/charmbracelet/crush/internal/fur/provider" +) -func SummarizerPrompt(_ models.InferenceProvider) string { +func SummarizerPrompt(_ provider.InferenceProvider) string { return `You are a helpful AI assistant tasked with summarizing conversations. When asked to summarize, provide a detailed but concise summary of the conversation. 
diff --git a/internal/llm/prompt/task.go b/internal/llm/prompt/task.go index 89acf1f02121ea008359eaa5201222061dad0cff..719c0ef45778814e38b391e86174708edcdd7c3e 100644 --- a/internal/llm/prompt/task.go +++ b/internal/llm/prompt/task.go @@ -3,10 +3,10 @@ package prompt import ( "fmt" - "github.com/charmbracelet/crush/internal/llm/models" + "github.com/charmbracelet/crush/internal/fur/provider" ) -func TaskPrompt(_ models.InferenceProvider) string { +func TaskPrompt(_ provider.InferenceProvider) string { agentPrompt := `You are an agent for Crush. Given the user's prompt, you should use the tools available to you to answer the user's question. Notes: 1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is .", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...". 
diff --git a/internal/llm/prompt/title.go b/internal/llm/prompt/title.go index 0b3177b37857c24d299df0d6e64393cd60ea23eb..11bab4b6835ac0e53adc578cfddd3133f8b654e5 100644 --- a/internal/llm/prompt/title.go +++ b/internal/llm/prompt/title.go @@ -1,8 +1,10 @@ package prompt -import "github.com/charmbracelet/crush/internal/llm/models" +import ( + "github.com/charmbracelet/crush/internal/fur/provider" +) -func TitlePrompt(_ models.InferenceProvider) string { +func TitlePrompt(_ provider.InferenceProvider) string { return `you will generate a short title based on the first message a user begins a conversation with - ensure it is not more than 50 characters long - the title should be a summary of the user's message diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 709a56263e0a8880d444c8ee7e9cab1373e67344..aca4d5b7f0adc4977fb349956be1005186e267e6 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -13,7 +13,7 @@ import ( "github.com/anthropics/anthropic-sdk-go/bedrock" "github.com/anthropics/anthropic-sdk-go/option" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/llm/models" + "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/logging" "github.com/charmbracelet/crush/internal/message" @@ -59,7 +59,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic var contentBlocks []anthropic.ContentBlockParamUnion contentBlocks = append(contentBlocks, content) for _, binaryContent := range msg.BinaryContent() { - base64Image := binaryContent.String(models.ProviderAnthropic) + base64Image := binaryContent.String(provider.InferenceProviderAnthropic) imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image) contentBlocks = append(contentBlocks, imageBlock) } @@ -164,7 +164,7 @@ func (a *anthropicClient) 
preparedMessages(messages []anthropic.MessageParam, to // } return anthropic.MessageNewParams{ - Model: anthropic.Model(a.providerOptions.model.APIModel), + Model: anthropic.Model(a.providerOptions.model.ID), MaxTokens: a.providerOptions.maxTokens, Temperature: temperature, Messages: messages, @@ -184,7 +184,7 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) { preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools)) cfg := config.Get() - if cfg.Debug { + if cfg.Options.Debug { jsonData, _ := json.Marshal(preparedMessages) logging.Debug("Prepared messages", "messages", string(jsonData)) } @@ -233,7 +233,7 @@ func (a *anthropicClient) send(ctx context.Context, messages []message.Message, func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools)) cfg := config.Get() - if cfg.Debug { + if cfg.Options.Debug { // jsonData, _ := json.Marshal(preparedMessages) // logging.Debug("Prepared messages", "messages", string(jsonData)) } diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go index 8db9c1e84a4e8496be77e69e612de4abb9ce0c07..6b31c7d7fd6625ad7c2962f409f6c50f01ff726b 100644 --- a/internal/llm/provider/bedrock.go +++ b/internal/llm/provider/bedrock.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "os" "strings" "github.com/charmbracelet/crush/internal/llm/tools" @@ -19,14 +18,8 @@ type bedrockClient struct { type BedrockClient ProviderClient func newBedrockClient(opts providerClientOptions) BedrockClient { - // Apply bedrock specific options if they are added in the future - // Get AWS region from environment - region := os.Getenv("AWS_REGION") - if region == "" { - region 
= os.Getenv("AWS_DEFAULT_REGION") - } - + region := opts.extraParams["region"] if region == "" { region = "us-east-1" // default region } @@ -39,11 +32,11 @@ func newBedrockClient(opts providerClientOptions) BedrockClient { // Prefix the model name with region regionPrefix := region[:2] - modelName := opts.model.APIModel - opts.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, modelName) + modelName := opts.model.ID + opts.model.ID = fmt.Sprintf("%s.%s", regionPrefix, modelName) // Determine which provider to use based on the model - if strings.Contains(string(opts.model.APIModel), "anthropic") { + if strings.Contains(string(opts.model.ID), "anthropic") { // Create Anthropic client with Bedrock configuration anthropicOpts := opts // TODO: later find a way to check if the AWS account has caching enabled diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index dd54dac4491634de06a31ee00f1ffd13ea935076..a91c1eae2427a7629ee1f4de6d6b9abb5944a972 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -157,7 +157,7 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too geminiMessages := g.convertMessages(messages) cfg := config.Get() - if cfg.Debug { + if cfg.Options.Debug { jsonData, _ := json.Marshal(geminiMessages) logging.Debug("Prepared messages", "messages", string(jsonData)) } @@ -173,7 +173,7 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too if len(tools) > 0 { config.Tools = g.convertTools(tools) } - chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history) + chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.ID, config, history) attempts := 0 for { @@ -245,7 +245,7 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t geminiMessages := g.convertMessages(messages) cfg := config.Get() - if cfg.Debug { + if cfg.Options.Debug { jsonData, _ := json.Marshal(geminiMessages) 
logging.Debug("Prepared messages", "messages", string(jsonData)) } @@ -261,7 +261,7 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t if len(tools) > 0 { config.Tools = g.convertTools(tools) } - chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history) + chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.ID, config, history) attempts := 0 eventChan := make(chan ProviderEvent) diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 334312f9e8c41f5d68251d9e7bbd890074fa3982..448ab3674f25053453f51c0f48475db5699ee913 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -9,7 +9,7 @@ import ( "time" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/llm/models" + "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/logging" "github.com/charmbracelet/crush/internal/message" @@ -68,7 +68,7 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()} content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock}) for _, binaryContent := range msg.BinaryContent() { - imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(models.ProviderOpenAI)} + imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(provider.InferenceProviderOpenAI)} imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL} content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock}) @@ -153,7 +153,7 @@ func (o *openaiClient) finishReason(reason string) message.FinishReason { func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools 
[]openai.ChatCompletionToolParam) openai.ChatCompletionNewParams { params := openai.ChatCompletionNewParams{ - Model: openai.ChatModel(o.providerOptions.model.APIModel), + Model: openai.ChatModel(o.providerOptions.model.ID), Messages: messages, Tools: tools, } @@ -180,7 +180,7 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) { params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools)) cfg := config.Get() - if cfg.Debug { + if cfg.Options.Debug { jsonData, _ := json.Marshal(params) logging.Debug("Prepared messages", "messages", string(jsonData)) } @@ -237,7 +237,7 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t } cfg := config.Get() - if cfg.Debug { + if cfg.Options.Debug { jsonData, _ := json.Marshal(params) logging.Debug("Prepared messages", "messages", string(jsonData)) } diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 0d98b74c3292c0aa066dfd0676445e587b800b57..3152cd6a9a7e6fd6a68d0e6b54b6ea6853a38273 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -3,9 +3,9 @@ package provider import ( "context" "fmt" - "os" - "github.com/charmbracelet/crush/internal/llm/models" + configv2 "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" ) @@ -55,17 +55,18 @@ type Provider interface { StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent - Model() models.Model + Model() configv2.Model } type providerClientOptions struct { baseURL string apiKey string - model models.Model + model configv2.Model disableCache bool maxTokens int64 systemMessage string extraHeaders 
map[string]string + extraParams map[string]string } type ProviderClientOption func(*providerClientOptions) @@ -80,77 +81,6 @@ type baseProvider[C ProviderClient] struct { client C } -func NewProvider(providerName models.InferenceProvider, opts ...ProviderClientOption) (Provider, error) { - clientOptions := providerClientOptions{} - for _, o := range opts { - o(&clientOptions) - } - switch providerName { - case models.ProviderAnthropic: - return &baseProvider[AnthropicClient]{ - options: clientOptions, - client: newAnthropicClient(clientOptions, false), - }, nil - case models.ProviderOpenAI: - return &baseProvider[OpenAIClient]{ - options: clientOptions, - client: newOpenAIClient(clientOptions), - }, nil - case models.ProviderGemini: - return &baseProvider[GeminiClient]{ - options: clientOptions, - client: newGeminiClient(clientOptions), - }, nil - case models.ProviderBedrock: - return &baseProvider[BedrockClient]{ - options: clientOptions, - client: newBedrockClient(clientOptions), - }, nil - case models.ProviderGROQ: - clientOptions.baseURL = "https://api.groq.com/openai/v1" - return &baseProvider[OpenAIClient]{ - options: clientOptions, - client: newOpenAIClient(clientOptions), - }, nil - case models.ProviderAzure: - return &baseProvider[AzureClient]{ - options: clientOptions, - client: newAzureClient(clientOptions), - }, nil - case models.ProviderVertexAI: - return &baseProvider[VertexAIClient]{ - options: clientOptions, - client: newVertexAIClient(clientOptions), - }, nil - case models.ProviderOpenRouter: - clientOptions.baseURL = "https://openrouter.ai/api/v1" - clientOptions.extraHeaders = map[string]string{ - "HTTP-Referer": "crush.charm.land", - "X-Title": "Crush", - } - return &baseProvider[OpenAIClient]{ - options: clientOptions, - client: newOpenAIClient(clientOptions), - }, nil - case models.ProviderXAI: - clientOptions.baseURL = "https://api.x.ai/v1" - return &baseProvider[OpenAIClient]{ - options: clientOptions, - client: 
newOpenAIClient(clientOptions), - }, nil - case models.ProviderLocal: - clientOptions.baseURL = os.Getenv("LOCAL_ENDPOINT") - return &baseProvider[OpenAIClient]{ - options: clientOptions, - client: newOpenAIClient(clientOptions), - }, nil - case models.ProviderMock: - // TODO: implement mock client for test - panic("not implemented") - } - return nil, fmt.Errorf("provider not supported: %s", providerName) -} - func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) { for _, msg := range messages { // The message has no content @@ -167,7 +97,7 @@ func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.M return p.client.send(ctx, messages, tools) } -func (p *baseProvider[C]) Model() models.Model { +func (p *baseProvider[C]) Model() configv2.Model { return p.options.model } @@ -176,7 +106,7 @@ func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message return p.client.stream(ctx, messages, tools) } -func WithModel(model models.Model) ProviderClientOption { +func WithModel(model configv2.Model) ProviderClientOption { return func(options *providerClientOptions) { options.model = model } @@ -199,3 +129,53 @@ func WithSystemMessage(systemMessage string) ProviderClientOption { options.systemMessage = systemMessage } } + +func NewProviderV2(cfg configv2.ProviderConfig, opts ...ProviderClientOption) (Provider, error) { + clientOptions := providerClientOptions{ + baseURL: cfg.BaseURL, + apiKey: cfg.APIKey, + extraHeaders: cfg.ExtraHeaders, + } + for _, o := range opts { + o(&clientOptions) + } + switch cfg.ProviderType { + case provider.TypeAnthropic: + return &baseProvider[AnthropicClient]{ + options: clientOptions, + client: newAnthropicClient(clientOptions, false), + }, nil + case provider.TypeOpenAI: + return &baseProvider[OpenAIClient]{ + options: clientOptions, + client: newOpenAIClient(clientOptions), + }, nil + case provider.TypeGemini: + return &baseProvider[GeminiClient]{ + options: 
clientOptions, + client: newGeminiClient(clientOptions), + }, nil + case provider.TypeBedrock: + return &baseProvider[BedrockClient]{ + options: clientOptions, + client: newBedrockClient(clientOptions), + }, nil + case provider.TypeAzure: + return &baseProvider[AzureClient]{ + options: clientOptions, + client: newAzureClient(clientOptions), + }, nil + case provider.TypeVertexAI: + return &baseProvider[VertexAIClient]{ + options: clientOptions, + client: newVertexAIClient(clientOptions), + }, nil + case provider.TypeXAI: + clientOptions.baseURL = "https://api.x.ai/v1" + return &baseProvider[OpenAIClient]{ + options: clientOptions, + client: newOpenAIClient(clientOptions), + }, nil + } + return nil, fmt.Errorf("provider not supported: %s", cfg.ProviderType) +} diff --git a/internal/llm/provider/vertexai.go b/internal/llm/provider/vertexai.go index 49374d33fa81ab42e9f0c4d6e7905bfa37a6154e..2d95ad3f60db22e1338db3931b0900e83bccab52 100644 --- a/internal/llm/provider/vertexai.go +++ b/internal/llm/provider/vertexai.go @@ -2,7 +2,6 @@ package provider import ( "context" - "os" "github.com/charmbracelet/crush/internal/logging" "google.golang.org/genai" @@ -11,9 +10,11 @@ import ( type VertexAIClient ProviderClient func newVertexAIClient(opts providerClientOptions) VertexAIClient { + project := opts.extraHeaders["project"] + location := opts.extraHeaders["location"] client, err := genai.NewClient(context.Background(), &genai.ClientConfig{ - Project: os.Getenv("GOOGLE_CLOUD_PROJECT"), - Location: os.Getenv("GOOGLE_CLOUD_LOCATION"), + Project: project, + Location: location, Backend: genai.BackendVertexAI, }) if err != nil { diff --git a/internal/lsp/client.go b/internal/lsp/client.go index f65b3dee20a3ee0264742257ca78a116661f1165..c04f10a8a924f8725609aace7d5363fe1751a791 100644 --- a/internal/lsp/client.go +++ b/internal/lsp/client.go @@ -286,7 +286,7 @@ func (c *Client) SetServerState(state ServerState) { // WaitForServerReady waits for the server to be ready by polling the 
server // with a simple request until it responds successfully or times out func (c *Client) WaitForServerReady(ctx context.Context) error { - cnf := config.Get() + cfg := config.Get() // Set initial state c.SetServerState(StateStarting) @@ -299,7 +299,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error { ticker := time.NewTicker(500 * time.Millisecond) defer ticker.Stop() - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Waiting for LSP server to be ready...") } @@ -308,7 +308,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error { // For TypeScript-like servers, we need to open some key files first if serverType == ServerTypeTypeScript { - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("TypeScript-like server detected, opening key configuration files") } c.openKeyConfigFiles(ctx) @@ -325,7 +325,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error { if err == nil { // Server responded successfully c.SetServerState(StateReady) - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("LSP server is ready") } return nil @@ -333,7 +333,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error { logging.Debug("LSP server not ready yet", "error", err, "serverType", serverType) } - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("LSP server not ready yet", "error", err, "serverType", serverType) } } @@ -496,7 +496,7 @@ func (c *Client) pingTypeScriptServer(ctx context.Context) error { // openTypeScriptFiles finds and opens TypeScript files to help initialize the server func (c *Client) openTypeScriptFiles(ctx context.Context, workDir string) { - cnf := config.Get() + cfg := config.Get() filesOpened := 0 maxFilesToOpen := 5 // Limit to a reasonable number of files @@ -526,7 +526,7 @@ func (c *Client) openTypeScriptFiles(ctx context.Context, workDir string) { // Try to open the file if err := c.OpenFile(ctx, path); err == nil { filesOpened++ - if cnf.DebugLSP { + if 
cfg.Options.DebugLSP { logging.Debug("Opened TypeScript file for initialization", "file", path) } } @@ -535,11 +535,11 @@ func (c *Client) openTypeScriptFiles(ctx context.Context, workDir string) { return nil }) - if err != nil && cnf.DebugLSP { + if err != nil && cfg.Options.DebugLSP { logging.Debug("Error walking directory for TypeScript files", "error", err) } - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Opened TypeScript files for initialization", "count", filesOpened) } } @@ -664,7 +664,7 @@ func (c *Client) NotifyChange(ctx context.Context, filepath string) error { } func (c *Client) CloseFile(ctx context.Context, filepath string) error { - cnf := config.Get() + cfg := config.Get() uri := string(protocol.URIFromPath(filepath)) c.openFilesMu.Lock() @@ -680,7 +680,7 @@ func (c *Client) CloseFile(ctx context.Context, filepath string) error { }, } - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Closing file", "file", filepath) } if err := c.Notify(ctx, "textDocument/didClose", params); err != nil { @@ -704,7 +704,7 @@ func (c *Client) IsFileOpen(filepath string) bool { // CloseAllFiles closes all currently open files func (c *Client) CloseAllFiles(ctx context.Context) { - cnf := config.Get() + cfg := config.Get() c.openFilesMu.Lock() filesToClose := make([]string, 0, len(c.openFiles)) @@ -719,12 +719,12 @@ func (c *Client) CloseAllFiles(ctx context.Context) { // Then close them all for _, filePath := range filesToClose { err := c.CloseFile(ctx, filePath) - if err != nil && cnf.DebugLSP { + if err != nil && cfg.Options.DebugLSP { logging.Warn("Error closing file", "file", filePath, "error", err) } } - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Closed all files", "files", filesToClose) } } diff --git a/internal/lsp/handlers.go b/internal/lsp/handlers.go index 9eb258d761ee36a909cddec16b72b2a3d933a5b4..f2fbfd0a589651590185fe9f73fc222e5bd6b08d 100644 --- a/internal/lsp/handlers.go +++ b/internal/lsp/handlers.go @@ 
-82,13 +82,13 @@ func notifyFileWatchRegistration(id string, watchers []protocol.FileSystemWatche // Notifications func HandleServerMessage(params json.RawMessage) { - cnf := config.Get() + cfg := config.Get() var msg struct { Type int `json:"type"` Message string `json:"message"` } if err := json.Unmarshal(params, &msg); err == nil { - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Server message", "type", msg.Type, "message", msg.Message) } } diff --git a/internal/lsp/transport.go b/internal/lsp/transport.go index c3d5d762feeccaaa363a189fd8014b705a583681..5433fb552d6ee3dae390dcf74e3e1d9c8b0d74f9 100644 --- a/internal/lsp/transport.go +++ b/internal/lsp/transport.go @@ -18,9 +18,9 @@ func WriteMessage(w io.Writer, msg *Message) error { if err != nil { return fmt.Errorf("failed to marshal message: %w", err) } - cnf := config.Get() + cfg := config.Get() - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Sending message to server", "method", msg.Method, "id", msg.ID) } @@ -39,7 +39,7 @@ func WriteMessage(w io.Writer, msg *Message) error { // ReadMessage reads a single LSP message from the given reader func ReadMessage(r *bufio.Reader) (*Message, error) { - cnf := config.Get() + cfg := config.Get() // Read headers var contentLength int for { @@ -49,7 +49,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) { } line = strings.TrimSpace(line) - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Received header", "line", line) } @@ -65,7 +65,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) { } } - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Content-Length", "length", contentLength) } @@ -76,7 +76,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) { return nil, fmt.Errorf("failed to read content: %w", err) } - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Received content", "content", string(content)) } @@ -91,11 +91,11 @@ func ReadMessage(r *bufio.Reader) (*Message, error) { // 
handleMessages reads and dispatches messages in a loop func (c *Client) handleMessages() { - cnf := config.Get() + cfg := config.Get() for { msg, err := ReadMessage(c.stdout) if err != nil { - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Error("Error reading message", "error", err) } return @@ -103,7 +103,7 @@ func (c *Client) handleMessages() { // Handle server->client request (has both Method and ID) if msg.Method != "" && msg.ID != 0 { - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Received request from server", "method", msg.Method, "id", msg.ID) } @@ -157,11 +157,11 @@ func (c *Client) handleMessages() { c.notificationMu.RUnlock() if ok { - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Handling notification", "method", msg.Method) } go handler(msg.Params) - } else if cnf.DebugLSP { + } else if cfg.Options.DebugLSP { logging.Debug("No handler for notification", "method", msg.Method) } continue @@ -174,12 +174,12 @@ func (c *Client) handleMessages() { c.handlersMu.RUnlock() if ok { - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Received response for request", "id", msg.ID) } ch <- msg close(ch) - } else if cnf.DebugLSP { + } else if cfg.Options.DebugLSP { logging.Debug("No handler for response", "id", msg.ID) } } @@ -188,10 +188,10 @@ func (c *Client) handleMessages() { // Call makes a request and waits for the response func (c *Client) Call(ctx context.Context, method string, params any, result any) error { - cnf := config.Get() + cfg := config.Get() id := c.nextID.Add(1) - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Making call", "method", method, "id", id) } @@ -217,14 +217,14 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any return fmt.Errorf("failed to send request: %w", err) } - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Request sent", "method", method, "id", id) } // Wait for response resp := <-ch - if cnf.DebugLSP { + if 
cfg.Options.DebugLSP { logging.Debug("Received response", "id", id) } @@ -249,8 +249,8 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any // Notify sends a notification (a request without an ID that doesn't expect a response) func (c *Client) Notify(ctx context.Context, method string, params any) error { - cnf := config.Get() - if cnf.DebugLSP { + cfg := config.Get() + if cfg.Options.DebugLSP { logging.Debug("Sending notification", "method", method) } diff --git a/internal/lsp/watcher/watcher.go b/internal/lsp/watcher/watcher.go index a69b3c10577d0c89ffb8aa9972a928201e2124f6..3c2dc05909bd8e6d473696efd09d22435f68dc10 100644 --- a/internal/lsp/watcher/watcher.go +++ b/internal/lsp/watcher/watcher.go @@ -43,7 +43,7 @@ func NewWorkspaceWatcher(client *lsp.Client) *WorkspaceWatcher { // AddRegistrations adds file watchers to track func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watchers []protocol.FileSystemWatcher) { - cnf := config.Get() + cfg := config.Get() logging.Debug("Adding file watcher registrations") w.registrationMu.Lock() @@ -53,7 +53,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc w.registrations = append(w.registrations, watchers...) 
// Print detailed registration information for debugging - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Adding file watcher registrations", "id", id, "watchers", len(watchers), @@ -122,7 +122,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc highPriorityFilesOpened := w.openHighPriorityFiles(ctx, serverName) filesOpened += highPriorityFilesOpened - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Opened high-priority files", "count", highPriorityFilesOpened, "serverName", serverName) @@ -130,7 +130,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc // If we've already opened enough high-priority files, we might not need more if filesOpened >= maxFilesToOpen { - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Reached file limit with high-priority files", "filesOpened", filesOpened, "maxFiles", maxFilesToOpen) @@ -148,7 +148,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc // Skip directories that should be excluded if d.IsDir() { if path != w.workspacePath && shouldExcludeDir(path) { - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Skipping excluded directory", "path", path) } return filepath.SkipDir @@ -176,7 +176,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc }) elapsedTime := time.Since(startTime) - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Limited workspace scan complete", "filesOpened", filesOpened, "maxFiles", maxFilesToOpen, @@ -185,11 +185,11 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc ) } - if err != nil && cnf.DebugLSP { + if err != nil && cfg.Options.DebugLSP { logging.Debug("Error scanning workspace for files to open", "error", err) } }() - } else if cnf.DebugLSP { + } else if cfg.Options.DebugLSP { logging.Debug("Using on-demand file loading for server", "server", serverName) } } @@ -197,7 +197,7 
@@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc // openHighPriorityFiles opens important files for the server type // Returns the number of files opened func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName string) int { - cnf := config.Get() + cfg := config.Get() filesOpened := 0 // Define patterns for high-priority files based on server type @@ -265,7 +265,7 @@ func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName // Use doublestar.Glob to find files matching the pattern (supports ** patterns) matches, err := doublestar.Glob(os.DirFS(w.workspacePath), pattern) if err != nil { - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Error finding high-priority files", "pattern", pattern, "error", err) } continue @@ -299,12 +299,12 @@ func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName for j := i; j < end; j++ { fullPath := filesToOpen[j] if err := w.client.OpenFile(ctx, fullPath); err != nil { - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Error opening high-priority file", "path", fullPath, "error", err) } } else { filesOpened++ - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Opened high-priority file", "path", fullPath) } } @@ -321,7 +321,7 @@ func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName // WatchWorkspace sets up file watching for a workspace func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath string) { - cnf := config.Get() + cfg := config.Get() w.workspacePath = workspacePath // Store the watcher in the context for later use @@ -356,7 +356,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str // Skip excluded directories (except workspace root) if d.IsDir() && path != workspacePath { if shouldExcludeDir(path) { - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Skipping excluded directory", "path", path) 
} return filepath.SkipDir @@ -409,7 +409,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str } // Debug logging - if cnf.DebugLSP { + if cfg.Options.DebugLSP { matched, kind := w.isPathWatched(event.Name) logging.Debug("File event", "path", event.Name, @@ -676,8 +676,8 @@ func (w *WorkspaceWatcher) handleFileEvent(ctx context.Context, uri string, chan // notifyFileEvent sends a didChangeWatchedFiles notification for a file event func (w *WorkspaceWatcher) notifyFileEvent(ctx context.Context, uri string, changeType protocol.FileChangeType) error { - cnf := config.Get() - if cnf.DebugLSP { + cfg := config.Get() + if cfg.Options.DebugLSP { logging.Debug("Notifying file event", "uri", uri, "changeType", changeType, @@ -826,7 +826,7 @@ func shouldExcludeDir(dirPath string) bool { // shouldExcludeFile returns true if the file should be excluded from opening func shouldExcludeFile(filePath string) bool { fileName := filepath.Base(filePath) - cnf := config.Get() + cfg := config.Get() // Skip dot files if strings.HasPrefix(fileName, ".") { return true @@ -852,12 +852,12 @@ func shouldExcludeFile(filePath string) bool { // Skip large files if info.Size() > maxFileSize { - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Skipping large file", "path", filePath, "size", info.Size(), "maxSize", maxFileSize, - "debug", cnf.Debug, + "debug", cfg.Options.Debug, "sizeMB", float64(info.Size())/(1024*1024), "maxSizeMB", float64(maxFileSize)/(1024*1024), ) @@ -870,7 +870,7 @@ func shouldExcludeFile(filePath string) bool { // openMatchingFile opens a file if it matches any of the registered patterns func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) { - cnf := config.Get() + cfg := config.Get() // Skip directories info, err := os.Stat(path) if err != nil || info.IsDir() { @@ -890,10 +890,10 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) { // Check if the file is a high-priority 
file that should be opened immediately // This helps with project initialization for certain language servers if isHighPriorityFile(path, serverName) { - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Opening high-priority file", "path", path, "serverName", serverName) } - if err := w.client.OpenFile(ctx, path); err != nil && cnf.DebugLSP { + if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP { logging.Error("Error opening high-priority file", "path", path, "error", err) } return @@ -905,7 +905,7 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) { // Check file size - for preloading we're more conservative if info.Size() > (1 * 1024 * 1024) { // 1MB limit for preloaded files - if cnf.DebugLSP { + if cfg.Options.DebugLSP { logging.Debug("Skipping large file for preloading", "path", path, "size", info.Size()) } return @@ -937,7 +937,7 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) { if shouldOpen { // Don't need to check if it's already open - the client.OpenFile handles that - if err := w.client.OpenFile(ctx, path); err != nil && cnf.DebugLSP { + if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP { logging.Error("Error opening file", "path", path, "error", err) } } diff --git a/internal/message/content.go b/internal/message/content.go index b9e83ba4dd7fcc96216755a3871f0553b58d88d7..3ab53e381aaf7755c141985ebe740dbc44356471 100644 --- a/internal/message/content.go +++ b/internal/message/content.go @@ -5,7 +5,7 @@ import ( "slices" "time" - "github.com/charmbracelet/crush/internal/llm/models" + "github.com/charmbracelet/crush/internal/fur/provider" ) type MessageRole string @@ -71,9 +71,9 @@ type BinaryContent struct { Data []byte } -func (bc BinaryContent) String(provider models.InferenceProvider) string { +func (bc BinaryContent) String(p provider.InferenceProvider) string { base64Encoded := base64.StdEncoding.EncodeToString(bc.Data) - if 
provider == models.ProviderOpenAI { + if p == provider.InferenceProviderOpenAI { return "data:" + bc.MIMEType + ";base64," + base64Encoded } return base64Encoded @@ -113,7 +113,8 @@ type Message struct { Role MessageRole SessionID string Parts []ContentPart - Model models.ModelID + Model string + Provider string CreatedAt int64 UpdatedAt int64 } diff --git a/internal/message/message.go b/internal/message/message.go index 9e241a0b011ee6277402709fdd8be3aefb5df6fe..7cd823bc3129df5f807ec478d9d6c02364c6cfec 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -8,15 +8,15 @@ import ( "time" "github.com/charmbracelet/crush/internal/db" - "github.com/charmbracelet/crush/internal/llm/models" "github.com/charmbracelet/crush/internal/pubsub" "github.com/google/uuid" ) type CreateMessageParams struct { - Role MessageRole - Parts []ContentPart - Model models.ModelID + Role MessageRole + Parts []ContentPart + Model string + Provider string } type Service interface { @@ -70,6 +70,7 @@ func (s *service) Create(ctx context.Context, sessionID string, params CreateMes Role: string(params.Role), Parts: string(partsJSON), Model: sql.NullString{String: string(params.Model), Valid: true}, + Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""}, }) if err != nil { return Message{}, err @@ -154,7 +155,8 @@ func (s *service) fromDBItem(item db.Message) (Message, error) { SessionID: item.SessionID, Role: MessageRole(item.Role), Parts: parts, - Model: models.ModelID(item.Model.String), + Model: item.Model.String, + Provider: item.Provider.String, CreatedAt: item.CreatedAt, UpdatedAt: item.UpdatedAt, }, nil diff --git a/internal/tui/components/chat/header/header.go b/internal/tui/components/chat/header/header.go index 78620161a75a3ade2e0e2416351c50699ac8bd4d..d924bdc3453dc3fce0351c490cb17b726fcc2549 100644 --- a/internal/tui/components/chat/header/header.go +++ b/internal/tui/components/chat/header/header.go @@ -7,7 +7,6 @@ import ( tea 
"github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/fsext" - "github.com/charmbracelet/crush/internal/llm/models" "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/lsp/protocol" "github.com/charmbracelet/crush/internal/pubsub" @@ -112,11 +111,7 @@ func (h *header) details() string { parts = append(parts, t.S().Error.Render(fmt.Sprintf("%s%d", styles.ErrorIcon, errorCount))) } - cfg := config.Get() - agentCfg := cfg.Agents[config.AgentCoder] - selectedModelID := agentCfg.Model - model := models.SupportedModels[selectedModelID] - + model := config.GetAgentModel(config.AgentCoder) percentage := (float64(h.session.CompletionTokens+h.session.PromptTokens) / float64(model.ContextWindow)) * 100 formattedPercentage := t.S().Muted.Render(fmt.Sprintf("%d%%", int(percentage))) parts = append(parts, formattedPercentage) diff --git a/internal/tui/components/chat/messages/messages.go b/internal/tui/components/chat/messages/messages.go index d5e95b4e3ebded500f73840fda483d3be53ca71d..98d8b2979a90f46fa5901bc77d1e8b4a5105f04d 100644 --- a/internal/tui/components/chat/messages/messages.go +++ b/internal/tui/components/chat/messages/messages.go @@ -10,7 +10,8 @@ import ( tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/lipgloss/v2" - "github.com/charmbracelet/crush/internal/llm/models" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/tui/components/anim" "github.com/charmbracelet/crush/internal/tui/components/core" @@ -290,8 +291,9 @@ func (m *assistantSectionModel) View() tea.View { duration := finishTime.Sub(m.lastUserMessageTime) infoMsg := t.S().Subtle.Render(duration.String()) icon := t.S().Subtle.Render(styles.ModelIcon) - model := t.S().Muted.Render(models.SupportedModels[m.message.Model].Name) - 
assistant := fmt.Sprintf("%s %s %s", icon, model, infoMsg) + model := config.GetProviderModel(provider.InferenceProvider(m.message.Provider), m.message.Model) + modelFormatted := t.S().Muted.Render(model.Name) + assistant := fmt.Sprintf("%s %s %s", icon, modelFormatted, infoMsg) return tea.NewView( t.S().Base.PaddingLeft(2).Render( core.Section(assistant, m.width-2), diff --git a/internal/tui/components/chat/sidebar/sidebar.go b/internal/tui/components/chat/sidebar/sidebar.go index 405bd1f0f8c7891db1958e70f97e290dd9a8d411..bfcc74c43a2727138d479af647ba461bdcc7520c 100644 --- a/internal/tui/components/chat/sidebar/sidebar.go +++ b/internal/tui/components/chat/sidebar/sidebar.go @@ -13,7 +13,6 @@ import ( "github.com/charmbracelet/crush/internal/diff" "github.com/charmbracelet/crush/internal/fsext" "github.com/charmbracelet/crush/internal/history" - "github.com/charmbracelet/crush/internal/llm/models" "github.com/charmbracelet/crush/internal/logging" "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/lsp/protocol" @@ -406,7 +405,7 @@ func (m *sidebarCmp) mcpBlock() string { mcpList := []string{section, ""} - mcp := config.Get().MCPServers + mcp := config.Get().MCP if len(mcp) == 0 { return lipgloss.JoinVertical( lipgloss.Left, @@ -475,10 +474,7 @@ func formatTokensAndCost(tokens, contextWindow int64, cost float64) string { } func (s *sidebarCmp) currentModelBlock() string { - cfg := config.Get() - agentCfg := cfg.Agents[config.AgentCoder] - selectedModelID := agentCfg.Model - model := models.SupportedModels[selectedModelID] + model := config.GetAgentModel(config.AgentCoder) t := styles.CurrentTheme() diff --git a/internal/tui/components/dialogs/commands/loader.go b/internal/tui/components/dialogs/commands/loader.go index 9f70afa3cd60342028b6d3fd00e017221c179686..9aee528ee48d0f23e48c417f8bee5bc0e3f381c5 100644 --- a/internal/tui/components/dialogs/commands/loader.go +++ b/internal/tui/components/dialogs/commands/loader.go @@ -63,7 
+63,7 @@ func buildCommandSources(cfg *config.Config) []commandSource { // Project directory sources = append(sources, commandSource{ - path: filepath.Join(cfg.Data.Directory, "commands"), + path: filepath.Join(cfg.Options.DataDirectory, "commands"), prefix: ProjectCommandPrefix, }) diff --git a/internal/tui/components/dialogs/init/init.go b/internal/tui/components/dialogs/init/init.go index 74d0dc0b3d9d4630b28c4b240fb17fbe611ba21f..4e331198f5984f81db87332e3c998d9477810806 100644 --- a/internal/tui/components/dialogs/init/init.go +++ b/internal/tui/components/dialogs/init/init.go @@ -5,7 +5,7 @@ import ( tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/lipgloss/v2" - "github.com/charmbracelet/crush/internal/config" + configv2 "github.com/charmbracelet/crush/internal/config" cmpChat "github.com/charmbracelet/crush/internal/tui/components/chat" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/dialogs" @@ -184,7 +184,7 @@ If there are Cursor rules (in .cursor/rules/ or .cursorrules) or Copilot rules ( Add the .crush directory to the .gitignore file if it's not already there.` // Mark the project as initialized - if err := config.MarkProjectInitialized(); err != nil { + if err := configv2.MarkProjectInitialized(); err != nil { return util.ReportError(err) } @@ -196,7 +196,7 @@ Add the .crush directory to the .gitignore file if it's not already there.` ) } else { // Mark the project as initialized without running the command - if err := config.MarkProjectInitialized(); err != nil { + if err := configv2.MarkProjectInitialized(); err != nil { return util.ReportError(err) } } diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index 0197b7141560a67008ceac64c31756bd19fff74a..b5f87b16681ea17e2fb303a4b52a3a83ae30eb85 100644 --- a/internal/tui/components/dialogs/models/models.go +++ 
b/internal/tui/components/dialogs/models/models.go @@ -1,13 +1,11 @@ package models import ( - "slices" - "github.com/charmbracelet/bubbles/v2/help" "github.com/charmbracelet/bubbles/v2/key" tea "github.com/charmbracelet/bubbletea/v2" - "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/llm/models" + configv2 "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/tui/components/completions" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/core/list" @@ -26,7 +24,7 @@ const ( // ModelSelectedMsg is sent when a model is selected type ModelSelectedMsg struct { - Model models.Model + Model configv2.PreferredModel } // CloseModelDialogMsg is sent when a model is selected @@ -37,6 +35,11 @@ type ModelDialog interface { dialogs.DialogModel } +type ModelOption struct { + Provider provider.Provider + Model provider.Model +} + type modelDialogCmp struct { width int wWidth int // Width of the terminal window @@ -80,47 +83,31 @@ func NewModelDialogCmp() ModelDialog { } } -var ProviderPopularity = map[models.InferenceProvider]int{ - models.ProviderAnthropic: 1, - models.ProviderOpenAI: 2, - models.ProviderGemini: 3, - models.ProviderGROQ: 4, - models.ProviderOpenRouter: 5, - models.ProviderBedrock: 6, - models.ProviderAzure: 7, - models.ProviderVertexAI: 8, - models.ProviderXAI: 9, -} - -var ProviderName = map[models.InferenceProvider]string{ - models.ProviderAnthropic: "Anthropic", - models.ProviderOpenAI: "OpenAI", - models.ProviderGemini: "Gemini", - models.ProviderGROQ: "Groq", - models.ProviderOpenRouter: "OpenRouter", - models.ProviderBedrock: "AWS Bedrock", - models.ProviderAzure: "Azure", - models.ProviderVertexAI: "VertexAI", - models.ProviderXAI: "xAI", -} - func (m *modelDialogCmp) Init() tea.Cmd { - cfg := config.Get() - enabledProviders := getEnabledProviders(cfg) + 
providers := configv2.Providers() + cfg := configv2.Get() + coderAgent := cfg.Agents[configv2.AgentCoder] modelItems := []util.Model{} - for _, provider := range enabledProviders { - name, ok := ProviderName[provider] - if !ok { - name = string(provider) // Fallback to provider ID if name is not defined + selectIndex := 0 + for _, provider := range providers { + name := provider.Name + if name == "" { + name = string(provider.ID) } modelItems = append(modelItems, commands.NewItemSection(name)) - for _, model := range getModelsForProvider(provider) { - modelItems = append(modelItems, completions.NewCompletionItem(model.Name, model)) + for _, model := range provider.Models { + if model.ID == coderAgent.Model && provider.ID == coderAgent.Provider { + selectIndex = len(modelItems) // Set the selected index to the current model + } + modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{ + Provider: provider, + Model: model, + })) } } - m.modelList.SetItems(modelItems) - return m.modelList.Init() + + return tea.Sequence(m.modelList.Init(), m.modelList.SetItems(modelItems), m.modelList.SetSelected(selectIndex)) } func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { @@ -137,11 +124,14 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, nil // No item selected, do nothing } items := m.modelList.Items() - selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(models.Model) + selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption) return m, tea.Sequence( util.CmdHandler(dialogs.CloseDialogMsg{}), - util.CmdHandler(ModelSelectedMsg{Model: selectedItem}), + util.CmdHandler(ModelSelectedMsg{Model: configv2.PreferredModel{ + ModelID: selectedItem.Model.ID, + Provider: selectedItem.Provider.ID, + }}), ) case key.Matches(msg, m.keyMap.Close): return m, util.CmdHandler(dialogs.CloseDialogMsg{}) @@ -189,58 +179,6 @@ func (m *modelDialogCmp) listHeight() 
int { return min(listHeigh, m.wHeight/2) } -func GetSelectedModel(cfg *config.Config) models.Model { - agentCfg := cfg.Agents[config.AgentCoder] - selectedModelID := agentCfg.Model - return models.SupportedModels[selectedModelID] -} - -func getEnabledProviders(cfg *config.Config) []models.InferenceProvider { - var providers []models.InferenceProvider - for providerID, provider := range cfg.Providers { - if !provider.Disabled { - providers = append(providers, providerID) - } - } - - // Sort by provider popularity - slices.SortFunc(providers, func(a, b models.InferenceProvider) int { - rA := ProviderPopularity[a] - rB := ProviderPopularity[b] - - // models not included in popularity ranking default to last - if rA == 0 { - rA = 999 - } - if rB == 0 { - rB = 999 - } - return rA - rB - }) - return providers -} - -func getModelsForProvider(provider models.InferenceProvider) []models.Model { - var providerModels []models.Model - for _, model := range models.SupportedModels { - if model.Provider == provider { - providerModels = append(providerModels, model) - } - } - - // reverse alphabetical order (if llm naming was consistent latest would appear first) - slices.SortFunc(providerModels, func(a, b models.Model) int { - if a.Name > b.Name { - return -1 - } else if a.Name < b.Name { - return 1 - } - return 0 - }) - - return providerModels -} - func (m *modelDialogCmp) Position() (int, int) { row := m.wHeight/4 - 2 // just a bit above the center col := m.wWidth / 2 diff --git a/internal/tui/page/chat/chat.go b/internal/tui/page/chat/chat.go index ffb6debb0f61cb1fcfa7e180b042b3b8325dd2e5..44d623847765175d3c38eb81122fa3d55abc430d 100644 --- a/internal/tui/page/chat/chat.go +++ b/internal/tui/page/chat/chat.go @@ -9,7 +9,6 @@ import ( tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/llm/models" "github.com/charmbracelet/crush/internal/message" 
"github.com/charmbracelet/crush/internal/session" "github.com/charmbracelet/crush/internal/tui/components/chat" @@ -171,14 +170,11 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { util.CmdHandler(ChatFocusedMsg{Focused: false}), ) case key.Matches(msg, p.keyMap.AddAttachment): - cfg := config.Get() - agentCfg := cfg.Agents[config.AgentCoder] - selectedModelID := agentCfg.Model - model := models.SupportedModels[selectedModelID] - if model.SupportsAttachments { + model := config.GetAgentModel(config.AgentCoder) + if model.SupportsImages { return p, util.CmdHandler(OpenFilePickerMsg{}) } else { - return p, util.ReportWarn("File attachments are not supported by the current model: " + string(selectedModelID)) + return p, util.ReportWarn("File attachments are not supported by the current model: " + model.Name) } case key.Matches(msg, p.keyMap.Tab): if p.session.ID == "" { diff --git a/internal/tui/tui.go b/internal/tui/tui.go index c6dee6532993becfbda24d115b8e1e5d05e4fd60..54978b53576940e6fa478b7d05af514f66641acf 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -8,6 +8,7 @@ import ( tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/config" + configv2 "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/llm/agent" "github.com/charmbracelet/crush/internal/logging" "github.com/charmbracelet/crush/internal/permission" @@ -69,7 +70,7 @@ func (a appModel) Init() tea.Cmd { // Check if we should show the init dialog cmds = append(cmds, func() tea.Msg { - shouldShow, err := config.ShouldShowInitDialog() + shouldShow, err := configv2.ProjectNeedsInitialization() if err != nil { return util.InfoMsg{ Type: util.InfoTypeError, @@ -172,7 +173,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // Model Switch case models.ModelSelectedMsg: - model, err := a.app.CoderAgent.Update(config.AgentCoder, msg.Model.ID) + model, err := 
a.app.CoderAgent.Update(msg.Model) if err != nil { return a, util.ReportError(err) } @@ -222,7 +223,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { model := a.app.CoderAgent.Model() contextWindow := model.ContextWindow tokens := session.CompletionTokens + session.PromptTokens - if (tokens >= int64(float64(contextWindow)*0.95)) && config.Get().AutoCompact { + if (tokens >= int64(float64(contextWindow)*0.95)) && !config.Get().Options.DisableAutoSummarize { // Show compact confirmation dialog cmds = append(cmds, util.CmdHandler(dialogs.OpenDialogMsg{ Model: compact.NewCompactDialogCmp(a.app.CoderAgent, a.selectedSessionID, false), From 5721e7b493696ed85723f5dabbc14dc01b29badf Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 27 Jun 2025 14:05:53 +0200 Subject: [PATCH 07/40] chore: change to the small/large model conceptg --- internal/app/app.go | 6 +- internal/config/config.go | 111 +++++++++++++++--- internal/llm/agent/agent.go | 109 ++++------------- internal/llm/provider/anthropic.go | 7 +- internal/llm/provider/bedrock.go | 22 +++- internal/llm/provider/gemini.go | 10 +- internal/llm/provider/openai.go | 10 +- internal/llm/provider/provider.go | 26 ++-- internal/tui/components/dialogs/init/init.go | 6 +- .../tui/components/dialogs/models/models.go | 14 +-- internal/tui/tui.go | 11 +- 11 files changed, 190 insertions(+), 142 deletions(-) diff --git a/internal/app/app.go b/internal/app/app.go index 75042e89648779cf50a4376aa01aa3b6ac8e72a0..b096c1b4f5612901a1cedeaa2ee758b666cda517 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -9,7 +9,7 @@ import ( "sync" "time" - configv2 "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/db" "github.com/charmbracelet/crush/internal/format" "github.com/charmbracelet/crush/internal/history" @@ -55,9 +55,9 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) { // Initialize LSP clients in the 
background go app.initLSPClients(ctx) - cfg := configv2.Get() + cfg := config.Get() - coderAgentCfg := cfg.Agents[configv2.AgentCoder] + coderAgentCfg := cfg.Agents[config.AgentCoder] if coderAgentCfg.ID == "" { return nil, fmt.Errorf("coder agent configuration is missing") } diff --git a/internal/config/config.go b/internal/config/config.go index 13444a5ccc8e99bdaa57a6156151b45a40176c09..8ebc1ce6cf5226fb7ad43601eb95d346bfebc0ef 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -47,6 +47,13 @@ const ( AgentTask AgentID = "task" ) +type ModelType string + +const ( + LargeModel ModelType = "large" + SmallModel ModelType = "small" +) + type Model struct { ID string `json:"id"` Name string `json:"model"` @@ -90,8 +97,7 @@ type Agent struct { // This is the id of the system prompt used by the agent Disabled bool `json:"disabled"` - Provider provider.InferenceProvider `json:"provider"` - Model string `json:"model"` + Model ModelType `json:"model"` // The available tools for the agent // if this is nil, all tools are available @@ -291,8 +297,7 @@ func loadConfig(cwd string, debug bool) (*Config, error) { ID: AgentCoder, Name: "Coder", Description: "An agent that helps with executing coding tasks.", - Provider: cfg.Models.Large.Provider, - Model: cfg.Models.Large.ModelID, + Model: LargeModel, ContextPaths: cfg.Options.ContextPaths, // All tools allowed }, @@ -300,8 +305,7 @@ func loadConfig(cwd string, debug bool) (*Config, error) { ID: AgentTask, Name: "Task", Description: "An agent that helps with searching for context and finding implementation details.", - Provider: cfg.Models.Large.Provider, - Model: cfg.Models.Large.ModelID, + Model: LargeModel, ContextPaths: cfg.Options.ContextPaths, AllowedTools: []string{ "glob", @@ -490,9 +494,8 @@ func mergeAgents(base, global, local *Config) { switch agentID { case AgentCoder: baseAgent := base.Agents[agentID] - if newAgent.Model != "" && newAgent.Provider != "" { + if newAgent.Model != "" { 
baseAgent.Model = newAgent.Model - baseAgent.Provider = newAgent.Provider } baseAgent.AllowedMCP = newAgent.AllowedMCP baseAgent.AllowedLSP = newAgent.AllowedLSP @@ -502,9 +505,8 @@ func mergeAgents(base, global, local *Config) { baseAgent.Name = newAgent.Name baseAgent.Description = newAgent.Description baseAgent.Disabled = newAgent.Disabled - if newAgent.Model == "" || newAgent.Provider == "" { - baseAgent.Provider = base.Models.Large.Provider - baseAgent.Model = base.Models.Large.ModelID + if newAgent.Model == "" { + baseAgent.Model = LargeModel } baseAgent.AllowedTools = newAgent.AllowedTools baseAgent.AllowedMCP = newAgent.AllowedMCP @@ -709,6 +711,8 @@ func WorkingDirectory() string { return cwd } +// TODO: Handle error state + func GetAgentModel(agentID AgentID) Model { cfg := Get() agent, ok := cfg.Agents[agentID] @@ -717,15 +721,25 @@ func GetAgentModel(agentID AgentID) Model { return Model{} } - providerConfig, ok := cfg.Providers[agent.Provider] + var model PreferredModel + switch agent.Model { + case LargeModel: + model = cfg.Models.Large + case SmallModel: + model = cfg.Models.Small + default: + logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model) + model = cfg.Models.Large // Fallback to large model + } + providerConfig, ok := cfg.Providers[model.Provider] if !ok { - logging.Error("Provider not found for agent", "agent_id", agentID, "provider", agent.Provider) + logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider) return Model{} } - for _, model := range providerConfig.Models { - if model.ID == agent.Model { - return model + for _, m := range providerConfig.Models { + if m.ID == model.ModelID { + return m } } @@ -733,6 +747,34 @@ func GetAgentModel(agentID AgentID) Model { return Model{} } +func GetAgentProvider(agentID AgentID) ProviderConfig { + cfg := Get() + agent, ok := cfg.Agents[agentID] + if !ok { + logging.Error("Agent not found", "agent_id", agentID) + return 
ProviderConfig{} + } + + var model PreferredModel + switch agent.Model { + case LargeModel: + model = cfg.Models.Large + case SmallModel: + model = cfg.Models.Small + default: + logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model) + model = cfg.Models.Large // Fallback to large model + } + + providerConfig, ok := cfg.Providers[model.Provider] + if !ok { + logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider) + return ProviderConfig{} + } + + return providerConfig +} + func GetProviderModel(provider provider.InferenceProvider, modelID string) Model { cfg := Get() providerConfig, ok := cfg.Providers[provider] @@ -750,3 +792,40 @@ func GetProviderModel(provider provider.InferenceProvider, modelID string) Model logging.Error("Model not found for provider", "provider", provider, "model_id", modelID) return Model{} } + +func GetModel(modelType ModelType) Model { + cfg := Get() + var model PreferredModel + switch modelType { + case LargeModel: + model = cfg.Models.Large + case SmallModel: + model = cfg.Models.Small + default: + model = cfg.Models.Large // Fallback to large model + } + providerConfig, ok := cfg.Providers[model.Provider] + if !ok { + return Model{} + } + + for _, m := range providerConfig.Models { + if m.ID == model.ModelID { + return m + } + } + return Model{} +} + +func UpdatePreferredModel(modelType ModelType, model PreferredModel) error { + cfg := Get() + switch modelType { + case LargeModel: + cfg.Models.Large = model + case SmallModel: + cfg.Models.Small = model + default: + return fmt.Errorf("unknown model type: %s", modelType) + } + return nil +} diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index f9e97b164aa98fe1ae76490fdfcf336efb43098f..8c6faf8c4a06bbef5da279847cd14ce2314648cd 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -9,7 +9,7 @@ import ( "sync" "time" - configv2 
"github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/history" "github.com/charmbracelet/crush/internal/llm/prompt" "github.com/charmbracelet/crush/internal/llm/provider" @@ -49,19 +49,18 @@ type AgentEvent struct { type Service interface { pubsub.Suscriber[AgentEvent] - Model() configv2.Model + Model() config.Model Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) Cancel(sessionID string) CancelAll() IsSessionBusy(sessionID string) bool IsBusy() bool - Update(model configv2.PreferredModel) (configv2.Model, error) Summarize(ctx context.Context, sessionID string) error } type agent struct { *pubsub.Broker[AgentEvent] - agentCfg configv2.Agent + agentCfg config.Agent sessions session.Service messages message.Service @@ -76,13 +75,13 @@ type agent struct { activeRequests sync.Map } -var agentPromptMap = map[configv2.AgentID]prompt.PromptID{ - configv2.AgentCoder: prompt.PromptCoder, - configv2.AgentTask: prompt.PromptTask, +var agentPromptMap = map[config.AgentID]prompt.PromptID{ + config.AgentCoder: prompt.PromptCoder, + config.AgentTask: prompt.PromptTask, } func NewAgent( - agentCfg configv2.Agent, + agentCfg config.Agent, // These services are needed in the tools permissions permission.Service, sessions session.Service, @@ -91,7 +90,7 @@ func NewAgent( lspClients map[string]*lsp.Client, ) (Service, error) { ctx := context.Background() - cfg := configv2.Get() + cfg := config.Get() otherTools := GetMcpTools(ctx, permissions) if len(lspClients) > 0 { otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients)) @@ -109,8 +108,8 @@ func NewAgent( tools.NewWriteTool(lspClients, permissions, history), } - if agentCfg.ID == configv2.AgentCoder { - taskAgentCfg := configv2.Get().Agents[configv2.AgentTask] + if agentCfg.ID == config.AgentCoder { + taskAgentCfg := config.Get().Agents[config.AgentTask] if 
taskAgentCfg.ID == "" { return nil, fmt.Errorf("task agent not found in config") } @@ -130,26 +129,14 @@ func NewAgent( } allTools = append(allTools, otherTools...) - var providerCfg configv2.ProviderConfig - for _, p := range cfg.Providers { - if p.ID == agentCfg.Provider { - providerCfg = p - break - } - } + providerCfg := config.GetAgentProvider(agentCfg.ID) if providerCfg.ID == "" { - return nil, fmt.Errorf("provider %s not found in config", agentCfg.Provider) + return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name) } + model := config.GetAgentModel(agentCfg.ID) - var model configv2.Model - for _, m := range providerCfg.Models { - if m.ID == agentCfg.Model { - model = m - break - } - } if model.ID == "" { - return nil, fmt.Errorf("model %s not found in provider %s", agentCfg.Model, agentCfg.Provider) + return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name) } promptID := agentPromptMap[agentCfg.ID] @@ -157,7 +144,7 @@ func NewAgent( promptID = prompt.PromptDefault } opts := []provider.ProviderClientOption{ - provider.WithModel(model), + provider.WithModel(agentCfg.Model), provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)), provider.WithMaxTokens(model.DefaultMaxTokens), } @@ -167,9 +154,9 @@ func NewAgent( } smallModelCfg := cfg.Models.Small - var smallModel configv2.Model + var smallModel config.Model - var smallModelProviderCfg configv2.ProviderConfig + var smallModelProviderCfg config.ProviderConfig if smallModelCfg.Provider == providerCfg.ID { smallModelProviderCfg = providerCfg } else { @@ -194,7 +181,7 @@ func NewAgent( } titleOpts := []provider.ProviderClientOption{ - provider.WithModel(smallModel), + provider.WithModel(config.SmallModel), provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)), provider.WithMaxTokens(40), } @@ -203,7 +190,7 @@ func NewAgent( return nil, err } summarizeOpts := []provider.ProviderClientOption{ - provider.WithModel(smallModel), 
+ provider.WithModel(config.SmallModel), provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)), provider.WithMaxTokens(smallModel.DefaultMaxTokens), } @@ -240,8 +227,8 @@ func NewAgent( return agent, nil } -func (a *agent) Model() configv2.Model { - return a.provider.Model() +func (a *agent) Model() config.Model { + return config.GetAgentModel(a.agentCfg.ID) } func (a *agent) Cancel(sessionID string) { @@ -336,7 +323,7 @@ func (a *agent) err(err error) AgentEvent { } func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) { - if !a.provider.Model().SupportsImages && attachments != nil { + if !a.Model().SupportsImages && attachments != nil { attachments = nil } events := make(chan AgentEvent) @@ -458,7 +445,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.Assistant, Parts: []message.ContentPart{}, - Model: a.provider.Model().ID, + Model: a.Model().ID, Provider: a.providerID, }) if err != nil { @@ -609,13 +596,13 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg if err := a.messages.Update(ctx, *assistantMsg); err != nil { return fmt.Errorf("failed to update message: %w", err) } - return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage) + return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage) } return nil } -func (a *agent) TrackUsage(ctx context.Context, sessionID string, model configv2.Model, usage provider.TokenUsage) error { +func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error { sess, err := a.sessions.Get(ctx, sessionID) if err != nil { return fmt.Errorf("failed to get session: %w", err) @@ -637,52 +624,6 @@ func (a *agent) TrackUsage(ctx context.Context, sessionID string, 
model configv2 return nil } -func (a *agent) Update(modelCfg configv2.PreferredModel) (configv2.Model, error) { - if a.IsBusy() { - return configv2.Model{}, fmt.Errorf("cannot change model while processing requests") - } - - cfg := configv2.Get() - var providerCfg configv2.ProviderConfig - for _, p := range cfg.Providers { - if p.ID == modelCfg.Provider { - providerCfg = p - break - } - } - if providerCfg.ID == "" { - return configv2.Model{}, fmt.Errorf("provider %s not found in config", modelCfg.Provider) - } - - var model configv2.Model - for _, m := range providerCfg.Models { - if m.ID == modelCfg.ModelID { - model = m - break - } - } - if model.ID == "" { - return configv2.Model{}, fmt.Errorf("model %s not found in provider %s", modelCfg.ModelID, modelCfg.Provider) - } - - promptID := agentPromptMap[a.agentCfg.ID] - if promptID == "" { - promptID = prompt.PromptDefault - } - opts := []provider.ProviderClientOption{ - provider.WithModel(model), - provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)), - provider.WithMaxTokens(model.DefaultMaxTokens), - } - agentProvider, err := provider.NewProviderV2(providerCfg, opts...) 
- if err != nil { - return configv2.Model{}, err - } - a.provider = agentProvider - - return a.provider.Model(), nil -} - func (a *agent) Summarize(ctx context.Context, sessionID string) error { if a.summarizeProvider == nil { return fmt.Errorf("summarize provider not available") diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index aca4d5b7f0adc4977fb349956be1005186e267e6..626882f283c030454477b27b152bd6a717d08476 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -145,6 +145,7 @@ func (a *anthropicClient) finishReason(reason string) message.FinishReason { } func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams { + model := a.providerOptions.model(a.providerOptions.modelType) var thinkingParam anthropic.ThinkingConfigParamUnion // TODO: Implement a proper thinking function // lastMessage := messages[len(messages)-1] @@ -164,7 +165,7 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to // } return anthropic.MessageNewParams{ - Model: anthropic.Model(a.providerOptions.model.ID), + Model: anthropic.Model(model.ID), MaxTokens: a.providerOptions.maxTokens, Temperature: temperature, Messages: messages, @@ -425,6 +426,10 @@ func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage { } } +func (a *anthropicClient) Model() config.Model { + return a.providerOptions.model(a.providerOptions.modelType) +} + // TODO: check if we need func DefaultShouldThinkFn(s string) bool { return strings.Contains(strings.ToLower(s), "think") diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go index 6b31c7d7fd6625ad7c2962f409f6c50f01ff726b..1519099b00401e32ad5f19c1f6ed253eb8b7130d 100644 --- a/internal/llm/provider/bedrock.go +++ b/internal/llm/provider/bedrock.go @@ -6,6 +6,7 @@ import ( "fmt" "strings" + "github.com/charmbracelet/crush/internal/config" 
"github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" ) @@ -30,13 +31,20 @@ func newBedrockClient(opts providerClientOptions) BedrockClient { } } - // Prefix the model name with region - regionPrefix := region[:2] - modelName := opts.model.ID - opts.model.ID = fmt.Sprintf("%s.%s", regionPrefix, modelName) + opts.model = func(modelType config.ModelType) config.Model { + model := config.GetModel(modelType) + + // Prefix the model name with region + regionPrefix := region[:2] + modelName := model.ID + model.ID = fmt.Sprintf("%s.%s", regionPrefix, modelName) + return model + } + + model := opts.model(opts.modelType) // Determine which provider to use based on the model - if strings.Contains(string(opts.model.ID), "anthropic") { + if strings.Contains(string(model.ID), "anthropic") { // Create Anthropic client with Bedrock configuration anthropicOpts := opts // TODO: later find a way to check if the AWS account has caching enabled @@ -78,3 +86,7 @@ func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, return b.childProvider.stream(ctx, messages, tools) } + +func (b *bedrockClient) Model() config.Model { + return b.providerOptions.model(b.providerOptions.modelType) +} diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index a91c1eae2427a7629ee1f4de6d6b9abb5944a972..a5c012861ad9e6b537c0e9bca8e957ef3f38bf2f 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -173,7 +173,8 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too if len(tools) > 0 { config.Tools = g.convertTools(tools) } - chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.ID, config, history) + model := g.providerOptions.model(g.providerOptions.modelType) + chat, _ := g.client.Chats.Create(ctx, model.ID, config, history) attempts := 0 for { @@ -261,7 +262,8 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, 
t if len(tools) > 0 { config.Tools = g.convertTools(tools) } - chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.ID, config, history) + model := g.providerOptions.model(g.providerOptions.modelType) + chat, _ := g.client.Chats.Create(ctx, model.ID, config, history) attempts := 0 eventChan := make(chan ProviderEvent) @@ -439,6 +441,10 @@ func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage { } } +func (g *geminiClient) Model() config.Model { + return g.providerOptions.model(g.providerOptions.modelType) +} + // Helper functions func parseJsonToMap(jsonStr string) (map[string]any, error) { var result map[string]any diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 448ab3674f25053453f51c0f48475db5699ee913..9af060a80f75309e1e314e3c33df72e607c9c77a 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -152,13 +152,13 @@ func (o *openaiClient) finishReason(reason string) message.FinishReason { } func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams { + model := o.providerOptions.model(o.providerOptions.modelType) params := openai.ChatCompletionNewParams{ - Model: openai.ChatModel(o.providerOptions.model.ID), + Model: openai.ChatModel(model.ID), Messages: messages, Tools: tools, } - - if o.providerOptions.model.CanReason { + if model.CanReason { params.MaxCompletionTokens = openai.Int(o.providerOptions.maxTokens) switch o.options.reasoningEffort { case "low": @@ -384,3 +384,7 @@ func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage { CacheReadTokens: cachedTokens, } } + +func (a *openaiClient) Model() config.Model { + return a.providerOptions.model(a.providerOptions.modelType) +} diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 3152cd6a9a7e6fd6a68d0e6b54b6ea6853a38273..9723dc9fe55af414ed415653e3e9e31031395a02 
100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -4,7 +4,7 @@ import ( "context" "fmt" - configv2 "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" @@ -55,13 +55,14 @@ type Provider interface { StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent - Model() configv2.Model + Model() config.Model } type providerClientOptions struct { baseURL string apiKey string - model configv2.Model + modelType config.ModelType + model func(config.ModelType) config.Model disableCache bool maxTokens int64 systemMessage string @@ -74,6 +75,8 @@ type ProviderClientOption func(*providerClientOptions) type ProviderClient interface { send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent + + Model() config.Model } type baseProvider[C ProviderClient] struct { @@ -97,18 +100,18 @@ func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.M return p.client.send(ctx, messages, tools) } -func (p *baseProvider[C]) Model() configv2.Model { - return p.options.model -} - func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { messages = p.cleanMessages(messages) return p.client.stream(ctx, messages, tools) } -func WithModel(model configv2.Model) ProviderClientOption { +func (p *baseProvider[C]) Model() config.Model { + return p.client.Model() +} + +func WithModel(model config.ModelType) ProviderClientOption { return func(options *providerClientOptions) { - options.model = model + options.modelType = model } } @@ -130,11 +133,14 @@ func 
WithSystemMessage(systemMessage string) ProviderClientOption { } } -func NewProviderV2(cfg configv2.ProviderConfig, opts ...ProviderClientOption) (Provider, error) { +func NewProviderV2(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) { clientOptions := providerClientOptions{ baseURL: cfg.BaseURL, apiKey: cfg.APIKey, extraHeaders: cfg.ExtraHeaders, + model: func(tp config.ModelType) config.Model { + return config.GetModel(tp) + }, } for _, o := range opts { o(&clientOptions) diff --git a/internal/tui/components/dialogs/init/init.go b/internal/tui/components/dialogs/init/init.go index 4e331198f5984f81db87332e3c998d9477810806..74d0dc0b3d9d4630b28c4b240fb17fbe611ba21f 100644 --- a/internal/tui/components/dialogs/init/init.go +++ b/internal/tui/components/dialogs/init/init.go @@ -5,7 +5,7 @@ import ( tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/lipgloss/v2" - configv2 "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/config" cmpChat "github.com/charmbracelet/crush/internal/tui/components/chat" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/dialogs" @@ -184,7 +184,7 @@ If there are Cursor rules (in .cursor/rules/ or .cursorrules) or Copilot rules ( Add the .crush directory to the .gitignore file if it's not already there.` // Mark the project as initialized - if err := configv2.MarkProjectInitialized(); err != nil { + if err := config.MarkProjectInitialized(); err != nil { return util.ReportError(err) } @@ -196,7 +196,7 @@ Add the .crush directory to the .gitignore file if it's not already there.` ) } else { // Mark the project as initialized without running the command - if err := configv2.MarkProjectInitialized(); err != nil { + if err := config.MarkProjectInitialized(); err != nil { return util.ReportError(err) } } diff --git a/internal/tui/components/dialogs/models/models.go 
b/internal/tui/components/dialogs/models/models.go index b5f87b16681ea17e2fb303a4b52a3a83ae30eb85..6d5fa155b2371865771b55c16f8fdbf65d3df952 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -4,7 +4,7 @@ import ( "github.com/charmbracelet/bubbles/v2/help" "github.com/charmbracelet/bubbles/v2/key" tea "github.com/charmbracelet/bubbletea/v2" - configv2 "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/tui/components/completions" "github.com/charmbracelet/crush/internal/tui/components/core" @@ -24,7 +24,7 @@ const ( // ModelSelectedMsg is sent when a model is selected type ModelSelectedMsg struct { - Model configv2.PreferredModel + Model config.PreferredModel } // CloseModelDialogMsg is sent when a model is selected @@ -84,12 +84,12 @@ func NewModelDialogCmp() ModelDialog { } func (m *modelDialogCmp) Init() tea.Cmd { - providers := configv2.Providers() - cfg := configv2.Get() + providers := config.Providers() - coderAgent := cfg.Agents[configv2.AgentCoder] modelItems := []util.Model{} selectIndex := 0 + agentModel := config.GetAgentModel(config.AgentCoder) + agentProvider := config.GetAgentProvider(config.AgentCoder) for _, provider := range providers { name := provider.Name if name == "" { @@ -97,7 +97,7 @@ func (m *modelDialogCmp) Init() tea.Cmd { } modelItems = append(modelItems, commands.NewItemSection(name)) for _, model := range provider.Models { - if model.ID == coderAgent.Model && provider.ID == coderAgent.Provider { + if model.ID == agentModel.ID && provider.ID == agentProvider.ID { selectIndex = len(modelItems) // Set the selected index to the current model } modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{ @@ -128,7 +128,7 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, 
tea.Sequence( util.CmdHandler(dialogs.CloseDialogMsg{}), - util.CmdHandler(ModelSelectedMsg{Model: configv2.PreferredModel{ + util.CmdHandler(ModelSelectedMsg{Model: config.PreferredModel{ ModelID: selectedItem.Model.ID, Provider: selectedItem.Provider.ID, }}), diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 54978b53576940e6fa478b7d05af514f66641acf..032b481eeaad75531debe7dc453efe19b866dd8d 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -8,7 +8,6 @@ import ( tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/config" - configv2 "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/llm/agent" "github.com/charmbracelet/crush/internal/logging" "github.com/charmbracelet/crush/internal/permission" @@ -70,7 +69,7 @@ func (a appModel) Init() tea.Cmd { // Check if we should show the init dialog cmds = append(cmds, func() tea.Msg { - shouldShow, err := configv2.ProjectNeedsInitialization() + shouldShow, err := config.ProjectNeedsInitialization() if err != nil { return util.InfoMsg{ Type: util.InfoTypeError, @@ -173,12 +172,8 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // Model Switch case models.ModelSelectedMsg: - model, err := a.app.CoderAgent.Update(msg.Model) - if err != nil { - return a, util.ReportError(err) - } - - return a, util.ReportInfo(fmt.Sprintf("Model changed to %s", model.Name)) + config.UpdatePreferredModel(config.LargeModel, msg.Model) + return a, util.ReportInfo(fmt.Sprintf("Model changed to %s", msg.Model.ModelID)) // File Picker case chat.OpenFilePickerMsg: From 086a777f433d9b038e4c303d8e60162c83e6b1a5 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 27 Jun 2025 15:56:53 +0200 Subject: [PATCH 08/40] chore: tests --- internal/config/config.go | 102 +- internal/config/config_test.go | 2202 +++++++++++++++++++++++++++++++- 2 files changed, 2263 insertions(+), 41 deletions(-) diff --git 
a/internal/config/config.go b/internal/config/config.go index 8ebc1ce6cf5226fb7ad43601eb95d346bfebc0ef..bddef684d9e1c45a5ed165cff000c3cb1d8302e2 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -275,23 +275,24 @@ func loadConfig(cwd string, debug bool) (*Config, error) { return cfg, nil } preferredProvider := getPreferredProvider(cfg.Providers) - cfg.Models = PreferredModels{ - Large: PreferredModel{ - ModelID: preferredProvider.DefaultLargeModel, - Provider: preferredProvider.ID, - }, - Small: PreferredModel{ - ModelID: preferredProvider.DefaultSmallModel, - Provider: preferredProvider.ID, - }, + if preferredProvider != nil { + cfg.Models = PreferredModels{ + Large: PreferredModel{ + ModelID: preferredProvider.DefaultLargeModel, + Provider: preferredProvider.ID, + }, + Small: PreferredModel{ + ModelID: preferredProvider.DefaultSmallModel, + Provider: preferredProvider.ID, + }, + } + } else { + // No valid providers found, set empty models + cfg.Models = PreferredModels{} } mergeModels(cfg, globalCfg, localConfig) - if preferredProvider == nil { - return nil, errors.New("no valid providers configured") - } - agents := map[AgentID]Agent{ AgentCoder: { ID: AgentCoder, @@ -376,7 +377,7 @@ func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfi if other.ProviderType != "" { base.ProviderType = other.ProviderType } - if len(base.ExtraHeaders) > 0 { + if len(other.ExtraHeaders) > 0 { if base.ExtraHeaders == nil { base.ExtraHeaders = make(map[string]string) } @@ -488,31 +489,71 @@ func mergeAgents(base, global, local *Config) { } for agentID, newAgent := range cfg.Agents { if _, ok := base.Agents[agentID]; !ok { + // New agent - apply defaults newAgent.ID = agentID // Ensure the ID is set correctly + if newAgent.Model == "" { + newAgent.Model = LargeModel // Default model type + } + // Context paths are always additive - start with global, then add custom + if len(newAgent.ContextPaths) > 0 { + newAgent.ContextPaths = 
append(base.Options.ContextPaths, newAgent.ContextPaths...) + } else { + newAgent.ContextPaths = base.Options.ContextPaths // Use global context paths only + } base.Agents[agentID] = newAgent } else { - switch agentID { - case AgentCoder: - baseAgent := base.Agents[agentID] + baseAgent := base.Agents[agentID] + + // Special handling for known agents - only allow model changes + if agentID == AgentCoder || agentID == AgentTask { if newAgent.Model != "" { baseAgent.Model = newAgent.Model } - baseAgent.AllowedMCP = newAgent.AllowedMCP - baseAgent.AllowedLSP = newAgent.AllowedLSP - base.Agents[agentID] = baseAgent - default: - baseAgent := base.Agents[agentID] - baseAgent.Name = newAgent.Name - baseAgent.Description = newAgent.Description + // For known agents, only allow MCP and LSP configuration + if newAgent.AllowedMCP != nil { + baseAgent.AllowedMCP = newAgent.AllowedMCP + } + if newAgent.AllowedLSP != nil { + baseAgent.AllowedLSP = newAgent.AllowedLSP + } + // Context paths are additive for known agents too + if len(newAgent.ContextPaths) > 0 { + baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...) 
+ } + } else { + // Custom agents - allow full merging + if newAgent.Name != "" { + baseAgent.Name = newAgent.Name + } + if newAgent.Description != "" { + baseAgent.Description = newAgent.Description + } + if newAgent.Model != "" { + baseAgent.Model = newAgent.Model + } else if baseAgent.Model == "" { + baseAgent.Model = LargeModel // Default fallback + } + + // Boolean fields - always update (including false values) baseAgent.Disabled = newAgent.Disabled - if newAgent.Model == "" { - baseAgent.Model = LargeModel + + // Slice/Map fields - update if provided (including empty slices/maps) + if newAgent.AllowedTools != nil { + baseAgent.AllowedTools = newAgent.AllowedTools + } + if newAgent.AllowedMCP != nil { + baseAgent.AllowedMCP = newAgent.AllowedMCP + } + if newAgent.AllowedLSP != nil { + baseAgent.AllowedLSP = newAgent.AllowedLSP + } + // Context paths are additive for custom agents too + if len(newAgent.ContextPaths) > 0 { + baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...) 
} - baseAgent.AllowedTools = newAgent.AllowedTools - baseAgent.AllowedMCP = newAgent.AllowedMCP - baseAgent.AllowedLSP = newAgent.AllowedLSP - base.Agents[agentID] = baseAgent } + + base.Agents[agentID] = baseAgent } } } @@ -555,6 +596,7 @@ func mergeProviderConfigs(base, global, local *Config) { err := validateProvider(providerName, providerConfig) if err != nil { logging.Warn("Skipping provider", "name", providerName, "error", err) + continue // Skip invalid providers } finalProviders[providerName] = providerConfig } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 2942c206aa4bb8b81ff3f3fca9a444411359e515..2f8b4a8acd8be9d979a663a53c28788a66ffd396 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -2,25 +2,299 @@ package config import ( "encoding/json" - "fmt" "os" + "path/filepath" + "sync" "testing" + "github.com/charmbracelet/crush/internal/fur/provider" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func resetEnvVars() { - os.Setenv("ANTHROPIC_API_KEY", "") - os.Setenv("OPENAI_API_KEY", "") - os.Setenv("GEMINI_API_KEY", "") - os.Setenv("XAI_API_KEY", "") - os.Setenv("OPENROUTER_API_KEY", "") +func reset() { + // Clear all environment variables that could affect config + envVarsToUnset := []string{ + // API Keys + "ANTHROPIC_API_KEY", + "OPENAI_API_KEY", + "GEMINI_API_KEY", + "XAI_API_KEY", + "OPENROUTER_API_KEY", + + // Google Cloud / VertexAI + "GOOGLE_GENAI_USE_VERTEXAI", + "GOOGLE_CLOUD_PROJECT", + "GOOGLE_CLOUD_LOCATION", + + // AWS Credentials + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_REGION", + "AWS_DEFAULT_REGION", + "AWS_PROFILE", + "AWS_DEFAULT_PROFILE", + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", + "AWS_CONTAINER_CREDENTIALS_FULL_URI", + + // Other + "CRUSH_DEV_DEBUG", + } + + for _, envVar := range envVarsToUnset { + os.Unsetenv(envVar) + } + + // Reset singleton + once = sync.Once{} + instance = nil + cwd = "" + testConfigDir = "" } 
-func TestConfigWithEnv(t *testing.T) { - resetEnvVars() +// Core Configuration Loading Tests + +func TestInit_ValidWorkingDirectory(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.NotNil(t, cfg) + assert.Equal(t, cwdDir, WorkingDirectory()) + assert.Equal(t, defaultDataDirectory, cfg.Options.DataDirectory) + assert.Equal(t, defaultContextPaths, cfg.Options.ContextPaths) +} + +func TestInit_WithDebugFlag(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + cfg, err := Init(cwdDir, true) + + require.NoError(t, err) + assert.True(t, cfg.Options.Debug) +} + +func TestInit_SingletonBehavior(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + cfg1, err1 := Init(cwdDir, false) + cfg2, err2 := Init(cwdDir, false) + + require.NoError(t, err1) + require.NoError(t, err2) + assert.Same(t, cfg1, cfg2) // Should be the same instance +} + +func TestGet_BeforeInitialization(t *testing.T) { + reset() + + assert.Panics(t, func() { + Get() + }) +} + +func TestGet_AfterInitialization(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + cfg1, err := Init(cwdDir, false) + require.NoError(t, err) + + cfg2 := Get() + assert.Same(t, cfg1, cfg2) +} + +func TestLoadConfig_NoConfigFiles(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Len(t, cfg.Providers, 0) // No providers without env vars or config files + assert.Equal(t, defaultContextPaths, cfg.Options.ContextPaths) +} + +func TestLoadConfig_OnlyGlobalConfig(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Create global config file + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: 
"test-key", + ProviderType: provider.TypeOpenAI, + }, + }, + Options: Options{ + ContextPaths: []string{"custom-context.md"}, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Len(t, cfg.Providers, 1) + assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI) + assert.Contains(t, cfg.Options.ContextPaths, "custom-context.md") +} + +func TestLoadConfig_OnlyLocalConfig(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Create local config file + localConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderAnthropic: { + ID: provider.InferenceProviderAnthropic, + APIKey: "local-key", + ProviderType: provider.TypeAnthropic, + }, + }, + Options: Options{ + TUI: TUIOptions{CompactMode: true}, + }, + } + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err := json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Len(t, cfg.Providers, 1) + assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic) + assert.True(t, cfg.Options.TUI.CompactMode) +} + +func TestLoadConfig_BothGlobalAndLocal(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Create global config + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "global-key", + ProviderType: provider.TypeOpenAI, + }, + }, + Options: Options{ + ContextPaths: []string{"global-context.md"}, + }, + } + + configPath := 
filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + // Create local config that overrides and adds + localConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + APIKey: "local-key", // Override global + }, + provider.InferenceProviderAnthropic: { + ID: provider.InferenceProviderAnthropic, + APIKey: "anthropic-key", + ProviderType: provider.TypeAnthropic, + }, + }, + Options: Options{ + ContextPaths: []string{"local-context.md"}, + TUI: TUIOptions{CompactMode: true}, + }, + } + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err = json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Len(t, cfg.Providers, 2) + + // Check that local config overrode global + openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] + assert.Equal(t, "local-key", openaiProvider.APIKey) + + // Check that local config added new provider + assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic) + + // Check that context paths were merged + assert.Contains(t, cfg.Options.ContextPaths, "global-context.md") + assert.Contains(t, cfg.Options.ContextPaths, "local-context.md") + assert.True(t, cfg.Options.TUI.CompactMode) +} + +func TestLoadConfig_MalformedGlobalJSON(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Create malformed global config + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + require.NoError(t, os.WriteFile(configPath, []byte(`{invalid json`), 0o644)) + + _, err := Init(cwdDir, false) + assert.Error(t, err) +} + +func 
TestLoadConfig_MalformedLocalJSON(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Create malformed local config + localConfigPath := filepath.Join(cwdDir, "crush.json") + require.NoError(t, os.WriteFile(localConfigPath, []byte(`{invalid json`), 0o644)) + + _, err := Init(cwdDir, false) + assert.Error(t, err) +} + +func TestConfigWithoutEnv(t *testing.T) { + reset() testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + cfg, _ := Init(cwdDir, false) + assert.Len(t, cfg.Providers, 0) +} +func TestConfigWithEnv(t *testing.T) { + reset() + testConfigDir = t.TempDir() cwdDir := t.TempDir() os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key") @@ -28,8 +302,1914 @@ func TestConfigWithEnv(t *testing.T) { os.Setenv("GEMINI_API_KEY", "test-gemini-key") os.Setenv("XAI_API_KEY", "test-xai-key") os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key") + cfg, _ := Init(cwdDir, false) - data, _ := json.MarshalIndent(cfg, "", " ") - fmt.Println(string(data)) assert.Len(t, cfg.Providers, 5) } + +// Environment Variable Tests + +func TestEnvVars_NoEnvironmentVariables(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Len(t, cfg.Providers, 0) +} + +func TestEnvVars_AllSupportedAPIKeys(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set all supported API keys + os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key") + os.Setenv("OPENAI_API_KEY", "test-openai-key") + os.Setenv("GEMINI_API_KEY", "test-gemini-key") + os.Setenv("XAI_API_KEY", "test-xai-key") + os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Len(t, cfg.Providers, 5) + + // Verify each provider is configured correctly + anthropicProvider := cfg.Providers[provider.InferenceProviderAnthropic] + assert.Equal(t, "test-anthropic-key", anthropicProvider.APIKey) + 
assert.Equal(t, provider.TypeAnthropic, anthropicProvider.ProviderType) + + openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] + assert.Equal(t, "test-openai-key", openaiProvider.APIKey) + assert.Equal(t, provider.TypeOpenAI, openaiProvider.ProviderType) + + geminiProvider := cfg.Providers[provider.InferenceProviderGemini] + assert.Equal(t, "test-gemini-key", geminiProvider.APIKey) + assert.Equal(t, provider.TypeGemini, geminiProvider.ProviderType) + + xaiProvider := cfg.Providers[provider.InferenceProviderXAI] + assert.Equal(t, "test-xai-key", xaiProvider.APIKey) + assert.Equal(t, provider.TypeXAI, xaiProvider.ProviderType) + + openrouterProvider := cfg.Providers[provider.InferenceProviderOpenRouter] + assert.Equal(t, "test-openrouter-key", openrouterProvider.APIKey) + assert.Equal(t, provider.TypeOpenAI, openrouterProvider.ProviderType) + assert.Equal(t, "https://openrouter.ai/api/v1", openrouterProvider.BaseURL) +} + +func TestEnvVars_PartialEnvironmentVariables(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set only some API keys + os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key") + os.Setenv("OPENAI_API_KEY", "test-openai-key") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Len(t, cfg.Providers, 2) + assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic) + assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI) + assert.NotContains(t, cfg.Providers, provider.InferenceProviderGemini) +} + +func TestEnvVars_VertexAIConfiguration(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set VertexAI environment variables + os.Setenv("GOOGLE_GENAI_USE_VERTEXAI", "true") + os.Setenv("GOOGLE_CLOUD_PROJECT", "test-project") + os.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Contains(t, cfg.Providers, provider.InferenceProviderVertexAI) + + vertexProvider 
:= cfg.Providers[provider.InferenceProviderVertexAI] + assert.Equal(t, provider.TypeVertexAI, vertexProvider.ProviderType) + assert.Equal(t, "test-project", vertexProvider.ExtraParams["project"]) + assert.Equal(t, "us-central1", vertexProvider.ExtraParams["location"]) +} + +func TestEnvVars_VertexAIWithoutUseFlag(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set Google Cloud vars but not the use flag + os.Setenv("GOOGLE_CLOUD_PROJECT", "test-project") + os.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.NotContains(t, cfg.Providers, provider.InferenceProviderVertexAI) +} + +func TestEnvVars_AWSBedrockWithAccessKeys(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set AWS credentials + os.Setenv("AWS_ACCESS_KEY_ID", "test-access-key") + os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key") + os.Setenv("AWS_DEFAULT_REGION", "us-east-1") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Contains(t, cfg.Providers, provider.InferenceProviderBedrock) + + bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock] + assert.Equal(t, provider.TypeBedrock, bedrockProvider.ProviderType) + assert.Equal(t, "us-east-1", bedrockProvider.ExtraParams["region"]) +} + +func TestEnvVars_AWSBedrockWithProfile(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set AWS profile + os.Setenv("AWS_PROFILE", "test-profile") + os.Setenv("AWS_REGION", "eu-west-1") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Contains(t, cfg.Providers, provider.InferenceProviderBedrock) + + bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock] + assert.Equal(t, "eu-west-1", bedrockProvider.ExtraParams["region"]) +} + +func TestEnvVars_AWSBedrockWithContainerCredentials(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // 
Set AWS container credentials + os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/v2/credentials/test") + os.Setenv("AWS_DEFAULT_REGION", "ap-southeast-1") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Contains(t, cfg.Providers, provider.InferenceProviderBedrock) +} + +func TestEnvVars_AWSBedrockRegionPriority(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set both region variables - AWS_DEFAULT_REGION should take priority + os.Setenv("AWS_ACCESS_KEY_ID", "test-key") + os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret") + os.Setenv("AWS_DEFAULT_REGION", "us-west-2") + os.Setenv("AWS_REGION", "us-east-1") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock] + assert.Equal(t, "us-west-2", bedrockProvider.ExtraParams["region"]) +} + +func TestEnvVars_AWSBedrockFallbackRegion(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set only AWS_REGION (not AWS_DEFAULT_REGION) + os.Setenv("AWS_ACCESS_KEY_ID", "test-key") + os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret") + os.Setenv("AWS_REGION", "us-east-1") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock] + assert.Equal(t, "us-east-1", bedrockProvider.ExtraParams["region"]) +} + +func TestEnvVars_NoAWSCredentials(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Don't set any AWS credentials + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.NotContains(t, cfg.Providers, provider.InferenceProviderBedrock) +} + +func TestEnvVars_CustomEnvironmentVariables(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Test that environment variables are properly resolved from provider definitions + // This test assumes the provider system uses $VARIABLE_NAME format + 
os.Setenv("ANTHROPIC_API_KEY", "resolved-anthropic-key") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + if len(cfg.Providers) > 0 { + // Verify that the environment variable was resolved + if anthropicProvider, exists := cfg.Providers[provider.InferenceProviderAnthropic]; exists { + assert.Equal(t, "resolved-anthropic-key", anthropicProvider.APIKey) + } + } +} + +func TestEnvVars_CombinedEnvironmentVariables(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set multiple types of environment variables + os.Setenv("ANTHROPIC_API_KEY", "test-anthropic") + os.Setenv("OPENAI_API_KEY", "test-openai") + os.Setenv("GOOGLE_GENAI_USE_VERTEXAI", "true") + os.Setenv("GOOGLE_CLOUD_PROJECT", "test-project") + os.Setenv("AWS_ACCESS_KEY_ID", "test-aws-key") + os.Setenv("AWS_SECRET_ACCESS_KEY", "test-aws-secret") + os.Setenv("AWS_DEFAULT_REGION", "us-west-1") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + // Should have API key providers + VertexAI + Bedrock + expectedProviders := []provider.InferenceProvider{ + provider.InferenceProviderAnthropic, + provider.InferenceProviderOpenAI, + provider.InferenceProviderVertexAI, + provider.InferenceProviderBedrock, + } + + for _, expectedProvider := range expectedProviders { + assert.Contains(t, cfg.Providers, expectedProvider) + } +} + +func TestHasAWSCredentials_AccessKeys(t *testing.T) { + reset() + + os.Setenv("AWS_ACCESS_KEY_ID", "test-key") + os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret") + + assert.True(t, hasAWSCredentials()) +} + +func TestHasAWSCredentials_Profile(t *testing.T) { + reset() + + os.Setenv("AWS_PROFILE", "test-profile") + + assert.True(t, hasAWSCredentials()) +} + +func TestHasAWSCredentials_DefaultProfile(t *testing.T) { + reset() + + os.Setenv("AWS_DEFAULT_PROFILE", "default") + + assert.True(t, hasAWSCredentials()) +} + +func TestHasAWSCredentials_Region(t *testing.T) { + reset() + + os.Setenv("AWS_REGION", "us-east-1") + + 
assert.True(t, hasAWSCredentials()) +} + +func TestHasAWSCredentials_ContainerCredentials(t *testing.T) { + reset() + + os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/v2/credentials/test") + + assert.True(t, hasAWSCredentials()) +} + +func TestHasAWSCredentials_NoCredentials(t *testing.T) { + reset() + + assert.False(t, hasAWSCredentials()) +} + +// Provider Configuration Tests + +func TestProviderMerging_GlobalToBase(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Create global config with provider + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "global-openai-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + }, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Len(t, cfg.Providers, 1) + + openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] + assert.Equal(t, "global-openai-key", openaiProvider.APIKey) + assert.Equal(t, "gpt-4", openaiProvider.DefaultLargeModel) + assert.Equal(t, "gpt-3.5-turbo", openaiProvider.DefaultSmallModel) + assert.Len(t, openaiProvider.Models, 1) +} + +func TestProviderMerging_LocalToBase(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Create local config with provider + localConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderAnthropic: { + ID: provider.InferenceProviderAnthropic, + APIKey: 
"local-anthropic-key", + ProviderType: provider.TypeAnthropic, + DefaultLargeModel: "claude-3-opus", + }, + }, + } + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err := json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Len(t, cfg.Providers, 1) + + anthropicProvider := cfg.Providers[provider.InferenceProviderAnthropic] + assert.Equal(t, "local-anthropic-key", anthropicProvider.APIKey) + assert.Equal(t, "claude-3-opus", anthropicProvider.DefaultLargeModel) +} + +func TestProviderMerging_ConflictingSettings(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Create global config + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "global-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + // Create local config that overrides + localConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + APIKey: "local-key", + DefaultLargeModel: "gpt-4-turbo", + // Test disabled separately - don't disable here as it causes nil pointer + }, + }, + } + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err = json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] + // Local should 
override global + assert.Equal(t, "local-key", openaiProvider.APIKey) + assert.Equal(t, "gpt-4-turbo", openaiProvider.DefaultLargeModel) + assert.False(t, openaiProvider.Disabled) // Should not be disabled + // Global values should remain where not overridden + assert.Equal(t, "gpt-3.5-turbo", openaiProvider.DefaultSmallModel) +} + +func TestProviderMerging_CustomVsKnownProviders(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + customProviderID := provider.InferenceProvider("custom-provider") + + // Create config with both known and custom providers + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + // Known provider - some fields should not be overrideable + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + BaseURL: "should-not-override", + ProviderType: provider.TypeAnthropic, // Should not override + }, + // Custom provider - all fields should be configurable + customProviderID: { + ID: customProviderID, + APIKey: "custom-key", + BaseURL: "https://custom.api.com", + ProviderType: provider.TypeOpenAI, + }, + }, + } + + localConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + BaseURL: "https://should-not-change.com", + ProviderType: provider.TypeGemini, // Should not change + }, + customProviderID: { + BaseURL: "https://updated-custom.api.com", + ProviderType: provider.TypeOpenAI, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err = json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + 
require.NoError(t, err) + + // Known provider should not have BaseURL/ProviderType overridden + openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] + assert.NotEqual(t, "https://should-not-change.com", openaiProvider.BaseURL) + assert.NotEqual(t, provider.TypeGemini, openaiProvider.ProviderType) + + // Custom provider should have all fields configurable + customProvider := cfg.Providers[customProviderID] + assert.Equal(t, "custom-key", customProvider.APIKey) // Should preserve from global + assert.Equal(t, "https://updated-custom.api.com", customProvider.BaseURL) + assert.Equal(t, provider.TypeOpenAI, customProvider.ProviderType) +} + +func TestProviderValidation_CustomProviderMissingBaseURL(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + customProviderID := provider.InferenceProvider("custom-provider") + + // Create config with custom provider missing BaseURL + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + customProviderID: { + ID: customProviderID, + APIKey: "custom-key", + ProviderType: provider.TypeOpenAI, + // Missing BaseURL + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + // Provider should be filtered out due to validation failure + assert.NotContains(t, cfg.Providers, customProviderID) +} + +func TestProviderValidation_CustomProviderMissingAPIKey(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + customProviderID := provider.InferenceProvider("custom-provider") + + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + customProviderID: { + ID: customProviderID, + BaseURL: "https://custom.api.com", + ProviderType: 
provider.TypeOpenAI, + // Missing APIKey + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.NotContains(t, cfg.Providers, customProviderID) +} + +func TestProviderValidation_CustomProviderInvalidType(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + customProviderID := provider.InferenceProvider("custom-provider") + + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + customProviderID: { + ID: customProviderID, + APIKey: "custom-key", + BaseURL: "https://custom.api.com", + ProviderType: provider.Type("invalid-type"), + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.NotContains(t, cfg.Providers, customProviderID) +} + +func TestProviderValidation_KnownProviderValid(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + ProviderType: provider.TypeOpenAI, + // BaseURL not required for known providers + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + 
require.NoError(t, err) + assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI) +} + +func TestProviderValidation_DisabledProvider(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + ProviderType: provider.TypeOpenAI, + Disabled: true, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + // Disabled providers should still be in the config but marked as disabled + assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI) + assert.True(t, cfg.Providers[provider.InferenceProviderOpenAI].Disabled) +} + +func TestProviderModels_AddingNewModels(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + ProviderType: provider.TypeOpenAI, + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + }, + }, + }, + } + + localConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + Models: []Model{ + { + ID: "gpt-4-turbo", + Name: "GPT-4 Turbo", + ContextWindow: 128000, + DefaultMaxTokens: 4096, + }, + }, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, 
os.WriteFile(configPath, data, 0o644)) + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err = json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] + assert.Len(t, openaiProvider.Models, 2) // Should have both models + + modelIDs := make([]string, len(openaiProvider.Models)) + for i, model := range openaiProvider.Models { + modelIDs[i] = model.ID + } + assert.Contains(t, modelIDs, "gpt-4") + assert.Contains(t, modelIDs, "gpt-4-turbo") +} + +func TestProviderModels_DuplicateModelHandling(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + ProviderType: provider.TypeOpenAI, + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + }, + }, + }, + } + + localConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + Models: []Model{ + { + ID: "gpt-4", // Same ID as global + Name: "GPT-4 Updated", + ContextWindow: 16384, + DefaultMaxTokens: 8192, + }, + }, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err = json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + openaiProvider := 
cfg.Providers[provider.InferenceProviderOpenAI] + assert.Len(t, openaiProvider.Models, 1) // Should not duplicate + + // Should keep the original model (global config) + model := openaiProvider.Models[0] + assert.Equal(t, "gpt-4", model.ID) + assert.Equal(t, "GPT-4", model.Name) // Original name + assert.Equal(t, int64(8192), model.ContextWindow) // Original context window +} + +func TestProviderModels_ModelCostAndCapabilities(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + ProviderType: provider.TypeOpenAI, + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + CostPer1MIn: 30.0, + CostPer1MOut: 60.0, + CostPer1MInCached: 15.0, + CostPer1MOutCached: 30.0, + ContextWindow: 8192, + DefaultMaxTokens: 4096, + CanReason: true, + ReasoningEffort: "medium", + SupportsImages: true, + }, + }, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] + require.Len(t, openaiProvider.Models, 1) + + model := openaiProvider.Models[0] + assert.Equal(t, 30.0, model.CostPer1MIn) + assert.Equal(t, 60.0, model.CostPer1MOut) + assert.Equal(t, 15.0, model.CostPer1MInCached) + assert.Equal(t, 30.0, model.CostPer1MOutCached) + assert.True(t, model.CanReason) + assert.Equal(t, "medium", model.ReasoningEffort) + assert.True(t, model.SupportsImages) +} + +// Agent Configuration Tests + +func TestDefaultAgents_CoderAgent(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider so we 
can test agent configuration + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Contains(t, cfg.Agents, AgentCoder) + + coderAgent := cfg.Agents[AgentCoder] + assert.Equal(t, AgentCoder, coderAgent.ID) + assert.Equal(t, "Coder", coderAgent.Name) + assert.Equal(t, "An agent that helps with executing coding tasks.", coderAgent.Description) + assert.Equal(t, LargeModel, coderAgent.Model) + assert.False(t, coderAgent.Disabled) + assert.Equal(t, cfg.Options.ContextPaths, coderAgent.ContextPaths) + // Coder agent should have all tools available (nil means all tools) + assert.Nil(t, coderAgent.AllowedTools) +} + +func TestDefaultAgents_TaskAgent(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider so we can test agent configuration + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + assert.Contains(t, cfg.Agents, AgentTask) + + taskAgent := cfg.Agents[AgentTask] + assert.Equal(t, AgentTask, taskAgent.ID) + assert.Equal(t, "Task", taskAgent.Name) + assert.Equal(t, "An agent that helps with searching for context and finding implementation details.", taskAgent.Description) + assert.Equal(t, LargeModel, taskAgent.Model) + assert.False(t, taskAgent.Disabled) + assert.Equal(t, cfg.Options.ContextPaths, taskAgent.ContextPaths) + + // Task agent should have restricted tools + expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"} + assert.Equal(t, expectedTools, taskAgent.AllowedTools) + + // Task agent should have no MCPs or LSPs by default + assert.Equal(t, map[string][]string{}, taskAgent.AllowedMCP) + assert.Equal(t, []string{}, taskAgent.AllowedLSP) +} + +func TestAgentMerging_CustomAgent(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // Create config with custom agent + 
globalConfig := Config{ + Agents: map[AgentID]Agent{ + AgentID("custom-agent"): { + ID: AgentID("custom-agent"), + Name: "Custom Agent", + Description: "A custom agent for testing", + Model: SmallModel, + AllowedTools: []string{"glob", "grep"}, + AllowedMCP: map[string][]string{"mcp1": {"tool1", "tool2"}}, + AllowedLSP: []string{"typescript", "go"}, + ContextPaths: []string{"custom-context.md"}, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + // Should have default agents plus custom agent + assert.Contains(t, cfg.Agents, AgentCoder) + assert.Contains(t, cfg.Agents, AgentTask) + assert.Contains(t, cfg.Agents, AgentID("custom-agent")) + + customAgent := cfg.Agents[AgentID("custom-agent")] + assert.Equal(t, "Custom Agent", customAgent.Name) + assert.Equal(t, "A custom agent for testing", customAgent.Description) + assert.Equal(t, SmallModel, customAgent.Model) + assert.Equal(t, []string{"glob", "grep"}, customAgent.AllowedTools) + assert.Equal(t, map[string][]string{"mcp1": {"tool1", "tool2"}}, customAgent.AllowedMCP) + assert.Equal(t, []string{"typescript", "go"}, customAgent.AllowedLSP) + // Context paths should be additive (default + custom) + expectedContextPaths := append(defaultContextPaths, "custom-context.md") + assert.Equal(t, expectedContextPaths, customAgent.ContextPaths) +} + +func TestAgentMerging_ModifyDefaultCoderAgent(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // Create config that modifies the default coder agent + globalConfig := Config{ + Agents: map[AgentID]Agent{ + AgentCoder: { + Model: SmallModel, // Change from default LargeModel + AllowedMCP: 
map[string][]string{"mcp1": {"tool1"}}, + AllowedLSP: []string{"typescript"}, + ContextPaths: []string{"coder-specific.md"}, // Should be additive + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + coderAgent := cfg.Agents[AgentCoder] + // Should preserve default values for unspecified fields + assert.Equal(t, AgentCoder, coderAgent.ID) + assert.Equal(t, "Coder", coderAgent.Name) + assert.Equal(t, "An agent that helps with executing coding tasks.", coderAgent.Description) + + // Context paths should be additive (default + custom) + expectedContextPaths := append(cfg.Options.ContextPaths, "coder-specific.md") + assert.Equal(t, expectedContextPaths, coderAgent.ContextPaths) + + // Should update specified fields + assert.Equal(t, SmallModel, coderAgent.Model) + assert.Equal(t, map[string][]string{"mcp1": {"tool1"}}, coderAgent.AllowedMCP) + assert.Equal(t, []string{"typescript"}, coderAgent.AllowedLSP) +} + +func TestAgentMerging_ModifyDefaultTaskAgent(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // Create config that modifies the default task agent + // Note: Only model, MCP, and LSP should be configurable for known agents + globalConfig := Config{ + Agents: map[AgentID]Agent{ + AgentTask: { + Model: SmallModel, // Should be updated + AllowedMCP: map[string][]string{"search-mcp": nil}, // Should be updated + AllowedLSP: []string{"python"}, // Should be updated + // These should be ignored for known agents: + Name: "Search Agent", // Should be ignored + Description: "Custom search agent", // Should be ignored + Disabled: true, // Should be ignored + AllowedTools: []string{"glob", 
"grep", "view"}, // Should be ignored + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + taskAgent := cfg.Agents[AgentTask] + // Should preserve default values for protected fields + assert.Equal(t, "Task", taskAgent.Name) // Should remain default + assert.Equal(t, "An agent that helps with searching for context and finding implementation details.", taskAgent.Description) // Should remain default + assert.False(t, taskAgent.Disabled) // Should remain default + assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools) // Should remain default + + // Should update configurable fields + assert.Equal(t, SmallModel, taskAgent.Model) + assert.Equal(t, map[string][]string{"search-mcp": nil}, taskAgent.AllowedMCP) + assert.Equal(t, []string{"python"}, taskAgent.AllowedLSP) +} + +func TestAgentMerging_LocalOverridesGlobal(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // Create global config with custom agent + globalConfig := Config{ + Agents: map[AgentID]Agent{ + AgentID("test-agent"): { + ID: AgentID("test-agent"), + Name: "Global Agent", + Description: "Global description", + Model: LargeModel, + AllowedTools: []string{"glob"}, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + // Create local config that overrides + localConfig := Config{ + Agents: map[AgentID]Agent{ + AgentID("test-agent"): { + Name: "Local Agent", + 
Description: "Local description", + Model: SmallModel, + Disabled: true, + AllowedTools: []string{"grep", "view"}, + AllowedMCP: map[string][]string{"local-mcp": {"tool1"}}, + }, + }, + } + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err = json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + testAgent := cfg.Agents[AgentID("test-agent")] + // Local should override global + assert.Equal(t, "Local Agent", testAgent.Name) + assert.Equal(t, "Local description", testAgent.Description) + assert.Equal(t, SmallModel, testAgent.Model) + assert.True(t, testAgent.Disabled) + assert.Equal(t, []string{"grep", "view"}, testAgent.AllowedTools) + assert.Equal(t, map[string][]string{"local-mcp": {"tool1"}}, testAgent.AllowedMCP) +} + +func TestAgentModelTypeAssignment(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // Create config with agents using different model types + globalConfig := Config{ + Agents: map[AgentID]Agent{ + AgentID("large-agent"): { + ID: AgentID("large-agent"), + Name: "Large Model Agent", + Model: LargeModel, + }, + AgentID("small-agent"): { + ID: AgentID("small-agent"), + Name: "Small Model Agent", + Model: SmallModel, + }, + AgentID("default-agent"): { + ID: AgentID("default-agent"), + Name: "Default Model Agent", + // No model specified - should default to LargeModel + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + assert.Equal(t, LargeModel, cfg.Agents[AgentID("large-agent")].Model) + assert.Equal(t, SmallModel, 
cfg.Agents[AgentID("small-agent")].Model) + assert.Equal(t, LargeModel, cfg.Agents[AgentID("default-agent")].Model) // Should default to LargeModel +} + +func TestAgentContextPathOverrides(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // Create config with custom context paths + globalConfig := Config{ + Options: Options{ + ContextPaths: []string{"global-context.md", "shared-context.md"}, + }, + Agents: map[AgentID]Agent{ + AgentID("custom-context-agent"): { + ID: AgentID("custom-context-agent"), + Name: "Custom Context Agent", + ContextPaths: []string{"agent-specific.md", "custom.md"}, + }, + AgentID("default-context-agent"): { + ID: AgentID("default-context-agent"), + Name: "Default Context Agent", + // No ContextPaths specified - should use global + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + // Agent with custom context paths should have default + global + custom paths (additive) + customAgent := cfg.Agents[AgentID("custom-context-agent")] + expectedCustomPaths := append(defaultContextPaths, "global-context.md", "shared-context.md", "agent-specific.md", "custom.md") + assert.Equal(t, expectedCustomPaths, customAgent.ContextPaths) + + // Agent without custom context paths should use global + defaults + defaultAgent := cfg.Agents[AgentID("default-context-agent")] + expectedContextPaths := append(defaultContextPaths, "global-context.md", "shared-context.md") + assert.Equal(t, expectedContextPaths, defaultAgent.ContextPaths) + + // Default agents should also use the merged context paths + coderAgent := cfg.Agents[AgentCoder] + assert.Equal(t, expectedContextPaths, 
coderAgent.ContextPaths) +} + +// Options and Settings Tests + +func TestOptionsMerging_ContextPaths(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // Create global config with context paths + globalConfig := Config{ + Options: Options{ + ContextPaths: []string{"global1.md", "global2.md"}, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + // Create local config with additional context paths + localConfig := Config{ + Options: Options{ + ContextPaths: []string{"local1.md", "local2.md"}, + }, + } + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err = json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + // Context paths should be merged: defaults + global + local + expectedContextPaths := append(defaultContextPaths, "global1.md", "global2.md", "local1.md", "local2.md") + assert.Equal(t, expectedContextPaths, cfg.Options.ContextPaths) +} + +func TestOptionsMerging_TUIOptions(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // Create global config with TUI options + globalConfig := Config{ + Options: Options{ + TUI: TUIOptions{ + CompactMode: false, // Default value + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + // Create local config that enables compact mode + localConfig := 
Config{ + Options: Options{ + TUI: TUIOptions{ + CompactMode: true, + }, + }, + } + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err = json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + // Local config should override global + assert.True(t, cfg.Options.TUI.CompactMode) +} + +func TestOptionsMerging_DebugFlags(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // Create global config with debug flags + globalConfig := Config{ + Options: Options{ + Debug: false, + DebugLSP: false, + DisableAutoSummarize: false, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + // Create local config that enables debug flags + localConfig := Config{ + Options: Options{ + DebugLSP: true, + DisableAutoSummarize: true, + }, + } + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err = json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + // Local config should override global for boolean flags + assert.False(t, cfg.Options.Debug) // Not set in local, remains global value + assert.True(t, cfg.Options.DebugLSP) // Set to true in local + assert.True(t, cfg.Options.DisableAutoSummarize) // Set to true in local +} + +func TestOptionsMerging_DataDirectory(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // Create global config with custom data directory + globalConfig := Config{ 
+ Options: Options{ + DataDirectory: "global-data", + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + // Create local config with different data directory + localConfig := Config{ + Options: Options{ + DataDirectory: "local-data", + }, + } + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err = json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + // Local config should override global + assert.Equal(t, "local-data", cfg.Options.DataDirectory) +} + +func TestOptionsMerging_DefaultValues(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // No config files - should use defaults + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + // Should have default values + assert.Equal(t, defaultDataDirectory, cfg.Options.DataDirectory) + assert.Equal(t, defaultContextPaths, cfg.Options.ContextPaths) + assert.False(t, cfg.Options.TUI.CompactMode) + assert.False(t, cfg.Options.Debug) + assert.False(t, cfg.Options.DebugLSP) + assert.False(t, cfg.Options.DisableAutoSummarize) +} + +func TestOptionsMerging_DebugFlagFromInit(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // Create config with debug false + globalConfig := Config{ + Options: Options{ + Debug: false, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, 
os.WriteFile(configPath, data, 0o644)) + + // Init with debug=true should override config + cfg, err := Init(cwdDir, true) + + require.NoError(t, err) + + // Debug flag from Init should take precedence + assert.True(t, cfg.Options.Debug) +} + +func TestOptionsMerging_ComplexScenario(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + // Create global config with various options + globalConfig := Config{ + Options: Options{ + ContextPaths: []string{"global-context.md"}, + DataDirectory: "global-data", + Debug: false, + DebugLSP: false, + DisableAutoSummarize: false, + TUI: TUIOptions{ + CompactMode: false, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + // Create local config that partially overrides + localConfig := Config{ + Options: Options{ + ContextPaths: []string{"local-context.md"}, + DebugLSP: true, // Override + DisableAutoSummarize: true, // Override + TUI: TUIOptions{ + CompactMode: true, // Override + }, + // DataDirectory and Debug not specified - should keep global values + }, + } + + localConfigPath := filepath.Join(cwdDir, "crush.json") + data, err = json.Marshal(localConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(localConfigPath, data, 0o644)) + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + + // Check merged results + expectedContextPaths := append(defaultContextPaths, "global-context.md", "local-context.md") + assert.Equal(t, expectedContextPaths, cfg.Options.ContextPaths) + assert.Equal(t, "global-data", cfg.Options.DataDirectory) // From global + assert.False(t, cfg.Options.Debug) // From global + assert.True(t, cfg.Options.DebugLSP) // From local + assert.True(t, 
cfg.Options.DisableAutoSummarize) // From local + assert.True(t, cfg.Options.TUI.CompactMode) // From local +} + +// Model Selection Tests + +func TestModelSelection_PreferredModelSelection(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up multiple providers to test selection logic + os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key") + os.Setenv("OPENAI_API_KEY", "test-openai-key") + + cfg, err := Init(cwdDir, false) + + require.NoError(t, err) + require.Len(t, cfg.Providers, 2) + + // Should have preferred models set + assert.NotEmpty(t, cfg.Models.Large.ModelID) + assert.NotEmpty(t, cfg.Models.Large.Provider) + assert.NotEmpty(t, cfg.Models.Small.ModelID) + assert.NotEmpty(t, cfg.Models.Small.Provider) + + // Both should use the same provider (first available) + assert.Equal(t, cfg.Models.Large.Provider, cfg.Models.Small.Provider) +} + +func TestModelSelection_GetAgentModel(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider with known models + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "test-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + CanReason: true, + SupportsImages: true, + }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + ContextWindow: 4096, + DefaultMaxTokens: 2048, + CanReason: false, + SupportsImages: false, + }, + }, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + _, err = Init(cwdDir, false) + + require.NoError(t, err) + + // Test 
GetAgentModel for default agents + coderModel := GetAgentModel(AgentCoder) + assert.Equal(t, "gpt-4", coderModel.ID) // Coder uses LargeModel + assert.Equal(t, "GPT-4", coderModel.Name) + assert.True(t, coderModel.CanReason) + assert.True(t, coderModel.SupportsImages) + + taskModel := GetAgentModel(AgentTask) + assert.Equal(t, "gpt-4", taskModel.ID) // Task also uses LargeModel by default + assert.Equal(t, "GPT-4", taskModel.Name) +} + +func TestModelSelection_GetAgentModelWithCustomModelType(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up provider and custom agent with SmallModel + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "test-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + ContextWindow: 4096, + DefaultMaxTokens: 2048, + }, + }, + }, + }, + Agents: map[AgentID]Agent{ + AgentID("small-agent"): { + ID: AgentID("small-agent"), + Name: "Small Agent", + Model: SmallModel, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + _, err = Init(cwdDir, false) + + require.NoError(t, err) + + // Test GetAgentModel for custom agent with SmallModel + smallAgentModel := GetAgentModel(AgentID("small-agent")) + assert.Equal(t, "gpt-3.5-turbo", smallAgentModel.ID) + assert.Equal(t, "GPT-3.5 Turbo", smallAgentModel.Name) +} + +func TestModelSelection_GetAgentProvider(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up multiple 
providers + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + }, + provider.InferenceProviderAnthropic: { + ID: provider.InferenceProviderAnthropic, + APIKey: "anthropic-key", + ProviderType: provider.TypeAnthropic, + DefaultLargeModel: "claude-3-opus", + DefaultSmallModel: "claude-3-haiku", + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + _, err = Init(cwdDir, false) + + require.NoError(t, err) + + // Test GetAgentProvider + coderProvider := GetAgentProvider(AgentCoder) + assert.NotEmpty(t, coderProvider.ID) + assert.NotEmpty(t, coderProvider.APIKey) + assert.NotEmpty(t, coderProvider.ProviderType) +} + +func TestModelSelection_GetProviderModel(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up provider with specific models + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "test-key", + ProviderType: provider.TypeOpenAI, + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + CostPer1MIn: 30.0, + CostPer1MOut: 60.0, + }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + ContextWindow: 4096, + DefaultMaxTokens: 2048, + CostPer1MIn: 1.5, + CostPer1MOut: 2.0, + }, + }, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + 
require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + _, err = Init(cwdDir, false) + + require.NoError(t, err) + + // Test GetProviderModel + gpt4Model := GetProviderModel(provider.InferenceProviderOpenAI, "gpt-4") + assert.Equal(t, "gpt-4", gpt4Model.ID) + assert.Equal(t, "GPT-4", gpt4Model.Name) + assert.Equal(t, int64(8192), gpt4Model.ContextWindow) + assert.Equal(t, 30.0, gpt4Model.CostPer1MIn) + + gpt35Model := GetProviderModel(provider.InferenceProviderOpenAI, "gpt-3.5-turbo") + assert.Equal(t, "gpt-3.5-turbo", gpt35Model.ID) + assert.Equal(t, "GPT-3.5 Turbo", gpt35Model.Name) + assert.Equal(t, 1.5, gpt35Model.CostPer1MIn) + + // Test non-existent model + nonExistentModel := GetProviderModel(provider.InferenceProviderOpenAI, "non-existent") + assert.Empty(t, nonExistentModel.ID) +} + +func TestModelSelection_GetModel(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up provider with models + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "test-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + ContextWindow: 4096, + DefaultMaxTokens: 2048, + }, + }, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + _, err = Init(cwdDir, false) + + require.NoError(t, err) + + // Test GetModel + largeModel := GetModel(LargeModel) + assert.Equal(t, "gpt-4", largeModel.ID) + assert.Equal(t, "GPT-4", largeModel.Name) + + smallModel := GetModel(SmallModel) + 
assert.Equal(t, "gpt-3.5-turbo", smallModel.ID) + assert.Equal(t, "GPT-3.5 Turbo", smallModel.Name) +} + +func TestModelSelection_UpdatePreferredModel(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up multiple providers to exercise switching the preferred models + globalConfig := Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + Models: []Model{ + {ID: "gpt-4", Name: "GPT-4"}, + {ID: "gpt-3.5-turbo", Name: "GPT-3.5 Turbo"}, + }, + }, + provider.InferenceProviderAnthropic: { + ID: provider.InferenceProviderAnthropic, + APIKey: "anthropic-key", + ProviderType: provider.TypeAnthropic, + DefaultLargeModel: "claude-3-opus", + DefaultSmallModel: "claude-3-haiku", + Models: []Model{ + {ID: "claude-3-opus", Name: "Claude 3 Opus"}, + {ID: "claude-3-haiku", Name: "Claude 3 Haiku"}, + }, + }, + }, + } + + configPath := filepath.Join(testConfigDir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + data, err := json.Marshal(globalConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, data, 0o644)) + + _, err = Init(cwdDir, false) + + require.NoError(t, err) + + // Get initial preferred models (Anthropic ends up selected initially here) + initialLargeModel := GetModel(LargeModel) + initialSmallModel := GetModel(SmallModel) + + // Verify initial models are Anthropic models + assert.Equal(t, "claude-3-opus", initialLargeModel.ID) + assert.Equal(t, "claude-3-haiku", initialSmallModel.ID) + + // Update preferred models to OpenAI + newLargeModel := PreferredModel{ + ModelID: "gpt-4", + Provider: provider.InferenceProviderOpenAI, + } + newSmallModel := PreferredModel{ + ModelID: "gpt-3.5-turbo", + Provider: provider.InferenceProviderOpenAI, + } + + err = 
UpdatePreferredModel(LargeModel, newLargeModel) + require.NoError(t, err) + + err = UpdatePreferredModel(SmallModel, newSmallModel) + require.NoError(t, err) + + // Verify models were updated + updatedLargeModel := GetModel(LargeModel) + assert.Equal(t, "gpt-4", updatedLargeModel.ID) + assert.NotEqual(t, initialLargeModel.ID, updatedLargeModel.ID) + + updatedSmallModel := GetModel(SmallModel) + assert.Equal(t, "gpt-3.5-turbo", updatedSmallModel.ID) + assert.NotEqual(t, initialSmallModel.ID, updatedSmallModel.ID) +} + +func TestModelSelection_InvalidModelType(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + _, err := Init(cwdDir, false) + require.NoError(t, err) + + // Test UpdatePreferredModel with invalid model type + invalidModel := PreferredModel{ + ModelID: "some-model", + Provider: provider.InferenceProviderAnthropic, + } + + err = UpdatePreferredModel(ModelType("invalid"), invalidModel) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unknown model type") +} + +func TestModelSelection_NonExistentAgent(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + _, err := Init(cwdDir, false) + require.NoError(t, err) + + // Test GetAgentModel with non-existent agent + nonExistentModel := GetAgentModel(AgentID("non-existent")) + assert.Empty(t, nonExistentModel.ID) + + // Test GetAgentProvider with non-existent agent + nonExistentProvider := GetAgentProvider(AgentID("non-existent")) + assert.Empty(t, nonExistentProvider.ID) +} + +func TestModelSelection_NonExistentProvider(t *testing.T) { + reset() + testConfigDir = t.TempDir() + cwdDir := t.TempDir() + + // Set up a provider + os.Setenv("ANTHROPIC_API_KEY", "test-key") + + _, err := Init(cwdDir, false) + require.NoError(t, err) + + // Test GetProviderModel with non-existent provider + nonExistentModel := 
GetProviderModel(provider.InferenceProvider("non-existent"), "some-model") + assert.Empty(t, nonExistentModel.ID) +} From 400bf3f7a8ad2dfb3671c3ab9105d3663537b5bb Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 27 Jun 2025 17:45:02 +0200 Subject: [PATCH 09/40] chore: clean up unused dependencies --- go.mod | 5 ----- go.sum | 18 ------------------ 2 files changed, 23 deletions(-) diff --git a/go.mod b/go.mod index 99e808c14a24cd34d5274c74eba183229f51dd07..29bcf5e483cb43dc7038d9220bc8b5259719e96d 100644 --- a/go.mod +++ b/go.mod @@ -33,7 +33,6 @@ require ( github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 github.com/sahilm/fuzzy v0.1.1 github.com/spf13/cobra v1.9.1 - github.com/spf13/viper v1.20.0 github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef github.com/stretchr/testify v1.10.0 @@ -78,7 +77,6 @@ require ( github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect - github.com/go-viper/mapstructure/v2 v2.2.1 // indirect github.com/golang-jwt/jwt/v5 v5.2.2 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/s2a-go v0.1.8 // indirect @@ -99,18 +97,15 @@ require ( github.com/muesli/mango-pflag v0.1.0 // indirect github.com/muesli/roff v0.1.0 // indirect github.com/ncruces/julianday v1.0.0 // indirect - github.com/pelletier/go-toml/v2 v2.2.3 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 - github.com/sagikazarmark/locafero v0.7.0 // indirect github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect github.com/sethvargo/go-retry v0.3.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.12.0 github.com/spf13/cast v1.7.1 // indirect github.com/spf13/pflag v1.0.6 // indirect - github.com/subosito/gotenv v1.6.0 // 
indirect github.com/tetratelabs/wazero v1.9.0 // indirect github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.1.1 // indirect diff --git a/go.sum b/go.sum index 3a0b0a6177b193cbbea1bb6d331e4b46fc19cb28..ddbea31f43ce1c6ec4235f3c7af740cca67795fe 100644 --- a/go.sum +++ b/go.sum @@ -114,8 +114,6 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= -github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M= github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4= @@ -127,8 +125,6 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI= github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow= -github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss= -github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= @@ -194,8 +190,6 @@ github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 
h1:zYyBkD/k9seD2A7fsi6 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= github.com/openai/openai-go v0.1.0-beta.2 h1:Ra5nCFkbEl9w+UJwAciC4kqnIBUCcJazhmMA0/YN894= github.com/openai/openai-go v0.1.0-beta.2/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= -github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= -github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -213,8 +207,6 @@ github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7 github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 h1:OkMGxebDjyw0ULyrTYWeN0UNCCkmCWfjPnIA2W6oviI= github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06/go.mod h1:+ePHsJ1keEjQtpvf9HHw0f4ZeJ0TLRsxhunSI2hYJSs= -github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo= -github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k= github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA= github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y= github.com/sebdah/goldie/v2 v2.5.3 h1:9ES/mNN+HNUbNWpVAlrzuZ7jE+Nrczbj8uFRjM7624Y= @@ -225,18 +217,10 @@ github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 h1:n661drycOFuPLCN github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE= 
github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas= -github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= -github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= -github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs= -github.com/spf13/afero v1.12.0/go.mod h1:ZTlWwG4/ahT8W7T0WQ5uYmjI9duaLQGy3Q2OAl4sk/4= -github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= -github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo= github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/spf13/viper v1.20.0 h1:zrxIyR3RQIOsarIrgL8+sAvALXul9jeEPa06Y0Ph6vY= -github.com/spf13/viper v1.20.0/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4= github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c h1:km8GpoQut05eY3GiYWEedbTT0qnSxrCjsVbb7yKY1KE= github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c/go.mod h1:cNQ3dwVJtS5Hmnjxy6AgTPd0Inb3pW05ftPSX7NZO7Q= github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef h1:Ch6Q+AZUxDBCVqdkI8FSpFyZDtCVBc2VmejdNrm5rRQ= @@ -247,8 +231,6 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= -github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/tetratelabs/wazero v1.9.0 
h1:IcZ56OuxrtaEz8UYNRHBrUa9bYeX9oVY93KspZZBf/I= github.com/tetratelabs/wazero v1.9.0/go.mod h1:TSbcXCfFP0L2FGkRPxHphadXPjo1T6W+CseNNY7EkjM= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= From 953863834943239385e14d69615ab46fab94f42c Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 27 Jun 2025 17:45:14 +0200 Subject: [PATCH 10/40] fix: formatting in validation test --- internal/config/validation_test.go | 462 +++++++++++++++++++++++++++++ 1 file changed, 462 insertions(+) create mode 100644 internal/config/validation_test.go diff --git a/internal/config/validation_test.go b/internal/config/validation_test.go new file mode 100644 index 0000000000000000000000000000000000000000..0aef035ae7bddfc7532e9dde550ab0184ed180db --- /dev/null +++ b/internal/config/validation_test.go @@ -0,0 +1,462 @@ +package config + +import ( + "testing" + + "github.com/charmbracelet/crush/internal/fur/provider" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConfig_Validate_ValidConfig(t *testing.T) { + cfg := &Config{ + Models: PreferredModels{ + Large: PreferredModel{ + ModelID: "gpt-4", + Provider: provider.InferenceProviderOpenAI, + }, + Small: PreferredModel{ + ModelID: "gpt-3.5-turbo", + Provider: provider.InferenceProviderOpenAI, + }, + }, + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "test-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + CostPer1MIn: 30.0, + CostPer1MOut: 60.0, + }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + ContextWindow: 4096, + DefaultMaxTokens: 2048, + CostPer1MIn: 1.5, + CostPer1MOut: 2.0, + }, + }, + }, + }, + Agents: map[AgentID]Agent{ + AgentCoder: { + ID: AgentCoder, + Name: "Coder", + 
Description: "An agent that helps with executing coding tasks.", + Model: LargeModel, + ContextPaths: []string{"CRUSH.md"}, + }, + AgentTask: { + ID: AgentTask, + Name: "Task", + Description: "An agent that helps with searching for context and finding implementation details.", + Model: LargeModel, + ContextPaths: []string{"CRUSH.md"}, + AllowedTools: []string{"glob", "grep", "ls", "sourcegraph", "view"}, + AllowedMCP: map[string][]string{}, + AllowedLSP: []string{}, + }, + }, + MCP: map[string]MCP{}, + LSP: map[string]LSPConfig{}, + Options: Options{ + DataDirectory: ".crush", + ContextPaths: []string{"CRUSH.md"}, + }, + } + + err := cfg.Validate() + assert.NoError(t, err) +} + +func TestConfig_Validate_MissingAPIKey(t *testing.T) { + cfg := &Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + ProviderType: provider.TypeOpenAI, + // Missing APIKey + }, + }, + Options: Options{ + DataDirectory: ".crush", + ContextPaths: []string{"CRUSH.md"}, + }, + } + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "API key is required") +} + +func TestConfig_Validate_InvalidProviderType(t *testing.T) { + cfg := &Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "test-key", + ProviderType: provider.Type("invalid"), + }, + }, + Options: Options{ + DataDirectory: ".crush", + ContextPaths: []string{"CRUSH.md"}, + }, + } + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid provider type") +} + +func TestConfig_Validate_CustomProviderMissingBaseURL(t *testing.T) { + customProvider := provider.InferenceProvider("custom-provider") + cfg := &Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + customProvider: { + ID: customProvider, + APIKey: "test-key", + ProviderType: provider.TypeOpenAI, + // 
Missing BaseURL for custom provider + }, + }, + Options: Options{ + DataDirectory: ".crush", + ContextPaths: []string{"CRUSH.md"}, + }, + } + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "BaseURL is required for custom providers") +} + +func TestConfig_Validate_DuplicateModelIDs(t *testing.T) { + cfg := &Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "test-key", + ProviderType: provider.TypeOpenAI, + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + { + ID: "gpt-4", // Duplicate ID + Name: "GPT-4 Duplicate", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + }, + }, + }, + Options: Options{ + DataDirectory: ".crush", + ContextPaths: []string{"CRUSH.md"}, + }, + } + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "duplicate model ID") +} + +func TestConfig_Validate_InvalidModelFields(t *testing.T) { + cfg := &Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "test-key", + ProviderType: provider.TypeOpenAI, + Models: []Model{ + { + ID: "", // Empty ID + Name: "GPT-4", + ContextWindow: 0, // Invalid context window + DefaultMaxTokens: -1, // Invalid max tokens + CostPer1MIn: -5.0, // Negative cost + }, + }, + }, + }, + Options: Options{ + DataDirectory: ".crush", + ContextPaths: []string{"CRUSH.md"}, + }, + } + + err := cfg.Validate() + require.Error(t, err) + validationErr := err.(ValidationErrors) + assert.True(t, len(validationErr) >= 4) // Should have multiple validation errors +} + +func TestConfig_Validate_DefaultModelNotFound(t *testing.T) { + cfg := &Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "test-key", + 
ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "nonexistent-model", + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + }, + }, + }, + Options: Options{ + DataDirectory: ".crush", + ContextPaths: []string{"CRUSH.md"}, + }, + } + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "default large model 'nonexistent-model' not found") +} + +func TestConfig_Validate_AgentIDMismatch(t *testing.T) { + cfg := &Config{ + Agents: map[AgentID]Agent{ + AgentCoder: { + ID: AgentTask, // Wrong ID + Name: "Coder", + }, + }, + Options: Options{ + DataDirectory: ".crush", + ContextPaths: []string{"CRUSH.md"}, + }, + } + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "agent ID mismatch") +} + +func TestConfig_Validate_InvalidAgentModelType(t *testing.T) { + cfg := &Config{ + Agents: map[AgentID]Agent{ + AgentCoder: { + ID: AgentCoder, + Name: "Coder", + Model: ModelType("invalid"), + }, + }, + Options: Options{ + DataDirectory: ".crush", + ContextPaths: []string{"CRUSH.md"}, + }, + } + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid model type") +} + +func TestConfig_Validate_UnknownTool(t *testing.T) { + cfg := &Config{ + Agents: map[AgentID]Agent{ + AgentID("custom-agent"): { + ID: AgentID("custom-agent"), + Name: "Custom Agent", + Model: LargeModel, + AllowedTools: []string{"unknown-tool"}, + }, + }, + Options: Options{ + DataDirectory: ".crush", + ContextPaths: []string{"CRUSH.md"}, + }, + } + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown tool") +} + +func TestConfig_Validate_MCPReference(t *testing.T) { + cfg := &Config{ + Agents: map[AgentID]Agent{ + AgentID("custom-agent"): { + ID: AgentID("custom-agent"), + Name: "Custom Agent", + Model: LargeModel, + AllowedMCP: map[string][]string{"nonexistent-mcp": nil}, + }, + }, + MCP: map[string]MCP{}, // Empty MCP 
map + Options: Options{ + DataDirectory: ".crush", + ContextPaths: []string{"CRUSH.md"}, + }, + } + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "referenced MCP 'nonexistent-mcp' not found") +} + +func TestConfig_Validate_InvalidMCPType(t *testing.T) { + cfg := &Config{ + MCP: map[string]MCP{ + "test-mcp": { + Type: MCPType("invalid"), + }, + }, + Options: Options{ + DataDirectory: ".crush", + ContextPaths: []string{"CRUSH.md"}, + }, + } + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid MCP type") +} + +func TestConfig_Validate_MCPMissingCommand(t *testing.T) { + cfg := &Config{ + MCP: map[string]MCP{ + "test-mcp": { + Type: MCPStdio, + // Missing Command + }, + }, + Options: Options{ + DataDirectory: ".crush", + ContextPaths: []string{"CRUSH.md"}, + }, + } + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "command is required for stdio MCP") +} + +func TestConfig_Validate_LSPMissingCommand(t *testing.T) { + cfg := &Config{ + LSP: map[string]LSPConfig{ + "test-lsp": { + // Missing Command + }, + }, + Options: Options{ + DataDirectory: ".crush", + ContextPaths: []string{"CRUSH.md"}, + }, + } + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "command is required for LSP") +} + +func TestConfig_Validate_NoValidProviders(t *testing.T) { + cfg := &Config{ + Providers: map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "test-key", + ProviderType: provider.TypeOpenAI, + Disabled: true, // Disabled + }, + }, + Options: Options{ + DataDirectory: ".crush", + ContextPaths: []string{"CRUSH.md"}, + }, + } + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "at least one non-disabled provider is required") +} + +func TestConfig_Validate_MissingDefaultAgents(t *testing.T) { + cfg := &Config{ + Providers: 
map[provider.InferenceProvider]ProviderConfig{ + provider.InferenceProviderOpenAI: { + ID: provider.InferenceProviderOpenAI, + APIKey: "test-key", + ProviderType: provider.TypeOpenAI, + }, + }, + Agents: map[AgentID]Agent{}, // Missing default agents + Options: Options{ + DataDirectory: ".crush", + ContextPaths: []string{"CRUSH.md"}, + }, + } + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "coder agent is required") + assert.Contains(t, err.Error(), "task agent is required") +} + +func TestConfig_Validate_KnownAgentProtection(t *testing.T) { + cfg := &Config{ + Agents: map[AgentID]Agent{ + AgentCoder: { + ID: AgentCoder, + Name: "Modified Coder", // Should not be allowed + Description: "Modified description", // Should not be allowed + Model: LargeModel, + }, + }, + Options: Options{ + DataDirectory: ".crush", + ContextPaths: []string{"CRUSH.md"}, + }, + } + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "coder agent name cannot be changed") + assert.Contains(t, err.Error(), "coder agent description cannot be changed") +} + +func TestConfig_Validate_EmptyDataDirectory(t *testing.T) { + cfg := &Config{ + Options: Options{ + DataDirectory: "", // Empty + ContextPaths: []string{"CRUSH.md"}, + }, + } + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "data directory is required") +} + +func TestConfig_Validate_EmptyContextPath(t *testing.T) { + cfg := &Config{ + Options: Options{ + DataDirectory: ".crush", + ContextPaths: []string{""}, // Empty context path + }, + } + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "context path cannot be empty") +} From f2cb2321edf285750a4de5e168b0657d4f1c31d3 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 27 Jun 2025 17:46:12 +0200 Subject: [PATCH 11/40] chore: migrate config file from .crush.json to crush.json --- .crush.json | 11 ----------- crush.json | 7 +++++++ 2 files changed, 7 
insertions(+), 11 deletions(-) delete mode 100644 .crush.json create mode 100644 crush.json diff --git a/.crush.json b/.crush.json deleted file mode 100644 index 0edb084c3ee5f74c9b279b634b5838b3035cee95..0000000000000000000000000000000000000000 --- a/.crush.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "$schema": "./crush-schema.json", - "lsp": { - "Go": { - "command": "gopls" - } - }, - "tui": { - "theme": "crush" - } -} diff --git a/crush.json b/crush.json new file mode 100644 index 0000000000000000000000000000000000000000..1b04ea6c24f8b64a3a12ceb47551f3177fa66302 --- /dev/null +++ b/crush.json @@ -0,0 +1,7 @@ +{ + "lsp": { + "Go": { + "command": "gopls" + } + } +} From e16fc3deeb917fdc6b7b19e3e84557fe67c00f73 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 27 Jun 2025 17:46:25 +0200 Subject: [PATCH 12/40] feat: add config validation and provider system with mock support --- internal/config/config.go | 524 ++++++++++++++++++++++++++++++- internal/config/config_test.go | 108 ++++++- internal/config/provider.go | 14 + internal/config/provider_mock.go | 177 +++++++++++ internal/config/provider_test.go | 105 +++++++ 5 files changed, 909 insertions(+), 19 deletions(-) create mode 100644 internal/config/provider_mock.go create mode 100644 internal/config/provider_test.go diff --git a/internal/config/config.go b/internal/config/config.go index bddef684d9e1c45a5ed165cff000c3cb1d8302e2..e33aab02a492e8a1a4c55554fe5a3656d101ec1e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -325,6 +325,11 @@ func loadConfig(cwd string, debug bool) (*Config, error) { mergeMCPs(cfg, globalCfg, localConfig) mergeLSPs(cfg, globalCfg, localConfig) + // Validate the final configuration + if err := cfg.Validate(); err != nil { + return cfg, fmt.Errorf("configuration validation failed: %w", err) + } + return cfg, nil } @@ -503,7 +508,7 @@ func mergeAgents(base, global, local *Config) { base.Agents[agentID] = newAgent } else { baseAgent := base.Agents[agentID] - + // 
Special handling for known agents - only allow model changes if agentID == AgentCoder || agentID == AgentTask { if newAgent.Model != "" { @@ -533,10 +538,10 @@ func mergeAgents(base, global, local *Config) { } else if baseAgent.Model == "" { baseAgent.Model = LargeModel // Default fallback } - + // Boolean fields - always update (including false values) baseAgent.Disabled = newAgent.Disabled - + // Slice/Map fields - update if provided (including empty slices/maps) if newAgent.AllowedTools != nil { baseAgent.AllowedTools = newAgent.AllowedTools @@ -552,7 +557,7 @@ func mergeAgents(base, global, local *Config) { baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...) } } - + base.Agents[agentID] = baseAgent } } @@ -666,6 +671,9 @@ func defaultConfigBasedOnEnv() *Config { ContextPaths: defaultContextPaths, }, Providers: make(map[provider.InferenceProvider]ProviderConfig), + Agents: make(map[AgentID]Agent), + LSP: make(map[string]LSPConfig), + MCP: make(map[string]MCP), } providers := Providers() @@ -681,9 +689,7 @@ func defaultConfigBasedOnEnv() *Config { baseURL := p.APIEndpoint if strings.HasPrefix(baseURL, "$") { envVar := strings.TrimPrefix(baseURL, "$") - if url := os.Getenv(envVar); url != "" { - baseURL = url - } + baseURL = os.Getenv(envVar) } providerConfig.BaseURL = baseURL for _, model := range p.Models { @@ -871,3 +877,507 @@ func UpdatePreferredModel(modelType ModelType, model PreferredModel) error { } return nil } + +// ValidationError represents a configuration validation error +type ValidationError struct { + Field string + Message string +} + +func (e ValidationError) Error() string { + return fmt.Sprintf("validation error in %s: %s", e.Field, e.Message) +} + +// ValidationErrors represents multiple validation errors +type ValidationErrors []ValidationError + +func (e ValidationErrors) Error() string { + if len(e) == 0 { + return "no validation errors" + } + if len(e) == 1 { + return e[0].Error() + } + + var messages 
[]string + for _, err := range e { + messages = append(messages, err.Error()) + } + return fmt.Sprintf("multiple validation errors: %s", strings.Join(messages, "; ")) +} + +// HasErrors returns true if there are any validation errors +func (e ValidationErrors) HasErrors() bool { + return len(e) > 0 +} + +// Add appends a new validation error +func (e *ValidationErrors) Add(field, message string) { + *e = append(*e, ValidationError{Field: field, Message: message}) +} + +// Validate performs comprehensive validation of the configuration +func (c *Config) Validate() error { + var errors ValidationErrors + + // Validate providers + c.validateProviders(&errors) + + // Validate models + c.validateModels(&errors) + + // Validate agents + c.validateAgents(&errors) + + // Validate options + c.validateOptions(&errors) + + // Validate MCP configurations + c.validateMCPs(&errors) + + // Validate LSP configurations + c.validateLSPs(&errors) + + // Validate cross-references + c.validateCrossReferences(&errors) + + // Validate completeness + c.validateCompleteness(&errors) + + if errors.HasErrors() { + return errors + } + + return nil +} + +// validateProviders validates all provider configurations +func (c *Config) validateProviders(errors *ValidationErrors) { + if c.Providers == nil { + c.Providers = make(map[provider.InferenceProvider]ProviderConfig) + } + + knownProviders := provider.KnownProviders() + validTypes := []provider.Type{ + provider.TypeOpenAI, + provider.TypeAnthropic, + provider.TypeGemini, + provider.TypeAzure, + provider.TypeBedrock, + provider.TypeVertexAI, + provider.TypeXAI, + } + + for providerID, providerConfig := range c.Providers { + fieldPrefix := fmt.Sprintf("providers.%s", providerID) + + // Validate API key for non-disabled providers + if !providerConfig.Disabled && providerConfig.APIKey == "" { + // Special case for AWS Bedrock and VertexAI which may use other auth methods + if providerID != provider.InferenceProviderBedrock && providerID != 
provider.InferenceProviderVertexAI { + errors.Add(fieldPrefix+".api_key", "API key is required for non-disabled providers") + } + } + + // Validate provider type + validType := false + for _, vt := range validTypes { + if providerConfig.ProviderType == vt { + validType = true + break + } + } + if !validType { + errors.Add(fieldPrefix+".provider_type", fmt.Sprintf("invalid provider type: %s", providerConfig.ProviderType)) + } + + // Validate custom providers + isKnownProvider := false + for _, kp := range knownProviders { + if providerID == kp { + isKnownProvider = true + break + } + } + + if !isKnownProvider { + // Custom provider validation + if providerConfig.BaseURL == "" { + errors.Add(fieldPrefix+".base_url", "BaseURL is required for custom providers") + } + if providerConfig.ProviderType != provider.TypeOpenAI { + errors.Add(fieldPrefix+".provider_type", "custom providers currently only support OpenAI type") + } + } + + // Validate models + modelIDs := make(map[string]bool) + for i, model := range providerConfig.Models { + modelFieldPrefix := fmt.Sprintf("%s.models[%d]", fieldPrefix, i) + + // Check for duplicate model IDs + if modelIDs[model.ID] { + errors.Add(modelFieldPrefix+".id", fmt.Sprintf("duplicate model ID: %s", model.ID)) + } + modelIDs[model.ID] = true + + // Validate required model fields + if model.ID == "" { + errors.Add(modelFieldPrefix+".id", "model ID is required") + } + if model.Name == "" { + errors.Add(modelFieldPrefix+".name", "model name is required") + } + if model.ContextWindow <= 0 { + errors.Add(modelFieldPrefix+".context_window", "context window must be positive") + } + if model.DefaultMaxTokens <= 0 { + errors.Add(modelFieldPrefix+".default_max_tokens", "default max tokens must be positive") + } + if model.DefaultMaxTokens > model.ContextWindow { + errors.Add(modelFieldPrefix+".default_max_tokens", "default max tokens cannot exceed context window") + } + + // Validate cost fields + if model.CostPer1MIn < 0 { + 
errors.Add(modelFieldPrefix+".cost_per_1m_in", "cost per 1M input tokens cannot be negative") + } + if model.CostPer1MOut < 0 { + errors.Add(modelFieldPrefix+".cost_per_1m_out", "cost per 1M output tokens cannot be negative") + } + if model.CostPer1MInCached < 0 { + errors.Add(modelFieldPrefix+".cost_per_1m_in_cached", "cached cost per 1M input tokens cannot be negative") + } + if model.CostPer1MOutCached < 0 { + errors.Add(modelFieldPrefix+".cost_per_1m_out_cached", "cached cost per 1M output tokens cannot be negative") + } + } + + // Validate default model references + if providerConfig.DefaultLargeModel != "" { + if !modelIDs[providerConfig.DefaultLargeModel] { + errors.Add(fieldPrefix+".default_large_model", fmt.Sprintf("default large model '%s' not found in provider models", providerConfig.DefaultLargeModel)) + } + } + if providerConfig.DefaultSmallModel != "" { + if !modelIDs[providerConfig.DefaultSmallModel] { + errors.Add(fieldPrefix+".default_small_model", fmt.Sprintf("default small model '%s' not found in provider models", providerConfig.DefaultSmallModel)) + } + } + + // Validate provider-specific requirements + c.validateProviderSpecific(providerID, providerConfig, errors) + } +} + +// validateProviderSpecific validates provider-specific requirements +func (c *Config) validateProviderSpecific(providerID provider.InferenceProvider, providerConfig ProviderConfig, errors *ValidationErrors) { + fieldPrefix := fmt.Sprintf("providers.%s", providerID) + + switch providerID { + case provider.InferenceProviderVertexAI: + if !providerConfig.Disabled { + if providerConfig.ExtraParams == nil { + errors.Add(fieldPrefix+".extra_params", "VertexAI requires extra_params configuration") + } else { + if providerConfig.ExtraParams["project"] == "" { + errors.Add(fieldPrefix+".extra_params.project", "VertexAI requires project parameter") + } + if providerConfig.ExtraParams["location"] == "" { + errors.Add(fieldPrefix+".extra_params.location", "VertexAI requires location 
parameter") + } + } + } + case provider.InferenceProviderBedrock: + if !providerConfig.Disabled { + if providerConfig.ExtraParams == nil || providerConfig.ExtraParams["region"] == "" { + errors.Add(fieldPrefix+".extra_params.region", "Bedrock requires region parameter") + } + // Check for AWS credentials in environment + if !hasAWSCredentials() { + errors.Add(fieldPrefix, "Bedrock requires AWS credentials in environment") + } + } + } +} + +// validateModels validates preferred model configurations +func (c *Config) validateModels(errors *ValidationErrors) { + // Validate large model + if c.Models.Large.ModelID != "" || c.Models.Large.Provider != "" { + if c.Models.Large.ModelID == "" { + errors.Add("models.large.model_id", "large model ID is required when provider is set") + } + if c.Models.Large.Provider == "" { + errors.Add("models.large.provider", "large model provider is required when model ID is set") + } + + // Check if provider exists and is not disabled + if providerConfig, exists := c.Providers[c.Models.Large.Provider]; exists { + if providerConfig.Disabled { + errors.Add("models.large.provider", "large model provider is disabled") + } + + // Check if model exists in provider + modelExists := false + for _, model := range providerConfig.Models { + if model.ID == c.Models.Large.ModelID { + modelExists = true + break + } + } + if !modelExists { + errors.Add("models.large.model_id", fmt.Sprintf("large model '%s' not found in provider '%s'", c.Models.Large.ModelID, c.Models.Large.Provider)) + } + } else { + errors.Add("models.large.provider", fmt.Sprintf("large model provider '%s' not found", c.Models.Large.Provider)) + } + } + + // Validate small model + if c.Models.Small.ModelID != "" || c.Models.Small.Provider != "" { + if c.Models.Small.ModelID == "" { + errors.Add("models.small.model_id", "small model ID is required when provider is set") + } + if c.Models.Small.Provider == "" { + errors.Add("models.small.provider", "small model provider is required when 
model ID is set") + } + + // Check if provider exists and is not disabled + if providerConfig, exists := c.Providers[c.Models.Small.Provider]; exists { + if providerConfig.Disabled { + errors.Add("models.small.provider", "small model provider is disabled") + } + + // Check if model exists in provider + modelExists := false + for _, model := range providerConfig.Models { + if model.ID == c.Models.Small.ModelID { + modelExists = true + break + } + } + if !modelExists { + errors.Add("models.small.model_id", fmt.Sprintf("small model '%s' not found in provider '%s'", c.Models.Small.ModelID, c.Models.Small.Provider)) + } + } else { + errors.Add("models.small.provider", fmt.Sprintf("small model provider '%s' not found", c.Models.Small.Provider)) + } + } +} + +// validateAgents validates agent configurations +func (c *Config) validateAgents(errors *ValidationErrors) { + if c.Agents == nil { + c.Agents = make(map[AgentID]Agent) + } + + validTools := []string{ + "bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent", + } + + for agentID, agent := range c.Agents { + fieldPrefix := fmt.Sprintf("agents.%s", agentID) + + // Validate agent ID consistency + if agent.ID != agentID { + errors.Add(fieldPrefix+".id", fmt.Sprintf("agent ID mismatch: expected '%s', got '%s'", agentID, agent.ID)) + } + + // Validate required fields + if agent.ID == "" { + errors.Add(fieldPrefix+".id", "agent ID is required") + } + if agent.Name == "" { + errors.Add(fieldPrefix+".name", "agent name is required") + } + + // Validate model type + if agent.Model != LargeModel && agent.Model != SmallModel { + errors.Add(fieldPrefix+".model", fmt.Sprintf("invalid model type: %s (must be 'large' or 'small')", agent.Model)) + } + + // Validate allowed tools + if agent.AllowedTools != nil { + for i, tool := range agent.AllowedTools { + validTool := false + for _, vt := range validTools { + if tool == vt { + validTool = true + break + } + } + if !validTool { + 
errors.Add(fmt.Sprintf("%s.allowed_tools[%d]", fieldPrefix, i), fmt.Sprintf("unknown tool: %s", tool)) + } + } + } + + // Validate MCP references + if agent.AllowedMCP != nil { + for mcpName := range agent.AllowedMCP { + if _, exists := c.MCP[mcpName]; !exists { + errors.Add(fieldPrefix+".allowed_mcp", fmt.Sprintf("referenced MCP '%s' not found", mcpName)) + } + } + } + + // Validate LSP references + if agent.AllowedLSP != nil { + for _, lspName := range agent.AllowedLSP { + if _, exists := c.LSP[lspName]; !exists { + errors.Add(fieldPrefix+".allowed_lsp", fmt.Sprintf("referenced LSP '%s' not found", lspName)) + } + } + } + + // Validate context paths (basic path validation) + for i, contextPath := range agent.ContextPaths { + if contextPath == "" { + errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path cannot be empty") + } + // Check for invalid characters in path + if strings.Contains(contextPath, "\x00") { + errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path contains invalid characters") + } + } + + // Validate known agents maintain their core properties + if agentID == AgentCoder { + if agent.Name != "Coder" { + errors.Add(fieldPrefix+".name", "coder agent name cannot be changed") + } + if agent.Description != "An agent that helps with executing coding tasks." { + errors.Add(fieldPrefix+".description", "coder agent description cannot be changed") + } + } else if agentID == AgentTask { + if agent.Name != "Task" { + errors.Add(fieldPrefix+".name", "task agent name cannot be changed") + } + if agent.Description != "An agent that helps with searching for context and finding implementation details." 
{ + errors.Add(fieldPrefix+".description", "task agent description cannot be changed") + } + expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"} + if agent.AllowedTools != nil && !slices.Equal(agent.AllowedTools, expectedTools) { + errors.Add(fieldPrefix+".allowed_tools", "task agent allowed tools cannot be changed") + } + } + } +} + +// validateOptions validates configuration options +func (c *Config) validateOptions(errors *ValidationErrors) { + // Validate data directory + if c.Options.DataDirectory == "" { + errors.Add("options.data_directory", "data directory is required") + } + + // Validate context paths + for i, contextPath := range c.Options.ContextPaths { + if contextPath == "" { + errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path cannot be empty") + } + if strings.Contains(contextPath, "\x00") { + errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path contains invalid characters") + } + } +} + +// validateMCPs validates MCP configurations +func (c *Config) validateMCPs(errors *ValidationErrors) { + if c.MCP == nil { + c.MCP = make(map[string]MCP) + } + + for mcpName, mcpConfig := range c.MCP { + fieldPrefix := fmt.Sprintf("mcp.%s", mcpName) + + // Validate MCP type + if mcpConfig.Type != MCPStdio && mcpConfig.Type != MCPSse { + errors.Add(fieldPrefix+".type", fmt.Sprintf("invalid MCP type: %s (must be 'stdio' or 'sse')", mcpConfig.Type)) + } + + // Validate based on type + if mcpConfig.Type == MCPStdio { + if mcpConfig.Command == "" { + errors.Add(fieldPrefix+".command", "command is required for stdio MCP") + } + } else if mcpConfig.Type == MCPSse { + if mcpConfig.URL == "" { + errors.Add(fieldPrefix+".url", "URL is required for SSE MCP") + } + } + } +} + +// validateLSPs validates LSP configurations +func (c *Config) validateLSPs(errors *ValidationErrors) { + if c.LSP == nil { + c.LSP = make(map[string]LSPConfig) + } + + for lspName, lspConfig := range c.LSP { + fieldPrefix := fmt.Sprintf("lsp.%s", 
lspName) + + if lspConfig.Command == "" { + errors.Add(fieldPrefix+".command", "command is required for LSP") + } + } +} + +// validateCrossReferences validates cross-references between different config sections +func (c *Config) validateCrossReferences(errors *ValidationErrors) { + // Validate that agents can use their assigned model types + for agentID, agent := range c.Agents { + fieldPrefix := fmt.Sprintf("agents.%s", agentID) + + var preferredModel PreferredModel + switch agent.Model { + case LargeModel: + preferredModel = c.Models.Large + case SmallModel: + preferredModel = c.Models.Small + } + + if preferredModel.Provider != "" { + if providerConfig, exists := c.Providers[preferredModel.Provider]; exists { + if providerConfig.Disabled { + errors.Add(fieldPrefix+".model", fmt.Sprintf("agent cannot use model type '%s' because provider '%s' is disabled", agent.Model, preferredModel.Provider)) + } + } + } + } +} + +// validateCompleteness validates that the configuration is complete and usable +func (c *Config) validateCompleteness(errors *ValidationErrors) { + // Check for at least one valid, non-disabled provider + hasValidProvider := false + for _, providerConfig := range c.Providers { + if !providerConfig.Disabled { + hasValidProvider = true + break + } + } + if !hasValidProvider { + errors.Add("providers", "at least one non-disabled provider is required") + } + + // Check that default agents exist + if _, exists := c.Agents[AgentCoder]; !exists { + errors.Add("agents", "coder agent is required") + } + if _, exists := c.Agents[AgentTask]; !exists { + errors.Add("agents", "task agent is required") + } + + // Check that preferred models are set if providers exist + if hasValidProvider { + if c.Models.Large.ModelID == "" || c.Models.Large.Provider == "" { + errors.Add("models.large", "large preferred model must be configured when providers are available") + } + if c.Models.Small.ModelID == "" || c.Models.Small.Provider == "" { + errors.Add("models.small", 
"small preferred model must be configured when providers are available") + } + } +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 2f8b4a8acd8be9d979a663a53c28788a66ffd396..f69b3c84701b07c6df948c7abca6e37a65e3c69e 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -50,6 +50,10 @@ func reset() { instance = nil cwd = "" testConfigDir = "" + + // Enable mock providers for all tests to avoid API calls + UseMockProviders = true + ResetProviders() } // Core Configuration Loading Tests @@ -133,9 +137,29 @@ func TestLoadConfig_OnlyGlobalConfig(t *testing.T) { globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "test-key", - ProviderType: provider.TypeOpenAI, + ID: provider.InferenceProviderOpenAI, + APIKey: "test-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + CostPer1MIn: 30.0, + CostPer1MOut: 60.0, + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + CostPer1MIn: 1.0, + CostPer1MOut: 2.0, + ContextWindow: 4096, + DefaultMaxTokens: 4096, + }, + }, }, }, Options: Options{ @@ -167,9 +191,29 @@ func TestLoadConfig_OnlyLocalConfig(t *testing.T) { localConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderAnthropic: { - ID: provider.InferenceProviderAnthropic, - APIKey: "local-key", - ProviderType: provider.TypeAnthropic, + ID: provider.InferenceProviderAnthropic, + APIKey: "local-key", + ProviderType: provider.TypeAnthropic, + DefaultLargeModel: "claude-3-opus", + DefaultSmallModel: "claude-3-haiku", + Models: []Model{ + { + ID: "claude-3-opus", + Name: "Claude 3 Opus", + CostPer1MIn: 15.0, + CostPer1MOut: 75.0, + ContextWindow: 200000, + DefaultMaxTokens: 4096, + }, + { + 
ID: "claude-3-haiku", + Name: "Claude 3 Haiku", + CostPer1MIn: 0.25, + CostPer1MOut: 1.25, + ContextWindow: 200000, + DefaultMaxTokens: 4096, + }, + }, }, }, Options: Options{ @@ -199,9 +243,29 @@ func TestLoadConfig_BothGlobalAndLocal(t *testing.T) { globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "global-key", - ProviderType: provider.TypeOpenAI, + ID: provider.InferenceProviderOpenAI, + APIKey: "global-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + CostPer1MIn: 30.0, + CostPer1MOut: 60.0, + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + CostPer1MIn: 1.0, + CostPer1MOut: 2.0, + ContextWindow: 4096, + DefaultMaxTokens: 4096, + }, + }, }, }, Options: Options{ @@ -222,9 +286,29 @@ func TestLoadConfig_BothGlobalAndLocal(t *testing.T) { APIKey: "local-key", // Override global }, provider.InferenceProviderAnthropic: { - ID: provider.InferenceProviderAnthropic, - APIKey: "anthropic-key", - ProviderType: provider.TypeAnthropic, + ID: provider.InferenceProviderAnthropic, + APIKey: "anthropic-key", + ProviderType: provider.TypeAnthropic, + DefaultLargeModel: "claude-3-opus", + DefaultSmallModel: "claude-3-haiku", + Models: []Model{ + { + ID: "claude-3-opus", + Name: "Claude 3 Opus", + CostPer1MIn: 15.0, + CostPer1MOut: 75.0, + ContextWindow: 200000, + DefaultMaxTokens: 4096, + }, + { + ID: "claude-3-haiku", + Name: "Claude 3 Haiku", + CostPer1MIn: 0.25, + CostPer1MOut: 1.25, + ContextWindow: 200000, + DefaultMaxTokens: 4096, + }, + }, }, }, Options: Options{ diff --git a/internal/config/provider.go b/internal/config/provider.go index 4c2b61ff6d5d86f62a8a1833a6ea91b500bbc7b0..09e3b0e3fc84b9e2688ccc4d2559604aca83ddfc 100644 --- a/internal/config/provider.go +++ 
b/internal/config/provider.go @@ -15,6 +15,8 @@ var fur = client.New() var ( providerOnc sync.Once // Ensures the initialization happens only once providerList []provider.Provider + // UseMockProviders can be set to true in tests to avoid API calls + UseMockProviders bool ) func providersPath() string { @@ -50,6 +52,12 @@ func loadProviders() ([]provider.Provider, error) { func Providers() []provider.Provider { providerOnc.Do(func() { + // Use mock providers when testing + if UseMockProviders { + providerList = MockProviders() + return + } + // Try to get providers from upstream API if providers, err := fur.GetProviders(); err == nil { providerList = providers @@ -67,3 +75,9 @@ func Providers() []provider.Provider { }) return providerList } + +// ResetProviders resets the provider cache. Useful for testing. +func ResetProviders() { + providerOnc = sync.Once{} + providerList = nil +} diff --git a/internal/config/provider_mock.go b/internal/config/provider_mock.go new file mode 100644 index 0000000000000000000000000000000000000000..86b87768b95246654e176ca5f40af5aef249c23f --- /dev/null +++ b/internal/config/provider_mock.go @@ -0,0 +1,177 @@ +package config + +import ( + "github.com/charmbracelet/crush/internal/fur/provider" +) + +// MockProviders returns a mock list of providers for testing. +// This avoids making API calls during tests and provides consistent test data. 
+func MockProviders() []provider.Provider { + return []provider.Provider{ + { + Name: "Anthropic", + ID: provider.InferenceProviderAnthropic, + APIKey: "$ANTHROPIC_API_KEY", + APIEndpoint: "$ANTHROPIC_API_ENDPOINT", + Type: provider.TypeAnthropic, + DefaultLargeModelID: "claude-3-opus", + DefaultSmallModelID: "claude-3-haiku", + Models: []provider.Model{ + { + ID: "claude-3-opus", + Name: "Claude 3 Opus", + CostPer1MIn: 15.0, + CostPer1MOut: 75.0, + CostPer1MInCached: 18.75, + CostPer1MOutCached: 1.5, + ContextWindow: 200000, + DefaultMaxTokens: 4096, + CanReason: false, + SupportsImages: true, + }, + { + ID: "claude-3-haiku", + Name: "Claude 3 Haiku", + CostPer1MIn: 0.25, + CostPer1MOut: 1.25, + CostPer1MInCached: 0.3, + CostPer1MOutCached: 0.03, + ContextWindow: 200000, + DefaultMaxTokens: 4096, + CanReason: false, + SupportsImages: true, + }, + { + ID: "claude-3-5-sonnet-20241022", + Name: "Claude 3.5 Sonnet", + CostPer1MIn: 3.0, + CostPer1MOut: 15.0, + CostPer1MInCached: 3.75, + CostPer1MOutCached: 0.3, + ContextWindow: 200000, + DefaultMaxTokens: 8192, + CanReason: false, + SupportsImages: true, + }, + { + ID: "claude-3-5-haiku-20241022", + Name: "Claude 3.5 Haiku", + CostPer1MIn: 0.8, + CostPer1MOut: 4.0, + CostPer1MInCached: 1.0, + CostPer1MOutCached: 0.08, + ContextWindow: 200000, + DefaultMaxTokens: 8192, + CanReason: false, + SupportsImages: true, + }, + }, + }, + { + Name: "OpenAI", + ID: provider.InferenceProviderOpenAI, + APIKey: "$OPENAI_API_KEY", + APIEndpoint: "$OPENAI_API_ENDPOINT", + Type: provider.TypeOpenAI, + DefaultLargeModelID: "gpt-4", + DefaultSmallModelID: "gpt-3.5-turbo", + Models: []provider.Model{ + { + ID: "gpt-4", + Name: "GPT-4", + CostPer1MIn: 30.0, + CostPer1MOut: 60.0, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.0, + ContextWindow: 8192, + DefaultMaxTokens: 4096, + CanReason: false, + SupportsImages: false, + }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + CostPer1MIn: 1.0, + CostPer1MOut: 2.0, + CostPer1MInCached: 
0.0, + CostPer1MOutCached: 0.0, + ContextWindow: 4096, + DefaultMaxTokens: 4096, + CanReason: false, + SupportsImages: false, + }, + { + ID: "gpt-4-turbo", + Name: "GPT-4 Turbo", + CostPer1MIn: 10.0, + CostPer1MOut: 30.0, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.0, + ContextWindow: 128000, + DefaultMaxTokens: 4096, + CanReason: false, + SupportsImages: true, + }, + { + ID: "gpt-4o", + Name: "GPT-4o", + CostPer1MIn: 2.5, + CostPer1MOut: 10.0, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 1.25, + ContextWindow: 128000, + DefaultMaxTokens: 16384, + CanReason: false, + SupportsImages: true, + }, + { + ID: "gpt-4o-mini", + Name: "GPT-4o-mini", + CostPer1MIn: 0.15, + CostPer1MOut: 0.6, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.075, + ContextWindow: 128000, + DefaultMaxTokens: 16384, + CanReason: false, + SupportsImages: true, + }, + }, + }, + { + Name: "Google Gemini", + ID: provider.InferenceProviderGemini, + APIKey: "$GEMINI_API_KEY", + APIEndpoint: "$GEMINI_API_ENDPOINT", + Type: provider.TypeGemini, + DefaultLargeModelID: "gemini-2.5-pro", + DefaultSmallModelID: "gemini-2.5-flash", + Models: []provider.Model{ + { + ID: "gemini-2.5-pro", + Name: "Gemini 2.5 Pro", + CostPer1MIn: 1.25, + CostPer1MOut: 10.0, + CostPer1MInCached: 1.625, + CostPer1MOutCached: 0.31, + ContextWindow: 1048576, + DefaultMaxTokens: 65536, + CanReason: true, + SupportsImages: true, + }, + { + ID: "gemini-2.5-flash", + Name: "Gemini 2.5 Flash", + CostPer1MIn: 0.3, + CostPer1MOut: 2.5, + CostPer1MInCached: 0.3833, + CostPer1MOutCached: 0.075, + ContextWindow: 1048576, + DefaultMaxTokens: 65535, + CanReason: true, + SupportsImages: true, + }, + }, + }, + } +} diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go new file mode 100644 index 0000000000000000000000000000000000000000..92547ff2925699d8519c33656395d3979a095b35 --- /dev/null +++ b/internal/config/provider_test.go @@ -0,0 +1,105 @@ +package config + +import ( + "testing" + + 
"github.com/charmbracelet/crush/internal/fur/provider" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMockProviders(t *testing.T) { + // Enable mock providers for testing + originalUseMock := UseMockProviders + UseMockProviders = true + defer func() { + UseMockProviders = originalUseMock + ResetProviders() + }() + + // Reset providers to ensure we get fresh mock data + ResetProviders() + + providers := Providers() + require.NotEmpty(t, providers, "Mock providers should not be empty") + + // Verify we have the expected mock providers + providerIDs := make(map[provider.InferenceProvider]bool) + for _, p := range providers { + providerIDs[p.ID] = true + } + + assert.True(t, providerIDs[provider.InferenceProviderAnthropic], "Should have Anthropic provider") + assert.True(t, providerIDs[provider.InferenceProviderOpenAI], "Should have OpenAI provider") + assert.True(t, providerIDs[provider.InferenceProviderGemini], "Should have Gemini provider") + + // Verify Anthropic provider details + var anthropicProvider provider.Provider + for _, p := range providers { + if p.ID == provider.InferenceProviderAnthropic { + anthropicProvider = p + break + } + } + + assert.Equal(t, "Anthropic", anthropicProvider.Name) + assert.Equal(t, provider.TypeAnthropic, anthropicProvider.Type) + assert.Equal(t, "claude-3-opus", anthropicProvider.DefaultLargeModelID) + assert.Equal(t, "claude-3-haiku", anthropicProvider.DefaultSmallModelID) + assert.Len(t, anthropicProvider.Models, 4, "Anthropic should have 4 models") + + // Verify model details + var opusModel provider.Model + for _, m := range anthropicProvider.Models { + if m.ID == "claude-3-opus" { + opusModel = m + break + } + } + + assert.Equal(t, "Claude 3 Opus", opusModel.Name) + assert.Equal(t, int64(200000), opusModel.ContextWindow) + assert.Equal(t, int64(4096), opusModel.DefaultMaxTokens) + assert.True(t, opusModel.SupportsImages) +} + +func TestProvidersWithoutMock(t *testing.T) { + // 
Ensure mock is disabled + originalUseMock := UseMockProviders + UseMockProviders = false + defer func() { + UseMockProviders = originalUseMock + ResetProviders() + }() + + // Reset providers to ensure we get fresh data + ResetProviders() + + // This will try to make an actual API call or use cached data + providers := Providers() + + // We can't guarantee what we'll get here since it depends on network/cache + // but we can at least verify the function doesn't panic + t.Logf("Got %d providers without mock", len(providers)) +} + +func TestResetProviders(t *testing.T) { + // Enable mock providers + UseMockProviders = true + defer func() { + UseMockProviders = false + ResetProviders() + }() + + // Get providers once + providers1 := Providers() + require.NotEmpty(t, providers1) + + // Reset and get again + ResetProviders() + providers2 := Providers() + require.NotEmpty(t, providers2) + + // Should get the same mock data + assert.Equal(t, len(providers1), len(providers2)) +} From d949a223d8b7e6eb6078661edccf68de6c2c80e3 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 27 Jun 2025 17:46:35 +0200 Subject: [PATCH 13/40] feat: add provider-specific prompts for OpenAI, Gemini, and Anthropic --- internal/llm/prompt/coder.go | 264 +++++++++++++++++++++++++++++++---- 1 file changed, 234 insertions(+), 30 deletions(-) diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index 9f1e5e7c19e739167bb9ab2bd359218e88fd4367..523933d18e5c39ea766c42e1aafe09b5aaff3e63 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -15,10 +15,14 @@ import ( ) func CoderPrompt(p provider.InferenceProvider, contextFiles ...string) string { - basePrompt := baseAnthropicCoderPrompt + var basePrompt string switch p { case provider.InferenceProviderOpenAI: basePrompt = baseOpenAICoderPrompt + case provider.InferenceProviderGemini, provider.InferenceProviderVertexAI: + basePrompt = baseGeminiCoderPrompt + default: + basePrompt = baseAnthropicCoderPrompt } 
envInfo := getEnvironmentInfo() @@ -33,46 +37,91 @@ func CoderPrompt(p provider.InferenceProvider, contextFiles ...string) string { } const baseOpenAICoderPrompt = ` -You are operating as and within the Crush CLI, a terminal-based agentic coding assistant built by OpenAI. It wraps OpenAI models to enable natural language interaction with a local codebase. You are expected to be precise, safe, and helpful. - -You can: -- Receive user prompts, project context, and files. -- Stream responses and emit function calls (e.g., shell commands, code edits). +Please resolve the user's task by editing and testing the code files in your current code execution session. +You are a deployed coding agent. +Your session allows you to easily modify and run code in the user's local environment. +The repo(s) are already available in your working directory, and you must fully solve the problem for your answer to be considered correct. +IMPORTANT: Before you begin work, think about what the code you're editing is supposed to do based on the filenames directory structure. -You are an agent - please keep going until the user's query is completely resolved, before ending your turn and yielding back to the user. Only terminate your turn when you are sure that the problem is solved. If you are not sure about file content or codebase structure pertaining to the user's request, use your tools to read files and gather the relevant information: do NOT guess or make up an answer. +# Memory +If the current working directory contains a file called CRUSH.md, it will be automatically added to your context. This file serves multiple purposes: +1. Storing frequently used bash commands (build, test, lint, etc.) so you can use them without searching each time +2. Recording the user's code style preferences (naming conventions, preferred libraries, etc.) +3. 
Maintaining useful information about the codebase structure and organization -Please resolve the user's task by editing and testing the code files in your current code execution session. You are a deployed coding agent. Your session allows for you to modify and run code. The repo(s) are already cloned in your working directory, and you must fully solve the problem for your answer to be considered correct. +When you spend time searching for commands to typecheck, lint, build, or test, you should ask the user if it's okay to add those commands to CRUSH.md. Similarly, when learning about code style preferences or important codebase information, ask if it's okay to add that to CRUSH.md so you can remember it for next time. You MUST adhere to the following criteria when executing the task: + - Working on the repo(s) in the current environment is allowed, even if they are proprietary. - Analyzing code for vulnerabilities is allowed. - Showing user code and tool call details is allowed. -- User instructions may overwrite the *CODING GUIDELINES* section in this developer message. +- User instructions may overwrite the _CODING GUIDELINES_ section in this developer message. +- Do not use ` + "`ls -R`" + `, ` + "`find`" + `, or ` + "`grep`" + ` - these are slow in large repos. Use the Agent tool for searching instead. +- Use the ` + "`edit`" + ` tool to modify files: provide file_path, old_string (with sufficient context), and new_string. The edit tool requires: + - Absolute file paths (starting with /) + - Unique old_string matches with 3-5 lines of context before and after + - Exact whitespace and indentation matching + - For new files: provide file_path and new_string, leave old_string empty + - For deleting content: provide file_path and old_string, leave new_string empty + +# Following conventions +When making changes to files, first understand the file's code conventions. Mimic code style, use existing libraries and utilities, and follow existing patterns. 
+- NEVER assume that a given library is available, even if it is well known. Whenever you write code that uses a library or framework, first check that this codebase already uses the given library. For example, you might look at neighboring files, or check the package.json (or cargo.toml, and so on depending on the language). +- When you create a new component, first look at existing components to see how they're written; then consider framework choice, naming conventions, typing, and other conventions. +- When you edit a piece of code, first look at the code's surrounding context (especially its imports) to understand the code's choice of frameworks and libraries. Then consider how to make the given change in a way that is most idiomatic. +- Always follow security best practices. Never introduce code that exposes or logs secrets and keys. Never commit secrets or keys to the repository. + +# Code style +- Do not add comments to the code you write, unless the user asks you to, or the code is complex and requires additional context. + - If completing the user's task requires writing or modifying files: - - Your code and final answer should follow these *CODING GUIDELINES*: - - Fix the problem at the root cause rather than applying surface-level patches, when possible. - - Avoid unneeded complexity in your solution. - - Ignore unrelated bugs or broken tests; it is not your responsibility to fix them. - - Update documentation as necessary. - - Keep changes consistent with the style of the existing codebase. Changes should be minimal and focused on the task. - - Use "git log" and "git blame" to search the history of the codebase if additional context is required; internet access is disabled. - - NEVER add copyright or license headers unless specifically requested. - - You do not need to "git commit" your changes; this will be done automatically for you. 
- - Once you finish coding, you must - - Check "git status" to sanity check your changes; revert any scratch files or changes. - - Remove all inline comments you added as much as possible, even if they look normal. Check using "git diff". Inline comments must be generally avoided, unless active maintainers of the repo, after long careful study of the code and the issue, will still misinterpret the code without the comments. - - Check if you accidentally add copyright or license headers. If so, remove them. - - For smaller tasks, describe in brief bullet points - - For more complex tasks, include brief high-level description, use bullet points, and include details that would be relevant to a code reviewer. + - Your code and final answer should follow these _CODING GUIDELINES_: + - Fix the problem at the root cause rather than applying surface-level patches, when possible. + - Avoid unneeded complexity in your solution. + - Ignore unrelated bugs or broken tests; it is not your responsibility to fix them. + - Update documentation as necessary. + - Keep changes consistent with the style of the existing codebase. Changes should be minimal and focused on the task. + - Use ` + "`git log`" + ` and ` + "`git blame`" + ` to search the history of the codebase if additional context is required. + - NEVER add copyright or license headers unless specifically requested. + - You do not need to ` + "`git commit`" + ` your changes; this will be done automatically for you. + - If there is a .pre-commit-config.yaml, use ` + "`pre-commit run --files ...`" + ` to check that your changes pass the pre-commit checks. However, do not fix pre-existing errors on lines you didn't touch. + - If pre-commit doesn't work after a few retries, politely inform the user that the pre-commit setup is broken. + - Once you finish coding, you must + - Check ` + "`git status`" + ` to sanity check your changes; revert any scratch files or changes. 
+ - Remove all inline comments you added as much as possible, even if they look normal. Check using ` + "`git diff`" + `. Inline comments must be generally avoided, unless active maintainers of the repo, after long careful study of the code and the issue, will still misinterpret the code without the comments. + - Check if you accidentally add copyright or license headers. If so, remove them. + - Try to run pre-commit if it is available. + - For smaller tasks, describe in brief bullet points + - For more complex tasks, include brief high-level description, use bullet points, and include details that would be relevant to a code reviewer. + +# Doing tasks +The user will primarily request you perform software engineering tasks. This includes solving bugs, adding new functionality, refactoring code, explaining code, and more. For these tasks the following steps are recommended: +1. Use the available search tools to understand the codebase and the user's query. +2. Implement the solution using all tools available to you +3. Verify the solution if possible with tests. NEVER assume specific test framework or test script. Check the README or search codebase to determine the testing approach. +4. VERY IMPORTANT: When you have completed a task, you MUST run the lint and typecheck commands (eg. npm run lint, npm run typecheck, ruff, etc.) if they were provided to you to ensure your code is correct. If you are unable to find the correct command, ask the user for the command to run and if they supply it, proactively suggest writing it to CRUSH.md so that you will know to run it next time. + +NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTANT to only commit when explicitly asked, otherwise the user will feel that you are being too proactive. + +# Tool usage policy +- When doing file search, prefer to use the Agent tool in order to reduce context usage. 
+- If you intend to call multiple tools and there are no dependencies between the calls, make all of the independent calls in parallel. +- IMPORTANT: The user does not see the full output of the tool responses, so if you need the output of the tool for the response make sure to summarize it for the user. + +# Proactiveness +You are allowed to be proactive, but only when the user asks you to do something. You should strive to strike a balance between: +1. Doing the right thing when asked, including taking actions and follow-up actions +2. Not surprising the user with actions you take without asking +For example, if the user asks you how to approach something, you should do your best to answer their question first, and not immediately jump into taking actions. +3. Do not add additional code explanation summary unless requested by the user. After working on a file, just stop, rather than providing an explanation of what you did. + - If completing the user's task DOES NOT require writing or modifying files (e.g., the user asks a question about the code base): - - Respond in a friendly tune as a remote teammate, who is knowledgeable, capable and eager to help with coding. + - Respond in a friendly tone as a remote teammate, who is knowledgeable, capable and eager to help with coding. - When your task involves writing or modifying files: - - Do NOT tell the user to "save the file" or "copy the code into a file" if you already created or modified the file using "edit/write". Instead, reference the file as already saved. - - Do NOT show the full contents of large files you have already written, unless the user explicitly asks for them. -- When doing things with paths, always use use the full path, if the working directory is /abc/xyz and you want to edit the file abc.go in the working dir refer to it as /abc/xyz/abc.go. -- If you send a path not including the working dir, the working dir will be prepended to it. 
-- Remember the user does not see the full output of tools + - Do NOT tell the user to "save the file" or "copy the code into a file" if you already created or modified the file using ` + "`edit`" + `. Instead, reference the file as already saved. + - Do NOT show the full contents of large files you have already written, unless the user explicitly asks for them. - NEVER use emojis in your responses ` @@ -175,6 +224,161 @@ VERY IMPORTANT NEVER use emojis in your responses. You MUST answer concisely with fewer than 4 lines of text (not including tool use or code generation), unless user asks for detail.` +const baseGeminiCoderPrompt = ` +You are an interactive CLI agent specializing in software engineering tasks. Your primary goal is to help users safely and efficiently, adhering strictly to the following instructions and utilizing your available tools. + +IMPORTANT: Before you begin work, think about what the code you're editing is supposed to do based on the filenames directory structure. + +# Memory +If the current working directory contains a file called CRUSH.md, it will be automatically added to your context. This file serves multiple purposes: +1. Storing frequently used bash commands (build, test, lint, etc.) so you can use them without searching each time +2. Recording the user's code style preferences (naming conventions, preferred libraries, etc.) +3. Maintaining useful information about the codebase structure and organization + +When you spend time searching for commands to typecheck, lint, build, or test, you should ask the user if it's okay to add those commands to CRUSH.md. Similarly, when learning about code style preferences or important codebase information, ask if it's okay to add that to CRUSH.md so you can remember it for next time. + +# Core Mandates + +- **Conventions:** Rigorously adhere to existing project conventions when reading or modifying code. Analyze surrounding code, tests, and configuration first. 
+- **Libraries/Frameworks:** NEVER assume a library/framework is available or appropriate. Verify its established usage within the project (check imports, configuration files like 'package.json', 'Cargo.toml', 'requirements.txt', 'build.gradle', etc., or observe neighboring files) before employing it. +- **Style & Structure:** Mimic the style (formatting, naming), structure, framework choices, typing, and architectural patterns of existing code in the project. +- **Idiomatic Changes:** When editing, understand the local context (imports, functions/classes) to ensure your changes integrate naturally and idiomatically. +- **Comments:** Add code comments sparingly. Focus on *why* something is done, especially for complex logic, rather than *what* is done. Only add high-value comments if necessary for clarity or if requested by the user. Do not edit comments that are separate from the code you are changing. *NEVER* talk to the user or describe your changes through comments. +- **Proactiveness:** Fulfill the user's request thoroughly, including reasonable, directly implied follow-up actions. +- **Confirm Ambiguity/Expansion:** Do not take significant actions beyond the clear scope of the request without confirming with the user. If asked *how* to do something, explain first, don't just do it. +- **Explaining Changes:** After completing a code modification or file operation *do not* provide summaries unless asked. +- **Do Not revert changes:** Do not revert changes to the codebase unless asked to do so by the user. Only revert changes made by you if they have resulted in an error or if the user has explicitly asked you to revert the changes. + +# Primary Workflows + +## Software Engineering Tasks +When requested to perform tasks like fixing bugs, adding features, refactoring, or explaining code, follow this sequence: +1. **Understand:** Think about the user's request and the relevant codebase context. 
Use ` + "`grep`" + ` and ` + "`glob`" + ` search tools extensively (in parallel if independent) to understand file structures, existing code patterns, and conventions. Use ` + "`view`" + ` to understand context and validate any assumptions you may have. +2. **Plan:** Build a coherent and grounded (based on the understanding in step 1) plan for how you intend to resolve the user's task. Share an extremely concise yet clear plan with the user if it would help the user understand your thought process. As part of the plan, you should try to use a self-verification loop by writing unit tests if relevant to the task. Use output logs or debug statements as part of this self verification loop to arrive at a solution. +3. **Implement:** Use the available tools (e.g., ` + "`edit`" + `, ` + "`write`" + ` ` + "`bash`" + ` ...) to act on the plan, strictly adhering to the project's established conventions (detailed under 'Core Mandates'). +4. **Verify (Tests):** If applicable and feasible, verify the changes using the project's testing procedures. Identify the correct test commands and frameworks by examining 'README' files, build/package configuration (e.g., 'package.json'), or existing test execution patterns. NEVER assume standard test commands. +5. **Verify (Standards):** VERY IMPORTANT: After making code changes, execute the project-specific build, linting and type-checking commands (e.g., 'tsc', 'npm run lint', 'ruff check .') that you have identified for this project (or obtained from the user). This ensures code quality and adherence to standards. If unsure about these commands, you can ask the user if they'd like you to run them and if so how to. + +NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTANT to only commit when explicitly asked, otherwise the user will feel that you are being too proactive. 
+ +# Operational Guidelines + +## Tone and Style (CLI Interaction) +- **Concise & Direct:** Adopt a professional, direct, and concise tone suitable for a CLI environment. +- **Minimal Output:** Aim for fewer than 3 lines of text output (excluding tool use/code generation) per response whenever practical. Focus strictly on the user's query. +- **Clarity over Brevity (When Needed):** While conciseness is key, prioritize clarity for essential explanations or when seeking necessary clarification if a request is ambiguous. +- **No Chitchat:** Avoid conversational filler, preambles ("Okay, I will now..."), or postambles ("I have finished the changes..."). Get straight to the action or answer. +- **Formatting:** Use GitHub-flavored Markdown. Responses will be rendered in monospace. +- **Tools vs. Text:** Use tools for actions, text output *only* for communication. Do not add explanatory comments within tool calls or code blocks unless specifically part of the required code/command itself. +- **Handling Inability:** If unable/unwilling to fulfill a request, state so briefly (1-2 sentences) without excessive justification. Offer alternatives if appropriate. + +## Security and Safety Rules +- **Explain Critical Commands:** Before executing commands with ` + "`bash`" + ` that modify the file system, codebase, or system state, you *must* provide a brief explanation of the command's purpose and potential impact. Prioritize user understanding and safety. +- **Security First:** Always apply security best practices. Never introduce code that exposes, logs, or commits secrets, API keys, or other sensitive information. + +## Tool Usage +- **File Paths:** Always use absolute paths when referring to files with tools like ` + "`view`" + ` or ` + "`write`" + `. Relative paths are not supported. You must provide an absolute path. +- **Parallelism:** Execute multiple independent tool calls in parallel when feasible (i.e. searching the codebase). 
+- **Command Execution:** Use the ` + "`bash`" + ` tool for running shell commands, remembering the safety rule to explain modifying commands first. +- **Background Processes:** Use background processes (via ` + "`&`" + `) for commands that are unlikely to stop on their own, e.g. ` + "`node server.js &`" + `. If unsure, ask the user. +- **Interactive Commands:** Try to avoid shell commands that are likely to require user interaction (e.g. ` + "`git rebase -i`" + `). Use non-interactive versions of commands (e.g. ` + "`npm init -y`" + ` instead of ` + "`npm init`" + `) when available, and otherwise remind the user that interactive shell commands are not supported and may cause hangs until canceled by the user. + +# Examples (Illustrating Tone and Workflow) + + +user: 1 + 2 +model: 3 + + + +user: is 13 a prime number? +model: true + + + +user: list files here. +model: [tool_call: ls for path '.'] + + + +user: start the server implemented in server.js +model: [tool_call: bash for 'node server.js &' because it must run in the background] + + + +user: Refactor the auth logic in src/auth.py to use the requests library instead of urllib. +model: Okay, I can refactor 'src/auth.py'. +First, I'll analyze the code and check for a test safety net before planning any changes. +[tool_call: glob for pattern 'tests/test_auth.py'] +[tool_call: view for absolute_path '/path/to/tests/test_auth.py'] +(After analysis) +Great, 'tests/test_auth.py' exists and covers the core authentication logic. With this safety net in place, I can safely plan the refactoring. +I'll also confirm 'requests' is a dependency. +[tool_call: view for absolute_path '/path/to/requirements.txt'] +(After analysis) +Looks good, 'requests' is available. + +Here's the plan: +1. Replace the 'urllib' calls with 'requests'. +2. Add proper 'try...except' error handling for the new network calls. +3. Remove the old 'urllib' import. +4. Run the project's linter and tests to verify the changes. + +Should I proceed? 
+user: Yes +model: +[tool_call: edit to apply the refactoring to 'src/auth.py'] +Refactoring complete. Running verification... +[tool_call: bash for 'ruff check src/auth.py && pytest'] +(After verification passes) +All checks passed. This is a stable checkpoint. + + + +user: Delete the temp directory. +model: I can run ` + "`rm -rf ./temp`" + `. This will permanently delete the directory and all its contents. + + + +user: Write tests for someFile.ts +model: +Okay, I can write those tests. First, I'll read ` + "`someFile.ts`" + ` to understand its functionality. +[tool_call: view for absolute_path '/path/to/someFile.ts' or use glob to find ` + "`someFile.ts`" + ` if its location is unknown] +Now I'll look for existing or related test files to understand current testing conventions and dependencies. +[tool_call: glob for pattern '**/*.test.ts' and '**/*.spec.ts'] +(After reviewing existing tests and the file content) +[tool_call: write to create /path/to/someFile.test.ts with the test code] +I've written the tests. Now I'll run the project's test command to verify them. +[tool_call: bash for 'npm run test'] + + + +user: How do I update the user's profile information in this system? +model: +I'm not immediately sure how user profile information is updated. I'll search the codebase for terms like 'UserProfile', 'updateProfile', or 'editUser' to find relevant files or API endpoints. +[tool_call: grep for pattern 'UserProfile|updateProfile|editUser'] +(After reviewing search results, assuming a relevant file like '/path/to/UserProfileService.java' was found) +Okay, ` + "`/path/to/UserProfileService.java`" + ` seems like the most relevant file. I'll read its content to understand how updates are handled. +[tool_call: view for absolute_path '/path/to/UserProfileService.java'] +(After reading the file) +It appears the ` + "`updateUserProfile`" + ` method in ` + "`UserProfileService.java`" + ` is responsible for this. 
It expects a user ID and a ` + "`UserProfileDTO`" + ` object... + + + +user: Where are all the 'app.config' files in this project? I need to check their settings. +model: +[tool_call: glob for pattern '**/app.config'] +(Assuming glob returns a list of paths like ['/path/to/moduleA/app.config', '/path/to/moduleB/app.config']) +I found the following 'app.config' files: +- /path/to/moduleA/app.config +- /path/to/moduleB/app.config +To help you check their settings, I can read their contents. Which one would you like to start with, or should I read all of them? + + +# Final Reminder +Your core function is efficient and safe assistance. Balance extreme conciseness with the crucial need for clarity, especially regarding safety and potential system modifications. Always prioritize user control and project conventions. Never make assumptions about the contents of files; instead use ` + "`view`" + ` to ensure you aren't making broad assumptions. Finally, you are an agent - please keep going until the user's query is completely resolved. 
+` + func getEnvironmentInfo() string { cwd := config.WorkingDirectory() isGit := isGitRepo(cwd) From 135aceb72f4f4de1694c22f653d95fd6c5015745 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 27 Jun 2025 17:46:51 +0200 Subject: [PATCH 14/40] feat: add dynamic model switching with agent provider updates --- internal/app/app.go | 4 ++ internal/llm/agent/agent.go | 98 +++++++++++++++++++++++++++++++++++++ internal/tui/tui.go | 7 +++ 3 files changed, 109 insertions(+) diff --git a/internal/app/app.go b/internal/app/app.go index b096c1b4f5612901a1cedeaa2ee758b666cda517..6dd1b9916d593c6f0e053aaef6714723f8fd5c60 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -168,3 +168,7 @@ func (app *App) Shutdown() { } app.CoderAgent.CancelAll() } + +func (app *App) UpdateAgentModel() error { + return app.CoderAgent.UpdateModel() +} diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 8c6faf8c4a06bbef5da279847cd14ce2314648cd..8312b0f8965a5d02f7ce049abff50953cc56e422 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -56,6 +56,7 @@ type Service interface { IsSessionBusy(sessionID string) bool IsBusy() bool Summarize(ctx context.Context, sessionID string) error + UpdateModel() error } type agent struct { @@ -805,3 +806,100 @@ func (a *agent) CancelAll() { return true }) } + +func (a *agent) UpdateModel() error { + cfg := config.Get() + + // Get current provider configuration + currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID) + if currentProviderCfg.ID == "" { + return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name) + } + + // Check if provider has changed + if string(currentProviderCfg.ID) != a.providerID { + // Provider changed, need to recreate the main provider + model := config.GetAgentModel(a.agentCfg.ID) + if model.ID == "" { + return fmt.Errorf("model not found for agent %s", a.agentCfg.Name) + } + + promptID := agentPromptMap[a.agentCfg.ID] + if promptID == "" { + promptID 
= prompt.PromptDefault + } + + opts := []provider.ProviderClientOption{ + provider.WithModel(a.agentCfg.Model), + provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)), + provider.WithMaxTokens(model.DefaultMaxTokens), + } + + newProvider, err := provider.NewProviderV2(currentProviderCfg, opts...) + if err != nil { + return fmt.Errorf("failed to create new provider: %w", err) + } + + // Update the provider and provider ID + a.provider = newProvider + a.providerID = string(currentProviderCfg.ID) + } + + // Check if small model provider has changed (affects title and summarize providers) + smallModelCfg := cfg.Models.Small + var smallModelProviderCfg config.ProviderConfig + + for _, p := range cfg.Providers { + if p.ID == smallModelCfg.Provider { + smallModelProviderCfg = p + break + } + } + + if smallModelProviderCfg.ID == "" { + return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider) + } + + // Check if summarize provider has changed + if string(smallModelProviderCfg.ID) != a.summarizeProviderID { + var smallModel config.Model + for _, m := range smallModelProviderCfg.Models { + if m.ID == smallModelCfg.ModelID { + smallModel = m + break + } + } + if smallModel.ID == "" { + return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID) + } + + // Recreate title provider + titleOpts := []provider.ProviderClientOption{ + provider.WithModel(config.SmallModel), + provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)), + provider.WithMaxTokens(40), + } + newTitleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...) 
+ if err != nil { + return fmt.Errorf("failed to create new title provider: %w", err) + } + + // Recreate summarize provider + summarizeOpts := []provider.ProviderClientOption{ + provider.WithModel(config.SmallModel), + provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)), + provider.WithMaxTokens(smallModel.DefaultMaxTokens), + } + newSummarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...) + if err != nil { + return fmt.Errorf("failed to create new summarize provider: %w", err) + } + + // Update the providers and provider ID + a.titleProvider = newTitleProvider + a.summarizeProvider = newSummarizeProvider + a.summarizeProviderID = string(smallModelProviderCfg.ID) + } + + return nil +} diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 032b481eeaad75531debe7dc453efe19b866dd8d..e3c974ca002529ce1ac90f420afcc5eedf2a45fd 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -173,6 +173,13 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // Model Switch case models.ModelSelectedMsg: config.UpdatePreferredModel(config.LargeModel, msg.Model) + + // Update the agent with the new model/provider configuration + if err := a.app.UpdateAgentModel(); err != nil { + logging.ErrorPersist(fmt.Sprintf("Failed to update agent model: %v", err)) + return a, util.ReportError(fmt.Errorf("model changed to %s but failed to update agent: %v", msg.Model.ModelID, err)) + } + return a, util.ReportInfo(fmt.Sprintf("Model changed to %s", msg.Model.ModelID)) // File Picker From 1649dd0aa25eb8aecf8a5b2e263c5cca482631b0 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 27 Jun 2025 17:47:01 +0200 Subject: [PATCH 15/40] fix: correct model selection index in models dialog --- internal/tui/components/dialogs/models/models.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/tui/components/dialogs/models/models.go 
b/internal/tui/components/dialogs/models/models.go index 6d5fa155b2371865771b55c16f8fdbf65d3df952..730cc5d5693dd06821c1ff30cb26398b14325c62 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -97,13 +97,13 @@ func (m *modelDialogCmp) Init() tea.Cmd { } modelItems = append(modelItems, commands.NewItemSection(name)) for _, model := range provider.Models { - if model.ID == agentModel.ID && provider.ID == agentProvider.ID { - selectIndex = len(modelItems) // Set the selected index to the current model - } modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{ Provider: provider, Model: model, })) + if model.ID == agentModel.ID && provider.ID == agentProvider.ID { + selectIndex = len(modelItems) - 1 // Set the selected index to the current model + } } } From e3b55cebdfc2f2917592097b9bab3a1d49c30f95 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 27 Jun 2025 17:47:09 +0200 Subject: [PATCH 16/40] docs: add testing guidelines for mock providers --- CRUSH.md | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/CRUSH.md b/CRUSH.md index e2f6053c110f317ee85113a2f08359673342e645..6ab75f143258e63d4881bf17447304b7d1ca6ba5 100644 --- a/CRUSH.md +++ b/CRUSH.md @@ -24,6 +24,34 @@ - **File permissions**: Use octal notation (0o755, 0o644) for file permissions - **Comments**: End comments in periods unless comments are at the end of the line. 
+## Testing with Mock Providers + +When writing tests that involve provider configurations, use the mock providers to avoid API calls: + +```go +func TestYourFunction(t *testing.T) { + // Enable mock providers for testing + originalUseMock := config.UseMockProviders + config.UseMockProviders = true + defer func() { + config.UseMockProviders = originalUseMock + config.ResetProviders() + }() + + // Reset providers to ensure fresh mock data + config.ResetProviders() + + // Your test code here - providers will now return mock data + providers := config.Providers() + // ... test logic +} +``` + +The mock providers include: +- **Anthropic**: claude-3-opus, claude-3-haiku, claude-3-5-sonnet-20241022, claude-3-5-haiku-20241022 +- **OpenAI**: gpt-4, gpt-3.5-turbo, gpt-4-turbo, gpt-4o, gpt-4o-mini +- **Gemini**: gemini-2.5-pro, gemini-2.5-flash + ## Formatting - ALWAYS format any Go code you write. From df9b100449184c289665121406e68483637e3c09 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 27 Jun 2025 17:48:30 +0200 Subject: [PATCH 17/40] chore: add new fields to provider model --- internal/fur/provider/provider.go | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/internal/fur/provider/provider.go b/internal/fur/provider/provider.go index 8545694dea70b410a3a1912b82313bde2852d942..e3c0f6209cbe71c239da104b38c3022e090599aa 100644 --- a/internal/fur/provider/provider.go +++ b/internal/fur/provider/provider.go @@ -44,16 +44,18 @@ type Provider struct { // Model represents an AI model configuration. 
type Model struct { - ID string `json:"id"` - Name string `json:"model"` - CostPer1MIn float64 `json:"cost_per_1m_in"` - CostPer1MOut float64 `json:"cost_per_1m_out"` - CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` - CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` - ContextWindow int64 `json:"context_window"` - DefaultMaxTokens int64 `json:"default_max_tokens"` - CanReason bool `json:"can_reason"` - SupportsImages bool `json:"supports_attachments"` + ID string `json:"id"` + Name string `json:"model"` + CostPer1MIn float64 `json:"cost_per_1m_in"` + CostPer1MOut float64 `json:"cost_per_1m_out"` + CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` + CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` + ContextWindow int64 `json:"context_window"` + DefaultMaxTokens int64 `json:"default_max_tokens"` + CanReason bool `json:"can_reason"` + HasReasoningEffort bool `json:"has_reasoning_efforts"` + DefaultReasoningEffort string `json:"default_reasoning_effort,omitempty"` + SupportsImages bool `json:"supports_attachments"` } // KnownProviders returns all the known inference providers. 
From 12d0f1f237639ca7f6da5e81b8d59a20d08cbbf9 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 27 Jun 2025 18:23:28 +0200 Subject: [PATCH 18/40] feat: add reasoning effort support and new providers - Add HasReasoningEffort field to Model struct with proper JSON field name - Add ReasoningEffort field to PreferredModel for user customization - Implement reasoning effort transfer from provider models to config - Add OpenAI o1-preview and o1-mini models with reasoning capabilities - Add xAI provider with Grok Beta model support - Add OpenRouter provider with Claude 3.5 models - Update OpenAI client to use configurable reasoning effort - Remove unused toolCalls method from Gemini client - Simplify validation logic using slices.Contains - Add comprehensive tests for reasoning effort functionality --- internal/config/config.go | 40 +++---- internal/config/provider_mock.go | 80 ++++++++++++++ internal/config/provider_test.go | 175 +++++++++++++++++++++++++++++++ internal/llm/provider/gemini.go | 21 ---- internal/llm/provider/openai.go | 24 +++-- 5 files changed, 283 insertions(+), 57 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index e33aab02a492e8a1a4c55554fe5a3656d101ec1e..f6814db44cdadefd0e88e57e2bcd2521bf8a3c28 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -65,6 +65,7 @@ type Model struct { DefaultMaxTokens int64 `json:"default_max_tokens"` CanReason bool `json:"can_reason"` ReasoningEffort string `json:"reasoning_effort"` + HasReasoningEffort bool `json:"has_reasoning_effort"` SupportsImages bool `json:"supports_attachments"` } @@ -156,8 +157,9 @@ type Options struct { } type PreferredModel struct { - ModelID string `json:"model_id"` - Provider provider.InferenceProvider `json:"provider"` + ModelID string `json:"model_id"` + Provider provider.InferenceProvider `json:"provider"` + ReasoningEffort string `json:"reasoning_effort,omitempty"` } type PreferredModels struct { @@ -693,7 +695,7 @@ func 
defaultConfigBasedOnEnv() *Config { } providerConfig.BaseURL = baseURL for _, model := range p.Models { - providerConfig.Models = append(providerConfig.Models, Model{ + configModel := Model{ ID: model.ID, Name: model.Name, CostPer1MIn: model.CostPer1MIn, @@ -704,7 +706,13 @@ func defaultConfigBasedOnEnv() *Config { DefaultMaxTokens: model.DefaultMaxTokens, CanReason: model.CanReason, SupportsImages: model.SupportsImages, - }) + } + // Set reasoning effort for reasoning models + if model.HasReasoningEffort && model.DefaultReasoningEffort != "" { + configModel.HasReasoningEffort = model.HasReasoningEffort + configModel.ReasoningEffort = model.DefaultReasoningEffort + } + providerConfig.Models = append(providerConfig.Models, configModel) } cfg.Providers[p.ID] = providerConfig } @@ -980,25 +988,13 @@ func (c *Config) validateProviders(errors *ValidationErrors) { } // Validate provider type - validType := false - for _, vt := range validTypes { - if providerConfig.ProviderType == vt { - validType = true - break - } - } + validType := slices.Contains(validTypes, providerConfig.ProviderType) if !validType { errors.Add(fieldPrefix+".provider_type", fmt.Sprintf("invalid provider type: %s", providerConfig.ProviderType)) } // Validate custom providers - isKnownProvider := false - for _, kp := range knownProviders { - if providerID == kp { - isKnownProvider = true - break - } - } + isKnownProvider := slices.Contains(knownProviders, providerID) if !isKnownProvider { // Custom provider validation @@ -1200,13 +1196,7 @@ func (c *Config) validateAgents(errors *ValidationErrors) { // Validate allowed tools if agent.AllowedTools != nil { for i, tool := range agent.AllowedTools { - validTool := false - for _, vt := range validTools { - if tool == vt { - validTool = true - break - } - } + validTool := slices.Contains(validTools, tool) if !validTool { errors.Add(fmt.Sprintf("%s.allowed_tools[%d]", fieldPrefix, i), fmt.Sprintf("unknown tool: %s", tool)) } diff --git 
a/internal/config/provider_mock.go b/internal/config/provider_mock.go index 86b87768b95246654e176ca5f40af5aef249c23f..af92cc2c33f0b0adbe65dbd728b29727c35aeaa8 100644 --- a/internal/config/provider_mock.go +++ b/internal/config/provider_mock.go @@ -136,6 +136,34 @@ func MockProviders() []provider.Provider { CanReason: false, SupportsImages: true, }, + { + ID: "o1-preview", + Name: "o1-preview", + CostPer1MIn: 15.0, + CostPer1MOut: 60.0, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.0, + ContextWindow: 128000, + DefaultMaxTokens: 32768, + CanReason: true, + HasReasoningEffort: true, + DefaultReasoningEffort: "medium", + SupportsImages: true, + }, + { + ID: "o1-mini", + Name: "o1-mini", + CostPer1MIn: 3.0, + CostPer1MOut: 12.0, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.0, + ContextWindow: 128000, + DefaultMaxTokens: 65536, + CanReason: true, + HasReasoningEffort: true, + DefaultReasoningEffort: "medium", + SupportsImages: true, + }, }, }, { @@ -173,5 +201,57 @@ func MockProviders() []provider.Provider { }, }, }, + { + Name: "xAI", + ID: provider.InferenceProviderXAI, + APIKey: "$XAI_API_KEY", + APIEndpoint: "https://api.x.ai/v1", + Type: provider.TypeXAI, + DefaultLargeModelID: "grok-beta", + DefaultSmallModelID: "grok-beta", + Models: []provider.Model{ + { + ID: "grok-beta", + Name: "Grok Beta", + CostPer1MIn: 5.0, + CostPer1MOut: 15.0, + ContextWindow: 131072, + DefaultMaxTokens: 4096, + CanReason: false, + SupportsImages: true, + }, + }, + }, + { + Name: "OpenRouter", + ID: provider.InferenceProviderOpenRouter, + APIKey: "$OPENROUTER_API_KEY", + APIEndpoint: "https://openrouter.ai/api/v1", + Type: provider.TypeOpenAI, + DefaultLargeModelID: "anthropic/claude-3.5-sonnet", + DefaultSmallModelID: "anthropic/claude-3.5-haiku", + Models: []provider.Model{ + { + ID: "anthropic/claude-3.5-sonnet", + Name: "Claude 3.5 Sonnet", + CostPer1MIn: 3.0, + CostPer1MOut: 15.0, + ContextWindow: 200000, + DefaultMaxTokens: 8192, + CanReason: false, + SupportsImages: true, 
+ }, + { + ID: "anthropic/claude-3.5-haiku", + Name: "Claude 3.5 Haiku", + CostPer1MIn: 0.8, + CostPer1MOut: 4.0, + ContextWindow: 200000, + DefaultMaxTokens: 8192, + CanReason: false, + SupportsImages: true, + }, + }, + }, } } diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index 92547ff2925699d8519c33656395d3979a095b35..b175107d0df2bfaabc29e88550dc89471baf5188 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -1,6 +1,7 @@ package config import ( + "encoding/json" "testing" "github.com/charmbracelet/crush/internal/fur/provider" @@ -103,3 +104,177 @@ func TestResetProviders(t *testing.T) { // Should get the same mock data assert.Equal(t, len(providers1), len(providers2)) } + +func TestReasoningEffortSupport(t *testing.T) { + originalUseMock := UseMockProviders + UseMockProviders = true + defer func() { + UseMockProviders = originalUseMock + ResetProviders() + }() + + ResetProviders() + providers := Providers() + + var openaiProvider provider.Provider + for _, p := range providers { + if p.ID == provider.InferenceProviderOpenAI { + openaiProvider = p + break + } + } + require.NotEmpty(t, openaiProvider.ID) + + var reasoningModel, nonReasoningModel provider.Model + for _, model := range openaiProvider.Models { + if model.CanReason && model.HasReasoningEffort { + reasoningModel = model + } else if !model.CanReason { + nonReasoningModel = model + } + } + + require.NotEmpty(t, reasoningModel.ID) + assert.Equal(t, "medium", reasoningModel.DefaultReasoningEffort) + assert.True(t, reasoningModel.HasReasoningEffort) + + require.NotEmpty(t, nonReasoningModel.ID) + assert.False(t, nonReasoningModel.HasReasoningEffort) + assert.Empty(t, nonReasoningModel.DefaultReasoningEffort) +} + +func TestReasoningEffortConfigTransfer(t *testing.T) { + originalUseMock := UseMockProviders + UseMockProviders = true + defer func() { + UseMockProviders = originalUseMock + ResetProviders() + }() + + ResetProviders() + 
t.Setenv("OPENAI_API_KEY", "test-openai-key") + + cfg, err := Init(t.TempDir(), false) + require.NoError(t, err) + + openaiProviderConfig, exists := cfg.Providers[provider.InferenceProviderOpenAI] + require.True(t, exists) + + var foundReasoning, foundNonReasoning bool + for _, model := range openaiProviderConfig.Models { + if model.CanReason && model.HasReasoningEffort && model.ReasoningEffort != "" { + assert.Equal(t, "medium", model.ReasoningEffort) + assert.True(t, model.HasReasoningEffort) + foundReasoning = true + } else if !model.CanReason { + assert.Empty(t, model.ReasoningEffort) + assert.False(t, model.HasReasoningEffort) + foundNonReasoning = true + } + } + + assert.True(t, foundReasoning, "Should find at least one reasoning model") + assert.True(t, foundNonReasoning, "Should find at least one non-reasoning model") +} + +func TestNewProviders(t *testing.T) { + originalUseMock := UseMockProviders + UseMockProviders = true + defer func() { + UseMockProviders = originalUseMock + ResetProviders() + }() + + ResetProviders() + providers := Providers() + require.NotEmpty(t, providers) + + var xaiProvider, openRouterProvider provider.Provider + for _, p := range providers { + switch p.ID { + case provider.InferenceProviderXAI: + xaiProvider = p + case provider.InferenceProviderOpenRouter: + openRouterProvider = p + } + } + + require.NotEmpty(t, xaiProvider.ID) + assert.Equal(t, "xAI", xaiProvider.Name) + assert.Equal(t, "grok-beta", xaiProvider.DefaultLargeModelID) + + require.NotEmpty(t, openRouterProvider.ID) + assert.Equal(t, "OpenRouter", openRouterProvider.Name) + assert.Equal(t, "anthropic/claude-3.5-sonnet", openRouterProvider.DefaultLargeModelID) +} + +func TestO1ModelsInMockProvider(t *testing.T) { + originalUseMock := UseMockProviders + UseMockProviders = true + defer func() { + UseMockProviders = originalUseMock + ResetProviders() + }() + + ResetProviders() + providers := Providers() + + var openaiProvider provider.Provider + for _, p := range 
providers { + if p.ID == provider.InferenceProviderOpenAI { + openaiProvider = p + break + } + } + require.NotEmpty(t, openaiProvider.ID) + + modelTests := []struct { + id string + name string + }{ + {"o1-preview", "o1-preview"}, + {"o1-mini", "o1-mini"}, + } + + for _, test := range modelTests { + var model provider.Model + var found bool + for _, m := range openaiProvider.Models { + if m.ID == test.id { + model = m + found = true + break + } + } + require.True(t, found, "Should find %s model", test.id) + assert.Equal(t, test.name, model.Name) + assert.True(t, model.CanReason) + assert.True(t, model.HasReasoningEffort) + assert.Equal(t, "medium", model.DefaultReasoningEffort) + } +} + +func TestPreferredModelReasoningEffort(t *testing.T) { + // Test that PreferredModel struct can hold reasoning effort + preferredModel := PreferredModel{ + ModelID: "o1-preview", + Provider: provider.InferenceProviderOpenAI, + ReasoningEffort: "high", + } + + assert.Equal(t, "o1-preview", preferredModel.ModelID) + assert.Equal(t, provider.InferenceProviderOpenAI, preferredModel.Provider) + assert.Equal(t, "high", preferredModel.ReasoningEffort) + + // Test JSON marshaling/unmarshaling + jsonData, err := json.Marshal(preferredModel) + require.NoError(t, err) + + var unmarshaled PreferredModel + err = json.Unmarshal(jsonData, &unmarshaled) + require.NoError(t, err) + + assert.Equal(t, preferredModel.ModelID, unmarshaled.ModelID) + assert.Equal(t, preferredModel.Provider, unmarshaled.Provider) + assert.Equal(t, preferredModel.ReasoningEffort, unmarshaled.ReasoningEffort) +} diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index a5c012861ad9e6b537c0e9bca8e957ef3f38bf2f..3531c5cc89cced262d5c22f2598216d9cfe4883e 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -407,27 +407,6 @@ func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) return true, int64(retryMs), nil } -func (g *geminiClient) 
toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall { - var toolCalls []message.ToolCall - - if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { - for _, part := range resp.Candidates[0].Content.Parts { - if part.FunctionCall != nil { - id := "call_" + uuid.New().String() - args, _ := json.Marshal(part.FunctionCall.Args) - toolCalls = append(toolCalls, message.ToolCall{ - ID: id, - Name: part.FunctionCall.Name, - Input: string(args), - Type: "function", - }) - } - } - } - - return toolCalls -} - func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage { if resp == nil || resp.UsageMetadata == nil { return TokenUsage{} diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 9af060a80f75309e1e314e3c33df72e607c9c77a..f6aaacb0ce09ae665fab4bdbb14b28e13e7684c7 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -18,23 +18,14 @@ import ( "github.com/openai/openai-go/shared" ) -type openaiOptions struct { - reasoningEffort string -} - type openaiClient struct { providerOptions providerClientOptions - options openaiOptions client openai.Client } type OpenAIClient ProviderClient func newOpenAIClient(opts providerClientOptions) OpenAIClient { - openaiOpts := openaiOptions{ - reasoningEffort: "medium", - } - openaiClientOptions := []option.RequestOption{} if opts.apiKey != "" { openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey)) @@ -52,7 +43,6 @@ func newOpenAIClient(opts providerClientOptions) OpenAIClient { client := openai.NewClient(openaiClientOptions...) 
return &openaiClient{ providerOptions: opts, - options: openaiOpts, client: client, } } @@ -153,6 +143,18 @@ func (o *openaiClient) finishReason(reason string) message.FinishReason { func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams { model := o.providerOptions.model(o.providerOptions.modelType) + cfg := config.Get() + + modelConfig := cfg.Models.Large + if o.providerOptions.modelType == config.SmallModel { + modelConfig = cfg.Models.Small + } + + reasoningEffort := model.ReasoningEffort + if modelConfig.ReasoningEffort != "" { + reasoningEffort = modelConfig.ReasoningEffort + } + params := openai.ChatCompletionNewParams{ Model: openai.ChatModel(model.ID), Messages: messages, @@ -160,7 +162,7 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar } if model.CanReason { params.MaxCompletionTokens = openai.Int(o.providerOptions.maxTokens) - switch o.options.reasoningEffort { + switch reasoningEffort { case "low": params.ReasoningEffort = shared.ReasoningEffortLow case "medium": From 1e18320f828889af1233505ec34489d484be21a3 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 27 Jun 2025 18:26:14 +0200 Subject: [PATCH 19/40] chore: fix lint --- internal/config/provider_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index b175107d0df2bfaabc29e88550dc89471baf5188..70224d194a6689d85602d6a0f7d92e03b02fa1b2 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -115,7 +115,7 @@ func TestReasoningEffortSupport(t *testing.T) { ResetProviders() providers := Providers() - + var openaiProvider provider.Provider for _, p := range providers { if p.ID == provider.InferenceProviderOpenAI { @@ -218,7 +218,7 @@ func TestO1ModelsInMockProvider(t *testing.T) { ResetProviders() providers := Providers() - + var 
openaiProvider provider.Provider for _, p := range providers { if p.ID == provider.InferenceProviderOpenAI { From eb3550a2bb26f9e4199a6339a3144dcd898dacbe Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Sat, 28 Jun 2025 11:56:46 +0200 Subject: [PATCH 20/40] chore: change how max tokens works --- internal/config/config.go | 9 ++++++--- internal/llm/agent/agent.go | 6 ------ internal/llm/provider/anthropic.go | 12 +++++++++++- internal/llm/provider/gemini.go | 26 +++++++++++++++++++++----- internal/llm/provider/openai.go | 9 +++++++-- internal/llm/provider/provider.go | 7 ------- 6 files changed, 45 insertions(+), 24 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index f6814db44cdadefd0e88e57e2bcd2521bf8a3c28..32ca8729295bb3994af27ec4359a1b4960527671 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -157,9 +157,12 @@ type Options struct { } type PreferredModel struct { - ModelID string `json:"model_id"` - Provider provider.InferenceProvider `json:"provider"` - ReasoningEffort string `json:"reasoning_effort,omitempty"` + ModelID string `json:"model_id"` + Provider provider.InferenceProvider `json:"provider"` + // Overrides the default reasoning effort for this model + ReasoningEffort string `json:"reasoning_effort,omitempty"` + // Overrides the default max tokens for this model + MaxTokens int64 `json:"max_tokens,omitempty"` } type PreferredModels struct { diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 8312b0f8965a5d02f7ce049abff50953cc56e422..5f3d41c2eee4cc41df159066379346bad4a97dfc 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -147,7 +147,6 @@ func NewAgent( opts := []provider.ProviderClientOption{ provider.WithModel(agentCfg.Model), provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)), - provider.WithMaxTokens(model.DefaultMaxTokens), } agentProvider, err := provider.NewProviderV2(providerCfg, opts...) 
if err != nil { @@ -184,7 +183,6 @@ func NewAgent( titleOpts := []provider.ProviderClientOption{ provider.WithModel(config.SmallModel), provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)), - provider.WithMaxTokens(40), } titleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...) if err != nil { @@ -193,7 +191,6 @@ func NewAgent( summarizeOpts := []provider.ProviderClientOption{ provider.WithModel(config.SmallModel), provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)), - provider.WithMaxTokens(smallModel.DefaultMaxTokens), } summarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...) if err != nil { @@ -832,7 +829,6 @@ func (a *agent) UpdateModel() error { opts := []provider.ProviderClientOption{ provider.WithModel(a.agentCfg.Model), provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)), - provider.WithMaxTokens(model.DefaultMaxTokens), } newProvider, err := provider.NewProviderV2(currentProviderCfg, opts...) @@ -877,7 +873,6 @@ func (a *agent) UpdateModel() error { titleOpts := []provider.ProviderClientOption{ provider.WithModel(config.SmallModel), provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)), - provider.WithMaxTokens(40), } newTitleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...) if err != nil { @@ -888,7 +883,6 @@ func (a *agent) UpdateModel() error { summarizeOpts := []provider.ProviderClientOption{ provider.WithModel(config.SmallModel), provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)), - provider.WithMaxTokens(smallModel.DefaultMaxTokens), } newSummarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...) 
if err != nil { diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 626882f283c030454477b27b152bd6a717d08476..d8d4ec002ed35ec06d6932643e070241fed0e227 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -164,9 +164,19 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to // } // } + cfg := config.Get() + modelConfig := cfg.Models.Large + if a.providerOptions.modelType == config.SmallModel { + modelConfig = cfg.Models.Small + } + maxTokens := model.DefaultMaxTokens + if modelConfig.MaxTokens > 0 { + maxTokens = modelConfig.MaxTokens + } + return anthropic.MessageNewParams{ Model: anthropic.Model(model.ID), - MaxTokens: a.providerOptions.maxTokens, + MaxTokens: maxTokens, Temperature: temperature, Messages: messages, Tools: tools, diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index 3531c5cc89cced262d5c22f2598216d9cfe4883e..56263a4389e28289db5adf9392f307d11908e1cc 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -155,17 +155,26 @@ func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishRea func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { // Convert messages geminiMessages := g.convertMessages(messages) - + model := g.providerOptions.model(g.providerOptions.modelType) cfg := config.Get() if cfg.Options.Debug { jsonData, _ := json.Marshal(geminiMessages) logging.Debug("Prepared messages", "messages", string(jsonData)) } + modelConfig := cfg.Models.Large + if g.providerOptions.modelType == config.SmallModel { + modelConfig = cfg.Models.Small + } + + maxTokens := model.DefaultMaxTokens + if modelConfig.MaxTokens > 0 { + maxTokens = modelConfig.MaxTokens + } history := geminiMessages[:len(geminiMessages)-1] // All but last message lastMsg := geminiMessages[len(geminiMessages)-1] config := 
&genai.GenerateContentConfig{ - MaxOutputTokens: int32(g.providerOptions.maxTokens), + MaxOutputTokens: int32(maxTokens), SystemInstruction: &genai.Content{ Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}}, }, @@ -173,7 +182,6 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too if len(tools) > 0 { config.Tools = g.convertTools(tools) } - model := g.providerOptions.model(g.providerOptions.modelType) chat, _ := g.client.Chats.Create(ctx, model.ID, config, history) attempts := 0 @@ -245,16 +253,25 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t // Convert messages geminiMessages := g.convertMessages(messages) + model := g.providerOptions.model(g.providerOptions.modelType) cfg := config.Get() if cfg.Options.Debug { jsonData, _ := json.Marshal(geminiMessages) logging.Debug("Prepared messages", "messages", string(jsonData)) } + modelConfig := cfg.Models.Large + if g.providerOptions.modelType == config.SmallModel { + modelConfig = cfg.Models.Small + } + maxTokens := model.DefaultMaxTokens + if modelConfig.MaxTokens > 0 { + maxTokens = modelConfig.MaxTokens + } history := geminiMessages[:len(geminiMessages)-1] // All but last message lastMsg := geminiMessages[len(geminiMessages)-1] config := &genai.GenerateContentConfig{ - MaxOutputTokens: int32(g.providerOptions.maxTokens), + MaxOutputTokens: int32(maxTokens), SystemInstruction: &genai.Content{ Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}}, }, @@ -262,7 +279,6 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t if len(tools) > 0 { config.Tools = g.convertTools(tools) } - model := g.providerOptions.model(g.providerOptions.modelType) chat, _ := g.client.Chats.Create(ctx, model.ID, config, history) attempts := 0 diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index f6aaacb0ce09ae665fab4bdbb14b28e13e7684c7..71040b485b426b5d80e078aea7c06c710f93e4e2 100644 --- 
a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -160,8 +160,13 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar Messages: messages, Tools: tools, } + + maxTokens := model.DefaultMaxTokens + if modelConfig.MaxTokens > 0 { + maxTokens = modelConfig.MaxTokens + } if model.CanReason { - params.MaxCompletionTokens = openai.Int(o.providerOptions.maxTokens) + params.MaxCompletionTokens = openai.Int(maxTokens) switch reasoningEffort { case "low": params.ReasoningEffort = shared.ReasoningEffortLow @@ -173,7 +178,7 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar params.ReasoningEffort = shared.ReasoningEffortMedium } } else { - params.MaxTokens = openai.Int(o.providerOptions.maxTokens) + params.MaxTokens = openai.Int(maxTokens) } return params diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 9723dc9fe55af414ed415653e3e9e31031395a02..6da5188b2bf6b17e1f91c9ff5e7eb2bd20931392 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -64,7 +64,6 @@ type providerClientOptions struct { modelType config.ModelType model func(config.ModelType) config.Model disableCache bool - maxTokens int64 systemMessage string extraHeaders map[string]string extraParams map[string]string @@ -121,12 +120,6 @@ func WithDisableCache(disableCache bool) ProviderClientOption { } } -func WithMaxTokens(maxTokens int64) ProviderClientOption { - return func(options *providerClientOptions) { - options.maxTokens = maxTokens - } -} - func WithSystemMessage(systemMessage string) ProviderClientOption { return func(options *providerClientOptions) { options.systemMessage = systemMessage From 03871a565bb8f51d274148c266fb4ae5a7a89c9a Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Sat, 28 Jun 2025 12:06:40 +0200 Subject: [PATCH 21/40] chore: make it possible to override maxTokens we still need to handle the case where the max tokens is more than 5% 
of the total context window, this can cause the endpoint to error. --- internal/llm/agent/agent.go | 18 +++++++++++------- internal/llm/provider/anthropic.go | 5 +++++ internal/llm/provider/gemini.go | 5 +++++ internal/llm/provider/openai.go | 5 +++++ internal/llm/provider/provider.go | 9 ++++++++- 5 files changed, 34 insertions(+), 8 deletions(-) diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 5f3d41c2eee4cc41df159066379346bad4a97dfc..57771a7dc98efd2fa897d655aa04b7fef628dab5 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -148,7 +148,7 @@ func NewAgent( provider.WithModel(agentCfg.Model), provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)), } - agentProvider, err := provider.NewProviderV2(providerCfg, opts...) + agentProvider, err := provider.NewProvider(providerCfg, opts...) if err != nil { return nil, err } @@ -184,7 +184,7 @@ func NewAgent( provider.WithModel(config.SmallModel), provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)), } - titleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...) + titleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...) if err != nil { return nil, err } @@ -192,7 +192,7 @@ func NewAgent( provider.WithModel(config.SmallModel), provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)), } - summarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...) + summarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...) 
if err != nil { return nil, err } @@ -277,7 +277,9 @@ func (a *agent) generateTitle(ctx context.Context, sessionID string, content str if err != nil { return err } - parts := []message.ContentPart{message.TextContent{Text: content}} + parts := []message.ContentPart{message.TextContent{ + Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content), + }} // Use streaming approach like summarization response := a.titleProvider.StreamResponse( @@ -831,7 +833,7 @@ func (a *agent) UpdateModel() error { provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)), } - newProvider, err := provider.NewProviderV2(currentProviderCfg, opts...) + newProvider, err := provider.NewProvider(currentProviderCfg, opts...) if err != nil { return fmt.Errorf("failed to create new provider: %w", err) } @@ -873,8 +875,10 @@ func (a *agent) UpdateModel() error { titleOpts := []provider.ProviderClientOption{ provider.WithModel(config.SmallModel), provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)), + // We want the title to be short, so we limit the max tokens + provider.WithMaxTokens(40), } - newTitleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...) + newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...) if err != nil { return fmt.Errorf("failed to create new title provider: %w", err) } @@ -884,7 +888,7 @@ func (a *agent) UpdateModel() error { provider.WithModel(config.SmallModel), provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)), } - newSummarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...) + newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...) 
if err != nil { return fmt.Errorf("failed to create new summarize provider: %w", err) } diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index d8d4ec002ed35ec06d6932643e070241fed0e227..df6b8490ebc48abc7c01a2a938c6f7d395526654 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -174,6 +174,11 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to maxTokens = modelConfig.MaxTokens } + // Override max tokens if set in provider options + if a.providerOptions.maxTokens > 0 { + maxTokens = a.providerOptions.maxTokens + } + return anthropic.MessageNewParams{ Model: anthropic.Model(model.ID), MaxTokens: maxTokens, diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index 56263a4389e28289db5adf9392f307d11908e1cc..f644d118b4ef642c5f9e835ecfaa450d9f835f4d 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -268,6 +268,11 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t if modelConfig.MaxTokens > 0 { maxTokens = modelConfig.MaxTokens } + + // Override max tokens if set in provider options + if g.providerOptions.maxTokens > 0 { + maxTokens = g.providerOptions.maxTokens + } history := geminiMessages[:len(geminiMessages)-1] // All but last message lastMsg := geminiMessages[len(geminiMessages)-1] config := &genai.GenerateContentConfig{ diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 71040b485b426b5d80e078aea7c06c710f93e4e2..1ae8847db441181a1a65bcacc8b4bd039b45a0fc 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -165,6 +165,11 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar if modelConfig.MaxTokens > 0 { maxTokens = modelConfig.MaxTokens } + + // Override max tokens if set in provider options + if o.providerOptions.maxTokens > 0 { + maxTokens = o.providerOptions.maxTokens + } 
if model.CanReason { params.MaxCompletionTokens = openai.Int(maxTokens) switch reasoningEffort { diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 6da5188b2bf6b17e1f91c9ff5e7eb2bd20931392..2133e23309b4d92d8c8b2efbf1bb386a2e7753cd 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -65,6 +65,7 @@ type providerClientOptions struct { model func(config.ModelType) config.Model disableCache bool systemMessage string + maxTokens int64 extraHeaders map[string]string extraParams map[string]string } @@ -126,7 +127,13 @@ func WithSystemMessage(systemMessage string) ProviderClientOption { } } -func NewProviderV2(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) { +func WithMaxTokens(maxTokens int64) ProviderClientOption { + return func(options *providerClientOptions) { + options.maxTokens = maxTokens + } +} + +func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) { clientOptions := providerClientOptions{ baseURL: cfg.BaseURL, apiKey: cfg.APIKey, From 19c0e04c7bb8be287931346bc53243c805958e4e Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Sat, 28 Jun 2025 23:21:48 +0200 Subject: [PATCH 22/40] feat: make it possible for api key to be programmatic --- cspell.json | 40 +--------------- internal/config/config.go | 6 +-- internal/config/shell.go | 74 ++++++++++++++++++++++++++++++ internal/llm/agent/agent.go | 4 +- internal/llm/provider/anthropic.go | 32 +++++++++---- internal/llm/provider/gemini.go | 23 +++++++++- internal/llm/provider/openai.go | 31 +++++++++---- internal/llm/provider/provider.go | 9 +++- 8 files changed, 154 insertions(+), 65 deletions(-) create mode 100644 internal/config/shell.go diff --git a/cspell.json b/cspell.json index 2595963182b8e2aa6fe575bbe9ca6a5da0f70a9b..d62c817e8c8699e6172e576eb0f91602dd8417a3 100644 --- a/cspell.json +++ b/cspell.json @@ -1,39 +1 @@ -{ - "flagWords": [], - "words": [ - "afero", - "alecthomas", 
- "bubbletea", - "charmbracelet", - "charmtone", - "Charple", - "crush", - "diffview", - "Emph", - "filepicker", - "Focusable", - "fsext", - "GROQ", - "Guac", - "imageorient", - "Lanczos", - "lipgloss", - "lsps", - "lucasb", - "nfnt", - "oksvg", - "Preproc", - "rasterx", - "rivo", - "Sourcegraph", - "srwiley", - "Strikethrough", - "termenv", - "textinput", - "trashhalo", - "uniseg", - "Unticked" - ], - "version": "0.2", - "language": "en" -} +{"language":"en","flagWords":[],"version":"0.2","words":["afero","alecthomas","bubbletea","charmbracelet","charmtone","Charple","crush","diffview","Emph","filepicker","Focusable","fsext","GROQ","Guac","imageorient","Lanczos","lipgloss","lsps","lucasb","nfnt","oksvg","Preproc","rasterx","rivo","Sourcegraph","srwiley","Strikethrough","termenv","textinput","trashhalo","uniseg","Unticked","genai"]} \ No newline at end of file diff --git a/internal/config/config.go b/internal/config/config.go index 32ca8729295bb3994af27ec4359a1b4960527671..74c0c63ecdd2843da0daf4875295d9d1f8ad20d7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -592,11 +592,11 @@ func mergeProviderConfigs(base, global, local *Config) { if cfg == nil { continue } - for providerName, globalProvider := range cfg.Providers { + for providerName, p := range cfg.Providers { if _, ok := base.Providers[providerName]; !ok { - base.Providers[providerName] = globalProvider + base.Providers[providerName] = p } else { - base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider) + base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], p) } } } diff --git a/internal/config/shell.go b/internal/config/shell.go new file mode 100644 index 0000000000000000000000000000000000000000..a12ecd1da3b82c113175a1f4825877a7fb94a95c --- /dev/null +++ b/internal/config/shell.go @@ -0,0 +1,74 @@ +package config + +import ( + "context" + "fmt" + "os" + "strings" + "time" + + 
"github.com/charmbracelet/crush/internal/llm/tools/shell" + "github.com/charmbracelet/crush/internal/logging" +) + +// ExecuteCommand executes a shell command and returns the output +// This is a shared utility that can be used by both provider config and tools +func ExecuteCommand(ctx context.Context, command string, workingDir string) (string, error) { + if workingDir == "" { + workingDir = WorkingDirectory() + } + + persistentShell := shell.GetPersistentShell(workingDir) + + stdout, stderr, err := persistentShell.Exec(ctx, command) + if err != nil { + logging.Debug("Command execution failed", "command", command, "error", err, "stderr", stderr) + return "", fmt.Errorf("command execution failed: %w", err) + } + + return strings.TrimSpace(stdout), nil +} + +// ResolveAPIKey resolves an API key that can be either: +// - A direct string value +// - An environment variable (prefixed with $) +// - A shell command (wrapped in $(...)) +func ResolveAPIKey(apiKey string) (string, error) { + if !strings.HasPrefix(apiKey, "$") { + return apiKey, nil + } + + if strings.HasPrefix(apiKey, "$(") && strings.HasSuffix(apiKey, ")") { + command := strings.TrimSuffix(strings.TrimPrefix(apiKey, "$("), ")") + logging.Debug("Resolving API key from command", "command", command) + return resolveCommandAPIKey(command) + } + + envVar := strings.TrimPrefix(apiKey, "$") + if value := os.Getenv(envVar); value != "" { + logging.Debug("Resolved environment variable", "envVar", envVar, "value", value) + return value, nil + } + + logging.Debug("Environment variable not found", "envVar", envVar) + + return "", fmt.Errorf("environment variable %s not found", envVar) +} + +// resolveCommandAPIKey executes a shell command (with a 30s timeout) and returns its trimmed output as the API key +func resolveCommandAPIKey(command string) (string, error) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + logging.Debug("Executing command for API key", "command", command) + + workingDir := 
WorkingDirectory() + + result, err := ExecuteCommand(ctx, command, workingDir) + if err != nil { + return "", fmt.Errorf("failed to execute API key command: %w", err) + } + logging.Debug("Command executed successfully", "command", command, "result", result) + return result, nil +} + diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 57771a7dc98efd2fa897d655aa04b7fef628dab5..d165921f639ffee7127e4044c42d154f091a0dca 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -253,10 +253,10 @@ func (a *agent) IsBusy() bool { if cancelFunc, ok := value.(context.CancelFunc); ok { if cancelFunc != nil { busy = true - return false // Stop iterating + return false } } - return true // Continue iterating + return true }) return busy } diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index df6b8490ebc48abc7c01a2a938c6f7d395526654..c86f4372acc4fafd2a829f42489f545c4d589861 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -21,12 +21,20 @@ import ( type anthropicClient struct { providerOptions providerClientOptions + useBedrock bool client anthropic.Client } type AnthropicClient ProviderClient func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient { + return &anthropicClient{ + providerOptions: opts, + client: createAnthropicClient(opts, useBedrock), + } +} + +func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropic.Client { anthropicClientOptions := []option.RequestOption{} if opts.apiKey != "" { anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey)) @@ -34,12 +42,7 @@ func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicCl if useBedrock { anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background())) } - - client := anthropic.NewClient(anthropicClientOptions...) 
- return &anthropicClient{ - providerOptions: opts, - client: client, - } + return anthropic.NewClient(anthropicClientOptions...) } func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) { @@ -385,12 +388,21 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message } func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) { - var apierr *anthropic.Error - if !errors.As(err, &apierr) { + var apiErr *anthropic.Error + if !errors.As(err, &apiErr) { return false, 0, err } - if apierr.StatusCode != 429 && apierr.StatusCode != 529 { + if apiErr.StatusCode == 401 { + a.providerOptions.apiKey, err = config.ResolveAPIKey(a.providerOptions.config.APIKey) + if err != nil { + return false, 0, fmt.Errorf("failed to resolve API key: %w", err) + } + a.client = createAnthropicClient(a.providerOptions, a.useBedrock) + return true, 0, nil + } + + if apiErr.StatusCode != 429 && apiErr.StatusCode != 529 { return false, 0, err } @@ -399,7 +411,7 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err } retryMs := 0 - retryAfterValues := apierr.Response.Header.Values("Retry-After") + retryAfterValues := apiErr.Response.Header.Values("Retry-After") backoffMs := 2000 * (1 << (attempts - 1)) jitterMs := int(float64(backoffMs) * 0.2) diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index f644d118b4ef642c5f9e835ecfaa450d9f835f4d..e80af34d0815695ea6ed76d01c25262381a836ec 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -25,7 +25,7 @@ type geminiClient struct { type GeminiClient ProviderClient func newGeminiClient(opts providerClientOptions) GeminiClient { - client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI}) + client, err := createGeminiClient(opts) if err != nil { logging.Error("Failed to create Gemini client", 
"error", err) return nil @@ -37,6 +37,14 @@ func newGeminiClient(opts providerClientOptions) GeminiClient { } } +func createGeminiClient(opts providerClientOptions) (*genai.Client, error) { + client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI}) + if err != nil { + return nil, err + } + return client, nil +} + func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content { var history []*genai.Content for _, msg := range messages { @@ -414,6 +422,19 @@ func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) errMsg := err.Error() isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests") + // Check for token expiration (401 Unauthorized) + if contains(errMsg, "unauthorized", "invalid api key", "api key expired") { + g.providerOptions.apiKey, err = config.ResolveAPIKey(g.providerOptions.config.APIKey) + if err != nil { + return false, 0, fmt.Errorf("failed to resolve API key: %w", err) + } + g.client, err = createGeminiClient(g.providerOptions) + if err != nil { + return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err) + } + return true, 0, nil + } + // Check for common rate limit error messages if !isRateLimit { diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 1ae8847db441181a1a65bcacc8b4bd039b45a0fc..e045029651f3e9fc158c9f38cf810584e2c06724 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -26,6 +26,13 @@ type openaiClient struct { type OpenAIClient ProviderClient func newOpenAIClient(opts providerClientOptions) OpenAIClient { + return &openaiClient{ + providerOptions: opts, + client: createOpenAIClient(opts), + } +} + +func createOpenAIClient(opts providerClientOptions) openai.Client { openaiClientOptions := []option.RequestOption{} if opts.apiKey != "" { openaiClientOptions = append(openaiClientOptions, 
option.WithAPIKey(opts.apiKey)) @@ -40,11 +47,7 @@ func newOpenAIClient(opts providerClientOptions) OpenAIClient { } } - client := openai.NewClient(openaiClientOptions...) - return &openaiClient{ - providerOptions: opts, - client: client, - } + return openai.NewClient(openaiClientOptions...) } func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) { @@ -339,12 +342,22 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t } func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) { - var apierr *openai.Error - if !errors.As(err, &apierr) { + var apiErr *openai.Error + if !errors.As(err, &apiErr) { return false, 0, err } - if apierr.StatusCode != 429 && apierr.StatusCode != 500 { + // Check for token expiration (401 Unauthorized) + if apiErr.StatusCode == 401 { + o.providerOptions.apiKey, err = config.ResolveAPIKey(o.providerOptions.config.APIKey) + if err != nil { + return false, 0, fmt.Errorf("failed to resolve API key: %w", err) + } + o.client = createOpenAIClient(o.providerOptions) + return true, 0, nil + } + + if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 { return false, 0, err } @@ -353,7 +366,7 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) } retryMs := 0 - retryAfterValues := apierr.Response.Header.Values("Retry-After") + retryAfterValues := apiErr.Response.Header.Values("Retry-After") backoffMs := 2000 * (1 << (attempts - 1)) jitterMs := int(float64(backoffMs) * 0.2) diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 2133e23309b4d92d8c8b2efbf1bb386a2e7753cd..3ffbf86c00c5e3ca27f1b68965f4ff950f1f7454 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -60,6 +60,7 @@ type Provider interface { type providerClientOptions struct { baseURL string + config config.ProviderConfig apiKey string modelType config.ModelType model 
func(config.ModelType) config.Model @@ -134,9 +135,15 @@ func WithMaxTokens(maxTokens int64) ProviderClientOption { } func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) { + resolvedAPIKey, err := config.ResolveAPIKey(cfg.APIKey) + if err != nil { + return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err) + } + clientOptions := providerClientOptions{ baseURL: cfg.BaseURL, - apiKey: cfg.APIKey, + config: cfg, + apiKey: resolvedAPIKey, extraHeaders: cfg.ExtraHeaders, model: func(tp config.ModelType) config.Model { return config.GetModel(tp) From 19f1a2d34f782e352cfefd35585e17f2ac150f2e Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 30 Jun 2025 11:01:25 +0200 Subject: [PATCH 23/40] chore: enable thinking for anthropic models --- internal/config/config.go | 7 +++++-- internal/llm/provider/anthropic.go | 30 +++++++----------------------- 2 files changed, 12 insertions(+), 25 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index 74c0c63ecdd2843da0daf4875295d9d1f8ad20d7..69a528d3c57eba4ef4b8802d800bfb0a7a764c14 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -159,10 +159,13 @@ type Options struct { type PreferredModel struct { ModelID string `json:"model_id"` Provider provider.InferenceProvider `json:"provider"` - // Overrides the default reasoning effort for this model + // ReasoningEffort overrides the default reasoning effort for this model ReasoningEffort string `json:"reasoning_effort,omitempty"` - // Overrides the default max tokens for this model + // MaxTokens overrides the default max tokens for this model MaxTokens int64 `json:"max_tokens,omitempty"` + + // Think indicates if the model should think, only applicable for anthropic reasoning models + Think bool `json:"think,omitempty"` } type PreferredModels struct { diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 
c86f4372acc4fafd2a829f42489f545c4d589861..05f09ad77e224a59bcb825e85f353e317c7c4a83 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "strings" "time" "github.com/anthropics/anthropic-sdk-go" @@ -150,28 +149,18 @@ func (a *anthropicClient) finishReason(reason string) message.FinishReason { func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams { model := a.providerOptions.model(a.providerOptions.modelType) var thinkingParam anthropic.ThinkingConfigParamUnion - // TODO: Implement a proper thinking function - // lastMessage := messages[len(messages)-1] - // isUser := lastMessage.Role == anthropic.MessageParamRoleUser - // messageContent := "" - temperature := anthropic.Float(0) - // if isUser { - // for _, m := range lastMessage.Content { - // if m.OfText != nil && m.OfText.Text != "" { - // messageContent = m.OfText.Text - // } - // } - // if messageContent != "" && a.shouldThink != nil && a.options.shouldThink(messageContent) { - // thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(a.providerOptions.maxTokens) * 0.8)) - // temperature = anthropic.Float(1) - // } - // } - cfg := config.Get() modelConfig := cfg.Models.Large if a.providerOptions.modelType == config.SmallModel { modelConfig = cfg.Models.Small } + temperature := anthropic.Float(0) + + if a.Model().CanReason && modelConfig.Think { + thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(a.providerOptions.maxTokens) * 0.8)) + temperature = anthropic.Float(1) + } + maxTokens := model.DefaultMaxTokens if modelConfig.MaxTokens > 0 { maxTokens = modelConfig.MaxTokens @@ -456,8 +445,3 @@ func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage { func (a *anthropicClient) Model() config.Model { return a.providerOptions.model(a.providerOptions.modelType) } - -// TODO: check if we need -func 
DefaultShouldThinkFn(s string) bool { - return strings.Contains(strings.ToLower(s), "think") -} From 28c85ea0cacafca9e2408366f710fb6c610c802c Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 30 Jun 2025 12:11:52 +0200 Subject: [PATCH 24/40] chore: generate schema, fix tests --- Taskfile.yaml | 6 + cmd/schema/README.md | 172 ++- cmd/schema/main.go | 766 ++++++++----- crush-schema.json | 1728 ++++++++++++++++++++++++------ internal/config/config.go | 56 + internal/config/config_test.go | 733 +++++-------- internal/config/provider_mock.go | 306 +++--- internal/config/provider_test.go | 221 +--- internal/config/shell.go | 1 - 9 files changed, 2572 insertions(+), 1417 deletions(-) diff --git a/Taskfile.yaml b/Taskfile.yaml index 7e181d84513bc1abc5d5a32807f88a75898633ea..9453eb363513d5b9e0987a637d761eef1bfe63c3 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -44,3 +44,9 @@ tasks: desc: Allocations profile cmds: - go tool pprof -http :6061 'http://localhost:6060/debug/pprof/allocs' + + schema: + desc: Generate JSON schema for configuration + cmds: + - go run cmd/schema/main.go > crush-schema.json + - echo "Generated crush-schema.json" diff --git a/cmd/schema/README.md b/cmd/schema/README.md index 517fdb4d20fb9f2b819051bd72e6c33f5dea2195..c7f0102e590dcd98ce150e506c28f9016ef50bbc 100644 --- a/cmd/schema/README.md +++ b/cmd/schema/README.md @@ -1,64 +1,182 @@ # Crush Configuration Schema Generator -This tool generates a JSON Schema for the Crush configuration file. The schema can be used to validate configuration files and provide autocompletion in editors that support JSON Schema. +This tool automatically generates a JSON Schema for the Crush configuration file by using Go reflection to analyze the configuration structs. The schema provides validation, autocompletion, and documentation for configuration files. 
+ +## Features + +- **Automated Generation**: Uses reflection to automatically generate schemas from Go structs +- **Always Up-to-Date**: Schema stays in sync with code changes automatically +- **Comprehensive**: Includes all configuration options, types, and validation rules +- **Enhanced**: Adds provider enums, model lists, and custom descriptions +- **Extensible**: Easy to add new fields and modify existing ones ## Usage ```bash +# Generate the schema go run cmd/schema/main.go > crush-schema.json + +# Or use the task runner +task schema ``` -This will generate a JSON Schema file that can be used to validate configuration files. +## How It Works + +The generator: + +1. **Reflects on Config Structs**: Analyzes the `config.Config` struct and all related types +2. **Generates Base Schema**: Creates JSON Schema definitions for all struct fields +3. **Enhances with Runtime Data**: Adds provider lists, model enums, and tool lists from the actual codebase +4. **Adds Custom Descriptions**: Provides meaningful descriptions for configuration options +5. **Sets Default Values**: Includes appropriate defaults for optional fields ## Schema Features The generated schema includes: -- All configuration options with descriptions -- Default values where applicable -- Validation for enum values (e.g., model IDs, provider types) -- Required fields -- Type checking +- **Type Safety**: Proper type definitions for all configuration fields +- **Validation**: Required fields, enum constraints, and format validation +- **Documentation**: Descriptions for all configuration options +- **Defaults**: Default values for optional settings +- **Provider Enums**: Current list of supported providers +- **Model Enums**: Available models from all configured providers +- **Tool Lists**: Valid tool names for agent configurations +- **Cross-References**: Proper relationships between different config sections + +## Adding New Configuration Fields + +To add new configuration options: + +1. 
**Add to Config Structs**: Add the field to the appropriate struct in `internal/config/` +2. **Add JSON Tags**: Include proper JSON tags with field names +3. **Regenerate Schema**: Run the schema generator to update the JSON schema +4. **Update Validation**: Add any custom validation logic if needed + +Example: +```go +type Options struct { + // ... existing fields ... + + // New field with JSON tag and description + NewFeature bool `json:"new_feature,omitempty"` +} +``` + +The schema generator will automatically: +- Detect the new field +- Generate appropriate JSON schema +- Add type information +- Include in validation ## Using the Schema -You can use the generated schema in several ways: +### Editor Integration + +Most modern editors support JSON Schema: + +**VS Code**: Add to your workspace settings: +```json +{ + "json.schemas": [ + { + "fileMatch": ["crush.json", ".crush.json"], + "url": "./crush-schema.json" + } + ] +} +``` -1. **Editor Integration**: Many editors (VS Code, JetBrains IDEs, etc.) support JSON Schema for validation and autocompletion. You can configure your editor to use the generated schema for `.crush.json` files. +**JetBrains IDEs**: Configure in Settings → Languages & Frameworks → Schemas and DTDs → JSON Schema Mappings -2. **Validation Tools**: You can use tools like [jsonschema](https://github.com/Julian/jsonschema) to validate your configuration files against the schema. +### Validation Tools -3. **Documentation**: The schema serves as documentation for the configuration options. 
+```bash +# Using jsonschema (Python) +pip install jsonschema +jsonschema -i crush.json crush-schema.json -## Example Configuration +# Using ajv-cli (Node.js) +npm install -g ajv-cli +ajv validate -s crush-schema.json -d crush.json +``` -Here's an example configuration that conforms to the schema: +### Configuration Example ```json { - "data": { - "directory": ".crush" + "models": { + "large": { + "model_id": "claude-3-5-sonnet-20241022", + "provider": "anthropic", + "reasoning_effort": "medium", + "max_tokens": 8192 + }, + "small": { + "model_id": "claude-3-5-haiku-20241022", + "provider": "anthropic" + } }, - "debug": false, "providers": { "anthropic": { - "apiKey": "your-api-key" + "id": "anthropic", + "provider_type": "anthropic", + "api_key": "your-api-key", + "disabled": false } }, "agents": { "coder": { - "model": "claude-3.7-sonnet", - "maxTokens": 5000, - "reasoningEffort": "medium" - }, - "task": { - "model": "claude-3.7-sonnet", - "maxTokens": 5000 + "id": "coder", + "name": "Coder", + "model": "large", + "disabled": false }, - "title": { - "model": "claude-3.7-sonnet", - "maxTokens": 80 + "custom-agent": { + "id": "custom-agent", + "name": "Custom Agent", + "description": "A custom agent for specific tasks", + "model": "small", + "allowed_tools": ["glob", "grep", "view"], + "allowed_mcp": { + "filesystem": ["read", "write"] + } + } + }, + "mcp": { + "filesystem": { + "command": "mcp-filesystem", + "args": ["--root", "/workspace"], + "type": "stdio" + } + }, + "lsp": { + "typescript": { + "command": "typescript-language-server", + "args": ["--stdio"], + "enabled": true + } + }, + "options": { + "context_paths": [ + "README.md", + "docs/", + ".cursorrules" + ], + "data_directory": ".crush", + "debug": false, + "tui": { + "compact_mode": false } } } ``` + +## Maintenance + +The schema generator is designed to be maintenance-free. 
As long as: + +- Configuration structs have proper JSON tags +- New enums are added to the enhancement functions +- The generator is run after significant config changes + +The schema will stay current with the codebase automatically. \ No newline at end of file diff --git a/cmd/schema/main.go b/cmd/schema/main.go index 9eb88769fd84772628df5332d3dcc1b1b234ac90..34939f1e84b9f3df04c7419a9ac4d7dfdc76386a 100644 --- a/cmd/schema/main.go +++ b/cmd/schema/main.go @@ -1,30 +1,70 @@ -// TODO: FIX THIS package main import ( "encoding/json" "fmt" "os" + "reflect" + "slices" + "strings" "github.com/charmbracelet/crush/internal/config" ) -// JSONSchemaType represents a JSON Schema type -type JSONSchemaType struct { - Type string `json:"type,omitempty"` - Description string `json:"description,omitempty"` - Properties map[string]any `json:"properties,omitempty"` - Required []string `json:"required,omitempty"` - AdditionalProperties any `json:"additionalProperties,omitempty"` - Enum []any `json:"enum,omitempty"` - Items map[string]any `json:"items,omitempty"` - OneOf []map[string]any `json:"oneOf,omitempty"` - AnyOf []map[string]any `json:"anyOf,omitempty"` - Default any `json:"default,omitempty"` +// JSONSchema represents a JSON Schema +type JSONSchema struct { + Schema string `json:"$schema,omitempty"` + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` + Type string `json:"type,omitempty"` + Properties map[string]*JSONSchema `json:"properties,omitempty"` + Items *JSONSchema `json:"items,omitempty"` + Required []string `json:"required,omitempty"` + AdditionalProperties any `json:"additionalProperties,omitempty"` + Enum []any `json:"enum,omitempty"` + Default any `json:"default,omitempty"` + Definitions map[string]*JSONSchema `json:"definitions,omitempty"` + Ref string `json:"$ref,omitempty"` + OneOf []*JSONSchema `json:"oneOf,omitempty"` + AnyOf []*JSONSchema `json:"anyOf,omitempty"` + AllOf []*JSONSchema `json:"allOf,omitempty"` + Not 
*JSONSchema `json:"not,omitempty"` + Format string `json:"format,omitempty"` + Pattern string `json:"pattern,omitempty"` + MinLength *int `json:"minLength,omitempty"` + MaxLength *int `json:"maxLength,omitempty"` + Minimum *float64 `json:"minimum,omitempty"` + Maximum *float64 `json:"maximum,omitempty"` + ExclusiveMinimum *float64 `json:"exclusiveMinimum,omitempty"` + ExclusiveMaximum *float64 `json:"exclusiveMaximum,omitempty"` + MultipleOf *float64 `json:"multipleOf,omitempty"` + MinItems *int `json:"minItems,omitempty"` + MaxItems *int `json:"maxItems,omitempty"` + UniqueItems *bool `json:"uniqueItems,omitempty"` + MinProperties *int `json:"minProperties,omitempty"` + MaxProperties *int `json:"maxProperties,omitempty"` +} + +// SchemaGenerator generates JSON schemas from Go types +type SchemaGenerator struct { + definitions map[string]*JSONSchema + visited map[reflect.Type]bool +} + +// NewSchemaGenerator creates a new schema generator +func NewSchemaGenerator() *SchemaGenerator { + return &SchemaGenerator{ + definitions: make(map[string]*JSONSchema), + visited: make(map[reflect.Type]bool), + } } func main() { - schema := generateSchema() + // Enable mock providers to avoid API calls during schema generation + config.UseMockProviders = true + + generator := NewSchemaGenerator() + schema := generator.GenerateSchema() // Pretty print the schema encoder := json.NewEncoder(os.Stdout) @@ -35,261 +75,457 @@ func main() { } } -func generateSchema() map[string]any { - schema := map[string]any{ - "$schema": "http://json-schema.org/draft-07/schema#", - "title": "Crush Configuration", - "description": "Configuration schema for the Crush application", - "type": "object", - "properties": map[string]any{}, - } - - // Add Data configuration - schema["properties"].(map[string]any)["data"] = map[string]any{ - "type": "object", - "description": "Storage configuration", - "properties": map[string]any{ - "directory": map[string]any{ - "type": "string", - "description": "Directory 
where application data is stored", - "default": ".crush", - }, - }, - "required": []string{"directory"}, - } - - // Add working directory - schema["properties"].(map[string]any)["wd"] = map[string]any{ - "type": "string", - "description": "Working directory for the application", - } - - // Add debug flags - schema["properties"].(map[string]any)["debug"] = map[string]any{ - "type": "boolean", - "description": "Enable debug mode", - "default": false, - } - - schema["properties"].(map[string]any)["debugLSP"] = map[string]any{ - "type": "boolean", - "description": "Enable LSP debug mode", - "default": false, - } - - schema["properties"].(map[string]any)["contextPaths"] = map[string]any{ - "type": "array", - "description": "Context paths for the application", - "items": map[string]any{ - "type": "string", - }, - "default": []string{ - ".github/copilot-instructions.md", - ".cursorrules", - ".cursor/rules/", - "CLAUDE.md", - "CLAUDE.local.md", - "GEMINI.md", - "gemini.md", - "crush.md", - "crush.local.md", - "Crush.md", - "Crush.local.md", - "CRUSH.md", - "CRUSH.local.md", - }, - } - - schema["properties"].(map[string]any)["tui"] = map[string]any{ - "type": "object", - "description": "Terminal User Interface configuration", - "properties": map[string]any{ - "theme": map[string]any{ - "type": "string", - "description": "TUI theme name", - "default": "crush", - "enum": []string{ - "crush", - "catppuccin", - "dracula", - "flexoki", - "gruvbox", - "monokai", - "onedark", - "tokyonight", - "tron", - }, - }, - }, - } - - // Add MCP servers - schema["properties"].(map[string]any)["mcpServers"] = map[string]any{ - "type": "object", - "description": "Model Control Protocol server configurations", - "additionalProperties": map[string]any{ - "type": "object", - "description": "MCP server configuration", - "properties": map[string]any{ - "command": map[string]any{ - "type": "string", - "description": "Command to execute for the MCP server", - }, - "env": map[string]any{ - "type": 
"array", - "description": "Environment variables for the MCP server", - "items": map[string]any{ - "type": "string", - }, - }, - "args": map[string]any{ - "type": "array", - "description": "Command arguments for the MCP server", - "items": map[string]any{ - "type": "string", - }, - }, - "type": map[string]any{ - "type": "string", - "description": "Type of MCP server", - "enum": []string{"stdio", "sse"}, - "default": "stdio", - }, - "url": map[string]any{ - "type": "string", - "description": "URL for SSE type MCP servers", - }, - "headers": map[string]any{ - "type": "object", - "description": "HTTP headers for SSE type MCP servers", - "additionalProperties": map[string]any{ - "type": "string", - }, - }, - }, - "required": []string{"command"}, - }, - } - - // Add providers - providerSchema := map[string]any{ - "type": "object", - "description": "LLM provider configurations", - "additionalProperties": map[string]any{ - "type": "object", - "description": "Provider configuration", - "properties": map[string]any{ - "apiKey": map[string]any{ - "type": "string", - "description": "API key for the provider", - }, - "disabled": map[string]any{ - "type": "boolean", - "description": "Whether the provider is disabled", - "default": false, - }, - }, - }, - } - - providerSchema["additionalProperties"].(map[string]any)["properties"].(map[string]any)["provider"] = map[string]any{ - "type": "string", - "description": "Provider type", - "enum": []string{}, - } - - schema["properties"].(map[string]any)["providers"] = providerSchema - - // Add agents - agentSchema := map[string]any{ - "type": "object", - "description": "Agent configurations", - "additionalProperties": map[string]any{ - "type": "object", - "description": "Agent configuration", - "properties": map[string]any{ - "model": map[string]any{ - "type": "string", - "description": "Model ID for the agent", - }, - "maxTokens": map[string]any{ - "type": "integer", - "description": "Maximum tokens for the agent", - "minimum": 1, - }, 
- "reasoningEffort": map[string]any{ - "type": "string", - "description": "Reasoning effort for models that support it (OpenAI, Anthropic)", - "enum": []string{"low", "medium", "high"}, - }, - }, - "required": []string{"model"}, - }, - } - - // Add model enum - modelEnum := []string{} - - agentSchema["additionalProperties"].(map[string]any)["properties"].(map[string]any)["model"].(map[string]any)["enum"] = modelEnum - - // Add specific agent properties - agentProperties := map[string]any{} - knownAgents := []string{ +// GenerateSchema generates the complete JSON schema for the Crush configuration +func (g *SchemaGenerator) GenerateSchema() *JSONSchema { + // Generate schema for the main Config struct + configType := reflect.TypeOf(config.Config{}) + configSchema := g.generateTypeSchema(configType) + + // Create the root schema + schema := &JSONSchema{ + Schema: "http://json-schema.org/draft-07/schema#", + Title: "Crush Configuration", + Description: "Configuration schema for the Crush application", + Type: configSchema.Type, + Properties: configSchema.Properties, + Required: configSchema.Required, + Definitions: g.definitions, + } + + // Add custom enhancements + g.enhanceSchema(schema) + + return schema +} + +// generateTypeSchema generates a JSON schema for a given Go type +func (g *SchemaGenerator) generateTypeSchema(t reflect.Type) *JSONSchema { + // Handle pointers + if t.Kind() == reflect.Ptr { + return g.generateTypeSchema(t.Elem()) + } + + // Check if we've already processed this type + if g.visited[t] { + // Return a reference to avoid infinite recursion + return &JSONSchema{ + Ref: fmt.Sprintf("#/definitions/%s", t.Name()), + } + } + + switch t.Kind() { + case reflect.String: + return &JSONSchema{Type: "string"} + case reflect.Bool: + return &JSONSchema{Type: "boolean"} + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return &JSONSchema{Type: 
"integer"} + case reflect.Float32, reflect.Float64: + return &JSONSchema{Type: "number"} + case reflect.Slice, reflect.Array: + itemSchema := g.generateTypeSchema(t.Elem()) + return &JSONSchema{ + Type: "array", + Items: itemSchema, + } + case reflect.Map: + valueSchema := g.generateTypeSchema(t.Elem()) + return &JSONSchema{ + Type: "object", + AdditionalProperties: valueSchema, + } + case reflect.Struct: + return g.generateStructSchema(t) + case reflect.Interface: + // For interface{} types, allow any value + return &JSONSchema{} + default: + // Fallback for unknown types + return &JSONSchema{} + } +} + +// generateStructSchema generates a JSON schema for a struct type +func (g *SchemaGenerator) generateStructSchema(t reflect.Type) *JSONSchema { + // Mark as visited to prevent infinite recursion + g.visited[t] = true + + schema := &JSONSchema{ + Type: "object", + Properties: make(map[string]*JSONSchema), + } + + var required []string + + for i := range t.NumField() { + field := t.Field(i) + + // Skip unexported fields + if !field.IsExported() { + continue + } + + // Get JSON tag + jsonTag := field.Tag.Get("json") + if jsonTag == "-" { + continue + } + + // Parse JSON tag + jsonName, options := parseJSONTag(jsonTag) + if jsonName == "" { + jsonName = strings.ToLower(field.Name) + } + + // Generate field schema + fieldSchema := g.generateTypeSchema(field.Type) + + // Add description from field name if not present + if fieldSchema.Description == "" { + fieldSchema.Description = generateFieldDescription(field.Name, field.Type) + } + + // Check if field is required (not omitempty and not a pointer) + if !slices.Contains(options, "omitempty") && field.Type.Kind() != reflect.Ptr { + required = append(required, jsonName) + } + + schema.Properties[jsonName] = fieldSchema + } + + if len(required) > 0 { + schema.Required = required + } + + // Store in definitions if it's a named type + if t.Name() != "" { + g.definitions[t.Name()] = schema + } + + return schema +} + +// 
parseJSONTag parses a JSON struct tag +func parseJSONTag(tag string) (name string, options []string) { + if tag == "" { + return "", nil + } + + parts := strings.Split(tag, ",") + name = parts[0] + if len(parts) > 1 { + options = parts[1:] + } + return name, options +} + +// generateFieldDescription generates a description for a field based on its name and type +func generateFieldDescription(fieldName string, fieldType reflect.Type) string { + // Convert camelCase to words + words := camelCaseToWords(fieldName) + description := strings.Join(words, " ") + + // Add type-specific information + switch fieldType.Kind() { + case reflect.Bool: + if !strings.Contains(strings.ToLower(description), "enable") && + !strings.Contains(strings.ToLower(description), "disable") { + description = "Enable " + strings.ToLower(description) + } + case reflect.Slice: + if !strings.HasSuffix(description, "s") { + description = description + " list" + } + case reflect.Map: + description = description + " configuration" + } + + return description +} + +// camelCaseToWords converts camelCase to separate words +func camelCaseToWords(s string) []string { + var words []string + var currentWord strings.Builder + + for i, r := range s { + if i > 0 && r >= 'A' && r <= 'Z' { + if currentWord.Len() > 0 { + words = append(words, currentWord.String()) + currentWord.Reset() + } + } + currentWord.WriteRune(r) + } + + if currentWord.Len() > 0 { + words = append(words, currentWord.String()) + } + + return words +} + +// enhanceSchema adds custom enhancements to the generated schema +func (g *SchemaGenerator) enhanceSchema(schema *JSONSchema) { + // Add provider enums + g.addProviderEnums(schema) + + // Add model enums + g.addModelEnums(schema) + + // Add agent enums + g.addAgentEnums(schema) + + // Add tool enums + g.addToolEnums(schema) + + // Add MCP type enums + g.addMCPTypeEnums(schema) + + // Add model type enums + g.addModelTypeEnums(schema) + + // Add default values + g.addDefaultValues(schema) + + 
// Add custom descriptions + g.addCustomDescriptions(schema) +} + +// addProviderEnums adds provider enums to the schema +func (g *SchemaGenerator) addProviderEnums(schema *JSONSchema) { + providers := config.Providers() + var providerIDs []any + for _, p := range providers { + providerIDs = append(providerIDs, string(p.ID)) + } + + // Add to PreferredModel provider field + if preferredModelDef, exists := schema.Definitions["PreferredModel"]; exists { + if providerProp, exists := preferredModelDef.Properties["provider"]; exists { + providerProp.Enum = providerIDs + } + } + + // Add to ProviderConfig ID field + if providerConfigDef, exists := schema.Definitions["ProviderConfig"]; exists { + if idProp, exists := providerConfigDef.Properties["id"]; exists { + idProp.Enum = providerIDs + } + } +} + +// addModelEnums adds model enums to the schema +func (g *SchemaGenerator) addModelEnums(schema *JSONSchema) { + providers := config.Providers() + var modelIDs []any + for _, p := range providers { + for _, m := range p.Models { + modelIDs = append(modelIDs, m.ID) + } + } + + // Add to PreferredModel model_id field + if preferredModelDef, exists := schema.Definitions["PreferredModel"]; exists { + if modelIDProp, exists := preferredModelDef.Properties["model_id"]; exists { + modelIDProp.Enum = modelIDs + } + } +} + +// addAgentEnums adds agent ID enums to the schema +func (g *SchemaGenerator) addAgentEnums(schema *JSONSchema) { + agentIDs := []any{ string(config.AgentCoder), string(config.AgentTask), } - for _, agentName := range knownAgents { - agentProperties[agentName] = map[string]any{ - "$ref": "#/definitions/agent", - } - } - - // Create a combined schema that allows both specific agents and additional ones - combinedAgentSchema := map[string]any{ - "type": "object", - "description": "Agent configurations", - "properties": agentProperties, - "additionalProperties": agentSchema["additionalProperties"], - } - - schema["properties"].(map[string]any)["agents"] = 
combinedAgentSchema - schema["definitions"] = map[string]any{ - "agent": agentSchema["additionalProperties"], - } - - // Add LSP configuration - schema["properties"].(map[string]any)["lsp"] = map[string]any{ - "type": "object", - "description": "Language Server Protocol configurations", - "additionalProperties": map[string]any{ - "type": "object", - "description": "LSP configuration for a language", - "properties": map[string]any{ - "disabled": map[string]any{ - "type": "boolean", - "description": "Whether the LSP is disabled", - "default": false, - }, - "command": map[string]any{ - "type": "string", - "description": "Command to execute for the LSP server", - }, - "args": map[string]any{ - "type": "array", - "description": "Command arguments for the LSP server", - "items": map[string]any{ - "type": "string", - }, - }, - "options": map[string]any{ - "type": "object", - "description": "Additional options for the LSP server", - }, - }, - "required": []string{"command"}, - }, + if agentDef, exists := schema.Definitions["Agent"]; exists { + if idProp, exists := agentDef.Properties["id"]; exists { + idProp.Enum = agentIDs + } } +} - return schema +// addToolEnums adds tool enums to the schema +func (g *SchemaGenerator) addToolEnums(schema *JSONSchema) { + tools := []any{ + "bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent", + } + + if agentDef, exists := schema.Definitions["Agent"]; exists { + if allowedToolsProp, exists := agentDef.Properties["allowed_tools"]; exists { + if allowedToolsProp.Items != nil { + allowedToolsProp.Items.Enum = tools + } + } + } +} + +// addMCPTypeEnums adds MCP type enums to the schema +func (g *SchemaGenerator) addMCPTypeEnums(schema *JSONSchema) { + mcpTypes := []any{ + string(config.MCPStdio), + string(config.MCPSse), + } + + if mcpDef, exists := schema.Definitions["MCP"]; exists { + if typeProp, exists := mcpDef.Properties["type"]; exists { + typeProp.Enum = mcpTypes + } + } +} + +// addModelTypeEnums 
adds model type enums to the schema +func (g *SchemaGenerator) addModelTypeEnums(schema *JSONSchema) { + modelTypes := []any{ + string(config.LargeModel), + string(config.SmallModel), + } + + if agentDef, exists := schema.Definitions["Agent"]; exists { + if modelProp, exists := agentDef.Properties["model"]; exists { + modelProp.Enum = modelTypes + } + } +} + +// addDefaultValues adds default values to the schema +func (g *SchemaGenerator) addDefaultValues(schema *JSONSchema) { + // Add default context paths + if optionsDef, exists := schema.Definitions["Options"]; exists { + if contextPathsProp, exists := optionsDef.Properties["context_paths"]; exists { + contextPathsProp.Default = []any{ + ".github/copilot-instructions.md", + ".cursorrules", + ".cursor/rules/", + "CLAUDE.md", + "CLAUDE.local.md", + "GEMINI.md", + "gemini.md", + "crush.md", + "crush.local.md", + "Crush.md", + "Crush.local.md", + "CRUSH.md", + "CRUSH.local.md", + } + } + if dataDirProp, exists := optionsDef.Properties["data_directory"]; exists { + dataDirProp.Default = ".crush" + } + if debugProp, exists := optionsDef.Properties["debug"]; exists { + debugProp.Default = false + } + if debugLSPProp, exists := optionsDef.Properties["debug_lsp"]; exists { + debugLSPProp.Default = false + } + if disableAutoSummarizeProp, exists := optionsDef.Properties["disable_auto_summarize"]; exists { + disableAutoSummarizeProp.Default = false + } + } + + // Add default MCP type + if mcpDef, exists := schema.Definitions["MCP"]; exists { + if typeProp, exists := mcpDef.Properties["type"]; exists { + typeProp.Default = string(config.MCPStdio) + } + } + + // Add default TUI options + if tuiOptionsDef, exists := schema.Definitions["TUIOptions"]; exists { + if compactModeProp, exists := tuiOptionsDef.Properties["compact_mode"]; exists { + compactModeProp.Default = false + } + } + + // Add default provider disabled + if providerConfigDef, exists := schema.Definitions["ProviderConfig"]; exists { + if disabledProp, exists := 
providerConfigDef.Properties["disabled"]; exists { + disabledProp.Default = false + } + } + + // Add default agent disabled + if agentDef, exists := schema.Definitions["Agent"]; exists { + if disabledProp, exists := agentDef.Properties["disabled"]; exists { + disabledProp.Default = false + } + } + + // Add default LSP disabled + if lspConfigDef, exists := schema.Definitions["LSPConfig"]; exists { + if disabledProp, exists := lspConfigDef.Properties["enabled"]; exists { + disabledProp.Default = true + } + } +} + +// addCustomDescriptions adds custom descriptions to improve the schema +func (g *SchemaGenerator) addCustomDescriptions(schema *JSONSchema) { + // Enhance main config descriptions + if schema.Properties != nil { + if modelsProp, exists := schema.Properties["models"]; exists { + modelsProp.Description = "Preferred model configurations for large and small model types" + } + if providersProp, exists := schema.Properties["providers"]; exists { + providersProp.Description = "LLM provider configurations" + } + if agentsProp, exists := schema.Properties["agents"]; exists { + agentsProp.Description = "Agent configurations for different tasks" + } + if mcpProp, exists := schema.Properties["mcp"]; exists { + mcpProp.Description = "Model Control Protocol server configurations" + } + if lspProp, exists := schema.Properties["lsp"]; exists { + lspProp.Description = "Language Server Protocol configurations" + } + if optionsProp, exists := schema.Properties["options"]; exists { + optionsProp.Description = "General application options and settings" + } + } + + // Enhance specific field descriptions + if providerConfigDef, exists := schema.Definitions["ProviderConfig"]; exists { + if apiKeyProp, exists := providerConfigDef.Properties["api_key"]; exists { + apiKeyProp.Description = "API key for authenticating with the provider" + } + if baseURLProp, exists := providerConfigDef.Properties["base_url"]; exists { + baseURLProp.Description = "Base URL for the provider API 
(required for custom providers)" + } + if extraHeadersProp, exists := providerConfigDef.Properties["extra_headers"]; exists { + extraHeadersProp.Description = "Additional HTTP headers to send with requests" + } + if extraParamsProp, exists := providerConfigDef.Properties["extra_params"]; exists { + extraParamsProp.Description = "Additional provider-specific parameters" + } + } + + if agentDef, exists := schema.Definitions["Agent"]; exists { + if allowedToolsProp, exists := agentDef.Properties["allowed_tools"]; exists { + allowedToolsProp.Description = "List of tools this agent is allowed to use (if nil, all tools are allowed)" + } + if allowedMCPProp, exists := agentDef.Properties["allowed_mcp"]; exists { + allowedMCPProp.Description = "Map of MCP servers this agent can use and their allowed tools" + } + if allowedLSPProp, exists := agentDef.Properties["allowed_lsp"]; exists { + allowedLSPProp.Description = "List of LSP servers this agent can use (if nil, all LSPs are allowed)" + } + if contextPathsProp, exists := agentDef.Properties["context_paths"]; exists { + contextPathsProp.Description = "Custom context paths for this agent (additive to global context paths)" + } + } + + if mcpDef, exists := schema.Definitions["MCP"]; exists { + if commandProp, exists := mcpDef.Properties["command"]; exists { + commandProp.Description = "Command to execute for stdio MCP servers" + } + if urlProp, exists := mcpDef.Properties["url"]; exists { + urlProp.Description = "URL for SSE MCP servers" + } + if headersProp, exists := mcpDef.Properties["headers"]; exists { + headersProp.Description = "HTTP headers for SSE MCP servers" + } + } } diff --git a/crush-schema.json b/crush-schema.json index 5412a4badecb3e9d49022a69b3c7eb20fce0812b..f5fa562c5aff42972eb2308c3374969e5d42cac8 100644 --- a/crush-schema.json +++ b/crush-schema.json @@ -1,383 +1,1505 @@ { "$schema": "http://json-schema.org/draft-07/schema#", - "definitions": { - "agent": { - "description": "Agent configuration", - 
"properties": { - "maxTokens": { - "description": "Maximum tokens for the agent", - "minimum": 1, - "type": "integer" - }, - "model": { - "description": "Model ID for the agent", - "enum": [ - "gpt-4.1", - "llama-3.3-70b-versatile", - "azure.gpt-4.1", - "openrouter.gpt-4o", - "openrouter.o1-mini", - "openrouter.claude-3-haiku", - "claude-3-opus", - "gpt-4o", - "gpt-4o-mini", - "o1", - "meta-llama/llama-4-maverick-17b-128e-instruct", - "azure.o3-mini", - "openrouter.gpt-4o-mini", - "openrouter.o1", - "claude-3.5-haiku", - "o4-mini", - "azure.gpt-4.1-mini", - "openrouter.o3", - "grok-3-beta", - "o3-mini", - "qwen-qwq", - "azure.o1", - "openrouter.gemini-2.5-flash", - "openrouter.gemini-2.5", - "o1-mini", - "azure.gpt-4o", - "openrouter.gpt-4.1-mini", - "openrouter.claude-3.5-sonnet", - "openrouter.o3-mini", - "gpt-4.1-mini", - "gpt-4.5-preview", - "gpt-4.1-nano", - "deepseek-r1-distill-llama-70b", - "azure.gpt-4o-mini", - "openrouter.gpt-4.1", - "bedrock.claude-3.7-sonnet", - "claude-3-haiku", - "o3", - "gemini-2.0-flash-lite", - "azure.o3", - "azure.gpt-4.5-preview", - "openrouter.claude-3-opus", - "grok-3-mini-fast-beta", - "claude-4-sonnet", - "azure.o4-mini", - "grok-3-fast-beta", - "claude-3.5-sonnet", - "azure.o1-mini", - "openrouter.claude-3.7-sonnet", - "openrouter.gpt-4.5-preview", - "grok-3-mini-beta", - "claude-3.7-sonnet", - "gemini-2.0-flash", - "openrouter.deepseek-r1-free", - "vertexai.gemini-2.5-flash", - "vertexai.gemini-2.5", - "o1-pro", - "gemini-2.5", - "meta-llama/llama-4-scout-17b-16e-instruct", - "azure.gpt-4.1-nano", - "openrouter.gpt-4.1-nano", - "gemini-2.5-flash", - "openrouter.o4-mini", - "openrouter.claude-3.5-haiku", - "claude-4-opus", - "openrouter.o1-pro" - ], - "type": "string" - }, - "reasoningEffort": { - "description": "Reasoning effort for models that support it (OpenAI, Anthropic)", - "enum": ["low", "medium", "high"], - "type": "string" - } - }, - "required": ["model"], - "type": "object" - } - }, + "title": "Crush 
Configuration", "description": "Configuration schema for the Crush application", + "type": "object", "properties": { "agents": { + "description": "Agent configurations for different tasks", + "type": "object", "additionalProperties": { - "description": "Agent configuration", + "type": "object", "properties": { - "maxTokens": { - "description": "Maximum tokens for the agent", - "minimum": 1, - "type": "integer" + "allowed_lsp": { + "description": "List of LSP servers this agent can use (if nil, all LSPs are allowed)", + "type": "array", + "items": { + "type": "string" + } + }, + "allowed_mcp": { + "description": "Map of MCP servers this agent can use and their allowed tools", + "type": "object", + "additionalProperties": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "allowed_tools": { + "description": "List of tools this agent is allowed to use (if nil, all tools are allowed)", + "type": "array", + "items": { + "type": "string", + "enum": [ + "bash", + "edit", + "fetch", + "glob", + "grep", + "ls", + "sourcegraph", + "view", + "write", + "agent" + ] + } + }, + "context_paths": { + "description": "Custom context paths for this agent (additive to global context paths)", + "type": "array", + "items": { + "type": "string" + } + }, + "description": { + "description": "Description", + "type": "string" + }, + "disabled": { + "description": "Disabled", + "type": "boolean", + "default": false + }, + "id": { + "description": "I D", + "type": "string", + "enum": [ + "coder", + "task" + ] }, "model": { - "description": "Model ID for the agent", + "description": "Model", + "type": "string", "enum": [ - "gpt-4.1", - "llama-3.3-70b-versatile", - "azure.gpt-4.1", - "openrouter.gpt-4o", - "openrouter.o1-mini", - "openrouter.claude-3-haiku", - "claude-3-opus", - "gpt-4o", - "gpt-4o-mini", - "o1", - "meta-llama/llama-4-maverick-17b-128e-instruct", - "azure.o3-mini", - "openrouter.gpt-4o-mini", - "openrouter.o1", - "claude-3.5-haiku", - "o4-mini", - 
"azure.gpt-4.1-mini", - "openrouter.o3", - "grok-3-beta", - "o3-mini", - "qwen-qwq", - "azure.o1", - "openrouter.gemini-2.5-flash", - "openrouter.gemini-2.5", - "o1-mini", - "azure.gpt-4o", - "openrouter.gpt-4.1-mini", - "openrouter.claude-3.5-sonnet", - "openrouter.o3-mini", - "gpt-4.1-mini", - "gpt-4.5-preview", - "gpt-4.1-nano", - "deepseek-r1-distill-llama-70b", - "azure.gpt-4o-mini", - "openrouter.gpt-4.1", - "bedrock.claude-3.7-sonnet", - "claude-3-haiku", - "o3", - "gemini-2.0-flash-lite", - "azure.o3", - "azure.gpt-4.5-preview", - "openrouter.claude-3-opus", - "grok-3-mini-fast-beta", - "claude-4-sonnet", - "azure.o4-mini", - "grok-3-fast-beta", - "claude-3.5-sonnet", - "azure.o1-mini", - "openrouter.claude-3.7-sonnet", - "openrouter.gpt-4.5-preview", - "grok-3-mini-beta", - "claude-3.7-sonnet", - "gemini-2.0-flash", - "openrouter.deepseek-r1-free", - "vertexai.gemini-2.5-flash", - "vertexai.gemini-2.5", - "o1-pro", - "gemini-2.5", - "meta-llama/llama-4-scout-17b-16e-instruct", - "azure.gpt-4.1-nano", - "openrouter.gpt-4.1-nano", - "gemini-2.5-flash", - "openrouter.o4-mini", - "openrouter.claude-3.5-haiku", - "claude-4-opus", - "openrouter.o1-pro" - ], - "type": "string" + "large", + "small" + ] }, - "reasoningEffort": { - "description": "Reasoning effort for models that support it (OpenAI, Anthropic)", - "enum": ["low", "medium", "high"], + "name": { + "description": "Name", "type": "string" } }, - "required": ["model"], - "type": "object" - }, - "description": "Agent configurations", - "properties": { - "coder": { - "$ref": "#/definitions/agent" - }, - "task": { - "$ref": "#/definitions/agent" - }, - "title": { - "$ref": "#/definitions/agent" - } - }, - "type": "object" - }, - "contextPaths": { - "default": [ - ".github/copilot-instructions.md", - ".cursorrules", - ".cursor/rules/", - "CLAUDE.md", - "CLAUDE.local.md", - "GEMINI.md", - "gemini.md", - "crush.md", - "crush.local.md", - "Crush.md", - "Crush.local.md", - "CRUSH.md", - "CRUSH.local.md" - ], - 
"description": "Context paths for the application", - "items": { - "type": "string" - }, - "type": "array" - }, - "data": { - "description": "Storage configuration", - "properties": { - "directory": { - "default": ".crush", - "description": "Directory where application data is stored", - "type": "string" - } - }, - "required": ["directory"], - "type": "object" - }, - "debug": { - "default": false, - "description": "Enable debug mode", - "type": "boolean" - }, - "debugLSP": { - "default": false, - "description": "Enable LSP debug mode", - "type": "boolean" + "required": [ + "id", + "name", + "disabled", + "model", + "allowed_tools", + "allowed_mcp", + "allowed_lsp", + "context_paths" + ] + } }, "lsp": { + "description": "Language Server Protocol configurations", + "type": "object", "additionalProperties": { - "description": "LSP configuration for a language", + "type": "object", "properties": { "args": { - "description": "Command arguments for the LSP server", + "description": "Args", + "type": "array", "items": { "type": "string" - }, - "type": "array" + } }, "command": { - "description": "Command to execute for the LSP server", + "description": "Command", "type": "string" }, - "disabled": { - "default": false, - "description": "Whether the LSP is disabled", - "type": "boolean" + "enabled": { + "description": "Disabled", + "type": "boolean", + "default": true }, "options": { - "description": "Additional options for the LSP server", - "type": "object" + "description": "Options" } }, - "required": ["command"], - "type": "object" - }, - "description": "Language Server Protocol configurations", - "type": "object" + "required": [ + "enabled", + "command", + "args", + "options" + ] + } }, - "mcpServers": { + "mcp": { + "description": "Model Control Protocol server configurations", + "type": "object", "additionalProperties": { - "description": "MCP server configuration", + "type": "object", "properties": { "args": { - "description": "Command arguments for the MCP server", 
+ "description": "Args", + "type": "array", "items": { "type": "string" - }, - "type": "array" + } }, "command": { - "description": "Command to execute for the MCP server", + "description": "Command to execute for stdio MCP servers", "type": "string" }, "env": { - "description": "Environment variables for the MCP server", + "description": "Env list", + "type": "array", "items": { "type": "string" - }, - "type": "array" + } }, "headers": { + "description": "HTTP headers for SSE MCP servers", + "type": "object", "additionalProperties": { "type": "string" - }, - "description": "HTTP headers for SSE type MCP servers", - "type": "object" + } }, "type": { - "default": "stdio", - "description": "Type of MCP server", - "enum": ["stdio", "sse"], - "type": "string" + "description": "Type", + "type": "string", + "enum": [ + "stdio", + "sse" + ], + "default": "stdio" }, "url": { - "description": "URL for SSE type MCP servers", + "description": "URL for SSE MCP servers", "type": "string" } }, - "required": ["command"], - "type": "object" + "required": [ + "command", + "env", + "args", + "type", + "url", + "headers" + ] + } + }, + "models": { + "description": "Preferred model configurations for large and small model types", + "type": "object", + "properties": { + "large": { + "description": "Large", + "type": "object", + "properties": { + "max_tokens": { + "description": "Max Tokens", + "type": "integer" + }, + "model_id": { + "description": "Model I D", + "type": "string", + "enum": [ + "claude-3-opus", + "claude-3-haiku", + "claude-3-5-sonnet-20241022", + "claude-3-5-haiku-20241022", + "gpt-4", + "gpt-3.5-turbo", + "gpt-4-turbo", + "gpt-4o", + "gpt-4o-mini", + "o1-preview", + "o1-mini", + "gemini-2.5-pro", + "gemini-2.5-flash", + "grok-beta", + "anthropic/claude-3.5-sonnet", + "anthropic/claude-3.5-haiku" + ] + }, + "provider": { + "description": "Provider", + "type": "string", + "enum": [ + "anthropic", + "openai", + "gemini", + "xai", + "openrouter" + ] + }, + 
"reasoning_effort": { + "description": "Reasoning Effort", + "type": "string" + }, + "think": { + "description": "Enable think", + "type": "boolean" + } + }, + "required": [ + "model_id", + "provider" + ] + }, + "small": { + "description": "Small", + "$ref": "#/definitions/PreferredModel" + } }, - "description": "Model Control Protocol server configurations", - "type": "object" + "required": [ + "large", + "small" + ] + }, + "options": { + "description": "General application options and settings", + "type": "object", + "properties": { + "context_paths": { + "description": "Context Paths", + "type": "array", + "items": { + "type": "string" + }, + "default": [ + ".github/copilot-instructions.md", + ".cursorrules", + ".cursor/rules/", + "CLAUDE.md", + "CLAUDE.local.md", + "GEMINI.md", + "gemini.md", + "crush.md", + "crush.local.md", + "Crush.md", + "Crush.local.md", + "CRUSH.md", + "CRUSH.local.md" + ] + }, + "data_directory": { + "description": "Data Directory", + "type": "string", + "default": ".crush" + }, + "debug": { + "description": "Enable debug", + "type": "boolean", + "default": false + }, + "debug_lsp": { + "description": "Enable debug l s p", + "type": "boolean", + "default": false + }, + "disable_auto_summarize": { + "description": "Disable Auto Summarize", + "type": "boolean", + "default": false + }, + "tui": { + "description": "T U I", + "type": "object", + "properties": { + "compact_mode": { + "description": "Enable compact mode", + "type": "boolean", + "default": false + } + }, + "required": [ + "compact_mode" + ] + } + }, + "required": [ + "context_paths", + "tui", + "debug", + "debug_lsp", + "disable_auto_summarize", + "data_directory" + ] }, "providers": { + "description": "LLM provider configurations", + "type": "object", "additionalProperties": { - "description": "Provider configuration", + "type": "object", "properties": { - "apiKey": { - "description": "API key for the provider", + "api_key": { + "description": "API key for authenticating with 
the provider", + "type": "string" + }, + "base_url": { + "description": "Base URL for the provider API (required for custom providers)", + "type": "string" + }, + "default_large_model": { + "description": "Default Large Model", + "type": "string" + }, + "default_small_model": { + "description": "Default Small Model", "type": "string" }, "disabled": { - "default": false, - "description": "Whether the provider is disabled", - "type": "boolean" + "description": "Disabled", + "type": "boolean", + "default": false + }, + "extra_headers": { + "description": "Additional HTTP headers to send with requests", + "type": "object", + "additionalProperties": { + "type": "string" + } }, - "provider": { - "description": "Provider type", + "extra_params": { + "description": "Additional provider-specific parameters", + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "id": { + "description": "I D", + "type": "string", "enum": [ "anthropic", "openai", "gemini", - "groq", - "openrouter", - "bedrock", - "azure", - "vertexai" - ], + "xai", + "openrouter" + ] + }, + "models": { + "description": "Models", + "type": "array", + "items": { + "type": "object", + "properties": { + "can_reason": { + "description": "Enable can reason", + "type": "boolean" + }, + "context_window": { + "description": "Context Window", + "type": "integer" + }, + "cost_per_1m_in": { + "description": "Cost Per1 M In", + "type": "number" + }, + "cost_per_1m_in_cached": { + "description": "Cost Per1 M In Cached", + "type": "number" + }, + "cost_per_1m_out": { + "description": "Cost Per1 M Out", + "type": "number" + }, + "cost_per_1m_out_cached": { + "description": "Cost Per1 M Out Cached", + "type": "number" + }, + "default_max_tokens": { + "description": "Default Max Tokens", + "type": "integer" + }, + "has_reasoning_effort": { + "description": "Enable has reasoning effort", + "type": "boolean" + }, + "id": { + "description": "I D", + "type": "string" + }, + "model": { + "description": 
"Name", + "type": "string" + }, + "reasoning_effort": { + "description": "Reasoning Effort", + "type": "string" + }, + "supports_attachments": { + "description": "Enable supports images", + "type": "boolean" + } + }, + "required": [ + "id", + "model", + "cost_per_1m_in", + "cost_per_1m_out", + "cost_per_1m_in_cached", + "cost_per_1m_out_cached", + "context_window", + "default_max_tokens", + "can_reason", + "reasoning_effort", + "has_reasoning_effort", + "supports_attachments" + ] + } + }, + "provider_type": { + "description": "Provider Type", + "type": "string" + } + }, + "required": [ + "id", + "provider_type", + "disabled" + ] + } + } + }, + "required": [ + "models", + "options" + ], + "definitions": { + "Agent": { + "type": "object", + "properties": { + "allowed_lsp": { + "description": "List of LSP servers this agent can use (if nil, all LSPs are allowed)", + "type": "array", + "items": { + "type": "string" + } + }, + "allowed_mcp": { + "description": "Map of MCP servers this agent can use and their allowed tools", + "type": "object", + "additionalProperties": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "allowed_tools": { + "description": "List of tools this agent is allowed to use (if nil, all tools are allowed)", + "type": "array", + "items": { + "type": "string", + "enum": [ + "bash", + "edit", + "fetch", + "glob", + "grep", + "ls", + "sourcegraph", + "view", + "write", + "agent" + ] + } + }, + "context_paths": { + "description": "Custom context paths for this agent (additive to global context paths)", + "type": "array", + "items": { "type": "string" } }, - "type": "object" + "description": { + "description": "Description", + "type": "string" + }, + "disabled": { + "description": "Disabled", + "type": "boolean", + "default": false + }, + "id": { + "description": "I D", + "type": "string", + "enum": [ + "coder", + "task" + ] + }, + "model": { + "description": "Model", + "type": "string", + "enum": [ + "large", + "small" + ] + }, + 
"name": { + "description": "Name", + "type": "string" + } }, - "description": "LLM provider configurations", - "type": "object" + "required": [ + "id", + "name", + "disabled", + "model", + "allowed_tools", + "allowed_mcp", + "allowed_lsp", + "context_paths" + ] + }, + "Config": { + "type": "object", + "properties": { + "agents": { + "description": "Agent configurations for different tasks", + "type": "object", + "additionalProperties": { + "type": "object", + "properties": { + "allowed_lsp": { + "description": "List of LSP servers this agent can use (if nil, all LSPs are allowed)", + "type": "array", + "items": { + "type": "string" + } + }, + "allowed_mcp": { + "description": "Map of MCP servers this agent can use and their allowed tools", + "type": "object", + "additionalProperties": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "allowed_tools": { + "description": "List of tools this agent is allowed to use (if nil, all tools are allowed)", + "type": "array", + "items": { + "type": "string", + "enum": [ + "bash", + "edit", + "fetch", + "glob", + "grep", + "ls", + "sourcegraph", + "view", + "write", + "agent" + ] + } + }, + "context_paths": { + "description": "Custom context paths for this agent (additive to global context paths)", + "type": "array", + "items": { + "type": "string" + } + }, + "description": { + "description": "Description", + "type": "string" + }, + "disabled": { + "description": "Disabled", + "type": "boolean", + "default": false + }, + "id": { + "description": "I D", + "type": "string", + "enum": [ + "coder", + "task" + ] + }, + "model": { + "description": "Model", + "type": "string", + "enum": [ + "large", + "small" + ] + }, + "name": { + "description": "Name", + "type": "string" + } + }, + "required": [ + "id", + "name", + "disabled", + "model", + "allowed_tools", + "allowed_mcp", + "allowed_lsp", + "context_paths" + ] + } + }, + "lsp": { + "description": "Language Server Protocol configurations", + "type": "object", + 
"additionalProperties": { + "type": "object", + "properties": { + "args": { + "description": "Args", + "type": "array", + "items": { + "type": "string" + } + }, + "command": { + "description": "Command", + "type": "string" + }, + "enabled": { + "description": "Disabled", + "type": "boolean", + "default": true + }, + "options": { + "description": "Options" + } + }, + "required": [ + "enabled", + "command", + "args", + "options" + ] + } + }, + "mcp": { + "description": "Model Control Protocol server configurations", + "type": "object", + "additionalProperties": { + "type": "object", + "properties": { + "args": { + "description": "Args", + "type": "array", + "items": { + "type": "string" + } + }, + "command": { + "description": "Command to execute for stdio MCP servers", + "type": "string" + }, + "env": { + "description": "Env list", + "type": "array", + "items": { + "type": "string" + } + }, + "headers": { + "description": "HTTP headers for SSE MCP servers", + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "type": { + "description": "Type", + "type": "string", + "enum": [ + "stdio", + "sse" + ], + "default": "stdio" + }, + "url": { + "description": "URL for SSE MCP servers", + "type": "string" + } + }, + "required": [ + "command", + "env", + "args", + "type", + "url", + "headers" + ] + } + }, + "models": { + "description": "Preferred model configurations for large and small model types", + "type": "object", + "properties": { + "large": { + "description": "Large", + "type": "object", + "properties": { + "max_tokens": { + "description": "Max Tokens", + "type": "integer" + }, + "model_id": { + "description": "Model I D", + "type": "string", + "enum": [ + "claude-3-opus", + "claude-3-haiku", + "claude-3-5-sonnet-20241022", + "claude-3-5-haiku-20241022", + "gpt-4", + "gpt-3.5-turbo", + "gpt-4-turbo", + "gpt-4o", + "gpt-4o-mini", + "o1-preview", + "o1-mini", + "gemini-2.5-pro", + "gemini-2.5-flash", + "grok-beta", + 
"anthropic/claude-3.5-sonnet", + "anthropic/claude-3.5-haiku" + ] + }, + "provider": { + "description": "Provider", + "type": "string", + "enum": [ + "anthropic", + "openai", + "gemini", + "xai", + "openrouter" + ] + }, + "reasoning_effort": { + "description": "Reasoning Effort", + "type": "string" + }, + "think": { + "description": "Enable think", + "type": "boolean" + } + }, + "required": [ + "model_id", + "provider" + ] + }, + "small": { + "description": "Small", + "$ref": "#/definitions/PreferredModel" + } + }, + "required": [ + "large", + "small" + ] + }, + "options": { + "description": "General application options and settings", + "type": "object", + "properties": { + "context_paths": { + "description": "Context Paths", + "type": "array", + "items": { + "type": "string" + }, + "default": [ + ".github/copilot-instructions.md", + ".cursorrules", + ".cursor/rules/", + "CLAUDE.md", + "CLAUDE.local.md", + "GEMINI.md", + "gemini.md", + "crush.md", + "crush.local.md", + "Crush.md", + "Crush.local.md", + "CRUSH.md", + "CRUSH.local.md" + ] + }, + "data_directory": { + "description": "Data Directory", + "type": "string", + "default": ".crush" + }, + "debug": { + "description": "Enable debug", + "type": "boolean", + "default": false + }, + "debug_lsp": { + "description": "Enable debug l s p", + "type": "boolean", + "default": false + }, + "disable_auto_summarize": { + "description": "Disable Auto Summarize", + "type": "boolean", + "default": false + }, + "tui": { + "description": "T U I", + "type": "object", + "properties": { + "compact_mode": { + "description": "Enable compact mode", + "type": "boolean", + "default": false + } + }, + "required": [ + "compact_mode" + ] + } + }, + "required": [ + "context_paths", + "tui", + "debug", + "debug_lsp", + "disable_auto_summarize", + "data_directory" + ] + }, + "providers": { + "description": "LLM provider configurations", + "type": "object", + "additionalProperties": { + "type": "object", + "properties": { + "api_key": { + 
"description": "API key for authenticating with the provider", + "type": "string" + }, + "base_url": { + "description": "Base URL for the provider API (required for custom providers)", + "type": "string" + }, + "default_large_model": { + "description": "Default Large Model", + "type": "string" + }, + "default_small_model": { + "description": "Default Small Model", + "type": "string" + }, + "disabled": { + "description": "Disabled", + "type": "boolean", + "default": false + }, + "extra_headers": { + "description": "Additional HTTP headers to send with requests", + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "extra_params": { + "description": "Additional provider-specific parameters", + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "id": { + "description": "I D", + "type": "string", + "enum": [ + "anthropic", + "openai", + "gemini", + "xai", + "openrouter" + ] + }, + "models": { + "description": "Models", + "type": "array", + "items": { + "type": "object", + "properties": { + "can_reason": { + "description": "Enable can reason", + "type": "boolean" + }, + "context_window": { + "description": "Context Window", + "type": "integer" + }, + "cost_per_1m_in": { + "description": "Cost Per1 M In", + "type": "number" + }, + "cost_per_1m_in_cached": { + "description": "Cost Per1 M In Cached", + "type": "number" + }, + "cost_per_1m_out": { + "description": "Cost Per1 M Out", + "type": "number" + }, + "cost_per_1m_out_cached": { + "description": "Cost Per1 M Out Cached", + "type": "number" + }, + "default_max_tokens": { + "description": "Default Max Tokens", + "type": "integer" + }, + "has_reasoning_effort": { + "description": "Enable has reasoning effort", + "type": "boolean" + }, + "id": { + "description": "I D", + "type": "string" + }, + "model": { + "description": "Name", + "type": "string" + }, + "reasoning_effort": { + "description": "Reasoning Effort", + "type": "string" + }, + "supports_attachments": { + 
"description": "Enable supports images", + "type": "boolean" + } + }, + "required": [ + "id", + "model", + "cost_per_1m_in", + "cost_per_1m_out", + "cost_per_1m_in_cached", + "cost_per_1m_out_cached", + "context_window", + "default_max_tokens", + "can_reason", + "reasoning_effort", + "has_reasoning_effort", + "supports_attachments" + ] + } + }, + "provider_type": { + "description": "Provider Type", + "type": "string" + } + }, + "required": [ + "id", + "provider_type", + "disabled" + ] + } + } + }, + "required": [ + "models", + "options" + ] }, - "tui": { - "description": "Terminal User Interface configuration", + "LSPConfig": { + "type": "object", "properties": { - "theme": { - "default": "crush", - "description": "TUI theme name", + "args": { + "description": "Args", + "type": "array", + "items": { + "type": "string" + } + }, + "command": { + "description": "Command", + "type": "string" + }, + "enabled": { + "description": "Disabled", + "type": "boolean", + "default": true + }, + "options": { + "description": "Options" + } + }, + "required": [ + "enabled", + "command", + "args", + "options" + ] + }, + "MCP": { + "type": "object", + "properties": { + "args": { + "description": "Args", + "type": "array", + "items": { + "type": "string" + } + }, + "command": { + "description": "Command to execute for stdio MCP servers", + "type": "string" + }, + "env": { + "description": "Env list", + "type": "array", + "items": { + "type": "string" + } + }, + "headers": { + "description": "HTTP headers for SSE MCP servers", + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "type": { + "description": "Type", + "type": "string", "enum": [ - "crush", - "catppuccin", - "dracula", - "flexoki", - "gruvbox", - "monokai", - "onedark", - "tokyonight", - "tron" + "stdio", + "sse" ], + "default": "stdio" + }, + "url": { + "description": "URL for SSE MCP servers", + "type": "string" + } + }, + "required": [ + "command", + "env", + "args", + "type", + "url", + 
"headers" + ] + }, + "Model": { + "type": "object", + "properties": { + "can_reason": { + "description": "Enable can reason", + "type": "boolean" + }, + "context_window": { + "description": "Context Window", + "type": "integer" + }, + "cost_per_1m_in": { + "description": "Cost Per1 M In", + "type": "number" + }, + "cost_per_1m_in_cached": { + "description": "Cost Per1 M In Cached", + "type": "number" + }, + "cost_per_1m_out": { + "description": "Cost Per1 M Out", + "type": "number" + }, + "cost_per_1m_out_cached": { + "description": "Cost Per1 M Out Cached", + "type": "number" + }, + "default_max_tokens": { + "description": "Default Max Tokens", + "type": "integer" + }, + "has_reasoning_effort": { + "description": "Enable has reasoning effort", + "type": "boolean" + }, + "id": { + "description": "I D", + "type": "string" + }, + "model": { + "description": "Name", + "type": "string" + }, + "reasoning_effort": { + "description": "Reasoning Effort", + "type": "string" + }, + "supports_attachments": { + "description": "Enable supports images", + "type": "boolean" + } + }, + "required": [ + "id", + "model", + "cost_per_1m_in", + "cost_per_1m_out", + "cost_per_1m_in_cached", + "cost_per_1m_out_cached", + "context_window", + "default_max_tokens", + "can_reason", + "reasoning_effort", + "has_reasoning_effort", + "supports_attachments" + ] + }, + "Options": { + "description": "General application options and settings", + "type": "object", + "properties": { + "context_paths": { + "description": "Context Paths", + "type": "array", + "items": { + "type": "string" + }, + "default": [ + ".github/copilot-instructions.md", + ".cursorrules", + ".cursor/rules/", + "CLAUDE.md", + "CLAUDE.local.md", + "GEMINI.md", + "gemini.md", + "crush.md", + "crush.local.md", + "Crush.md", + "Crush.local.md", + "CRUSH.md", + "CRUSH.local.md" + ] + }, + "data_directory": { + "description": "Data Directory", + "type": "string", + "default": ".crush" + }, + "debug": { + "description": "Enable debug", 
+ "type": "boolean", + "default": false + }, + "debug_lsp": { + "description": "Enable debug l s p", + "type": "boolean", + "default": false + }, + "disable_auto_summarize": { + "description": "Disable Auto Summarize", + "type": "boolean", + "default": false + }, + "tui": { + "description": "T U I", + "type": "object", + "properties": { + "compact_mode": { + "description": "Enable compact mode", + "type": "boolean", + "default": false + } + }, + "required": [ + "compact_mode" + ] + } + }, + "required": [ + "context_paths", + "tui", + "debug", + "debug_lsp", + "disable_auto_summarize", + "data_directory" + ] + }, + "PreferredModel": { + "description": "Large", + "type": "object", + "properties": { + "max_tokens": { + "description": "Max Tokens", + "type": "integer" + }, + "model_id": { + "description": "Model I D", + "type": "string", + "enum": [ + "claude-3-opus", + "claude-3-haiku", + "claude-3-5-sonnet-20241022", + "claude-3-5-haiku-20241022", + "gpt-4", + "gpt-3.5-turbo", + "gpt-4-turbo", + "gpt-4o", + "gpt-4o-mini", + "o1-preview", + "o1-mini", + "gemini-2.5-pro", + "gemini-2.5-flash", + "grok-beta", + "anthropic/claude-3.5-sonnet", + "anthropic/claude-3.5-haiku" + ] + }, + "provider": { + "description": "Provider", + "type": "string", + "enum": [ + "anthropic", + "openai", + "gemini", + "xai", + "openrouter" + ] + }, + "reasoning_effort": { + "description": "Reasoning Effort", + "type": "string" + }, + "think": { + "description": "Enable think", + "type": "boolean" + } + }, + "required": [ + "model_id", + "provider" + ] + }, + "PreferredModels": { + "description": "Preferred model configurations for large and small model types", + "type": "object", + "properties": { + "large": { + "description": "Large", + "type": "object", + "properties": { + "max_tokens": { + "description": "Max Tokens", + "type": "integer" + }, + "model_id": { + "description": "Model I D", + "type": "string", + "enum": [ + "claude-3-opus", + "claude-3-haiku", + "claude-3-5-sonnet-20241022", 
+ "claude-3-5-haiku-20241022", + "gpt-4", + "gpt-3.5-turbo", + "gpt-4-turbo", + "gpt-4o", + "gpt-4o-mini", + "o1-preview", + "o1-mini", + "gemini-2.5-pro", + "gemini-2.5-flash", + "grok-beta", + "anthropic/claude-3.5-sonnet", + "anthropic/claude-3.5-haiku" + ] + }, + "provider": { + "description": "Provider", + "type": "string", + "enum": [ + "anthropic", + "openai", + "gemini", + "xai", + "openrouter" + ] + }, + "reasoning_effort": { + "description": "Reasoning Effort", + "type": "string" + }, + "think": { + "description": "Enable think", + "type": "boolean" + } + }, + "required": [ + "model_id", + "provider" + ] + }, + "small": { + "description": "Small", + "$ref": "#/definitions/PreferredModel" + } + }, + "required": [ + "large", + "small" + ] + }, + "ProviderConfig": { + "type": "object", + "properties": { + "api_key": { + "description": "API key for authenticating with the provider", + "type": "string" + }, + "base_url": { + "description": "Base URL for the provider API (required for custom providers)", + "type": "string" + }, + "default_large_model": { + "description": "Default Large Model", + "type": "string" + }, + "default_small_model": { + "description": "Default Small Model", + "type": "string" + }, + "disabled": { + "description": "Disabled", + "type": "boolean", + "default": false + }, + "extra_headers": { + "description": "Additional HTTP headers to send with requests", + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "extra_params": { + "description": "Additional provider-specific parameters", + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "id": { + "description": "I D", + "type": "string", + "enum": [ + "anthropic", + "openai", + "gemini", + "xai", + "openrouter" + ] + }, + "models": { + "description": "Models", + "type": "array", + "items": { + "type": "object", + "properties": { + "can_reason": { + "description": "Enable can reason", + "type": "boolean" + }, + "context_window": { + 
"description": "Context Window", + "type": "integer" + }, + "cost_per_1m_in": { + "description": "Cost Per1 M In", + "type": "number" + }, + "cost_per_1m_in_cached": { + "description": "Cost Per1 M In Cached", + "type": "number" + }, + "cost_per_1m_out": { + "description": "Cost Per1 M Out", + "type": "number" + }, + "cost_per_1m_out_cached": { + "description": "Cost Per1 M Out Cached", + "type": "number" + }, + "default_max_tokens": { + "description": "Default Max Tokens", + "type": "integer" + }, + "has_reasoning_effort": { + "description": "Enable has reasoning effort", + "type": "boolean" + }, + "id": { + "description": "I D", + "type": "string" + }, + "model": { + "description": "Name", + "type": "string" + }, + "reasoning_effort": { + "description": "Reasoning Effort", + "type": "string" + }, + "supports_attachments": { + "description": "Enable supports images", + "type": "boolean" + } + }, + "required": [ + "id", + "model", + "cost_per_1m_in", + "cost_per_1m_out", + "cost_per_1m_in_cached", + "cost_per_1m_out_cached", + "context_window", + "default_max_tokens", + "can_reason", + "reasoning_effort", + "has_reasoning_effort", + "supports_attachments" + ] + } + }, + "provider_type": { + "description": "Provider Type", "type": "string" } }, - "type": "object" + "required": [ + "id", + "provider_type", + "disabled" + ] }, - "wd": { - "description": "Working directory for the application", - "type": "string" + "TUIOptions": { + "description": "T U I", + "type": "object", + "properties": { + "compact_mode": { + "description": "Enable compact mode", + "type": "boolean", + "default": false + } + }, + "required": [ + "compact_mode" + ] } - }, - "title": "Crush Configuration", - "type": "object" + } } diff --git a/internal/config/config.go b/internal/config/config.go index 69a528d3c57eba4ef4b8802d800bfb0a7a764c14..3caf9f01c4afdba4dd2c29c43fc690dd360173ef 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -732,6 +732,34 @@ func 
defaultConfigBasedOnEnv() *Config { "project": os.Getenv("GOOGLE_CLOUD_PROJECT"), "location": os.Getenv("GOOGLE_CLOUD_LOCATION"), } + // Find the VertexAI provider definition to get default models + for _, p := range providers { + if p.ID == provider.InferenceProviderVertexAI { + providerConfig.DefaultLargeModel = p.DefaultLargeModelID + providerConfig.DefaultSmallModel = p.DefaultSmallModelID + for _, model := range p.Models { + configModel := Model{ + ID: model.ID, + Name: model.Name, + CostPer1MIn: model.CostPer1MIn, + CostPer1MOut: model.CostPer1MOut, + CostPer1MInCached: model.CostPer1MInCached, + CostPer1MOutCached: model.CostPer1MOutCached, + ContextWindow: model.ContextWindow, + DefaultMaxTokens: model.DefaultMaxTokens, + CanReason: model.CanReason, + SupportsImages: model.SupportsImages, + } + // Set reasoning effort for reasoning models + if model.HasReasoningEffort && model.DefaultReasoningEffort != "" { + configModel.HasReasoningEffort = model.HasReasoningEffort + configModel.ReasoningEffort = model.DefaultReasoningEffort + } + providerConfig.Models = append(providerConfig.Models, configModel) + } + break + } + } cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig } @@ -743,6 +771,34 @@ func defaultConfigBasedOnEnv() *Config { if providerConfig.ExtraParams["region"] == "" { providerConfig.ExtraParams["region"] = os.Getenv("AWS_REGION") } + // Find the Bedrock provider definition to get default models + for _, p := range providers { + if p.ID == provider.InferenceProviderBedrock { + providerConfig.DefaultLargeModel = p.DefaultLargeModelID + providerConfig.DefaultSmallModel = p.DefaultSmallModelID + for _, model := range p.Models { + configModel := Model{ + ID: model.ID, + Name: model.Name, + CostPer1MIn: model.CostPer1MIn, + CostPer1MOut: model.CostPer1MOut, + CostPer1MInCached: model.CostPer1MInCached, + CostPer1MOutCached: model.CostPer1MOutCached, + ContextWindow: model.ContextWindow, + DefaultMaxTokens: model.DefaultMaxTokens, + 
CanReason: model.CanReason, + SupportsImages: model.SupportsImages, + } + // Set reasoning effort for reasoning models + if model.HasReasoningEffort && model.DefaultReasoningEffort != "" { + configModel.HasReasoningEffort = model.HasReasoningEffort + configModel.ReasoningEffort = model.DefaultReasoningEffort + } + providerConfig.Models = append(providerConfig.Models, configModel) + } + break + } + } cfg.Providers[provider.InferenceProviderBedrock] = providerConfig } return cfg diff --git a/internal/config/config_test.go b/internal/config/config_test.go index f69b3c84701b07c6df948c7abca6e37a65e3c69e..b48a9eba0a92a9f9239d6f6e3526c24cc8790ac9 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -93,7 +93,7 @@ func TestInit_SingletonBehavior(t *testing.T) { require.NoError(t, err1) require.NoError(t, err2) - assert.Same(t, cfg1, cfg2) // Should be the same instance + assert.Same(t, cfg1, cfg2) } func TestGet_BeforeInitialization(t *testing.T) { @@ -124,7 +124,7 @@ func TestLoadConfig_NoConfigFiles(t *testing.T) { cfg, err := Init(cwdDir, false) require.NoError(t, err) - assert.Len(t, cfg.Providers, 0) // No providers without env vars or config files + assert.Len(t, cfg.Providers, 0) assert.Equal(t, defaultContextPaths, cfg.Options.ContextPaths) } @@ -133,7 +133,6 @@ func TestLoadConfig_OnlyGlobalConfig(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Create global config file globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { @@ -187,7 +186,6 @@ func TestLoadConfig_OnlyLocalConfig(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Create local config file localConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderAnthropic: { @@ -239,7 +237,6 @@ func TestLoadConfig_BothGlobalAndLocal(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Create global config 
globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { @@ -279,7 +276,6 @@ func TestLoadConfig_BothGlobalAndLocal(t *testing.T) { require.NoError(t, err) require.NoError(t, os.WriteFile(configPath, data, 0o644)) - // Create local config that overrides and adds localConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { @@ -327,14 +323,11 @@ func TestLoadConfig_BothGlobalAndLocal(t *testing.T) { require.NoError(t, err) assert.Len(t, cfg.Providers, 2) - // Check that local config overrode global openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] assert.Equal(t, "local-key", openaiProvider.APIKey) - // Check that local config added new provider assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic) - // Check that context paths were merged assert.Contains(t, cfg.Options.ContextPaths, "global-context.md") assert.Contains(t, cfg.Options.ContextPaths, "local-context.md") assert.True(t, cfg.Options.TUI.CompactMode) @@ -345,7 +338,6 @@ func TestLoadConfig_MalformedGlobalJSON(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Create malformed global config configPath := filepath.Join(testConfigDir, "crush.json") require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) require.NoError(t, os.WriteFile(configPath, []byte(`{invalid json`), 0o644)) @@ -359,7 +351,6 @@ func TestLoadConfig_MalformedLocalJSON(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Create malformed local config localConfigPath := filepath.Join(cwdDir, "crush.json") require.NoError(t, os.WriteFile(localConfigPath, []byte(`{invalid json`), 0o644)) @@ -409,7 +400,6 @@ func TestEnvVars_AllSupportedAPIKeys(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set all supported API keys os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key") os.Setenv("OPENAI_API_KEY", "test-openai-key") 
os.Setenv("GEMINI_API_KEY", "test-gemini-key") @@ -421,7 +411,6 @@ func TestEnvVars_AllSupportedAPIKeys(t *testing.T) { require.NoError(t, err) assert.Len(t, cfg.Providers, 5) - // Verify each provider is configured correctly anthropicProvider := cfg.Providers[provider.InferenceProviderAnthropic] assert.Equal(t, "test-anthropic-key", anthropicProvider.APIKey) assert.Equal(t, provider.TypeAnthropic, anthropicProvider.ProviderType) @@ -449,7 +438,6 @@ func TestEnvVars_PartialEnvironmentVariables(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set only some API keys os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key") os.Setenv("OPENAI_API_KEY", "test-openai-key") @@ -467,7 +455,6 @@ func TestEnvVars_VertexAIConfiguration(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set VertexAI environment variables os.Setenv("GOOGLE_GENAI_USE_VERTEXAI", "true") os.Setenv("GOOGLE_CLOUD_PROJECT", "test-project") os.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1") @@ -488,7 +475,6 @@ func TestEnvVars_VertexAIWithoutUseFlag(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set Google Cloud vars but not the use flag os.Setenv("GOOGLE_CLOUD_PROJECT", "test-project") os.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1") @@ -503,7 +489,6 @@ func TestEnvVars_AWSBedrockWithAccessKeys(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set AWS credentials os.Setenv("AWS_ACCESS_KEY_ID", "test-access-key") os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key") os.Setenv("AWS_DEFAULT_REGION", "us-east-1") @@ -523,7 +508,6 @@ func TestEnvVars_AWSBedrockWithProfile(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set AWS profile os.Setenv("AWS_PROFILE", "test-profile") os.Setenv("AWS_REGION", "eu-west-1") @@ -541,7 +525,6 @@ func TestEnvVars_AWSBedrockWithContainerCredentials(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set AWS container credentials 
os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/v2/credentials/test") os.Setenv("AWS_DEFAULT_REGION", "ap-southeast-1") @@ -556,7 +539,6 @@ func TestEnvVars_AWSBedrockRegionPriority(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set both region variables - AWS_DEFAULT_REGION should take priority os.Setenv("AWS_ACCESS_KEY_ID", "test-key") os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret") os.Setenv("AWS_DEFAULT_REGION", "us-west-2") @@ -574,7 +556,6 @@ func TestEnvVars_AWSBedrockFallbackRegion(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set only AWS_REGION (not AWS_DEFAULT_REGION) os.Setenv("AWS_ACCESS_KEY_ID", "test-key") os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret") os.Setenv("AWS_REGION", "us-east-1") @@ -591,7 +572,6 @@ func TestEnvVars_NoAWSCredentials(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Don't set any AWS credentials cfg, err := Init(cwdDir, false) require.NoError(t, err) @@ -603,15 +583,12 @@ func TestEnvVars_CustomEnvironmentVariables(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Test that environment variables are properly resolved from provider definitions - // This test assumes the provider system uses $VARIABLE_NAME format os.Setenv("ANTHROPIC_API_KEY", "resolved-anthropic-key") cfg, err := Init(cwdDir, false) require.NoError(t, err) if len(cfg.Providers) > 0 { - // Verify that the environment variable was resolved if anthropicProvider, exists := cfg.Providers[provider.InferenceProviderAnthropic]; exists { assert.Equal(t, "resolved-anthropic-key", anthropicProvider.APIKey) } @@ -623,11 +600,11 @@ func TestEnvVars_CombinedEnvironmentVariables(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set multiple types of environment variables os.Setenv("ANTHROPIC_API_KEY", "test-anthropic") os.Setenv("OPENAI_API_KEY", "test-openai") os.Setenv("GOOGLE_GENAI_USE_VERTEXAI", "true") os.Setenv("GOOGLE_CLOUD_PROJECT", 
"test-project") + os.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1") os.Setenv("AWS_ACCESS_KEY_ID", "test-aws-key") os.Setenv("AWS_SECRET_ACCESS_KEY", "test-aws-secret") os.Setenv("AWS_DEFAULT_REGION", "us-west-1") @@ -636,7 +613,6 @@ func TestEnvVars_CombinedEnvironmentVariables(t *testing.T) { require.NoError(t, err) - // Should have API key providers + VertexAI + Bedrock expectedProviders := []provider.InferenceProvider{ provider.InferenceProviderAnthropic, provider.InferenceProviderOpenAI, @@ -696,14 +672,11 @@ func TestHasAWSCredentials_NoCredentials(t *testing.T) { assert.False(t, hasAWSCredentials()) } -// Provider Configuration Tests - func TestProviderMerging_GlobalToBase(t *testing.T) { reset() testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Create global config with provider globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { @@ -719,6 +692,12 @@ func TestProviderMerging_GlobalToBase(t *testing.T) { ContextWindow: 8192, DefaultMaxTokens: 4096, }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + ContextWindow: 4096, + DefaultMaxTokens: 2048, + }, }, }, }, @@ -739,7 +718,7 @@ func TestProviderMerging_GlobalToBase(t *testing.T) { assert.Equal(t, "global-openai-key", openaiProvider.APIKey) assert.Equal(t, "gpt-4", openaiProvider.DefaultLargeModel) assert.Equal(t, "gpt-3.5-turbo", openaiProvider.DefaultSmallModel) - assert.Len(t, openaiProvider.Models, 1) + assert.Len(t, openaiProvider.Models, 2) } func TestProviderMerging_LocalToBase(t *testing.T) { @@ -747,7 +726,6 @@ func TestProviderMerging_LocalToBase(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Create local config with provider localConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderAnthropic: { @@ -755,6 +733,25 @@ func TestProviderMerging_LocalToBase(t *testing.T) { APIKey: "local-anthropic-key", ProviderType: provider.TypeAnthropic, 
DefaultLargeModel: "claude-3-opus", + DefaultSmallModel: "claude-3-haiku", + Models: []Model{ + { + ID: "claude-3-opus", + Name: "Claude 3 Opus", + ContextWindow: 200000, + DefaultMaxTokens: 4096, + CostPer1MIn: 15.0, + CostPer1MOut: 75.0, + }, + { + ID: "claude-3-haiku", + Name: "Claude 3 Haiku", + ContextWindow: 200000, + DefaultMaxTokens: 4096, + CostPer1MIn: 0.25, + CostPer1MOut: 1.25, + }, + }, }, }, } @@ -772,6 +769,8 @@ func TestProviderMerging_LocalToBase(t *testing.T) { anthropicProvider := cfg.Providers[provider.InferenceProviderAnthropic] assert.Equal(t, "local-anthropic-key", anthropicProvider.APIKey) assert.Equal(t, "claude-3-opus", anthropicProvider.DefaultLargeModel) + assert.Equal(t, "claude-3-haiku", anthropicProvider.DefaultSmallModel) + assert.Len(t, anthropicProvider.Models, 2) } func TestProviderMerging_ConflictingSettings(t *testing.T) { @@ -779,7 +778,6 @@ func TestProviderMerging_ConflictingSettings(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Create global config globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { @@ -788,6 +786,26 @@ func TestProviderMerging_ConflictingSettings(t *testing.T) { ProviderType: provider.TypeOpenAI, DefaultLargeModel: "gpt-4", DefaultSmallModel: "gpt-3.5-turbo", + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + ContextWindow: 4096, + DefaultMaxTokens: 2048, + }, + { + ID: "gpt-4-turbo", + Name: "GPT-4 Turbo", + ContextWindow: 128000, + DefaultMaxTokens: 4096, + }, + }, }, }, } @@ -804,7 +822,6 @@ func TestProviderMerging_ConflictingSettings(t *testing.T) { provider.InferenceProviderOpenAI: { APIKey: "local-key", DefaultLargeModel: "gpt-4-turbo", - // Test disabled separately - don't disable here as it causes nil pointer }, }, } @@ -819,11 +836,9 @@ func TestProviderMerging_ConflictingSettings(t *testing.T) { 
require.NoError(t, err) openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] - // Local should override global assert.Equal(t, "local-key", openaiProvider.APIKey) assert.Equal(t, "gpt-4-turbo", openaiProvider.DefaultLargeModel) - assert.False(t, openaiProvider.Disabled) // Should not be disabled - // Global values should remain where not overridden + assert.False(t, openaiProvider.Disabled) assert.Equal(t, "gpt-3.5-turbo", openaiProvider.DefaultSmallModel) } @@ -834,22 +849,51 @@ func TestProviderMerging_CustomVsKnownProviders(t *testing.T) { customProviderID := provider.InferenceProvider("custom-provider") - // Create config with both known and custom providers globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ - // Known provider - some fields should not be overrideable provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "openai-key", - BaseURL: "should-not-override", - ProviderType: provider.TypeAnthropic, // Should not override + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + BaseURL: "should-not-override", + ProviderType: provider.TypeAnthropic, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + ContextWindow: 4096, + DefaultMaxTokens: 2048, + }, + }, }, - // Custom provider - all fields should be configurable customProviderID: { - ID: customProviderID, - APIKey: "custom-key", - BaseURL: "https://custom.api.com", - ProviderType: provider.TypeOpenAI, + ID: customProviderID, + APIKey: "custom-key", + BaseURL: "https://custom.api.com", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "custom-large", + DefaultSmallModel: "custom-small", + Models: []Model{ + { + ID: "custom-large", + Name: "Custom Large", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + { + ID: "custom-small", + 
Name: "Custom Small", + ContextWindow: 4096, + DefaultMaxTokens: 2048, + }, + }, }, }, } @@ -882,14 +926,12 @@ func TestProviderMerging_CustomVsKnownProviders(t *testing.T) { require.NoError(t, err) - // Known provider should not have BaseURL/ProviderType overridden openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] assert.NotEqual(t, "https://should-not-change.com", openaiProvider.BaseURL) assert.NotEqual(t, provider.TypeGemini, openaiProvider.ProviderType) - // Custom provider should have all fields configurable customProvider := cfg.Providers[customProviderID] - assert.Equal(t, "custom-key", customProvider.APIKey) // Should preserve from global + assert.Equal(t, "custom-key", customProvider.APIKey) assert.Equal(t, "https://updated-custom.api.com", customProvider.BaseURL) assert.Equal(t, provider.TypeOpenAI, customProvider.ProviderType) } @@ -901,14 +943,12 @@ func TestProviderValidation_CustomProviderMissingBaseURL(t *testing.T) { customProviderID := provider.InferenceProvider("custom-provider") - // Create config with custom provider missing BaseURL globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ customProviderID: { ID: customProviderID, APIKey: "custom-key", ProviderType: provider.TypeOpenAI, - // Missing BaseURL }, }, } @@ -922,7 +962,6 @@ func TestProviderValidation_CustomProviderMissingBaseURL(t *testing.T) { cfg, err := Init(cwdDir, false) require.NoError(t, err) - // Provider should be filtered out due to validation failure assert.NotContains(t, cfg.Providers, customProviderID) } @@ -939,7 +978,6 @@ func TestProviderValidation_CustomProviderMissingAPIKey(t *testing.T) { ID: customProviderID, BaseURL: "https://custom.api.com", ProviderType: provider.TypeOpenAI, - // Missing APIKey }, }, } @@ -994,10 +1032,26 @@ func TestProviderValidation_KnownProviderValid(t *testing.T) { globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { - ID: 
provider.InferenceProviderOpenAI, - APIKey: "openai-key", - ProviderType: provider.TypeOpenAI, - // BaseURL not required for known providers + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + ContextWindow: 4096, + DefaultMaxTokens: 2048, + }, + }, + }, }, } @@ -1022,10 +1076,48 @@ func TestProviderValidation_DisabledProvider(t *testing.T) { globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "openai-key", - ProviderType: provider.TypeOpenAI, - Disabled: true, + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + ProviderType: provider.TypeOpenAI, + Disabled: true, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-3.5-turbo", + Models: []Model{ + { + ID: "gpt-4", + Name: "GPT-4", + ContextWindow: 8192, + DefaultMaxTokens: 4096, + }, + { + ID: "gpt-3.5-turbo", + Name: "GPT-3.5 Turbo", + ContextWindow: 4096, + DefaultMaxTokens: 2048, + }, + }, + }, + provider.InferenceProviderAnthropic: { + ID: provider.InferenceProviderAnthropic, + APIKey: "anthropic-key", + ProviderType: provider.TypeAnthropic, + Disabled: false, // This one is enabled + DefaultLargeModel: "claude-3-opus", + DefaultSmallModel: "claude-3-haiku", + Models: []Model{ + { + ID: "claude-3-opus", + Name: "Claude 3 Opus", + ContextWindow: 200000, + DefaultMaxTokens: 4096, + }, + { + ID: "claude-3-haiku", + Name: "Claude 3 Haiku", + ContextWindow: 200000, + DefaultMaxTokens: 4096, + }, + }, }, }, } @@ -1039,9 +1131,10 @@ func TestProviderValidation_DisabledProvider(t *testing.T) { cfg, err := Init(cwdDir, false) require.NoError(t, err) - // Disabled providers should still be in the config but marked 
as disabled assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI) assert.True(t, cfg.Providers[provider.InferenceProviderOpenAI].Disabled) + assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic) + assert.False(t, cfg.Providers[provider.InferenceProviderAnthropic].Disabled) } func TestProviderModels_AddingNewModels(t *testing.T) { @@ -1052,9 +1145,11 @@ func TestProviderModels_AddingNewModels(t *testing.T) { globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "openai-key", - ProviderType: provider.TypeOpenAI, + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-4-turbo", Models: []Model{ { ID: "gpt-4", @@ -1098,7 +1193,7 @@ func TestProviderModels_AddingNewModels(t *testing.T) { require.NoError(t, err) openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] - assert.Len(t, openaiProvider.Models, 2) // Should have both models + assert.Len(t, openaiProvider.Models, 2) modelIDs := make([]string, len(openaiProvider.Models)) for i, model := range openaiProvider.Models { @@ -1116,9 +1211,11 @@ func TestProviderModels_DuplicateModelHandling(t *testing.T) { globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "openai-key", - ProviderType: provider.TypeOpenAI, + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-4", Models: []Model{ { ID: "gpt-4", @@ -1136,7 +1233,7 @@ func TestProviderModels_DuplicateModelHandling(t *testing.T) { provider.InferenceProviderOpenAI: { Models: []Model{ { - ID: "gpt-4", // Same ID as global + ID: "gpt-4", Name: "GPT-4 Updated", ContextWindow: 16384, DefaultMaxTokens: 8192, 
@@ -1162,13 +1259,12 @@ func TestProviderModels_DuplicateModelHandling(t *testing.T) { require.NoError(t, err) openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI] - assert.Len(t, openaiProvider.Models, 1) // Should not duplicate + assert.Len(t, openaiProvider.Models, 1) - // Should keep the original model (global config) model := openaiProvider.Models[0] assert.Equal(t, "gpt-4", model.ID) - assert.Equal(t, "GPT-4", model.Name) // Original name - assert.Equal(t, int64(8192), model.ContextWindow) // Original context window + assert.Equal(t, "GPT-4", model.Name) + assert.Equal(t, int64(8192), model.ContextWindow) } func TestProviderModels_ModelCostAndCapabilities(t *testing.T) { @@ -1179,9 +1275,11 @@ func TestProviderModels_ModelCostAndCapabilities(t *testing.T) { globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "openai-key", - ProviderType: provider.TypeOpenAI, + ID: provider.InferenceProviderOpenAI, + APIKey: "openai-key", + ProviderType: provider.TypeOpenAI, + DefaultLargeModel: "gpt-4", + DefaultSmallModel: "gpt-4", Models: []Model{ { ID: "gpt-4", @@ -1224,14 +1322,11 @@ func TestProviderModels_ModelCostAndCapabilities(t *testing.T) { assert.True(t, model.SupportsImages) } -// Agent Configuration Tests - func TestDefaultAgents_CoderAgent(t *testing.T) { reset() testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider so we can test agent configuration os.Setenv("ANTHROPIC_API_KEY", "test-key") cfg, err := Init(cwdDir, false) @@ -1246,7 +1341,6 @@ func TestDefaultAgents_CoderAgent(t *testing.T) { assert.Equal(t, LargeModel, coderAgent.Model) assert.False(t, coderAgent.Disabled) assert.Equal(t, cfg.Options.ContextPaths, coderAgent.ContextPaths) - // Coder agent should have all tools available (nil means all tools) assert.Nil(t, coderAgent.AllowedTools) } @@ -1255,7 +1349,6 @@ func TestDefaultAgents_TaskAgent(t 
*testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider so we can test agent configuration os.Setenv("ANTHROPIC_API_KEY", "test-key") cfg, err := Init(cwdDir, false) @@ -1271,11 +1364,9 @@ func TestDefaultAgents_TaskAgent(t *testing.T) { assert.False(t, taskAgent.Disabled) assert.Equal(t, cfg.Options.ContextPaths, taskAgent.ContextPaths) - // Task agent should have restricted tools expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"} assert.Equal(t, expectedTools, taskAgent.AllowedTools) - // Task agent should have no MCPs or LSPs by default assert.Equal(t, map[string][]string{}, taskAgent.AllowedMCP) assert.Equal(t, []string{}, taskAgent.AllowedLSP) } @@ -1285,10 +1376,8 @@ func TestAgentMerging_CustomAgent(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider os.Setenv("ANTHROPIC_API_KEY", "test-key") - // Create config with custom agent globalConfig := Config{ Agents: map[AgentID]Agent{ AgentID("custom-agent"): { @@ -1302,6 +1391,23 @@ func TestAgentMerging_CustomAgent(t *testing.T) { ContextPaths: []string{"custom-context.md"}, }, }, + MCP: map[string]MCP{ + "mcp1": { + Type: MCPStdio, + Command: "test-mcp-command", + Args: []string{"--test"}, + }, + }, + LSP: map[string]LSPConfig{ + "typescript": { + Command: "typescript-language-server", + Args: []string{"--stdio"}, + }, + "go": { + Command: "gopls", + Args: []string{}, + }, + }, } configPath := filepath.Join(testConfigDir, "crush.json") @@ -1314,7 +1420,6 @@ func TestAgentMerging_CustomAgent(t *testing.T) { require.NoError(t, err) - // Should have default agents plus custom agent assert.Contains(t, cfg.Agents, AgentCoder) assert.Contains(t, cfg.Agents, AgentTask) assert.Contains(t, cfg.Agents, AgentID("custom-agent")) @@ -1326,7 +1431,6 @@ func TestAgentMerging_CustomAgent(t *testing.T) { assert.Equal(t, []string{"glob", "grep"}, customAgent.AllowedTools) assert.Equal(t, map[string][]string{"mcp1": {"tool1", "tool2"}}, 
customAgent.AllowedMCP) assert.Equal(t, []string{"typescript", "go"}, customAgent.AllowedLSP) - // Context paths should be additive (default + custom) expectedContextPaths := append(defaultContextPaths, "custom-context.md") assert.Equal(t, expectedContextPaths, customAgent.ContextPaths) } @@ -1336,17 +1440,28 @@ func TestAgentMerging_ModifyDefaultCoderAgent(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider os.Setenv("ANTHROPIC_API_KEY", "test-key") - // Create config that modifies the default coder agent globalConfig := Config{ Agents: map[AgentID]Agent{ AgentCoder: { - Model: SmallModel, // Change from default LargeModel + Model: SmallModel, AllowedMCP: map[string][]string{"mcp1": {"tool1"}}, AllowedLSP: []string{"typescript"}, - ContextPaths: []string{"coder-specific.md"}, // Should be additive + ContextPaths: []string{"coder-specific.md"}, + }, + }, + MCP: map[string]MCP{ + "mcp1": { + Type: MCPStdio, + Command: "test-mcp-command", + Args: []string{"--test"}, + }, + }, + LSP: map[string]LSPConfig{ + "typescript": { + Command: "typescript-language-server", + Args: []string{"--stdio"}, }, }, } @@ -1362,16 +1477,13 @@ func TestAgentMerging_ModifyDefaultCoderAgent(t *testing.T) { require.NoError(t, err) coderAgent := cfg.Agents[AgentCoder] - // Should preserve default values for unspecified fields assert.Equal(t, AgentCoder, coderAgent.ID) assert.Equal(t, "Coder", coderAgent.Name) assert.Equal(t, "An agent that helps with executing coding tasks.", coderAgent.Description) - // Context paths should be additive (default + custom) expectedContextPaths := append(cfg.Options.ContextPaths, "coder-specific.md") assert.Equal(t, expectedContextPaths, coderAgent.ContextPaths) - // Should update specified fields assert.Equal(t, SmallModel, coderAgent.Model) assert.Equal(t, map[string][]string{"mcp1": {"tool1"}}, coderAgent.AllowedMCP) assert.Equal(t, []string{"typescript"}, coderAgent.AllowedLSP) @@ -1382,22 +1494,31 @@ func 
TestAgentMerging_ModifyDefaultTaskAgent(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider os.Setenv("ANTHROPIC_API_KEY", "test-key") - // Create config that modifies the default task agent - // Note: Only model, MCP, and LSP should be configurable for known agents globalConfig := Config{ Agents: map[AgentID]Agent{ AgentTask: { - Model: SmallModel, // Should be updated - AllowedMCP: map[string][]string{"search-mcp": nil}, // Should be updated - AllowedLSP: []string{"python"}, // Should be updated - // These should be ignored for known agents: - Name: "Search Agent", // Should be ignored - Description: "Custom search agent", // Should be ignored - Disabled: true, // Should be ignored - AllowedTools: []string{"glob", "grep", "view"}, // Should be ignored + Model: SmallModel, + AllowedMCP: map[string][]string{"search-mcp": nil}, + AllowedLSP: []string{"python"}, + Name: "Search Agent", + Description: "Custom search agent", + Disabled: true, + AllowedTools: []string{"glob", "grep", "view"}, + }, + }, + MCP: map[string]MCP{ + "search-mcp": { + Type: MCPStdio, + Command: "search-mcp-command", + Args: []string{"--search"}, + }, + }, + LSP: map[string]LSPConfig{ + "python": { + Command: "pylsp", + Args: []string{}, }, }, } @@ -1413,13 +1534,11 @@ func TestAgentMerging_ModifyDefaultTaskAgent(t *testing.T) { require.NoError(t, err) taskAgent := cfg.Agents[AgentTask] - // Should preserve default values for protected fields - assert.Equal(t, "Task", taskAgent.Name) // Should remain default - assert.Equal(t, "An agent that helps with searching for context and finding implementation details.", taskAgent.Description) // Should remain default - assert.False(t, taskAgent.Disabled) // Should remain default - assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools) // Should remain default + assert.Equal(t, "Task", taskAgent.Name) + assert.Equal(t, "An agent that helps with searching for context and finding 
implementation details.", taskAgent.Description) + assert.False(t, taskAgent.Disabled) + assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools) - // Should update configurable fields assert.Equal(t, SmallModel, taskAgent.Model) assert.Equal(t, map[string][]string{"search-mcp": nil}, taskAgent.AllowedMCP) assert.Equal(t, []string{"python"}, taskAgent.AllowedLSP) @@ -1430,10 +1549,8 @@ func TestAgentMerging_LocalOverridesGlobal(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider os.Setenv("ANTHROPIC_API_KEY", "test-key") - // Create global config with custom agent globalConfig := Config{ Agents: map[AgentID]Agent{ AgentID("test-agent"): { @@ -1464,6 +1581,13 @@ func TestAgentMerging_LocalOverridesGlobal(t *testing.T) { AllowedMCP: map[string][]string{"local-mcp": {"tool1"}}, }, }, + MCP: map[string]MCP{ + "local-mcp": { + Type: MCPStdio, + Command: "local-mcp-command", + Args: []string{"--local"}, + }, + }, } localConfigPath := filepath.Join(cwdDir, "crush.json") @@ -1476,7 +1600,6 @@ func TestAgentMerging_LocalOverridesGlobal(t *testing.T) { require.NoError(t, err) testAgent := cfg.Agents[AgentID("test-agent")] - // Local should override global assert.Equal(t, "Local Agent", testAgent.Name) assert.Equal(t, "Local description", testAgent.Description) assert.Equal(t, SmallModel, testAgent.Model) @@ -1490,10 +1613,8 @@ func TestAgentModelTypeAssignment(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider os.Setenv("ANTHROPIC_API_KEY", "test-key") - // Create config with agents using different model types globalConfig := Config{ Agents: map[AgentID]Agent{ AgentID("large-agent"): { @@ -1509,7 +1630,6 @@ func TestAgentModelTypeAssignment(t *testing.T) { AgentID("default-agent"): { ID: AgentID("default-agent"), Name: "Default Model Agent", - // No model specified - should default to LargeModel }, }, } @@ -1526,7 +1646,7 @@ func TestAgentModelTypeAssignment(t 
*testing.T) { assert.Equal(t, LargeModel, cfg.Agents[AgentID("large-agent")].Model) assert.Equal(t, SmallModel, cfg.Agents[AgentID("small-agent")].Model) - assert.Equal(t, LargeModel, cfg.Agents[AgentID("default-agent")].Model) // Should default to LargeModel + assert.Equal(t, LargeModel, cfg.Agents[AgentID("default-agent")].Model) } func TestAgentContextPathOverrides(t *testing.T) { @@ -1534,10 +1654,8 @@ func TestAgentContextPathOverrides(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider os.Setenv("ANTHROPIC_API_KEY", "test-key") - // Create config with custom context paths globalConfig := Config{ Options: Options{ ContextPaths: []string{"global-context.md", "shared-context.md"}, @@ -1551,7 +1669,6 @@ func TestAgentContextPathOverrides(t *testing.T) { AgentID("default-context-agent"): { ID: AgentID("default-context-agent"), Name: "Default Context Agent", - // No ContextPaths specified - should use global }, }, } @@ -1566,32 +1683,25 @@ func TestAgentContextPathOverrides(t *testing.T) { require.NoError(t, err) - // Agent with custom context paths should have default + global + custom paths (additive) customAgent := cfg.Agents[AgentID("custom-context-agent")] expectedCustomPaths := append(defaultContextPaths, "global-context.md", "shared-context.md", "agent-specific.md", "custom.md") assert.Equal(t, expectedCustomPaths, customAgent.ContextPaths) - // Agent without custom context paths should use global + defaults defaultAgent := cfg.Agents[AgentID("default-context-agent")] expectedContextPaths := append(defaultContextPaths, "global-context.md", "shared-context.md") assert.Equal(t, expectedContextPaths, defaultAgent.ContextPaths) - // Default agents should also use the merged context paths coderAgent := cfg.Agents[AgentCoder] assert.Equal(t, expectedContextPaths, coderAgent.ContextPaths) } -// Options and Settings Tests - func TestOptionsMerging_ContextPaths(t *testing.T) { reset() testConfigDir = t.TempDir() cwdDir := 
t.TempDir() - // Set up a provider os.Setenv("ANTHROPIC_API_KEY", "test-key") - // Create global config with context paths globalConfig := Config{ Options: Options{ ContextPaths: []string{"global1.md", "global2.md"}, @@ -1604,7 +1714,6 @@ func TestOptionsMerging_ContextPaths(t *testing.T) { require.NoError(t, err) require.NoError(t, os.WriteFile(configPath, data, 0o644)) - // Create local config with additional context paths localConfig := Config{ Options: Options{ ContextPaths: []string{"local1.md", "local2.md"}, @@ -1620,7 +1729,6 @@ func TestOptionsMerging_ContextPaths(t *testing.T) { require.NoError(t, err) - // Context paths should be merged: defaults + global + local expectedContextPaths := append(defaultContextPaths, "global1.md", "global2.md", "local1.md", "local2.md") assert.Equal(t, expectedContextPaths, cfg.Options.ContextPaths) } @@ -1630,14 +1738,12 @@ func TestOptionsMerging_TUIOptions(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider os.Setenv("ANTHROPIC_API_KEY", "test-key") - // Create global config with TUI options globalConfig := Config{ Options: Options{ TUI: TUIOptions{ - CompactMode: false, // Default value + CompactMode: false, }, }, } @@ -1648,7 +1754,6 @@ func TestOptionsMerging_TUIOptions(t *testing.T) { require.NoError(t, err) require.NoError(t, os.WriteFile(configPath, data, 0o644)) - // Create local config that enables compact mode localConfig := Config{ Options: Options{ TUI: TUIOptions{ @@ -1666,7 +1771,6 @@ func TestOptionsMerging_TUIOptions(t *testing.T) { require.NoError(t, err) - // Local config should override global assert.True(t, cfg.Options.TUI.CompactMode) } @@ -1675,10 +1779,8 @@ func TestOptionsMerging_DebugFlags(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider os.Setenv("ANTHROPIC_API_KEY", "test-key") - // Create global config with debug flags globalConfig := Config{ Options: Options{ Debug: false, @@ -1693,7 +1795,6 @@ func 
TestOptionsMerging_DebugFlags(t *testing.T) { require.NoError(t, err) require.NoError(t, os.WriteFile(configPath, data, 0o644)) - // Create local config that enables debug flags localConfig := Config{ Options: Options{ DebugLSP: true, @@ -1710,10 +1811,9 @@ func TestOptionsMerging_DebugFlags(t *testing.T) { require.NoError(t, err) - // Local config should override global for boolean flags - assert.False(t, cfg.Options.Debug) // Not set in local, remains global value - assert.True(t, cfg.Options.DebugLSP) // Set to true in local - assert.True(t, cfg.Options.DisableAutoSummarize) // Set to true in local + assert.False(t, cfg.Options.Debug) + assert.True(t, cfg.Options.DebugLSP) + assert.True(t, cfg.Options.DisableAutoSummarize) } func TestOptionsMerging_DataDirectory(t *testing.T) { @@ -1721,10 +1821,8 @@ func TestOptionsMerging_DataDirectory(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider os.Setenv("ANTHROPIC_API_KEY", "test-key") - // Create global config with custom data directory globalConfig := Config{ Options: Options{ DataDirectory: "global-data", @@ -1737,7 +1835,6 @@ func TestOptionsMerging_DataDirectory(t *testing.T) { require.NoError(t, err) require.NoError(t, os.WriteFile(configPath, data, 0o644)) - // Create local config with different data directory localConfig := Config{ Options: Options{ DataDirectory: "local-data", @@ -1753,7 +1850,6 @@ func TestOptionsMerging_DataDirectory(t *testing.T) { require.NoError(t, err) - // Local config should override global assert.Equal(t, "local-data", cfg.Options.DataDirectory) } @@ -1762,15 +1858,12 @@ func TestOptionsMerging_DefaultValues(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider os.Setenv("ANTHROPIC_API_KEY", "test-key") - // No config files - should use defaults cfg, err := Init(cwdDir, false) require.NoError(t, err) - // Should have default values assert.Equal(t, defaultDataDirectory, cfg.Options.DataDirectory) assert.Equal(t, 
defaultContextPaths, cfg.Options.ContextPaths) assert.False(t, cfg.Options.TUI.CompactMode) @@ -1784,10 +1877,8 @@ func TestOptionsMerging_DebugFlagFromInit(t *testing.T) { testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up a provider os.Setenv("ANTHROPIC_API_KEY", "test-key") - // Create config with debug false globalConfig := Config{ Options: Options{ Debug: false, @@ -1800,7 +1891,6 @@ func TestOptionsMerging_DebugFlagFromInit(t *testing.T) { require.NoError(t, err) require.NoError(t, os.WriteFile(configPath, data, 0o644)) - // Init with debug=true should override config cfg, err := Init(cwdDir, true) require.NoError(t, err) @@ -1895,85 +1985,20 @@ func TestModelSelection_PreferredModelSelection(t *testing.T) { assert.Equal(t, cfg.Models.Large.Provider, cfg.Models.Small.Provider) } -func TestModelSelection_GetAgentModel(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - // Set up a provider with known models - globalConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "test-key", - ProviderType: provider.TypeOpenAI, - DefaultLargeModel: "gpt-4", - DefaultSmallModel: "gpt-3.5-turbo", - Models: []Model{ - { - ID: "gpt-4", - Name: "GPT-4", - ContextWindow: 8192, - DefaultMaxTokens: 4096, - CanReason: true, - SupportsImages: true, - }, - { - ID: "gpt-3.5-turbo", - Name: "GPT-3.5 Turbo", - ContextWindow: 4096, - DefaultMaxTokens: 2048, - CanReason: false, - SupportsImages: false, - }, - }, - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - _, err = Init(cwdDir, false) - - require.NoError(t, err) - - // Test GetAgentModel for default agents - coderModel := GetAgentModel(AgentCoder) - 
assert.Equal(t, "gpt-4", coderModel.ID) // Coder uses LargeModel - assert.Equal(t, "GPT-4", coderModel.Name) - assert.True(t, coderModel.CanReason) - assert.True(t, coderModel.SupportsImages) - - taskModel := GetAgentModel(AgentTask) - assert.Equal(t, "gpt-4", taskModel.ID) // Task also uses LargeModel by default - assert.Equal(t, "GPT-4", taskModel.Name) -} - -func TestModelSelection_GetAgentModelWithCustomModelType(t *testing.T) { +func TestValidation_InvalidModelReference(t *testing.T) { reset() testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up provider and custom agent with SmallModel globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { ID: provider.InferenceProviderOpenAI, APIKey: "test-key", ProviderType: provider.TypeOpenAI, - DefaultLargeModel: "gpt-4", + DefaultLargeModel: "non-existent-model", DefaultSmallModel: "gpt-3.5-turbo", Models: []Model{ - { - ID: "gpt-4", - Name: "GPT-4", - ContextWindow: 8192, - DefaultMaxTokens: 4096, - }, { ID: "gpt-3.5-turbo", Name: "GPT-3.5 Turbo", @@ -1983,54 +2008,6 @@ func TestModelSelection_GetAgentModelWithCustomModelType(t *testing.T) { }, }, }, - Agents: map[AgentID]Agent{ - AgentID("small-agent"): { - ID: AgentID("small-agent"), - Name: "Small Agent", - Model: SmallModel, - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - _, err = Init(cwdDir, false) - - require.NoError(t, err) - - // Test GetAgentModel for custom agent with SmallModel - smallAgentModel := GetAgentModel(AgentID("small-agent")) - assert.Equal(t, "gpt-3.5-turbo", smallAgentModel.ID) - assert.Equal(t, "GPT-3.5 Turbo", smallAgentModel.Name) -} - -func TestModelSelection_GetAgentProvider(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := 
t.TempDir() - - // Set up multiple providers - globalConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "openai-key", - ProviderType: provider.TypeOpenAI, - DefaultLargeModel: "gpt-4", - DefaultSmallModel: "gpt-3.5-turbo", - }, - provider.InferenceProviderAnthropic: { - ID: provider.InferenceProviderAnthropic, - APIKey: "anthropic-key", - ProviderType: provider.TypeAnthropic, - DefaultLargeModel: "claude-3-opus", - DefaultSmallModel: "claude-3-haiku", - }, - }, } configPath := filepath.Join(testConfigDir, "crush.json") @@ -2040,27 +2017,18 @@ func TestModelSelection_GetAgentProvider(t *testing.T) { require.NoError(t, os.WriteFile(configPath, data, 0o644)) _, err = Init(cwdDir, false) - - require.NoError(t, err) - - // Test GetAgentProvider - coderProvider := GetAgentProvider(AgentCoder) - assert.NotEmpty(t, coderProvider.ID) - assert.NotEmpty(t, coderProvider.APIKey) - assert.NotEmpty(t, coderProvider.ProviderType) + assert.Error(t, err) } -func TestModelSelection_GetProviderModel(t *testing.T) { +func TestValidation_EmptyAPIKey(t *testing.T) { reset() testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up provider with specific models globalConfig := Config{ Providers: map[provider.InferenceProvider]ProviderConfig{ provider.InferenceProviderOpenAI: { ID: provider.InferenceProviderOpenAI, - APIKey: "test-key", ProviderType: provider.TypeOpenAI, Models: []Model{ { @@ -2068,16 +2036,6 @@ func TestModelSelection_GetProviderModel(t *testing.T) { Name: "GPT-4", ContextWindow: 8192, DefaultMaxTokens: 4096, - CostPer1MIn: 30.0, - CostPer1MOut: 60.0, - }, - { - ID: "gpt-3.5-turbo", - Name: "GPT-3.5 Turbo", - ContextWindow: 4096, - DefaultMaxTokens: 2048, - CostPer1MIn: 1.5, - CostPer1MOut: 2.0, }, }, }, @@ -2091,107 +2049,22 @@ func TestModelSelection_GetProviderModel(t *testing.T) { require.NoError(t, os.WriteFile(configPath, data, 0o644)) _, err = 
Init(cwdDir, false) - - require.NoError(t, err) - - // Test GetProviderModel - gpt4Model := GetProviderModel(provider.InferenceProviderOpenAI, "gpt-4") - assert.Equal(t, "gpt-4", gpt4Model.ID) - assert.Equal(t, "GPT-4", gpt4Model.Name) - assert.Equal(t, int64(8192), gpt4Model.ContextWindow) - assert.Equal(t, 30.0, gpt4Model.CostPer1MIn) - - gpt35Model := GetProviderModel(provider.InferenceProviderOpenAI, "gpt-3.5-turbo") - assert.Equal(t, "gpt-3.5-turbo", gpt35Model.ID) - assert.Equal(t, "GPT-3.5 Turbo", gpt35Model.Name) - assert.Equal(t, 1.5, gpt35Model.CostPer1MIn) - - // Test non-existent model - nonExistentModel := GetProviderModel(provider.InferenceProviderOpenAI, "non-existent") - assert.Empty(t, nonExistentModel.ID) + assert.Error(t, err) } -func TestModelSelection_GetModel(t *testing.T) { +func TestValidation_InvalidAgentModelType(t *testing.T) { reset() testConfigDir = t.TempDir() cwdDir := t.TempDir() - // Set up provider with models - globalConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "test-key", - ProviderType: provider.TypeOpenAI, - DefaultLargeModel: "gpt-4", - DefaultSmallModel: "gpt-3.5-turbo", - Models: []Model{ - { - ID: "gpt-4", - Name: "GPT-4", - ContextWindow: 8192, - DefaultMaxTokens: 4096, - }, - { - ID: "gpt-3.5-turbo", - Name: "GPT-3.5 Turbo", - ContextWindow: 4096, - DefaultMaxTokens: 2048, - }, - }, - }, - }, - } - - configPath := filepath.Join(testConfigDir, "crush.json") - require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) - data, err := json.Marshal(globalConfig) - require.NoError(t, err) - require.NoError(t, os.WriteFile(configPath, data, 0o644)) - - _, err = Init(cwdDir, false) - - require.NoError(t, err) - - // Test GetModel - largeModel := GetModel(LargeModel) - assert.Equal(t, "gpt-4", largeModel.ID) - assert.Equal(t, "GPT-4", largeModel.Name) - - smallModel := GetModel(SmallModel) - 
assert.Equal(t, "gpt-3.5-turbo", smallModel.ID) - assert.Equal(t, "GPT-3.5 Turbo", smallModel.Name) -} - -func TestModelSelection_UpdatePreferredModel(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() + os.Setenv("ANTHROPIC_API_KEY", "test-key") - // Set up multiple providers with OpenAI first to ensure it's selected initially globalConfig := Config{ - Providers: map[provider.InferenceProvider]ProviderConfig{ - provider.InferenceProviderOpenAI: { - ID: provider.InferenceProviderOpenAI, - APIKey: "openai-key", - ProviderType: provider.TypeOpenAI, - DefaultLargeModel: "gpt-4", - DefaultSmallModel: "gpt-3.5-turbo", - Models: []Model{ - {ID: "gpt-4", Name: "GPT-4"}, - {ID: "gpt-3.5-turbo", Name: "GPT-3.5 Turbo"}, - }, - }, - provider.InferenceProviderAnthropic: { - ID: provider.InferenceProviderAnthropic, - APIKey: "anthropic-key", - ProviderType: provider.TypeAnthropic, - DefaultLargeModel: "claude-3-opus", - DefaultSmallModel: "claude-3-haiku", - Models: []Model{ - {ID: "claude-3-opus", Name: "Claude 3 Opus"}, - {ID: "claude-3-haiku", Name: "Claude 3 Haiku"}, - }, + Agents: map[AgentID]Agent{ + AgentID("invalid-agent"): { + ID: AgentID("invalid-agent"), + Name: "Invalid Agent", + Model: ModelType("invalid"), }, }, } @@ -2203,97 +2076,5 @@ func TestModelSelection_UpdatePreferredModel(t *testing.T) { require.NoError(t, os.WriteFile(configPath, data, 0o644)) _, err = Init(cwdDir, false) - - require.NoError(t, err) - - // Get initial preferred models (should be OpenAI since it's listed first) - initialLargeModel := GetModel(LargeModel) - initialSmallModel := GetModel(SmallModel) - - // Verify initial models are OpenAI models - assert.Equal(t, "claude-3-opus", initialLargeModel.ID) - assert.Equal(t, "claude-3-haiku", initialSmallModel.ID) - - // Update preferred models to Anthropic - newLargeModel := PreferredModel{ - ModelID: "gpt-4", - Provider: provider.InferenceProviderOpenAI, - } - newSmallModel := PreferredModel{ - ModelID: 
"gpt-3.5-turbo", - Provider: provider.InferenceProviderOpenAI, - } - - err = UpdatePreferredModel(LargeModel, newLargeModel) - require.NoError(t, err) - - err = UpdatePreferredModel(SmallModel, newSmallModel) - require.NoError(t, err) - - // Verify models were updated - updatedLargeModel := GetModel(LargeModel) - assert.Equal(t, "gpt-4", updatedLargeModel.ID) - assert.NotEqual(t, initialLargeModel.ID, updatedLargeModel.ID) - - updatedSmallModel := GetModel(SmallModel) - assert.Equal(t, "gpt-3.5-turbo", updatedSmallModel.ID) - assert.NotEqual(t, initialSmallModel.ID, updatedSmallModel.ID) -} - -func TestModelSelection_InvalidModelType(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - // Set up a provider - os.Setenv("ANTHROPIC_API_KEY", "test-key") - - _, err := Init(cwdDir, false) - require.NoError(t, err) - - // Test UpdatePreferredModel with invalid model type - invalidModel := PreferredModel{ - ModelID: "some-model", - Provider: provider.InferenceProviderAnthropic, - } - - err = UpdatePreferredModel(ModelType("invalid"), invalidModel) assert.Error(t, err) - assert.Contains(t, err.Error(), "unknown model type") -} - -func TestModelSelection_NonExistentAgent(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - // Set up a provider - os.Setenv("ANTHROPIC_API_KEY", "test-key") - - _, err := Init(cwdDir, false) - require.NoError(t, err) - - // Test GetAgentModel with non-existent agent - nonExistentModel := GetAgentModel(AgentID("non-existent")) - assert.Empty(t, nonExistentModel.ID) - - // Test GetAgentProvider with non-existent agent - nonExistentProvider := GetAgentProvider(AgentID("non-existent")) - assert.Empty(t, nonExistentProvider.ID) -} - -func TestModelSelection_NonExistentProvider(t *testing.T) { - reset() - testConfigDir = t.TempDir() - cwdDir := t.TempDir() - - // Set up a provider - os.Setenv("ANTHROPIC_API_KEY", "test-key") - - _, err := Init(cwdDir, false) - require.NoError(t, err) - - 
// Test GetProviderModel with non-existent provider - nonExistentModel := GetProviderModel(provider.InferenceProvider("non-existent"), "some-model") - assert.Empty(t, nonExistentModel.ID) } diff --git a/internal/config/provider_mock.go b/internal/config/provider_mock.go index af92cc2c33f0b0adbe65dbd728b29727c35aeaa8..73d39d761b15ae682b272f356c62234aaa3ca0ad 100644 --- a/internal/config/provider_mock.go +++ b/internal/config/provider_mock.go @@ -6,6 +6,7 @@ import ( // MockProviders returns a mock list of providers for testing. // This avoids making API calls during tests and provides consistent test data. +// Simplified version with only default models from each provider. func MockProviders() []provider.Provider { return []provider.Provider{ { @@ -14,43 +15,19 @@ func MockProviders() []provider.Provider { APIKey: "$ANTHROPIC_API_KEY", APIEndpoint: "$ANTHROPIC_API_ENDPOINT", Type: provider.TypeAnthropic, - DefaultLargeModelID: "claude-3-opus", - DefaultSmallModelID: "claude-3-haiku", + DefaultLargeModelID: "claude-sonnet-4-20250514", + DefaultSmallModelID: "claude-3-5-haiku-20241022", Models: []provider.Model{ { - ID: "claude-3-opus", - Name: "Claude 3 Opus", - CostPer1MIn: 15.0, - CostPer1MOut: 75.0, - CostPer1MInCached: 18.75, - CostPer1MOutCached: 1.5, - ContextWindow: 200000, - DefaultMaxTokens: 4096, - CanReason: false, - SupportsImages: true, - }, - { - ID: "claude-3-haiku", - Name: "Claude 3 Haiku", - CostPer1MIn: 0.25, - CostPer1MOut: 1.25, - CostPer1MInCached: 0.3, - CostPer1MOutCached: 0.03, - ContextWindow: 200000, - DefaultMaxTokens: 4096, - CanReason: false, - SupportsImages: true, - }, - { - ID: "claude-3-5-sonnet-20241022", - Name: "Claude 3.5 Sonnet", + ID: "claude-sonnet-4-20250514", + Name: "Claude Sonnet 4", CostPer1MIn: 3.0, CostPer1MOut: 15.0, CostPer1MInCached: 3.75, CostPer1MOutCached: 0.3, ContextWindow: 200000, - DefaultMaxTokens: 8192, - CanReason: false, + DefaultMaxTokens: 50000, + CanReason: true, SupportsImages: true, }, { @@ -61,7 
+38,7 @@ func MockProviders() []provider.Provider { CostPer1MInCached: 1.0, CostPer1MOutCached: 0.08, ContextWindow: 200000, - DefaultMaxTokens: 8192, + DefaultMaxTokens: 5000, CanReason: false, SupportsImages: true, }, @@ -73,44 +50,22 @@ func MockProviders() []provider.Provider { APIKey: "$OPENAI_API_KEY", APIEndpoint: "$OPENAI_API_ENDPOINT", Type: provider.TypeOpenAI, - DefaultLargeModelID: "gpt-4", - DefaultSmallModelID: "gpt-3.5-turbo", + DefaultLargeModelID: "codex-mini-latest", + DefaultSmallModelID: "gpt-4o", Models: []provider.Model{ { - ID: "gpt-4", - Name: "GPT-4", - CostPer1MIn: 30.0, - CostPer1MOut: 60.0, - CostPer1MInCached: 0.0, - CostPer1MOutCached: 0.0, - ContextWindow: 8192, - DefaultMaxTokens: 4096, - CanReason: false, - SupportsImages: false, - }, - { - ID: "gpt-3.5-turbo", - Name: "GPT-3.5 Turbo", - CostPer1MIn: 1.0, - CostPer1MOut: 2.0, - CostPer1MInCached: 0.0, - CostPer1MOutCached: 0.0, - ContextWindow: 4096, - DefaultMaxTokens: 4096, - CanReason: false, - SupportsImages: false, - }, - { - ID: "gpt-4-turbo", - Name: "GPT-4 Turbo", - CostPer1MIn: 10.0, - CostPer1MOut: 30.0, - CostPer1MInCached: 0.0, - CostPer1MOutCached: 0.0, - ContextWindow: 128000, - DefaultMaxTokens: 4096, - CanReason: false, - SupportsImages: true, + ID: "codex-mini-latest", + Name: "Codex Mini", + CostPer1MIn: 1.5, + CostPer1MOut: 6.0, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.375, + ContextWindow: 200000, + DefaultMaxTokens: 50000, + CanReason: true, + HasReasoningEffort: true, + DefaultReasoningEffort: "medium", + SupportsImages: true, }, { ID: "gpt-4o", @@ -120,50 +75,10 @@ func MockProviders() []provider.Provider { CostPer1MInCached: 0.0, CostPer1MOutCached: 1.25, ContextWindow: 128000, - DefaultMaxTokens: 16384, + DefaultMaxTokens: 20000, CanReason: false, SupportsImages: true, }, - { - ID: "gpt-4o-mini", - Name: "GPT-4o-mini", - CostPer1MIn: 0.15, - CostPer1MOut: 0.6, - CostPer1MInCached: 0.0, - CostPer1MOutCached: 0.075, - ContextWindow: 128000, - 
DefaultMaxTokens: 16384, - CanReason: false, - SupportsImages: true, - }, - { - ID: "o1-preview", - Name: "o1-preview", - CostPer1MIn: 15.0, - CostPer1MOut: 60.0, - CostPer1MInCached: 0.0, - CostPer1MOutCached: 0.0, - ContextWindow: 128000, - DefaultMaxTokens: 32768, - CanReason: true, - HasReasoningEffort: true, - DefaultReasoningEffort: "medium", - SupportsImages: true, - }, - { - ID: "o1-mini", - Name: "o1-mini", - CostPer1MIn: 3.0, - CostPer1MOut: 12.0, - CostPer1MInCached: 0.0, - CostPer1MOutCached: 0.0, - ContextWindow: 128000, - DefaultMaxTokens: 65536, - CanReason: true, - HasReasoningEffort: true, - DefaultReasoningEffort: "medium", - SupportsImages: true, - }, }, }, { @@ -183,7 +98,7 @@ func MockProviders() []provider.Provider { CostPer1MInCached: 1.625, CostPer1MOutCached: 0.31, ContextWindow: 1048576, - DefaultMaxTokens: 65536, + DefaultMaxTokens: 50000, CanReason: true, SupportsImages: true, }, @@ -195,7 +110,7 @@ func MockProviders() []provider.Provider { CostPer1MInCached: 0.3833, CostPer1MOutCached: 0.075, ContextWindow: 1048576, - DefaultMaxTokens: 65535, + DefaultMaxTokens: 50000, CanReason: true, SupportsImages: true, }, @@ -207,18 +122,135 @@ func MockProviders() []provider.Provider { APIKey: "$XAI_API_KEY", APIEndpoint: "https://api.x.ai/v1", Type: provider.TypeXAI, - DefaultLargeModelID: "grok-beta", - DefaultSmallModelID: "grok-beta", + DefaultLargeModelID: "grok-3", + DefaultSmallModelID: "grok-3-mini", Models: []provider.Model{ { - ID: "grok-beta", - Name: "Grok Beta", - CostPer1MIn: 5.0, + ID: "grok-3", + Name: "Grok 3", + CostPer1MIn: 3.0, CostPer1MOut: 15.0, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.75, ContextWindow: 131072, - DefaultMaxTokens: 4096, + DefaultMaxTokens: 20000, CanReason: false, - SupportsImages: true, + SupportsImages: false, + }, + { + ID: "grok-3-mini", + Name: "Grok 3 Mini", + CostPer1MIn: 0.3, + CostPer1MOut: 0.5, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.075, + ContextWindow: 131072, + 
DefaultMaxTokens: 20000, + CanReason: true, + SupportsImages: false, + }, + }, + }, + { + Name: "Azure OpenAI", + ID: provider.InferenceProviderAzure, + APIKey: "$AZURE_OPENAI_API_KEY", + APIEndpoint: "$AZURE_OPENAI_API_ENDPOINT", + Type: provider.TypeAzure, + DefaultLargeModelID: "o4-mini", + DefaultSmallModelID: "gpt-4o", + Models: []provider.Model{ + { + ID: "o4-mini", + Name: "o4 Mini", + CostPer1MIn: 1.1, + CostPer1MOut: 4.4, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.275, + ContextWindow: 200000, + DefaultMaxTokens: 50000, + CanReason: true, + HasReasoningEffort: false, + DefaultReasoningEffort: "medium", + SupportsImages: true, + }, + { + ID: "gpt-4o", + Name: "GPT-4o", + CostPer1MIn: 2.5, + CostPer1MOut: 10.0, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 1.25, + ContextWindow: 128000, + DefaultMaxTokens: 20000, + CanReason: false, + SupportsImages: true, + }, + }, + }, + { + Name: "AWS Bedrock", + ID: provider.InferenceProviderBedrock, + Type: provider.TypeBedrock, + DefaultLargeModelID: "anthropic.claude-sonnet-4-20250514-v1:0", + DefaultSmallModelID: "anthropic.claude-3-5-haiku-20241022-v1:0", + Models: []provider.Model{ + { + ID: "anthropic.claude-sonnet-4-20250514-v1:0", + Name: "AWS Claude Sonnet 4", + CostPer1MIn: 3.0, + CostPer1MOut: 15.0, + CostPer1MInCached: 3.75, + CostPer1MOutCached: 0.3, + ContextWindow: 200000, + DefaultMaxTokens: 50000, + CanReason: true, + SupportsImages: true, + }, + { + ID: "anthropic.claude-3-5-haiku-20241022-v1:0", + Name: "AWS Claude 3.5 Haiku", + CostPer1MIn: 0.8, + CostPer1MOut: 4.0, + CostPer1MInCached: 1.0, + CostPer1MOutCached: 0.08, + ContextWindow: 200000, + DefaultMaxTokens: 50000, + CanReason: false, + SupportsImages: true, + }, + }, + }, + { + Name: "Google Vertex AI", + ID: provider.InferenceProviderVertexAI, + Type: provider.TypeVertexAI, + DefaultLargeModelID: "gemini-2.5-pro", + DefaultSmallModelID: "gemini-2.5-flash", + Models: []provider.Model{ + { + ID: "gemini-2.5-pro", + Name: "Gemini 2.5 
Pro", + CostPer1MIn: 1.25, + CostPer1MOut: 10.0, + CostPer1MInCached: 1.625, + CostPer1MOutCached: 0.31, + ContextWindow: 1048576, + DefaultMaxTokens: 50000, + CanReason: true, + SupportsImages: true, + }, + { + ID: "gemini-2.5-flash", + Name: "Gemini 2.5 Flash", + CostPer1MIn: 0.3, + CostPer1MOut: 2.5, + CostPer1MInCached: 0.3833, + CostPer1MOutCached: 0.075, + ContextWindow: 1048576, + DefaultMaxTokens: 50000, + CanReason: true, + SupportsImages: true, }, }, }, @@ -228,28 +260,32 @@ func MockProviders() []provider.Provider { APIKey: "$OPENROUTER_API_KEY", APIEndpoint: "https://openrouter.ai/api/v1", Type: provider.TypeOpenAI, - DefaultLargeModelID: "anthropic/claude-3.5-sonnet", - DefaultSmallModelID: "anthropic/claude-3.5-haiku", + DefaultLargeModelID: "anthropic/claude-sonnet-4", + DefaultSmallModelID: "anthropic/claude-haiku-3.5", Models: []provider.Model{ { - ID: "anthropic/claude-3.5-sonnet", - Name: "Claude 3.5 Sonnet", - CostPer1MIn: 3.0, - CostPer1MOut: 15.0, - ContextWindow: 200000, - DefaultMaxTokens: 8192, - CanReason: false, - SupportsImages: true, + ID: "anthropic/claude-sonnet-4", + Name: "Anthropic: Claude Sonnet 4", + CostPer1MIn: 3.0, + CostPer1MOut: 15.0, + CostPer1MInCached: 3.75, + CostPer1MOutCached: 0.3, + ContextWindow: 200000, + DefaultMaxTokens: 32000, + CanReason: true, + SupportsImages: true, }, { - ID: "anthropic/claude-3.5-haiku", - Name: "Claude 3.5 Haiku", - CostPer1MIn: 0.8, - CostPer1MOut: 4.0, - ContextWindow: 200000, - DefaultMaxTokens: 8192, - CanReason: false, - SupportsImages: true, + ID: "anthropic/claude-haiku-3.5", + Name: "Anthropic: Claude 3.5 Haiku", + CostPer1MIn: 0.8, + CostPer1MOut: 4.0, + CostPer1MInCached: 1.0, + CostPer1MOutCached: 0.08, + ContextWindow: 200000, + DefaultMaxTokens: 4096, + CanReason: false, + SupportsImages: true, }, }, }, diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index 70224d194a6689d85602d6a0f7d92e03b02fa1b2..8f2a31f06ab121fa049e7ca8bed159976cb2e92f 100644 
--- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -1,7 +1,6 @@ package config import ( - "encoding/json" "testing" "github.com/charmbracelet/crush/internal/fur/provider" @@ -9,8 +8,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestMockProviders(t *testing.T) { - // Enable mock providers for testing +func TestProviders_MockEnabled(t *testing.T) { originalUseMock := UseMockProviders UseMockProviders = true defer func() { @@ -18,94 +16,38 @@ func TestMockProviders(t *testing.T) { ResetProviders() }() - // Reset providers to ensure we get fresh mock data ResetProviders() - providers := Providers() - require.NotEmpty(t, providers, "Mock providers should not be empty") + require.NotEmpty(t, providers) - // Verify we have the expected mock providers providerIDs := make(map[provider.InferenceProvider]bool) for _, p := range providers { providerIDs[p.ID] = true } - assert.True(t, providerIDs[provider.InferenceProviderAnthropic], "Should have Anthropic provider") - assert.True(t, providerIDs[provider.InferenceProviderOpenAI], "Should have OpenAI provider") - assert.True(t, providerIDs[provider.InferenceProviderGemini], "Should have Gemini provider") - - // Verify Anthropic provider details - var anthropicProvider provider.Provider - for _, p := range providers { - if p.ID == provider.InferenceProviderAnthropic { - anthropicProvider = p - break - } - } - - assert.Equal(t, "Anthropic", anthropicProvider.Name) - assert.Equal(t, provider.TypeAnthropic, anthropicProvider.Type) - assert.Equal(t, "claude-3-opus", anthropicProvider.DefaultLargeModelID) - assert.Equal(t, "claude-3-haiku", anthropicProvider.DefaultSmallModelID) - assert.Len(t, anthropicProvider.Models, 4, "Anthropic should have 4 models") - - // Verify model details - var opusModel provider.Model - for _, m := range anthropicProvider.Models { - if m.ID == "claude-3-opus" { - opusModel = m - break - } - } - - assert.Equal(t, "Claude 3 Opus", opusModel.Name) - assert.Equal(t, 
int64(200000), opusModel.ContextWindow) - assert.Equal(t, int64(4096), opusModel.DefaultMaxTokens) - assert.True(t, opusModel.SupportsImages) -} - -func TestProvidersWithoutMock(t *testing.T) { - // Ensure mock is disabled - originalUseMock := UseMockProviders - UseMockProviders = false - defer func() { - UseMockProviders = originalUseMock - ResetProviders() - }() - - // Reset providers to ensure we get fresh data - ResetProviders() - - // This will try to make an actual API call or use cached data - providers := Providers() - - // We can't guarantee what we'll get here since it depends on network/cache - // but we can at least verify the function doesn't panic - t.Logf("Got %d providers without mock", len(providers)) + assert.True(t, providerIDs[provider.InferenceProviderAnthropic]) + assert.True(t, providerIDs[provider.InferenceProviderOpenAI]) + assert.True(t, providerIDs[provider.InferenceProviderGemini]) } -func TestResetProviders(t *testing.T) { - // Enable mock providers +func TestProviders_ResetFunctionality(t *testing.T) { UseMockProviders = true defer func() { UseMockProviders = false ResetProviders() }() - // Get providers once providers1 := Providers() require.NotEmpty(t, providers1) - // Reset and get again ResetProviders() providers2 := Providers() require.NotEmpty(t, providers2) - // Should get the same mock data assert.Equal(t, len(providers1), len(providers2)) } -func TestReasoningEffortSupport(t *testing.T) { +func TestProviders_ModelCapabilities(t *testing.T) { originalUseMock := UseMockProviders UseMockProviders = true defer func() { @@ -125,156 +67,15 @@ func TestReasoningEffortSupport(t *testing.T) { } require.NotEmpty(t, openaiProvider.ID) - var reasoningModel, nonReasoningModel provider.Model + var foundReasoning, foundNonReasoning bool for _, model := range openaiProvider.Models { if model.CanReason && model.HasReasoningEffort { - reasoningModel = model - } else if !model.CanReason { - nonReasoningModel = model - } - } - - 
require.NotEmpty(t, reasoningModel.ID) - assert.Equal(t, "medium", reasoningModel.DefaultReasoningEffort) - assert.True(t, reasoningModel.HasReasoningEffort) - - require.NotEmpty(t, nonReasoningModel.ID) - assert.False(t, nonReasoningModel.HasReasoningEffort) - assert.Empty(t, nonReasoningModel.DefaultReasoningEffort) -} - -func TestReasoningEffortConfigTransfer(t *testing.T) { - originalUseMock := UseMockProviders - UseMockProviders = true - defer func() { - UseMockProviders = originalUseMock - ResetProviders() - }() - - ResetProviders() - t.Setenv("OPENAI_API_KEY", "test-openai-key") - - cfg, err := Init(t.TempDir(), false) - require.NoError(t, err) - - openaiProviderConfig, exists := cfg.Providers[provider.InferenceProviderOpenAI] - require.True(t, exists) - - var foundReasoning, foundNonReasoning bool - for _, model := range openaiProviderConfig.Models { - if model.CanReason && model.HasReasoningEffort && model.ReasoningEffort != "" { - assert.Equal(t, "medium", model.ReasoningEffort) - assert.True(t, model.HasReasoningEffort) foundReasoning = true } else if !model.CanReason { - assert.Empty(t, model.ReasoningEffort) - assert.False(t, model.HasReasoningEffort) foundNonReasoning = true } } - assert.True(t, foundReasoning, "Should find at least one reasoning model") - assert.True(t, foundNonReasoning, "Should find at least one non-reasoning model") -} - -func TestNewProviders(t *testing.T) { - originalUseMock := UseMockProviders - UseMockProviders = true - defer func() { - UseMockProviders = originalUseMock - ResetProviders() - }() - - ResetProviders() - providers := Providers() - require.NotEmpty(t, providers) - - var xaiProvider, openRouterProvider provider.Provider - for _, p := range providers { - switch p.ID { - case provider.InferenceProviderXAI: - xaiProvider = p - case provider.InferenceProviderOpenRouter: - openRouterProvider = p - } - } - - require.NotEmpty(t, xaiProvider.ID) - assert.Equal(t, "xAI", xaiProvider.Name) - assert.Equal(t, "grok-beta", 
xaiProvider.DefaultLargeModelID) - - require.NotEmpty(t, openRouterProvider.ID) - assert.Equal(t, "OpenRouter", openRouterProvider.Name) - assert.Equal(t, "anthropic/claude-3.5-sonnet", openRouterProvider.DefaultLargeModelID) -} - -func TestO1ModelsInMockProvider(t *testing.T) { - originalUseMock := UseMockProviders - UseMockProviders = true - defer func() { - UseMockProviders = originalUseMock - ResetProviders() - }() - - ResetProviders() - providers := Providers() - - var openaiProvider provider.Provider - for _, p := range providers { - if p.ID == provider.InferenceProviderOpenAI { - openaiProvider = p - break - } - } - require.NotEmpty(t, openaiProvider.ID) - - modelTests := []struct { - id string - name string - }{ - {"o1-preview", "o1-preview"}, - {"o1-mini", "o1-mini"}, - } - - for _, test := range modelTests { - var model provider.Model - var found bool - for _, m := range openaiProvider.Models { - if m.ID == test.id { - model = m - found = true - break - } - } - require.True(t, found, "Should find %s model", test.id) - assert.Equal(t, test.name, model.Name) - assert.True(t, model.CanReason) - assert.True(t, model.HasReasoningEffort) - assert.Equal(t, "medium", model.DefaultReasoningEffort) - } -} - -func TestPreferredModelReasoningEffort(t *testing.T) { - // Test that PreferredModel struct can hold reasoning effort - preferredModel := PreferredModel{ - ModelID: "o1-preview", - Provider: provider.InferenceProviderOpenAI, - ReasoningEffort: "high", - } - - assert.Equal(t, "o1-preview", preferredModel.ModelID) - assert.Equal(t, provider.InferenceProviderOpenAI, preferredModel.Provider) - assert.Equal(t, "high", preferredModel.ReasoningEffort) - - // Test JSON marshaling/unmarshaling - jsonData, err := json.Marshal(preferredModel) - require.NoError(t, err) - - var unmarshaled PreferredModel - err = json.Unmarshal(jsonData, &unmarshaled) - require.NoError(t, err) - - assert.Equal(t, preferredModel.ModelID, unmarshaled.ModelID) - assert.Equal(t, 
preferredModel.Provider, unmarshaled.Provider) - assert.Equal(t, preferredModel.ReasoningEffort, unmarshaled.ReasoningEffort) -} + assert.True(t, foundReasoning) + assert.True(t, foundNonReasoning) +} \ No newline at end of file diff --git a/internal/config/shell.go b/internal/config/shell.go index a12ecd1da3b82c113175a1f4825877a7fb94a95c..74931bfefc3a9e16e830fac2c3478a6f0d5396f2 100644 --- a/internal/config/shell.go +++ b/internal/config/shell.go @@ -71,4 +71,3 @@ func resolveCommandAPIKey(command string) (string, error) { logging.Debug("Command executed successfully", "command", command, "result", result) return result, nil } - From 21719ee394b679b3b5522d3a128b1b6148c65353 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 30 Jun 2025 13:43:08 +0200 Subject: [PATCH 25/40] chore: change how we generate the json schema --- cmd/schema/main.go | 528 ++------- crush-schema.json | 1761 ++++++++---------------------- crush.json | 3 +- cspell.json | 2 +- go.mod | 8 + go.sum | 11 + internal/config/config.go | 187 ++-- internal/config/config_test.go | 1 - internal/config/provider_mock.go | 36 +- internal/config/provider_test.go | 2 +- 10 files changed, 702 insertions(+), 1837 deletions(-) diff --git a/cmd/schema/main.go b/cmd/schema/main.go index 34939f1e84b9f3df04c7419a9ac4d7dfdc76386a..43f361662cd5d357e4d3b736ba7b4f3af2222724 100644 --- a/cmd/schema/main.go +++ b/cmd/schema/main.go @@ -4,67 +4,31 @@ import ( "encoding/json" "fmt" "os" - "reflect" - "slices" - "strings" "github.com/charmbracelet/crush/internal/config" + "github.com/invopop/jsonschema" ) -// JSONSchema represents a JSON Schema -type JSONSchema struct { - Schema string `json:"$schema,omitempty"` - Title string `json:"title,omitempty"` - Description string `json:"description,omitempty"` - Type string `json:"type,omitempty"` - Properties map[string]*JSONSchema `json:"properties,omitempty"` - Items *JSONSchema `json:"items,omitempty"` - Required []string `json:"required,omitempty"` - AdditionalProperties 
any `json:"additionalProperties,omitempty"` - Enum []any `json:"enum,omitempty"` - Default any `json:"default,omitempty"` - Definitions map[string]*JSONSchema `json:"definitions,omitempty"` - Ref string `json:"$ref,omitempty"` - OneOf []*JSONSchema `json:"oneOf,omitempty"` - AnyOf []*JSONSchema `json:"anyOf,omitempty"` - AllOf []*JSONSchema `json:"allOf,omitempty"` - Not *JSONSchema `json:"not,omitempty"` - Format string `json:"format,omitempty"` - Pattern string `json:"pattern,omitempty"` - MinLength *int `json:"minLength,omitempty"` - MaxLength *int `json:"maxLength,omitempty"` - Minimum *float64 `json:"minimum,omitempty"` - Maximum *float64 `json:"maximum,omitempty"` - ExclusiveMinimum *float64 `json:"exclusiveMinimum,omitempty"` - ExclusiveMaximum *float64 `json:"exclusiveMaximum,omitempty"` - MultipleOf *float64 `json:"multipleOf,omitempty"` - MinItems *int `json:"minItems,omitempty"` - MaxItems *int `json:"maxItems,omitempty"` - UniqueItems *bool `json:"uniqueItems,omitempty"` - MinProperties *int `json:"minProperties,omitempty"` - MaxProperties *int `json:"maxProperties,omitempty"` -} - -// SchemaGenerator generates JSON schemas from Go types -type SchemaGenerator struct { - definitions map[string]*JSONSchema - visited map[reflect.Type]bool -} - -// NewSchemaGenerator creates a new schema generator -func NewSchemaGenerator() *SchemaGenerator { - return &SchemaGenerator{ - definitions: make(map[string]*JSONSchema), - visited: make(map[reflect.Type]bool), +func main() { + // Create a new reflector + r := &jsonschema.Reflector{ + // Use anonymous schemas to avoid ID conflicts + Anonymous: true, + // Expand the root struct instead of referencing it + ExpandedStruct: true, + AllowAdditionalProperties: true, } -} -func main() { - // Enable mock providers to avoid API calls during schema generation - config.UseMockProviders = true + // Generate schema for the main Config struct + schema := r.Reflect(&config.Config{}) + + // Enhance the schema with additional 
information + enhanceSchema(schema) - generator := NewSchemaGenerator() - schema := generator.GenerateSchema() + // Set the schema metadata + schema.Version = "https://json-schema.org/draft/2020-12/schema" + schema.Title = "Crush Configuration" + schema.Description = "Configuration schema for the Crush application" // Pretty print the schema encoder := json.NewEncoder(os.Stdout) @@ -75,226 +39,23 @@ func main() { } } -// GenerateSchema generates the complete JSON schema for the Crush configuration -func (g *SchemaGenerator) GenerateSchema() *JSONSchema { - // Generate schema for the main Config struct - configType := reflect.TypeOf(config.Config{}) - configSchema := g.generateTypeSchema(configType) - - // Create the root schema - schema := &JSONSchema{ - Schema: "http://json-schema.org/draft-07/schema#", - Title: "Crush Configuration", - Description: "Configuration schema for the Crush application", - Type: configSchema.Type, - Properties: configSchema.Properties, - Required: configSchema.Required, - Definitions: g.definitions, - } - - // Add custom enhancements - g.enhanceSchema(schema) - - return schema -} - -// generateTypeSchema generates a JSON schema for a given Go type -func (g *SchemaGenerator) generateTypeSchema(t reflect.Type) *JSONSchema { - // Handle pointers - if t.Kind() == reflect.Ptr { - return g.generateTypeSchema(t.Elem()) - } - - // Check if we've already processed this type - if g.visited[t] { - // Return a reference to avoid infinite recursion - return &JSONSchema{ - Ref: fmt.Sprintf("#/definitions/%s", t.Name()), - } - } - - switch t.Kind() { - case reflect.String: - return &JSONSchema{Type: "string"} - case reflect.Bool: - return &JSONSchema{Type: "boolean"} - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return &JSONSchema{Type: "integer"} - case reflect.Float32, reflect.Float64: - return &JSONSchema{Type: "number"} - case 
reflect.Slice, reflect.Array: - itemSchema := g.generateTypeSchema(t.Elem()) - return &JSONSchema{ - Type: "array", - Items: itemSchema, - } - case reflect.Map: - valueSchema := g.generateTypeSchema(t.Elem()) - return &JSONSchema{ - Type: "object", - AdditionalProperties: valueSchema, - } - case reflect.Struct: - return g.generateStructSchema(t) - case reflect.Interface: - // For interface{} types, allow any value - return &JSONSchema{} - default: - // Fallback for unknown types - return &JSONSchema{} - } -} - -// generateStructSchema generates a JSON schema for a struct type -func (g *SchemaGenerator) generateStructSchema(t reflect.Type) *JSONSchema { - // Mark as visited to prevent infinite recursion - g.visited[t] = true - - schema := &JSONSchema{ - Type: "object", - Properties: make(map[string]*JSONSchema), - } - - var required []string - - for i := range t.NumField() { - field := t.Field(i) - - // Skip unexported fields - if !field.IsExported() { - continue - } - - // Get JSON tag - jsonTag := field.Tag.Get("json") - if jsonTag == "-" { - continue - } - - // Parse JSON tag - jsonName, options := parseJSONTag(jsonTag) - if jsonName == "" { - jsonName = strings.ToLower(field.Name) - } - - // Generate field schema - fieldSchema := g.generateTypeSchema(field.Type) - - // Add description from field name if not present - if fieldSchema.Description == "" { - fieldSchema.Description = generateFieldDescription(field.Name, field.Type) - } - - // Check if field is required (not omitempty and not a pointer) - if !slices.Contains(options, "omitempty") && field.Type.Kind() != reflect.Ptr { - required = append(required, jsonName) - } - - schema.Properties[jsonName] = fieldSchema - } - - if len(required) > 0 { - schema.Required = required - } - - // Store in definitions if it's a named type - if t.Name() != "" { - g.definitions[t.Name()] = schema - } - - return schema -} - -// parseJSONTag parses a JSON struct tag -func parseJSONTag(tag string) (name string, options []string) 
{ - if tag == "" { - return "", nil - } - - parts := strings.Split(tag, ",") - name = parts[0] - if len(parts) > 1 { - options = parts[1:] - } - return name, options -} - -// generateFieldDescription generates a description for a field based on its name and type -func generateFieldDescription(fieldName string, fieldType reflect.Type) string { - // Convert camelCase to words - words := camelCaseToWords(fieldName) - description := strings.Join(words, " ") - - // Add type-specific information - switch fieldType.Kind() { - case reflect.Bool: - if !strings.Contains(strings.ToLower(description), "enable") && - !strings.Contains(strings.ToLower(description), "disable") { - description = "Enable " + strings.ToLower(description) - } - case reflect.Slice: - if !strings.HasSuffix(description, "s") { - description = description + " list" - } - case reflect.Map: - description = description + " configuration" - } - - return description -} - -// camelCaseToWords converts camelCase to separate words -func camelCaseToWords(s string) []string { - var words []string - var currentWord strings.Builder - - for i, r := range s { - if i > 0 && r >= 'A' && r <= 'Z' { - if currentWord.Len() > 0 { - words = append(words, currentWord.String()) - currentWord.Reset() - } - } - currentWord.WriteRune(r) - } - - if currentWord.Len() > 0 { - words = append(words, currentWord.String()) - } - - return words -} - -// enhanceSchema adds custom enhancements to the generated schema -func (g *SchemaGenerator) enhanceSchema(schema *JSONSchema) { +// enhanceSchema adds additional enhancements to the generated schema +func enhanceSchema(schema *jsonschema.Schema) { // Add provider enums - g.addProviderEnums(schema) + addProviderEnums(schema) // Add model enums - g.addModelEnums(schema) - - // Add agent enums - g.addAgentEnums(schema) + addModelEnums(schema) // Add tool enums - g.addToolEnums(schema) - - // Add MCP type enums - g.addMCPTypeEnums(schema) - - // Add model type enums - 
g.addModelTypeEnums(schema) - - // Add default values - g.addDefaultValues(schema) + addToolEnums(schema) - // Add custom descriptions - g.addCustomDescriptions(schema) + // Add default context paths + addDefaultContextPaths(schema) } // addProviderEnums adds provider enums to the schema -func (g *SchemaGenerator) addProviderEnums(schema *JSONSchema) { +func addProviderEnums(schema *jsonschema.Schema) { providers := config.Providers() var providerIDs []any for _, p := range providers { @@ -302,22 +63,24 @@ func (g *SchemaGenerator) addProviderEnums(schema *JSONSchema) { } // Add to PreferredModel provider field - if preferredModelDef, exists := schema.Definitions["PreferredModel"]; exists { - if providerProp, exists := preferredModelDef.Properties["provider"]; exists { - providerProp.Enum = providerIDs + if schema.Definitions != nil { + if preferredModelDef, exists := schema.Definitions["PreferredModel"]; exists { + if providerProp, exists := preferredModelDef.Properties.Get("provider"); exists { + providerProp.Enum = providerIDs + } } - } - // Add to ProviderConfig ID field - if providerConfigDef, exists := schema.Definitions["ProviderConfig"]; exists { - if idProp, exists := providerConfigDef.Properties["id"]; exists { - idProp.Enum = providerIDs + // Add to ProviderConfig ID field + if providerConfigDef, exists := schema.Definitions["ProviderConfig"]; exists { + if idProp, exists := providerConfigDef.Properties.Get("id"); exists { + idProp.Enum = providerIDs + } } } } // addModelEnums adds model enums to the schema -func (g *SchemaGenerator) addModelEnums(schema *JSONSchema) { +func addModelEnums(schema *jsonschema.Schema) { providers := config.Providers() var modelIDs []any for _, p := range providers { @@ -327,205 +90,66 @@ func (g *SchemaGenerator) addModelEnums(schema *JSONSchema) { } // Add to PreferredModel model_id field - if preferredModelDef, exists := schema.Definitions["PreferredModel"]; exists { - if modelIDProp, exists := 
preferredModelDef.Properties["model_id"]; exists { - modelIDProp.Enum = modelIDs - } - } -} - -// addAgentEnums adds agent ID enums to the schema -func (g *SchemaGenerator) addAgentEnums(schema *JSONSchema) { - agentIDs := []any{ - string(config.AgentCoder), - string(config.AgentTask), - } - - if agentDef, exists := schema.Definitions["Agent"]; exists { - if idProp, exists := agentDef.Properties["id"]; exists { - idProp.Enum = agentIDs + if schema.Definitions != nil { + if preferredModelDef, exists := schema.Definitions["PreferredModel"]; exists { + if modelIDProp, exists := preferredModelDef.Properties.Get("model_id"); exists { + modelIDProp.Enum = modelIDs + } } } } // addToolEnums adds tool enums to the schema -func (g *SchemaGenerator) addToolEnums(schema *JSONSchema) { +func addToolEnums(schema *jsonschema.Schema) { tools := []any{ "bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent", } - if agentDef, exists := schema.Definitions["Agent"]; exists { - if allowedToolsProp, exists := agentDef.Properties["allowed_tools"]; exists { - if allowedToolsProp.Items != nil { - allowedToolsProp.Items.Enum = tools + if schema.Definitions != nil { + if agentDef, exists := schema.Definitions["Agent"]; exists { + if allowedToolsProp, exists := agentDef.Properties.Get("allowed_tools"); exists { + if allowedToolsProp.Items != nil { + allowedToolsProp.Items.Enum = tools + } } } } } -// addMCPTypeEnums adds MCP type enums to the schema -func (g *SchemaGenerator) addMCPTypeEnums(schema *JSONSchema) { - mcpTypes := []any{ - string(config.MCPStdio), - string(config.MCPSse), +// addDefaultContextPaths adds default context paths to the schema +func addDefaultContextPaths(schema *jsonschema.Schema) { + defaultContextPaths := []any{ + ".github/copilot-instructions.md", + ".cursorrules", + ".cursor/rules/", + "CLAUDE.md", + "CLAUDE.local.md", + "GEMINI.md", + "gemini.md", + "crush.md", + "crush.local.md", + "Crush.md", + "Crush.local.md", + "CRUSH.md", + 
"CRUSH.local.md", } - if mcpDef, exists := schema.Definitions["MCP"]; exists { - if typeProp, exists := mcpDef.Properties["type"]; exists { - typeProp.Enum = mcpTypes - } - } -} - -// addModelTypeEnums adds model type enums to the schema -func (g *SchemaGenerator) addModelTypeEnums(schema *JSONSchema) { - modelTypes := []any{ - string(config.LargeModel), - string(config.SmallModel), - } - - if agentDef, exists := schema.Definitions["Agent"]; exists { - if modelProp, exists := agentDef.Properties["model"]; exists { - modelProp.Enum = modelTypes - } - } -} - -// addDefaultValues adds default values to the schema -func (g *SchemaGenerator) addDefaultValues(schema *JSONSchema) { - // Add default context paths - if optionsDef, exists := schema.Definitions["Options"]; exists { - if contextPathsProp, exists := optionsDef.Properties["context_paths"]; exists { - contextPathsProp.Default = []any{ - ".github/copilot-instructions.md", - ".cursorrules", - ".cursor/rules/", - "CLAUDE.md", - "CLAUDE.local.md", - "GEMINI.md", - "gemini.md", - "crush.md", - "crush.local.md", - "Crush.md", - "Crush.local.md", - "CRUSH.md", - "CRUSH.local.md", + if schema.Definitions != nil { + if optionsDef, exists := schema.Definitions["Options"]; exists { + if contextPathsProp, exists := optionsDef.Properties.Get("context_paths"); exists { + contextPathsProp.Default = defaultContextPaths } } - if dataDirProp, exists := optionsDef.Properties["data_directory"]; exists { - dataDirProp.Default = ".crush" - } - if debugProp, exists := optionsDef.Properties["debug"]; exists { - debugProp.Default = false - } - if debugLSPProp, exists := optionsDef.Properties["debug_lsp"]; exists { - debugLSPProp.Default = false - } - if disableAutoSummarizeProp, exists := optionsDef.Properties["disable_auto_summarize"]; exists { - disableAutoSummarizeProp.Default = false - } - } - - // Add default MCP type - if mcpDef, exists := schema.Definitions["MCP"]; exists { - if typeProp, exists := mcpDef.Properties["type"]; 
exists { - typeProp.Default = string(config.MCPStdio) - } - } - - // Add default TUI options - if tuiOptionsDef, exists := schema.Definitions["TUIOptions"]; exists { - if compactModeProp, exists := tuiOptionsDef.Properties["compact_mode"]; exists { - compactModeProp.Default = false - } - } - - // Add default provider disabled - if providerConfigDef, exists := schema.Definitions["ProviderConfig"]; exists { - if disabledProp, exists := providerConfigDef.Properties["disabled"]; exists { - disabledProp.Default = false - } - } - - // Add default agent disabled - if agentDef, exists := schema.Definitions["Agent"]; exists { - if disabledProp, exists := agentDef.Properties["disabled"]; exists { - disabledProp.Default = false - } - } - - // Add default LSP disabled - if lspConfigDef, exists := schema.Definitions["LSPConfig"]; exists { - if disabledProp, exists := lspConfigDef.Properties["enabled"]; exists { - disabledProp.Default = true - } } -} -// addCustomDescriptions adds custom descriptions to improve the schema -func (g *SchemaGenerator) addCustomDescriptions(schema *JSONSchema) { - // Enhance main config descriptions + // Also add to root properties if they exist if schema.Properties != nil { - if modelsProp, exists := schema.Properties["models"]; exists { - modelsProp.Description = "Preferred model configurations for large and small model types" - } - if providersProp, exists := schema.Properties["providers"]; exists { - providersProp.Description = "LLM provider configurations" - } - if agentsProp, exists := schema.Properties["agents"]; exists { - agentsProp.Description = "Agent configurations for different tasks" - } - if mcpProp, exists := schema.Properties["mcp"]; exists { - mcpProp.Description = "Model Control Protocol server configurations" - } - if lspProp, exists := schema.Properties["lsp"]; exists { - lspProp.Description = "Language Server Protocol configurations" - } - if optionsProp, exists := schema.Properties["options"]; exists { - 
optionsProp.Description = "General application options and settings" - } - } - - // Enhance specific field descriptions - if providerConfigDef, exists := schema.Definitions["ProviderConfig"]; exists { - if apiKeyProp, exists := providerConfigDef.Properties["api_key"]; exists { - apiKeyProp.Description = "API key for authenticating with the provider" - } - if baseURLProp, exists := providerConfigDef.Properties["base_url"]; exists { - baseURLProp.Description = "Base URL for the provider API (required for custom providers)" - } - if extraHeadersProp, exists := providerConfigDef.Properties["extra_headers"]; exists { - extraHeadersProp.Description = "Additional HTTP headers to send with requests" - } - if extraParamsProp, exists := providerConfigDef.Properties["extra_params"]; exists { - extraParamsProp.Description = "Additional provider-specific parameters" - } - } - - if agentDef, exists := schema.Definitions["Agent"]; exists { - if allowedToolsProp, exists := agentDef.Properties["allowed_tools"]; exists { - allowedToolsProp.Description = "List of tools this agent is allowed to use (if nil, all tools are allowed)" - } - if allowedMCPProp, exists := agentDef.Properties["allowed_mcp"]; exists { - allowedMCPProp.Description = "Map of MCP servers this agent can use and their allowed tools" - } - if allowedLSPProp, exists := agentDef.Properties["allowed_lsp"]; exists { - allowedLSPProp.Description = "List of LSP servers this agent can use (if nil, all LSPs are allowed)" - } - if contextPathsProp, exists := agentDef.Properties["context_paths"]; exists { - contextPathsProp.Description = "Custom context paths for this agent (additive to global context paths)" - } - } - - if mcpDef, exists := schema.Definitions["MCP"]; exists { - if commandProp, exists := mcpDef.Properties["command"]; exists { - commandProp.Description = "Command to execute for stdio MCP servers" - } - if urlProp, exists := mcpDef.Properties["url"]; exists { - urlProp.Description = "URL for SSE MCP servers" - 
} - if headersProp, exists := mcpDef.Properties["headers"]; exists { - headersProp.Description = "HTTP headers for SSE MCP servers" + if optionsProp, exists := schema.Properties.Get("options"); exists { + if optionsProp.Properties != nil { + if contextPathsProp, exists := optionsProp.Properties.Get("context_paths"); exists { + contextPathsProp.Default = defaultContextPaths + } + } } } } diff --git a/crush-schema.json b/crush-schema.json index f5fa562c5aff42972eb2308c3374969e5d42cac8..35bd388ceca5b0aca5401d66db963b77f4f584f8 100644 --- a/crush-schema.json +++ b/crush-schema.json @@ -1,492 +1,47 @@ { - "$schema": "http://json-schema.org/draft-07/schema#", - "title": "Crush Configuration", - "description": "Configuration schema for the Crush application", - "type": "object", - "properties": { - "agents": { - "description": "Agent configurations for different tasks", - "type": "object", - "additionalProperties": { - "type": "object", - "properties": { - "allowed_lsp": { - "description": "List of LSP servers this agent can use (if nil, all LSPs are allowed)", - "type": "array", - "items": { - "type": "string" - } - }, - "allowed_mcp": { - "description": "Map of MCP servers this agent can use and their allowed tools", - "type": "object", - "additionalProperties": { - "type": "array", - "items": { - "type": "string" - } - } - }, - "allowed_tools": { - "description": "List of tools this agent is allowed to use (if nil, all tools are allowed)", - "type": "array", - "items": { - "type": "string", - "enum": [ - "bash", - "edit", - "fetch", - "glob", - "grep", - "ls", - "sourcegraph", - "view", - "write", - "agent" - ] - } - }, - "context_paths": { - "description": "Custom context paths for this agent (additive to global context paths)", - "type": "array", - "items": { - "type": "string" - } - }, - "description": { - "description": "Description", - "type": "string" - }, - "disabled": { - "description": "Disabled", - "type": "boolean", - "default": false - }, - "id": { - 
"description": "I D", - "type": "string", - "enum": [ - "coder", - "task" - ] - }, - "model": { - "description": "Model", - "type": "string", - "enum": [ - "large", - "small" - ] - }, - "name": { - "description": "Name", - "type": "string" - } - }, - "required": [ - "id", - "name", - "disabled", - "model", - "allowed_tools", - "allowed_mcp", - "allowed_lsp", - "context_paths" - ] - } - }, - "lsp": { - "description": "Language Server Protocol configurations", - "type": "object", - "additionalProperties": { - "type": "object", - "properties": { - "args": { - "description": "Args", - "type": "array", - "items": { - "type": "string" - } - }, - "command": { - "description": "Command", - "type": "string" - }, - "enabled": { - "description": "Disabled", - "type": "boolean", - "default": true - }, - "options": { - "description": "Options" - } - }, - "required": [ - "enabled", - "command", - "args", - "options" - ] - } - }, - "mcp": { - "description": "Model Control Protocol server configurations", - "type": "object", - "additionalProperties": { - "type": "object", - "properties": { - "args": { - "description": "Args", - "type": "array", - "items": { - "type": "string" - } - }, - "command": { - "description": "Command to execute for stdio MCP servers", - "type": "string" - }, - "env": { - "description": "Env list", - "type": "array", - "items": { - "type": "string" - } - }, - "headers": { - "description": "HTTP headers for SSE MCP servers", - "type": "object", - "additionalProperties": { - "type": "string" - } - }, - "type": { - "description": "Type", - "type": "string", - "enum": [ - "stdio", - "sse" - ], - "default": "stdio" - }, - "url": { - "description": "URL for SSE MCP servers", - "type": "string" - } - }, - "required": [ - "command", - "env", - "args", - "type", - "url", - "headers" - ] - } - }, - "models": { - "description": "Preferred model configurations for large and small model types", - "type": "object", - "properties": { - "large": { - "description": "Large", 
- "type": "object", - "properties": { - "max_tokens": { - "description": "Max Tokens", - "type": "integer" - }, - "model_id": { - "description": "Model I D", - "type": "string", - "enum": [ - "claude-3-opus", - "claude-3-haiku", - "claude-3-5-sonnet-20241022", - "claude-3-5-haiku-20241022", - "gpt-4", - "gpt-3.5-turbo", - "gpt-4-turbo", - "gpt-4o", - "gpt-4o-mini", - "o1-preview", - "o1-mini", - "gemini-2.5-pro", - "gemini-2.5-flash", - "grok-beta", - "anthropic/claude-3.5-sonnet", - "anthropic/claude-3.5-haiku" - ] - }, - "provider": { - "description": "Provider", - "type": "string", - "enum": [ - "anthropic", - "openai", - "gemini", - "xai", - "openrouter" - ] - }, - "reasoning_effort": { - "description": "Reasoning Effort", - "type": "string" - }, - "think": { - "description": "Enable think", - "type": "boolean" - } - }, - "required": [ - "model_id", - "provider" - ] - }, - "small": { - "description": "Small", - "$ref": "#/definitions/PreferredModel" - } - }, - "required": [ - "large", - "small" - ] - }, - "options": { - "description": "General application options and settings", - "type": "object", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$defs": { + "Agent": { "properties": { - "context_paths": { - "description": "Context Paths", - "type": "array", - "items": { - "type": "string" - }, - "default": [ - ".github/copilot-instructions.md", - ".cursorrules", - ".cursor/rules/", - "CLAUDE.md", - "CLAUDE.local.md", - "GEMINI.md", - "gemini.md", - "crush.md", - "crush.local.md", - "Crush.md", - "Crush.local.md", - "CRUSH.md", - "CRUSH.local.md" - ] - }, - "data_directory": { - "description": "Data Directory", + "id": { "type": "string", - "default": ".crush" + "enum": [ + "coder", + "task", + "coder", + "task" + ], + "title": "Agent ID", + "description": "Unique identifier for the agent" }, - "debug": { - "description": "Enable debug", - "type": "boolean", - "default": false + "name": { + "type": "string", + "title": "Name", + "description": 
"Display name of the agent" }, - "debug_lsp": { - "description": "Enable debug l s p", - "type": "boolean", - "default": false + "description": { + "type": "string", + "title": "Description", + "description": "Description of what the agent does" }, - "disable_auto_summarize": { - "description": "Disable Auto Summarize", + "disabled": { "type": "boolean", + "title": "Disabled", + "description": "Whether this agent is disabled", "default": false }, - "tui": { - "description": "T U I", - "type": "object", - "properties": { - "compact_mode": { - "description": "Enable compact mode", - "type": "boolean", - "default": false - } - }, - "required": [ - "compact_mode" - ] - } - }, - "required": [ - "context_paths", - "tui", - "debug", - "debug_lsp", - "disable_auto_summarize", - "data_directory" - ] - }, - "providers": { - "description": "LLM provider configurations", - "type": "object", - "additionalProperties": { - "type": "object", - "properties": { - "api_key": { - "description": "API key for authenticating with the provider", - "type": "string" - }, - "base_url": { - "description": "Base URL for the provider API (required for custom providers)", - "type": "string" - }, - "default_large_model": { - "description": "Default Large Model", - "type": "string" - }, - "default_small_model": { - "description": "Default Small Model", - "type": "string" - }, - "disabled": { - "description": "Disabled", - "type": "boolean", - "default": false - }, - "extra_headers": { - "description": "Additional HTTP headers to send with requests", - "type": "object", - "additionalProperties": { - "type": "string" - } - }, - "extra_params": { - "description": "Additional provider-specific parameters", - "type": "object", - "additionalProperties": { - "type": "string" - } - }, - "id": { - "description": "I D", - "type": "string", - "enum": [ - "anthropic", - "openai", - "gemini", - "xai", - "openrouter" - ] - }, - "models": { - "description": "Models", - "type": "array", - "items": { - "type": 
"object", - "properties": { - "can_reason": { - "description": "Enable can reason", - "type": "boolean" - }, - "context_window": { - "description": "Context Window", - "type": "integer" - }, - "cost_per_1m_in": { - "description": "Cost Per1 M In", - "type": "number" - }, - "cost_per_1m_in_cached": { - "description": "Cost Per1 M In Cached", - "type": "number" - }, - "cost_per_1m_out": { - "description": "Cost Per1 M Out", - "type": "number" - }, - "cost_per_1m_out_cached": { - "description": "Cost Per1 M Out Cached", - "type": "number" - }, - "default_max_tokens": { - "description": "Default Max Tokens", - "type": "integer" - }, - "has_reasoning_effort": { - "description": "Enable has reasoning effort", - "type": "boolean" - }, - "id": { - "description": "I D", - "type": "string" - }, - "model": { - "description": "Name", - "type": "string" - }, - "reasoning_effort": { - "description": "Reasoning Effort", - "type": "string" - }, - "supports_attachments": { - "description": "Enable supports images", - "type": "boolean" - } - }, - "required": [ - "id", - "model", - "cost_per_1m_in", - "cost_per_1m_out", - "cost_per_1m_in_cached", - "cost_per_1m_out_cached", - "context_window", - "default_max_tokens", - "can_reason", - "reasoning_effort", - "has_reasoning_effort", - "supports_attachments" - ] - } - }, - "provider_type": { - "description": "Provider Type", - "type": "string" - } - }, - "required": [ - "id", - "provider_type", - "disabled" - ] - } - } - }, - "required": [ - "models", - "options" - ], - "definitions": { - "Agent": { - "type": "object", - "properties": { - "allowed_lsp": { - "description": "List of LSP servers this agent can use (if nil, all LSPs are allowed)", - "type": "array", - "items": { - "type": "string" - } - }, - "allowed_mcp": { - "description": "Map of MCP servers this agent can use and their allowed tools", - "type": "object", - "additionalProperties": { - "type": "array", - "items": { - "type": "string" - } - } + "model": { + "type": 
"string", + "enum": [ + "large", + "small", + "large", + "small" + ], + "title": "Model Type", + "description": "Type of model to use (large or small)" }, "allowed_tools": { - "description": "List of tools this agent is allowed to use (if nil, all tools are allowed)", - "type": "array", "items": { "type": "string", "enum": [ @@ -501,652 +56,200 @@ "write", "agent" ] - } - }, - "context_paths": { - "description": "Custom context paths for this agent (additive to global context paths)", + }, "type": "array", - "items": { - "type": "string" - } - }, - "description": { - "description": "Description", - "type": "string" - }, - "disabled": { - "description": "Disabled", - "type": "boolean", - "default": false - }, - "id": { - "description": "I D", - "type": "string", - "enum": [ - "coder", - "task" - ] - }, - "model": { - "description": "Model", - "type": "string", - "enum": [ - "large", - "small" - ] + "title": "Allowed Tools", + "description": "List of tools this agent is allowed to use (if nil all tools are allowed)" }, - "name": { - "description": "Name", - "type": "string" - } - }, - "required": [ - "id", - "name", - "disabled", - "model", - "allowed_tools", - "allowed_mcp", - "allowed_lsp", - "context_paths" - ] - }, - "Config": { - "type": "object", - "properties": { - "agents": { - "description": "Agent configurations for different tasks", - "type": "object", + "allowed_mcp": { "additionalProperties": { - "type": "object", - "properties": { - "allowed_lsp": { - "description": "List of LSP servers this agent can use (if nil, all LSPs are allowed)", - "type": "array", - "items": { - "type": "string" - } - }, - "allowed_mcp": { - "description": "Map of MCP servers this agent can use and their allowed tools", - "type": "object", - "additionalProperties": { - "type": "array", - "items": { - "type": "string" - } - } - }, - "allowed_tools": { - "description": "List of tools this agent is allowed to use (if nil, all tools are allowed)", - "type": "array", - "items": { - 
"type": "string", - "enum": [ - "bash", - "edit", - "fetch", - "glob", - "grep", - "ls", - "sourcegraph", - "view", - "write", - "agent" - ] - } - }, - "context_paths": { - "description": "Custom context paths for this agent (additive to global context paths)", - "type": "array", - "items": { - "type": "string" - } - }, - "description": { - "description": "Description", - "type": "string" - }, - "disabled": { - "description": "Disabled", - "type": "boolean", - "default": false - }, - "id": { - "description": "I D", - "type": "string", - "enum": [ - "coder", - "task" - ] - }, - "model": { - "description": "Model", - "type": "string", - "enum": [ - "large", - "small" - ] - }, - "name": { - "description": "Name", - "type": "string" - } + "items": { + "type": "string" }, - "required": [ - "id", - "name", - "disabled", - "model", - "allowed_tools", - "allowed_mcp", - "allowed_lsp", - "context_paths" - ] - } - }, - "lsp": { - "description": "Language Server Protocol configurations", + "type": "array" + }, "type": "object", - "additionalProperties": { - "type": "object", - "properties": { - "args": { - "description": "Args", - "type": "array", - "items": { - "type": "string" - } - }, - "command": { - "description": "Command", - "type": "string" - }, - "enabled": { - "description": "Disabled", - "type": "boolean", - "default": true - }, - "options": { - "description": "Options" - } - }, - "required": [ - "enabled", - "command", - "args", - "options" - ] - } + "title": "Allowed MCP", + "description": "Map of MCP servers this agent can use and their allowed tools" }, - "mcp": { - "description": "Model Control Protocol server configurations", - "type": "object", - "additionalProperties": { - "type": "object", - "properties": { - "args": { - "description": "Args", - "type": "array", - "items": { - "type": "string" - } - }, - "command": { - "description": "Command to execute for stdio MCP servers", - "type": "string" - }, - "env": { - "description": "Env list", - "type": 
"array", - "items": { - "type": "string" - } - }, - "headers": { - "description": "HTTP headers for SSE MCP servers", - "type": "object", - "additionalProperties": { - "type": "string" - } - }, - "type": { - "description": "Type", - "type": "string", - "enum": [ - "stdio", - "sse" - ], - "default": "stdio" - }, - "url": { - "description": "URL for SSE MCP servers", - "type": "string" - } - }, - "required": [ - "command", - "env", - "args", - "type", - "url", - "headers" - ] - } - }, - "models": { - "description": "Preferred model configurations for large and small model types", - "type": "object", - "properties": { - "large": { - "description": "Large", - "type": "object", - "properties": { - "max_tokens": { - "description": "Max Tokens", - "type": "integer" - }, - "model_id": { - "description": "Model I D", - "type": "string", - "enum": [ - "claude-3-opus", - "claude-3-haiku", - "claude-3-5-sonnet-20241022", - "claude-3-5-haiku-20241022", - "gpt-4", - "gpt-3.5-turbo", - "gpt-4-turbo", - "gpt-4o", - "gpt-4o-mini", - "o1-preview", - "o1-mini", - "gemini-2.5-pro", - "gemini-2.5-flash", - "grok-beta", - "anthropic/claude-3.5-sonnet", - "anthropic/claude-3.5-haiku" - ] - }, - "provider": { - "description": "Provider", - "type": "string", - "enum": [ - "anthropic", - "openai", - "gemini", - "xai", - "openrouter" - ] - }, - "reasoning_effort": { - "description": "Reasoning Effort", - "type": "string" - }, - "think": { - "description": "Enable think", - "type": "boolean" - } - }, - "required": [ - "model_id", - "provider" - ] - }, - "small": { - "description": "Small", - "$ref": "#/definitions/PreferredModel" - } + "allowed_lsp": { + "items": { + "type": "string" }, - "required": [ - "large", - "small" - ] + "type": "array", + "title": "Allowed LSP", + "description": "List of LSP servers this agent can use (if nil all LSPs are allowed)" }, - "options": { - "description": "General application options and settings", - "type": "object", - "properties": { - "context_paths": { 
- "description": "Context Paths", - "type": "array", - "items": { - "type": "string" - }, - "default": [ - ".github/copilot-instructions.md", - ".cursorrules", - ".cursor/rules/", - "CLAUDE.md", - "CLAUDE.local.md", - "GEMINI.md", - "gemini.md", - "crush.md", - "crush.local.md", - "Crush.md", - "Crush.local.md", - "CRUSH.md", - "CRUSH.local.md" - ] - }, - "data_directory": { - "description": "Data Directory", - "type": "string", - "default": ".crush" - }, - "debug": { - "description": "Enable debug", - "type": "boolean", - "default": false - }, - "debug_lsp": { - "description": "Enable debug l s p", - "type": "boolean", - "default": false - }, - "disable_auto_summarize": { - "description": "Disable Auto Summarize", - "type": "boolean", - "default": false - }, - "tui": { - "description": "T U I", - "type": "object", - "properties": { - "compact_mode": { - "description": "Enable compact mode", - "type": "boolean", - "default": false - } - }, - "required": [ - "compact_mode" - ] - } + "context_paths": { + "items": { + "type": "string" }, - "required": [ - "context_paths", - "tui", - "debug", - "debug_lsp", - "disable_auto_summarize", - "data_directory" - ] - }, - "providers": { - "description": "LLM provider configurations", - "type": "object", - "additionalProperties": { - "type": "object", - "properties": { - "api_key": { - "description": "API key for authenticating with the provider", - "type": "string" - }, - "base_url": { - "description": "Base URL for the provider API (required for custom providers)", - "type": "string" - }, - "default_large_model": { - "description": "Default Large Model", - "type": "string" - }, - "default_small_model": { - "description": "Default Small Model", - "type": "string" - }, - "disabled": { - "description": "Disabled", - "type": "boolean", - "default": false - }, - "extra_headers": { - "description": "Additional HTTP headers to send with requests", - "type": "object", - "additionalProperties": { - "type": "string" - } - }, - 
"extra_params": { - "description": "Additional provider-specific parameters", - "type": "object", - "additionalProperties": { - "type": "string" - } - }, - "id": { - "description": "I D", - "type": "string", - "enum": [ - "anthropic", - "openai", - "gemini", - "xai", - "openrouter" - ] - }, - "models": { - "description": "Models", - "type": "array", - "items": { - "type": "object", - "properties": { - "can_reason": { - "description": "Enable can reason", - "type": "boolean" - }, - "context_window": { - "description": "Context Window", - "type": "integer" - }, - "cost_per_1m_in": { - "description": "Cost Per1 M In", - "type": "number" - }, - "cost_per_1m_in_cached": { - "description": "Cost Per1 M In Cached", - "type": "number" - }, - "cost_per_1m_out": { - "description": "Cost Per1 M Out", - "type": "number" - }, - "cost_per_1m_out_cached": { - "description": "Cost Per1 M Out Cached", - "type": "number" - }, - "default_max_tokens": { - "description": "Default Max Tokens", - "type": "integer" - }, - "has_reasoning_effort": { - "description": "Enable has reasoning effort", - "type": "boolean" - }, - "id": { - "description": "I D", - "type": "string" - }, - "model": { - "description": "Name", - "type": "string" - }, - "reasoning_effort": { - "description": "Reasoning Effort", - "type": "string" - }, - "supports_attachments": { - "description": "Enable supports images", - "type": "boolean" - } - }, - "required": [ - "id", - "model", - "cost_per_1m_in", - "cost_per_1m_out", - "cost_per_1m_in_cached", - "cost_per_1m_out_cached", - "context_window", - "default_max_tokens", - "can_reason", - "reasoning_effort", - "has_reasoning_effort", - "supports_attachments" - ] - } - }, - "provider_type": { - "description": "Provider Type", - "type": "string" - } - }, - "required": [ - "id", - "provider_type", - "disabled" - ] - } + "type": "array", + "title": "Context Paths", + "description": "Custom context paths for this agent (additive to global context paths)" } }, + "type": 
"object", "required": [ - "models", - "options" + "model" ] }, "LSPConfig": { - "type": "object", "properties": { - "args": { - "description": "Args", - "type": "array", - "items": { - "type": "string" - } - }, - "command": { - "description": "Command", - "type": "string" - }, "enabled": { - "description": "Disabled", "type": "boolean", + "title": "Enabled", + "description": "Whether this LSP server is enabled", "default": true }, + "command": { + "type": "string", + "title": "Command", + "description": "Command to execute for the LSP server" + }, + "args": { + "items": { + "type": "string" + }, + "type": "array", + "title": "Arguments", + "description": "Command line arguments for the LSP server" + }, "options": { - "description": "Options" + "title": "Options", + "description": "LSP server specific options" } }, + "type": "object", "required": [ - "enabled", - "command", - "args", - "options" + "command" ] }, "MCP": { - "type": "object", "properties": { - "args": { - "description": "Args", - "type": "array", - "items": { - "type": "string" - } - }, "command": { - "description": "Command to execute for stdio MCP servers", - "type": "string" + "type": "string", + "title": "Command", + "description": "Command to execute for stdio MCP servers" }, "env": { - "description": "Env list", - "type": "array", "items": { "type": "string" - } + }, + "type": "array", + "title": "Environment", + "description": "Environment variables for the MCP server" }, - "headers": { - "description": "HTTP headers for SSE MCP servers", - "type": "object", - "additionalProperties": { + "args": { + "items": { "type": "string" - } + }, + "type": "array", + "title": "Arguments", + "description": "Command line arguments for the MCP server" }, "type": { - "description": "Type", "type": "string", "enum": [ + "stdio", + "sse", "stdio", "sse" ], + "title": "Type", + "description": "Type of MCP connection", "default": "stdio" }, "url": { - "description": "URL for SSE MCP servers", - "type": "string" + 
"type": "string", + "title": "URL", + "description": "URL for SSE MCP servers" + }, + "headers": { + "additionalProperties": { + "type": "string" + }, + "type": "object", + "title": "Headers", + "description": "HTTP headers for SSE MCP servers" } }, + "type": "object", "required": [ "command", - "env", - "args", - "type", - "url", - "headers" + "type" ] }, "Model": { - "type": "object", "properties": { - "can_reason": { - "description": "Enable can reason", - "type": "boolean" + "id": { + "type": "string", + "title": "Model ID", + "description": "Unique identifier for the model" }, - "context_window": { - "description": "Context Window", - "type": "integer" + "model": { + "type": "string", + "title": "Model Name", + "description": "Display name of the model" }, "cost_per_1m_in": { - "description": "Cost Per1 M In", - "type": "number" - }, - "cost_per_1m_in_cached": { - "description": "Cost Per1 M In Cached", - "type": "number" + "type": "number", + "minimum": 0, + "title": "Input Cost", + "description": "Cost per 1 million input tokens" }, "cost_per_1m_out": { - "description": "Cost Per1 M Out", - "type": "number" + "type": "number", + "minimum": 0, + "title": "Output Cost", + "description": "Cost per 1 million output tokens" }, - "cost_per_1m_out_cached": { - "description": "Cost Per1 M Out Cached", - "type": "number" + "cost_per_1m_in_cached": { + "type": "number", + "minimum": 0, + "title": "Cached Input Cost", + "description": "Cost per 1 million cached input tokens" }, - "default_max_tokens": { - "description": "Default Max Tokens", - "type": "integer" + "cost_per_1m_out_cached": { + "type": "number", + "minimum": 0, + "title": "Cached Output Cost", + "description": "Cost per 1 million cached output tokens" }, - "has_reasoning_effort": { - "description": "Enable has reasoning effort", - "type": "boolean" + "context_window": { + "type": "integer", + "minimum": 1, + "title": "Context Window", + "description": "Maximum context window size in tokens" }, - "id": { 
- "description": "I D", - "type": "string" + "default_max_tokens": { + "type": "integer", + "minimum": 1, + "title": "Default Max Tokens", + "description": "Default maximum tokens for responses" }, - "model": { - "description": "Name", - "type": "string" + "can_reason": { + "type": "boolean", + "title": "Can Reason", + "description": "Whether the model supports reasoning capabilities" }, "reasoning_effort": { - "description": "Reasoning Effort", - "type": "string" + "type": "string", + "title": "Reasoning Effort", + "description": "Default reasoning effort level for reasoning models" + }, + "has_reasoning_effort": { + "type": "boolean", + "title": "Has Reasoning Effort", + "description": "Whether the model supports reasoning effort configuration" }, "supports_attachments": { - "description": "Enable supports images", - "type": "boolean" + "type": "boolean", + "title": "Supports Images", + "description": "Whether the model supports image attachments" } }, + "type": "object", "required": [ "id", "model", @@ -1163,15 +266,14 @@ ] }, "Options": { - "description": "General application options and settings", - "type": "object", "properties": { "context_paths": { - "description": "Context Paths", - "type": "array", "items": { "type": "string" }, + "type": "array", + "title": "Context Paths", + "description": "List of paths to search for context files", "default": [ ".github/copilot-instructions.md", ".cursorrules", @@ -1188,318 +290,419 @@ "CRUSH.local.md" ] }, - "data_directory": { - "description": "Data Directory", - "type": "string", - "default": ".crush" + "tui": { + "$ref": "#/$defs/TUIOptions", + "title": "TUI Options", + "description": "Terminal UI configuration options" }, "debug": { - "description": "Enable debug", "type": "boolean", + "title": "Debug", + "description": "Enable debug logging", "default": false }, "debug_lsp": { - "description": "Enable debug l s p", "type": "boolean", + "title": "Debug LSP", + "description": "Enable LSP debug logging", "default": 
false }, "disable_auto_summarize": { - "description": "Disable Auto Summarize", "type": "boolean", + "title": "Disable Auto Summarize", + "description": "Disable automatic conversation summarization", "default": false }, - "tui": { - "description": "T U I", - "type": "object", - "properties": { - "compact_mode": { - "description": "Enable compact mode", - "type": "boolean", - "default": false - } - }, - "required": [ - "compact_mode" - ] + "data_directory": { + "type": "string", + "title": "Data Directory", + "description": "Directory for storing application data", + "default": ".crush" } }, - "required": [ - "context_paths", - "tui", - "debug", - "debug_lsp", - "disable_auto_summarize", - "data_directory" - ] + "type": "object" }, "PreferredModel": { - "description": "Large", - "type": "object", "properties": { - "max_tokens": { - "description": "Max Tokens", - "type": "integer" - }, "model_id": { - "description": "Model I D", "type": "string", "enum": [ - "claude-3-opus", - "claude-3-haiku", - "claude-3-5-sonnet-20241022", + "claude-opus-4-20250514", + "claude-sonnet-4-20250514", + "claude-3-7-sonnet-20250219", "claude-3-5-haiku-20241022", - "gpt-4", - "gpt-3.5-turbo", - "gpt-4-turbo", + "claude-3-5-sonnet-20240620", + "claude-3-5-sonnet-20241022", + "codex-mini-latest", + "o4-mini", + "o3", + "o3-pro", + "gpt-4.1", + "gpt-4.1-mini", + "gpt-4.1-nano", + "gpt-4.5-preview", + "o3-mini", + "gpt-4o", + "gpt-4o-mini", + "gemini-2.5-pro", + "gemini-2.5-flash", + "codex-mini-latest", + "o4-mini", + "o3", + "o3-pro", + "gpt-4.1", + "gpt-4.1-mini", + "gpt-4.1-nano", + "gpt-4.5-preview", + "o3-mini", "gpt-4o", "gpt-4o-mini", - "o1-preview", - "o1-mini", + "anthropic.claude-opus-4-20250514-v1:0", + "anthropic.claude-sonnet-4-20250514-v1:0", + "anthropic.claude-3-7-sonnet-20250219-v1:0", + "anthropic.claude-3-5-haiku-20241022-v1:0", "gemini-2.5-pro", "gemini-2.5-flash", - "grok-beta", + "grok-3-mini", + "grok-3", + "mistralai/mistral-small-3.2-24b-instruct:free", + 
"mistralai/mistral-small-3.2-24b-instruct", + "minimax/minimax-m1:extended", + "minimax/minimax-m1", + "google/gemini-2.5-flash-lite-preview-06-17", + "google/gemini-2.5-flash", + "google/gemini-2.5-pro", + "openai/o3-pro", + "x-ai/grok-3-mini", + "x-ai/grok-3", + "mistralai/magistral-small-2506", + "mistralai/magistral-medium-2506", + "mistralai/magistral-medium-2506:thinking", + "google/gemini-2.5-pro-preview", + "deepseek/deepseek-r1-0528", + "anthropic/claude-opus-4", + "anthropic/claude-sonnet-4", + "mistralai/devstral-small:free", + "mistralai/devstral-small", + "google/gemini-2.5-flash-preview-05-20", + "google/gemini-2.5-flash-preview-05-20:thinking", + "openai/codex-mini", + "mistralai/mistral-medium-3", + "google/gemini-2.5-pro-preview-05-06", + "arcee-ai/caller-large", + "arcee-ai/virtuoso-large", + "arcee-ai/virtuoso-medium-v2", + "qwen/qwen3-30b-a3b", + "qwen/qwen3-14b", + "qwen/qwen3-32b", + "qwen/qwen3-235b-a22b", + "google/gemini-2.5-flash-preview", + "google/gemini-2.5-flash-preview:thinking", + "openai/o4-mini-high", + "openai/o3", + "openai/o4-mini", + "openai/gpt-4.1", + "openai/gpt-4.1-mini", + "openai/gpt-4.1-nano", + "x-ai/grok-3-mini-beta", + "x-ai/grok-3-beta", + "meta-llama/llama-4-maverick", + "meta-llama/llama-4-scout", + "all-hands/openhands-lm-32b-v0.1", + "google/gemini-2.5-pro-exp-03-25", + "deepseek/deepseek-chat-v3-0324:free", + "deepseek/deepseek-chat-v3-0324", + "mistralai/mistral-small-3.1-24b-instruct:free", + "mistralai/mistral-small-3.1-24b-instruct", + "ai21/jamba-1.6-large", + "ai21/jamba-1.6-mini", + "openai/gpt-4.5-preview", + "google/gemini-2.0-flash-lite-001", + "anthropic/claude-3.7-sonnet", + "anthropic/claude-3.7-sonnet:beta", + "anthropic/claude-3.7-sonnet:thinking", + "mistralai/mistral-saba", + "openai/o3-mini-high", + "google/gemini-2.0-flash-001", + "qwen/qwen-turbo", + "qwen/qwen-plus", + "qwen/qwen-max", + "openai/o3-mini", + "mistralai/mistral-small-24b-instruct-2501", + 
"deepseek/deepseek-r1-distill-llama-70b", + "deepseek/deepseek-r1", + "mistralai/codestral-2501", + "deepseek/deepseek-chat", + "openai/o1", + "x-ai/grok-2-1212", + "meta-llama/llama-3.3-70b-instruct", + "amazon/nova-lite-v1", + "amazon/nova-micro-v1", + "amazon/nova-pro-v1", + "openai/gpt-4o-2024-11-20", + "mistralai/mistral-large-2411", + "mistralai/mistral-large-2407", + "mistralai/pixtral-large-2411", + "thedrummer/unslopnemo-12b", + "anthropic/claude-3.5-haiku:beta", + "anthropic/claude-3.5-haiku", + "anthropic/claude-3.5-haiku-20241022:beta", + "anthropic/claude-3.5-haiku-20241022", + "anthropic/claude-3.5-sonnet:beta", "anthropic/claude-3.5-sonnet", - "anthropic/claude-3.5-haiku" - ] + "x-ai/grok-beta", + "mistralai/ministral-8b", + "mistralai/ministral-3b", + "nvidia/llama-3.1-nemotron-70b-instruct", + "google/gemini-flash-1.5-8b", + "meta-llama/llama-3.2-11b-vision-instruct", + "meta-llama/llama-3.2-3b-instruct", + "qwen/qwen-2.5-72b-instruct", + "mistralai/pixtral-12b", + "cohere/command-r-plus-08-2024", + "cohere/command-r-08-2024", + "microsoft/phi-3.5-mini-128k-instruct", + "nousresearch/hermes-3-llama-3.1-70b", + "openai/gpt-4o-2024-08-06", + "meta-llama/llama-3.1-405b-instruct", + "meta-llama/llama-3.1-70b-instruct", + "meta-llama/llama-3.1-8b-instruct", + "mistralai/mistral-nemo", + "openai/gpt-4o-mini", + "openai/gpt-4o-mini-2024-07-18", + "anthropic/claude-3.5-sonnet-20240620:beta", + "anthropic/claude-3.5-sonnet-20240620", + "mistralai/mistral-7b-instruct-v0.3", + "mistralai/mistral-7b-instruct:free", + "mistralai/mistral-7b-instruct", + "microsoft/phi-3-mini-128k-instruct", + "microsoft/phi-3-medium-128k-instruct", + "google/gemini-flash-1.5", + "openai/gpt-4o-2024-05-13", + "openai/gpt-4o", + "openai/gpt-4o:extended", + "meta-llama/llama-3-8b-instruct", + "meta-llama/llama-3-70b-instruct", + "mistralai/mixtral-8x22b-instruct", + "openai/gpt-4-turbo", + "google/gemini-pro-1.5", + "cohere/command-r-plus", + "cohere/command-r-plus-04-2024", + 
"cohere/command-r", + "anthropic/claude-3-haiku:beta", + "anthropic/claude-3-haiku", + "anthropic/claude-3-opus:beta", + "anthropic/claude-3-opus", + "anthropic/claude-3-sonnet:beta", + "anthropic/claude-3-sonnet", + "cohere/command-r-03-2024", + "mistralai/mistral-large", + "openai/gpt-3.5-turbo-0613", + "openai/gpt-4-turbo-preview", + "mistralai/mistral-small", + "mistralai/mistral-tiny", + "mistralai/mixtral-8x7b-instruct", + "openai/gpt-4-1106-preview", + "mistralai/mistral-7b-instruct-v0.1", + "openai/gpt-3.5-turbo-16k", + "openai/gpt-4", + "openai/gpt-4-0314" + ], + "title": "Model ID", + "description": "ID of the preferred model" }, "provider": { - "description": "Provider", "type": "string", "enum": [ "anthropic", "openai", "gemini", + "azure", + "bedrock", + "vertex", "xai", "openrouter" - ] + ], + "title": "Provider", + "description": "Provider for the preferred model" }, "reasoning_effort": { - "description": "Reasoning Effort", - "type": "string" + "type": "string", + "title": "Reasoning Effort", + "description": "Override reasoning effort for this model" + }, + "max_tokens": { + "type": "integer", + "minimum": 1, + "title": "Max Tokens", + "description": "Override max tokens for this model" }, "think": { - "description": "Enable think", - "type": "boolean" + "type": "boolean", + "title": "Think", + "description": "Enable thinking for reasoning models", + "default": false } }, + "type": "object", "required": [ "model_id", "provider" ] }, "PreferredModels": { - "description": "Preferred model configurations for large and small model types", - "type": "object", "properties": { "large": { - "description": "Large", - "type": "object", - "properties": { - "max_tokens": { - "description": "Max Tokens", - "type": "integer" - }, - "model_id": { - "description": "Model I D", - "type": "string", - "enum": [ - "claude-3-opus", - "claude-3-haiku", - "claude-3-5-sonnet-20241022", - "claude-3-5-haiku-20241022", - "gpt-4", - "gpt-3.5-turbo", - "gpt-4-turbo", - 
"gpt-4o", - "gpt-4o-mini", - "o1-preview", - "o1-mini", - "gemini-2.5-pro", - "gemini-2.5-flash", - "grok-beta", - "anthropic/claude-3.5-sonnet", - "anthropic/claude-3.5-haiku" - ] - }, - "provider": { - "description": "Provider", - "type": "string", - "enum": [ - "anthropic", - "openai", - "gemini", - "xai", - "openrouter" - ] - }, - "reasoning_effort": { - "description": "Reasoning Effort", - "type": "string" - }, - "think": { - "description": "Enable think", - "type": "boolean" - } - }, - "required": [ - "model_id", - "provider" - ] + "$ref": "#/$defs/PreferredModel", + "title": "Large Model", + "description": "Preferred model configuration for large model type" }, "small": { - "description": "Small", - "$ref": "#/definitions/PreferredModel" + "$ref": "#/$defs/PreferredModel", + "title": "Small Model", + "description": "Preferred model configuration for small model type" } }, - "required": [ - "large", - "small" - ] + "type": "object" }, "ProviderConfig": { - "type": "object", "properties": { - "api_key": { - "description": "API key for authenticating with the provider", - "type": "string" + "id": { + "type": "string", + "enum": [ + "anthropic", + "openai", + "gemini", + "azure", + "bedrock", + "vertex", + "xai", + "openrouter" + ], + "title": "Provider ID", + "description": "Unique identifier for the provider" }, "base_url": { - "description": "Base URL for the provider API (required for custom providers)", - "type": "string" + "type": "string", + "title": "Base URL", + "description": "Base URL for the provider API (required for custom providers)" }, - "default_large_model": { - "description": "Default Large Model", - "type": "string" + "provider_type": { + "type": "string", + "title": "Provider Type", + "description": "Type of the provider (openai" }, - "default_small_model": { - "description": "Default Small Model", - "type": "string" + "api_key": { + "type": "string", + "title": "API Key", + "description": "API key for authenticating with the provider" }, 
"disabled": { - "description": "Disabled", "type": "boolean", + "title": "Disabled", + "description": "Whether this provider is disabled", "default": false }, "extra_headers": { - "description": "Additional HTTP headers to send with requests", - "type": "object", "additionalProperties": { "type": "string" - } + }, + "type": "object", + "title": "Extra Headers", + "description": "Additional HTTP headers to send with requests" }, "extra_params": { - "description": "Additional provider-specific parameters", - "type": "object", "additionalProperties": { "type": "string" - } + }, + "type": "object", + "title": "Extra Parameters", + "description": "Additional provider-specific parameters" }, - "id": { - "description": "I D", + "default_large_model": { "type": "string", - "enum": [ - "anthropic", - "openai", - "gemini", - "xai", - "openrouter" - ] + "title": "Default Large Model", + "description": "Default model ID for large model type" + }, + "default_small_model": { + "type": "string", + "title": "Default Small Model", + "description": "Default model ID for small model type" }, "models": { - "description": "Models", - "type": "array", "items": { - "type": "object", - "properties": { - "can_reason": { - "description": "Enable can reason", - "type": "boolean" - }, - "context_window": { - "description": "Context Window", - "type": "integer" - }, - "cost_per_1m_in": { - "description": "Cost Per1 M In", - "type": "number" - }, - "cost_per_1m_in_cached": { - "description": "Cost Per1 M In Cached", - "type": "number" - }, - "cost_per_1m_out": { - "description": "Cost Per1 M Out", - "type": "number" - }, - "cost_per_1m_out_cached": { - "description": "Cost Per1 M Out Cached", - "type": "number" - }, - "default_max_tokens": { - "description": "Default Max Tokens", - "type": "integer" - }, - "has_reasoning_effort": { - "description": "Enable has reasoning effort", - "type": "boolean" - }, - "id": { - "description": "I D", - "type": "string" - }, - "model": { - "description": 
"Name", - "type": "string" - }, - "reasoning_effort": { - "description": "Reasoning Effort", - "type": "string" - }, - "supports_attachments": { - "description": "Enable supports images", - "type": "boolean" - } - }, - "required": [ - "id", - "model", - "cost_per_1m_in", - "cost_per_1m_out", - "cost_per_1m_in_cached", - "cost_per_1m_out_cached", - "context_window", - "default_max_tokens", - "can_reason", - "reasoning_effort", - "has_reasoning_effort", - "supports_attachments" - ] - } - }, - "provider_type": { - "description": "Provider Type", - "type": "string" + "$ref": "#/$defs/Model" + }, + "type": "array", + "title": "Models", + "description": "List of available models for this provider" } }, + "type": "object", "required": [ - "id", - "provider_type", - "disabled" + "provider_type" ] }, "TUIOptions": { - "description": "T U I", - "type": "object", "properties": { "compact_mode": { - "description": "Enable compact mode", "type": "boolean", + "title": "Compact Mode", + "description": "Enable compact mode for the TUI", "default": false } }, + "type": "object", "required": [ "compact_mode" ] } - } + }, + "properties": { + "models": { + "$ref": "#/$defs/PreferredModels", + "title": "Models", + "description": "Preferred model configurations for large and small model types" + }, + "providers": { + "additionalProperties": { + "$ref": "#/$defs/ProviderConfig" + }, + "type": "object", + "title": "Providers", + "description": "LLM provider configurations" + }, + "agents": { + "additionalProperties": { + "$ref": "#/$defs/Agent" + }, + "type": "object", + "title": "Agents", + "description": "Agent configurations for different tasks" + }, + "mcp": { + "additionalProperties": { + "$ref": "#/$defs/MCP" + }, + "type": "object", + "title": "MCP", + "description": "Model Control Protocol server configurations" + }, + "lsp": { + "additionalProperties": { + "$ref": "#/$defs/LSPConfig" + }, + "type": "object", + "title": "LSP", + "description": "Language Server Protocol 
configurations" + }, + "options": { + "$ref": "#/$defs/Options", + "title": "Options", + "description": "General application options and settings" + } + }, + "type": "object", + "title": "Crush Configuration", + "description": "Configuration schema for the Crush application" } diff --git a/crush.json b/crush.json index 1b04ea6c24f8b64a3a12ceb47551f3177fa66302..4937665c513258840f1efb4f88fb2bdd73f6ff68 100644 --- a/crush.json +++ b/crush.json @@ -1,6 +1,7 @@ { + "$schema": "./crush-schema.json", "lsp": { - "Go": { + "go": { "command": "gopls" } } diff --git a/cspell.json b/cspell.json index d62c817e8c8699e6172e576eb0f91602dd8417a3..d98b1326e54c8b62c7ad700fe19b4cbbe3e4f672 100644 --- a/cspell.json +++ b/cspell.json @@ -1 +1 @@ -{"language":"en","flagWords":[],"version":"0.2","words":["afero","alecthomas","bubbletea","charmbracelet","charmtone","Charple","crush","diffview","Emph","filepicker","Focusable","fsext","GROQ","Guac","imageorient","Lanczos","lipgloss","lsps","lucasb","nfnt","oksvg","Preproc","rasterx","rivo","Sourcegraph","srwiley","Strikethrough","termenv","textinput","trashhalo","uniseg","Unticked","genai"]} \ No newline at end of file +{"flagWords":[],"words":["afero","alecthomas","bubbletea","charmbracelet","charmtone","Charple","crush","diffview","Emph","filepicker","Focusable","fsext","GROQ","Guac","imageorient","Lanczos","lipgloss","lsps","lucasb","nfnt","oksvg","Preproc","rasterx","rivo","Sourcegraph","srwiley","Strikethrough","termenv","textinput","trashhalo","uniseg","Unticked","genai","jsonschema"],"version":"0.2","language":"en"} \ No newline at end of file diff --git a/go.mod b/go.mod index 29bcf5e483cb43dc7038d9220bc8b5259719e96d..0e3f432a3e3262516cc910f1e3c0c309b05d0e8a 100644 --- a/go.mod +++ b/go.mod @@ -39,6 +39,14 @@ require ( mvdan.cc/sh/v3 v3.11.0 ) +require ( + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect + 
github.com/mailru/easyjson v0.7.7 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect +) + require ( cloud.google.com/go v0.116.0 // indirect cloud.google.com/go/auth v0.13.0 // indirect diff --git a/go.sum b/go.sum index ddbea31f43ce1c6ec4235f3c7af740cca67795fe..234b9663cab709b9babde6e324374fa7022cb1bc 100644 --- a/go.sum +++ b/go.sum @@ -64,8 +64,12 @@ github.com/aymanbagabas/go-udiff v0.3.1 h1:LV+qyBQ2pqe0u42ZsUEtPiCaUoqgA9gYRDs3v github.com/aymanbagabas/go-udiff v0.3.1/go.mod h1:G0fsKmG+P6ylD0r6N/KgQD/nWzgfnl8ZBcNLgcbrw8E= github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= github.com/bmatcuk/doublestar/v4 v4.8.1 h1:54Bopc5c2cAvhLRAzqOGCYHYyhcDHsFF4wWIR5wKP38= github.com/bmatcuk/doublestar/v4 v4.8.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/charlievieth/fastwalk v1.0.11 h1:5sLT/q9+d9xMdpKExawLppqvXFZCVKf6JHnr2u/ufj8= github.com/charlievieth/fastwalk v1.0.11/go.mod h1:yGy1zbxog41ZVMcKA/i8ojXLFsuayX5VvwhQVoj9PBI= github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250607113720-eb5e1cf3b09e h1:99Ugtt633rqauFsXjZobZmtkNpeaWialfj8dl6COC6A= @@ -147,6 +151,9 @@ github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUq github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/invopop/jsonschema v0.13.0 
h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= @@ -158,6 +165,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mark3labs/mcp-go v0.17.0 h1:5Ps6T7qXr7De/2QTqs9h6BKeZ/qdeUeGrgM5lPzi930= github.com/mark3labs/mcp-go v0.17.0/go.mod h1:KmJndYv7GIgcPVwEKJjNcbhVQ+hJGJhrCCB/9xITzpE= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= @@ -243,6 +252,8 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod 
h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= diff --git a/internal/config/config.go b/internal/config/config.go index 3caf9f01c4afdba4dd2c29c43fc690dd360173ef..0cccbdfffe5dc1c42e21b03bd9a20d7112005bee 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -14,6 +14,7 @@ import ( "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/logging" + "github.com/invopop/jsonschema" ) const ( @@ -55,18 +56,18 @@ const ( ) type Model struct { - ID string `json:"id"` - Name string `json:"model"` - CostPer1MIn float64 `json:"cost_per_1m_in"` - CostPer1MOut float64 `json:"cost_per_1m_out"` - CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` - CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` - ContextWindow int64 `json:"context_window"` - DefaultMaxTokens int64 `json:"default_max_tokens"` - CanReason bool `json:"can_reason"` - ReasoningEffort string `json:"reasoning_effort"` - HasReasoningEffort bool `json:"has_reasoning_effort"` - SupportsImages bool `json:"supports_attachments"` + ID string `json:"id" jsonschema:"title=Model ID,description=Unique identifier for the model"` + Name string `json:"model" jsonschema:"title=Model Name,description=Display name of the model"` + CostPer1MIn float64 `json:"cost_per_1m_in" jsonschema:"title=Input Cost,description=Cost per 1 million input tokens,minimum=0"` + CostPer1MOut float64 `json:"cost_per_1m_out" jsonschema:"title=Output Cost,description=Cost per 1 million output tokens,minimum=0"` + CostPer1MInCached float64 `json:"cost_per_1m_in_cached" jsonschema:"title=Cached Input Cost,description=Cost per 1 million cached input tokens,minimum=0"` + CostPer1MOutCached float64 `json:"cost_per_1m_out_cached" jsonschema:"title=Cached Output Cost,description=Cost per 1 million cached output tokens,minimum=0"` + ContextWindow int64 `json:"context_window" jsonschema:"title=Context 
Window,description=Maximum context window size in tokens,minimum=1"` + DefaultMaxTokens int64 `json:"default_max_tokens" jsonschema:"title=Default Max Tokens,description=Default maximum tokens for responses,minimum=1"` + CanReason bool `json:"can_reason" jsonschema:"title=Can Reason,description=Whether the model supports reasoning capabilities"` + ReasoningEffort string `json:"reasoning_effort" jsonschema:"title=Reasoning Effort,description=Default reasoning effort level for reasoning models"` + HasReasoningEffort bool `json:"has_reasoning_effort" jsonschema:"title=Has Reasoning Effort,description=Whether the model supports reasoning effort configuration"` + SupportsImages bool `json:"supports_attachments" jsonschema:"title=Supports Images,description=Whether the model supports image attachments"` } type VertexAIOptions struct { @@ -76,46 +77,46 @@ type VertexAIOptions struct { } type ProviderConfig struct { - ID provider.InferenceProvider `json:"id"` - BaseURL string `json:"base_url,omitempty"` - ProviderType provider.Type `json:"provider_type"` - APIKey string `json:"api_key,omitempty"` - Disabled bool `json:"disabled"` - ExtraHeaders map[string]string `json:"extra_headers,omitempty"` + ID provider.InferenceProvider `json:"id,omitempty" jsonschema:"title=Provider ID,description=Unique identifier for the provider"` + BaseURL string `json:"base_url,omitempty" jsonschema:"title=Base URL,description=Base URL for the provider API (required for custom providers)"` + ProviderType provider.Type `json:"provider_type" jsonschema:"title=Provider Type,description=Type of the provider (openai, anthropic, etc.)"` + APIKey string `json:"api_key,omitempty" jsonschema:"title=API Key,description=API key for authenticating with the provider"` + Disabled bool `json:"disabled,omitempty" jsonschema:"title=Disabled,description=Whether this provider is disabled,default=false"` + ExtraHeaders map[string]string `json:"extra_headers,omitempty" jsonschema:"title=Extra 
Headers,description=Additional HTTP headers to send with requests"` // used for e.x for vertex to set the project - ExtraParams map[string]string `json:"extra_params,omitempty"` + ExtraParams map[string]string `json:"extra_params,omitempty" jsonschema:"title=Extra Parameters,description=Additional provider-specific parameters"` - DefaultLargeModel string `json:"default_large_model,omitempty"` - DefaultSmallModel string `json:"default_small_model,omitempty"` + DefaultLargeModel string `json:"default_large_model,omitempty" jsonschema:"title=Default Large Model,description=Default model ID for large model type"` + DefaultSmallModel string `json:"default_small_model,omitempty" jsonschema:"title=Default Small Model,description=Default model ID for small model type"` - Models []Model `json:"models,omitempty"` + Models []Model `json:"models,omitempty" jsonschema:"title=Models,description=List of available models for this provider"` } type Agent struct { - ID AgentID `json:"id"` - Name string `json:"name"` - Description string `json:"description,omitempty"` + ID AgentID `json:"id,omitempty" jsonschema:"title=Agent ID,description=Unique identifier for the agent,enum=coder,enum=task"` + Name string `json:"name,omitempty" jsonschema:"title=Name,description=Display name of the agent"` + Description string `json:"description,omitempty" jsonschema:"title=Description,description=Description of what the agent does"` // This is the id of the system prompt used by the agent - Disabled bool `json:"disabled"` + Disabled bool `json:"disabled,omitempty" jsonschema:"title=Disabled,description=Whether this agent is disabled,default=false"` - Model ModelType `json:"model"` + Model ModelType `json:"model" jsonschema:"title=Model Type,description=Type of model to use (large or small),enum=large,enum=small"` // The available tools for the agent // if this is nil, all tools are available - AllowedTools []string `json:"allowed_tools"` + AllowedTools []string `json:"allowed_tools,omitempty" 
jsonschema:"title=Allowed Tools,description=List of tools this agent is allowed to use (if nil all tools are allowed)"` // this tells us which MCPs are available for this agent // if this is empty all mcps are available // the string array is the list of tools from the AllowedMCP the agent has available // if the string array is nil, all tools from the AllowedMCP are available - AllowedMCP map[string][]string `json:"allowed_mcp"` + AllowedMCP map[string][]string `json:"allowed_mcp,omitempty" jsonschema:"title=Allowed MCP,description=Map of MCP servers this agent can use and their allowed tools"` // The list of LSPs that this agent can use // if this is nil, all LSPs are available - AllowedLSP []string `json:"allowed_lsp"` + AllowedLSP []string `json:"allowed_lsp,omitempty" jsonschema:"title=Allowed LSP,description=List of LSP servers this agent can use (if nil all LSPs are allowed)"` // Overrides the context paths for this agent - ContextPaths []string `json:"context_paths"` + ContextPaths []string `json:"context_paths,omitempty" jsonschema:"title=Context Paths,description=Custom context paths for this agent (additive to global context paths)"` } type MCPType string @@ -126,69 +127,70 @@ const ( ) type MCP struct { - Command string `json:"command"` - Env []string `json:"env"` - Args []string `json:"args"` - Type MCPType `json:"type"` - URL string `json:"url"` - Headers map[string]string `json:"headers"` + Command string `json:"command" jsonschema:"title=Command,description=Command to execute for stdio MCP servers"` + Env []string `json:"env,omitempty" jsonschema:"title=Environment,description=Environment variables for the MCP server"` + Args []string `json:"args,omitempty" jsonschema:"title=Arguments,description=Command line arguments for the MCP server"` + Type MCPType `json:"type" jsonschema:"title=Type,description=Type of MCP connection,enum=stdio,enum=sse,default=stdio"` + URL string `json:"url,omitempty" jsonschema:"title=URL,description=URL for SSE MCP 
servers"` + // TODO: maybe make it possible to get the value from the env + Headers map[string]string `json:"headers,omitempty" jsonschema:"title=Headers,description=HTTP headers for SSE MCP servers"` } type LSPConfig struct { - Disabled bool `json:"enabled"` - Command string `json:"command"` - Args []string `json:"args"` - Options any `json:"options"` + Disabled bool `json:"enabled,omitempty" jsonschema:"title=Enabled,description=Whether this LSP server is enabled,default=true"` + Command string `json:"command" jsonschema:"title=Command,description=Command to execute for the LSP server"` + Args []string `json:"args,omitempty" jsonschema:"title=Arguments,description=Command line arguments for the LSP server"` + Options any `json:"options,omitempty" jsonschema:"title=Options,description=LSP server specific options"` } type TUIOptions struct { - CompactMode bool `json:"compact_mode"` + CompactMode bool `json:"compact_mode" jsonschema:"title=Compact Mode,description=Enable compact mode for the TUI,default=false"` // Here we can add themes later or any TUI related options } type Options struct { - ContextPaths []string `json:"context_paths"` - TUI TUIOptions `json:"tui"` - Debug bool `json:"debug"` - DebugLSP bool `json:"debug_lsp"` - DisableAutoSummarize bool `json:"disable_auto_summarize"` + ContextPaths []string `json:"context_paths,omitempty" jsonschema:"title=Context Paths,description=List of paths to search for context files"` + TUI TUIOptions `json:"tui,omitempty" jsonschema:"title=TUI Options,description=Terminal UI configuration options"` + Debug bool `json:"debug,omitempty" jsonschema:"title=Debug,description=Enable debug logging,default=false"` + DebugLSP bool `json:"debug_lsp,omitempty" jsonschema:"title=Debug LSP,description=Enable LSP debug logging,default=false"` + DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty" jsonschema:"title=Disable Auto Summarize,description=Disable automatic conversation summarization,default=false"` // Relative 
to the cwd - DataDirectory string `json:"data_directory"` + DataDirectory string `json:"data_directory,omitempty" jsonschema:"title=Data Directory,description=Directory for storing application data,default=.crush"` } type PreferredModel struct { - ModelID string `json:"model_id"` - Provider provider.InferenceProvider `json:"provider"` + ModelID string `json:"model_id" jsonschema:"title=Model ID,description=ID of the preferred model"` + Provider provider.InferenceProvider `json:"provider" jsonschema:"title=Provider,description=Provider for the preferred model"` // ReasoningEffort overrides the default reasoning effort for this model - ReasoningEffort string `json:"reasoning_effort,omitempty"` + ReasoningEffort string `json:"reasoning_effort,omitempty" jsonschema:"title=Reasoning Effort,description=Override reasoning effort for this model"` // MaxTokens overrides the default max tokens for this model - MaxTokens int64 `json:"max_tokens,omitempty"` + MaxTokens int64 `json:"max_tokens,omitempty" jsonschema:"title=Max Tokens,description=Override max tokens for this model,minimum=1"` // Think indicates if the model should think, only applicable for anthropic reasoning models - Think bool `json:"think,omitempty"` + Think bool `json:"think,omitempty" jsonschema:"title=Think,description=Enable thinking for reasoning models,default=false"` } type PreferredModels struct { - Large PreferredModel `json:"large"` - Small PreferredModel `json:"small"` + Large PreferredModel `json:"large,omitempty" jsonschema:"title=Large Model,description=Preferred model configuration for large model type"` + Small PreferredModel `json:"small,omitempty" jsonschema:"title=Small Model,description=Preferred model configuration for small model type"` } type Config struct { - Models PreferredModels `json:"models"` + Models PreferredModels `json:"models,omitempty" jsonschema:"title=Models,description=Preferred model configurations for large and small model types"` // List of configured providers - 
Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"` + Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty" jsonschema:"title=Providers,description=LLM provider configurations"` // List of configured agents - Agents map[AgentID]Agent `json:"agents,omitempty"` + Agents map[AgentID]Agent `json:"agents,omitempty" jsonschema:"title=Agents,description=Agent configurations for different tasks"` // List of configured MCPs - MCP map[string]MCP `json:"mcp,omitempty"` + MCP map[string]MCP `json:"mcp,omitempty" jsonschema:"title=MCP,description=Model Control Protocol server configurations"` // List of configured LSPs - LSP map[string]LSPConfig `json:"lsp,omitempty"` + LSP map[string]LSPConfig `json:"lsp,omitempty" jsonschema:"title=LSP,description=Language Server Protocol configurations"` // Miscellaneous options - Options Options `json:"options"` + Options Options `json:"options,omitempty" jsonschema:"title=Options,description=General application options and settings"` } var ( @@ -502,27 +504,23 @@ func mergeAgents(base, global, local *Config) { } for agentID, newAgent := range cfg.Agents { if _, ok := base.Agents[agentID]; !ok { - // New agent - apply defaults - newAgent.ID = agentID // Ensure the ID is set correctly + newAgent.ID = agentID if newAgent.Model == "" { - newAgent.Model = LargeModel // Default model type + newAgent.Model = LargeModel } - // Context paths are always additive - start with global, then add custom if len(newAgent.ContextPaths) > 0 { newAgent.ContextPaths = append(base.Options.ContextPaths, newAgent.ContextPaths...) 
} else { - newAgent.ContextPaths = base.Options.ContextPaths // Use global context paths only + newAgent.ContextPaths = base.Options.ContextPaths } base.Agents[agentID] = newAgent } else { baseAgent := base.Agents[agentID] - // Special handling for known agents - only allow model changes if agentID == AgentCoder || agentID == AgentTask { if newAgent.Model != "" { baseAgent.Model = newAgent.Model } - // For known agents, only allow MCP and LSP configuration if newAgent.AllowedMCP != nil { baseAgent.AllowedMCP = newAgent.AllowedMCP } @@ -534,7 +532,6 @@ func mergeAgents(base, global, local *Config) { baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...) } } else { - // Custom agents - allow full merging if newAgent.Name != "" { baseAgent.Name = newAgent.Name } @@ -544,13 +541,11 @@ func mergeAgents(base, global, local *Config) { if newAgent.Model != "" { baseAgent.Model = newAgent.Model } else if baseAgent.Model == "" { - baseAgent.Model = LargeModel // Default fallback + baseAgent.Model = LargeModel } - // Boolean fields - always update (including false values) baseAgent.Disabled = newAgent.Disabled - // Slice/Map fields - update if provided (including empty slices/maps) if newAgent.AllowedTools != nil { baseAgent.AllowedTools = newAgent.AllowedTools } @@ -560,7 +555,6 @@ func mergeAgents(base, global, local *Config) { if newAgent.AllowedLSP != nil { baseAgent.AllowedLSP = newAgent.AllowedLSP } - // Context paths are additive for custom agents too if len(newAgent.ContextPaths) > 0 { baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...) 
} @@ -596,6 +590,7 @@ func mergeProviderConfigs(base, global, local *Config) { continue } for providerName, p := range cfg.Providers { + p.ID = providerName if _, ok := base.Providers[providerName]; !ok { base.Providers[providerName] = p } else { @@ -616,36 +611,36 @@ func mergeProviderConfigs(base, global, local *Config) { base.Providers = finalProviders } -func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig { - switch providerId { +func providerDefaultConfig(providerID provider.InferenceProvider) ProviderConfig { + switch providerID { case provider.InferenceProviderAnthropic: return ProviderConfig{ - ID: providerId, + ID: providerID, ProviderType: provider.TypeAnthropic, } case provider.InferenceProviderOpenAI: return ProviderConfig{ - ID: providerId, + ID: providerID, ProviderType: provider.TypeOpenAI, } case provider.InferenceProviderGemini: return ProviderConfig{ - ID: providerId, + ID: providerID, ProviderType: provider.TypeGemini, } case provider.InferenceProviderBedrock: return ProviderConfig{ - ID: providerId, + ID: providerID, ProviderType: provider.TypeBedrock, } case provider.InferenceProviderAzure: return ProviderConfig{ - ID: providerId, + ID: providerID, ProviderType: provider.TypeAzure, } case provider.InferenceProviderOpenRouter: return ProviderConfig{ - ID: providerId, + ID: providerID, ProviderType: provider.TypeOpenAI, BaseURL: "https://openrouter.ai/api/v1", ExtraHeaders: map[string]string{ @@ -655,18 +650,18 @@ func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig } case provider.InferenceProviderXAI: return ProviderConfig{ - ID: providerId, + ID: providerID, ProviderType: provider.TypeXAI, BaseURL: "https://api.x.ai/v1", } case provider.InferenceProviderVertexAI: return ProviderConfig{ - ID: providerId, + ID: providerID, ProviderType: provider.TypeVertexAI, } default: return ProviderConfig{ - ID: providerId, + ID: providerID, ProviderType: provider.TypeOpenAI, } } @@ -1433,3 +1428,27 @@ 
func (c *Config) validateCompleteness(errors *ValidationErrors) { } } } + +// JSONSchemaExtend adds custom schema properties for AgentID +func (AgentID) JSONSchemaExtend(schema *jsonschema.Schema) { + schema.Enum = []any{ + string(AgentCoder), + string(AgentTask), + } +} + +// JSONSchemaExtend adds custom schema properties for ModelType +func (ModelType) JSONSchemaExtend(schema *jsonschema.Schema) { + schema.Enum = []any{ + string(LargeModel), + string(SmallModel), + } +} + +// JSONSchemaExtend adds custom schema properties for MCPType +func (MCPType) JSONSchemaExtend(schema *jsonschema.Schema) { + schema.Enum = []any{ + string(MCPStdio), + string(MCPSse), + } +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index b48a9eba0a92a9f9239d6f6e3526c24cc8790ac9..2b4764489dd7795df0473c22eb529bbc65fc9a2b 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1051,7 +1051,6 @@ func TestProviderValidation_KnownProviderValid(t *testing.T) { DefaultMaxTokens: 2048, }, }, - }, }, } diff --git a/internal/config/provider_mock.go b/internal/config/provider_mock.go index 73d39d761b15ae682b272f356c62234aaa3ca0ad..801afdd8d6c9891eb47fa53294c047917b031637 100644 --- a/internal/config/provider_mock.go +++ b/internal/config/provider_mock.go @@ -126,28 +126,28 @@ func MockProviders() []provider.Provider { DefaultSmallModelID: "grok-3-mini", Models: []provider.Model{ { - ID: "grok-3", - Name: "Grok 3", - CostPer1MIn: 3.0, - CostPer1MOut: 15.0, - CostPer1MInCached: 0.0, + ID: "grok-3", + Name: "Grok 3", + CostPer1MIn: 3.0, + CostPer1MOut: 15.0, + CostPer1MInCached: 0.0, CostPer1MOutCached: 0.75, - ContextWindow: 131072, - DefaultMaxTokens: 20000, - CanReason: false, - SupportsImages: false, + ContextWindow: 131072, + DefaultMaxTokens: 20000, + CanReason: false, + SupportsImages: false, }, { - ID: "grok-3-mini", - Name: "Grok 3 Mini", - CostPer1MIn: 0.3, - CostPer1MOut: 0.5, - CostPer1MInCached: 0.0, + ID: "grok-3-mini", + Name: 
"Grok 3 Mini", + CostPer1MIn: 0.3, + CostPer1MOut: 0.5, + CostPer1MInCached: 0.0, CostPer1MOutCached: 0.075, - ContextWindow: 131072, - DefaultMaxTokens: 20000, - CanReason: true, - SupportsImages: false, + ContextWindow: 131072, + DefaultMaxTokens: 20000, + CanReason: true, + SupportsImages: false, }, }, }, diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index 8f2a31f06ab121fa049e7ca8bed159976cb2e92f..53a084d244c2d48538a514e8c72530a3850782d7 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -78,4 +78,4 @@ func TestProviders_ModelCapabilities(t *testing.T) { assert.True(t, foundReasoning) assert.True(t, foundNonReasoning) -} \ No newline at end of file +} From 2a10ac99f755f4d25f8385b58d5d77e47fa747e4 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 30 Jun 2025 14:06:25 +0200 Subject: [PATCH 26/40] chore: implement large/small model switching --- go.mod | 2 +- .../tui/components/dialogs/models/keys.go | 10 +- .../tui/components/dialogs/models/models.go | 116 +++++++++++++----- internal/tui/tui.go | 8 +- 4 files changed, 98 insertions(+), 38 deletions(-) diff --git a/go.mod b/go.mod index 0e3f432a3e3262516cc910f1e3c0c309b05d0e8a..7795febc9cf458a4830ac5ecfe849c3557024786 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( github.com/fsnotify/fsnotify v1.8.0 github.com/go-logfmt/logfmt v0.6.0 github.com/google/uuid v1.6.0 + github.com/invopop/jsonschema v0.13.0 github.com/mark3labs/mcp-go v0.17.0 github.com/muesli/termenv v0.16.0 github.com/ncruces/go-sqlite3 v0.25.0 @@ -42,7 +43,6 @@ require ( require ( github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect - github.com/invopop/jsonschema v0.13.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect ) diff --git a/internal/tui/components/dialogs/models/keys.go b/internal/tui/components/dialogs/models/keys.go index 
f9c3c549c30c6d95282e88818c247ec80a2f0e4c..dd34f82860b21a98a6ae4a5eea57aa052b383f33 100644 --- a/internal/tui/components/dialogs/models/keys.go +++ b/internal/tui/components/dialogs/models/keys.go @@ -8,13 +8,14 @@ type KeyMap struct { Select, Next, Previous, + Tab, Close key.Binding } func DefaultKeyMap() KeyMap { return KeyMap{ Select: key.NewBinding( - key.WithKeys("enter", "tab", "ctrl+y"), + key.WithKeys("enter", "ctrl+y"), key.WithHelp("enter", "confirm"), ), Next: key.NewBinding( @@ -25,6 +26,10 @@ func DefaultKeyMap() KeyMap { key.WithKeys("up", "ctrl+p"), key.WithHelp("↑", "previous item"), ), + Tab: key.NewBinding( + key.WithKeys("tab"), + key.WithHelp("tab", "toggle model type"), + ), Close: key.NewBinding( key.WithKeys("esc"), key.WithHelp("esc", "cancel"), @@ -38,6 +43,7 @@ func (k KeyMap) KeyBindings() []key.Binding { k.Select, k.Next, k.Previous, + k.Tab, k.Close, } } @@ -57,10 +63,10 @@ func (k KeyMap) FullHelp() [][]key.Binding { func (k KeyMap) ShortHelp() []key.Binding { return []key.Binding{ key.NewBinding( - key.WithKeys("down", "up"), key.WithHelp("↑↓", "choose"), ), + k.Tab, k.Select, k.Close, } diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index 730cc5d5693dd06821c1ff30cb26398b14325c62..8e81777d7e3e6c45628cb8839b52c484e47f5640 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -22,9 +22,15 @@ const ( defaultWidth = 60 ) +const ( + LargeModelType int = iota + SmallModelType +) + // ModelSelectedMsg is sent when a model is selected type ModelSelectedMsg struct { - Model config.PreferredModel + Model config.PreferredModel + ModelType config.ModelType } // CloseModelDialogMsg is sent when a model is selected @@ -42,12 +48,13 @@ type ModelOption struct { type modelDialogCmp struct { width int - wWidth int // Width of the terminal window - wHeight int // Height of the terminal window + wWidth int + wHeight int 
modelList list.ListModel keyMap KeyMap help help.Model + modelType int } func NewModelDialogCmp() ModelDialog { @@ -80,34 +87,13 @@ func NewModelDialogCmp() ModelDialog { width: defaultWidth, keyMap: DefaultKeyMap(), help: help, + modelType: LargeModelType, } } func (m *modelDialogCmp) Init() tea.Cmd { - providers := config.Providers() - - modelItems := []util.Model{} - selectIndex := 0 - agentModel := config.GetAgentModel(config.AgentCoder) - agentProvider := config.GetAgentProvider(config.AgentCoder) - for _, provider := range providers { - name := provider.Name - if name == "" { - name = string(provider.ID) - } - modelItems = append(modelItems, commands.NewItemSection(name)) - for _, model := range provider.Models { - modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{ - Provider: provider, - Model: model, - })) - if model.ID == agentModel.ID && provider.ID == agentProvider.ID { - selectIndex = len(modelItems) - 1 // Set the selected index to the current model - } - } - } - - return tea.Sequence(m.modelList.Init(), m.modelList.SetItems(modelItems), m.modelList.SetSelected(selectIndex)) + m.SetModelType(m.modelType) + return m.modelList.Init() } func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { @@ -115,24 +101,41 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case tea.WindowSizeMsg: m.wWidth = msg.Width m.wHeight = msg.Height + m.SetModelType(m.modelType) return m, m.modelList.SetSize(m.listWidth(), m.listHeight()) case tea.KeyPressMsg: switch { case key.Matches(msg, m.keyMap.Select): selectedItemInx := m.modelList.SelectedIndex() if selectedItemInx == list.NoSelection { - return m, nil // No item selected, do nothing + return m, nil } items := m.modelList.Items() selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption) + var modelType config.ModelType + if m.modelType == LargeModelType { + modelType = config.LargeModel + } else { + modelType = 
config.SmallModel + } + return m, tea.Sequence( util.CmdHandler(dialogs.CloseDialogMsg{}), - util.CmdHandler(ModelSelectedMsg{Model: config.PreferredModel{ - ModelID: selectedItem.Model.ID, - Provider: selectedItem.Provider.ID, - }}), + util.CmdHandler(ModelSelectedMsg{ + Model: config.PreferredModel{ + ModelID: selectedItem.Model.ID, + Provider: selectedItem.Provider.ID, + }, + ModelType: modelType, + }), ) + case key.Matches(msg, m.keyMap.Tab): + if m.modelType == LargeModelType { + return m, m.SetModelType(SmallModelType) + } else { + return m, m.SetModelType(LargeModelType) + } case key.Matches(msg, m.keyMap.Close): return m, util.CmdHandler(dialogs.CloseDialogMsg{}) default: @@ -147,9 +150,10 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { func (m *modelDialogCmp) View() tea.View { t := styles.CurrentTheme() listView := m.modelList.View() + radio := m.modelTypeRadio() content := lipgloss.JoinVertical( lipgloss.Left, - t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-4)), + t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio), listView.String(), "", t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)), @@ -197,3 +201,49 @@ func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor { func (m *modelDialogCmp) ID() dialogs.DialogID { return ModelsDialogID } + +func (m *modelDialogCmp) modelTypeRadio() string { + t := styles.CurrentTheme() + choices := []string{"Large", "Small"} + iconSelected := "◉" + iconUnselected := "○" + if m.modelType == LargeModelType { + return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1]) + } + return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1]) +} + +func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd { + m.modelType = 
modelType + + providers := config.Providers() + modelItems := []util.Model{} + selectIndex := 0 + + cfg := config.Get() + var currentModel config.PreferredModel + if m.modelType == LargeModelType { + currentModel = cfg.Models.Large + } else { + currentModel = cfg.Models.Small + } + + for _, provider := range providers { + name := provider.Name + if name == "" { + name = string(provider.ID) + } + modelItems = append(modelItems, commands.NewItemSection(name)) + for _, model := range provider.Models { + modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{ + Provider: provider, + Model: model, + })) + if model.ID == currentModel.ModelID && provider.ID == currentModel.Provider { + selectIndex = len(modelItems) - 1 // Set the selected index to the current model + } + } + } + + return tea.Sequence(m.modelList.SetItems(modelItems), m.modelList.SetSelected(selectIndex)) +} diff --git a/internal/tui/tui.go b/internal/tui/tui.go index e3c974ca002529ce1ac90f420afcc5eedf2a45fd..46a3f3c2a4ee2a6f4b2441a3d10e98f7e46eca2a 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -172,7 +172,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // Model Switch case models.ModelSelectedMsg: - config.UpdatePreferredModel(config.LargeModel, msg.Model) + config.UpdatePreferredModel(msg.ModelType, msg.Model) // Update the agent with the new model/provider configuration if err := a.app.UpdateAgentModel(); err != nil { @@ -180,7 +180,11 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return a, util.ReportError(fmt.Errorf("model changed to %s but failed to update agent: %v", msg.Model.ModelID, err)) } - return a, util.ReportInfo(fmt.Sprintf("Model changed to %s", msg.Model.ModelID)) + modelTypeName := "large" + if msg.ModelType == config.SmallModel { + modelTypeName = "small" + } + return a, util.ReportInfo(fmt.Sprintf("%s model changed to %s", modelTypeName, msg.Model.ModelID)) // File Picker case chat.OpenFilePickerMsg: 
From 9eab1ba11efd7caade2e0c1c66a4c4b67d32026e Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 30 Jun 2025 14:13:08 +0200 Subject: [PATCH 27/40] fix: resolve golangci-lint issues - Fix gofumpt formatting in models.go - Replace deprecated database methods with context-aware versions: - db.Ping() -> db.PingContext() - db.Exec() -> db.ExecContext() - db.Begin() -> db.BeginTx() - Update Taskfile.yaml to match CI lint configuration --- Taskfile.yaml | 4 ++-- internal/tui/components/dialogs/models/models.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Taskfile.yaml b/Taskfile.yaml index 9453eb363513d5b9e0987a637d761eef1bfe63c3..077c994ea7c6ce79fabbc18f76cfb3ef27af8d5e 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -6,12 +6,12 @@ tasks: lint: desc: Run base linters cmds: - - golangci-lint run + - golangci-lint run --path-mode=abs --config=".golangci.yml" --timeout=5m lint-fix: desc: Run base linters and fix issues cmds: - - golangci-lint run --fix + - golangci-lint run --path-mode=abs --config=".golangci.yml" --timeout=5m --fix test: desc: Run tests diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index 8e81777d7e3e6c45628cb8839b52c484e47f5640..c2c65e5d90c1d9bf9bfca708692f0f288c27a594 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -219,7 +219,7 @@ func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd { providers := config.Providers() modelItems := []util.Model{} selectIndex := 0 - + cfg := config.Get() var currentModel config.PreferredModel if m.modelType == LargeModelType { @@ -227,7 +227,7 @@ func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd { } else { currentModel = cfg.Models.Small } - + for _, provider := range providers { name := provider.Name if name == "" { From f0091c38b199a0e49d58c899a15ac6f4b6178283 Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Mon, 30 Jun 2025 
08:44:12 -0400 Subject: [PATCH 28/40] chore: adjust wording of small/large model labels --- internal/tui/components/dialogs/models/models.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index c2c65e5d90c1d9bf9bfca708692f0f288c27a594..b5ca9fe79c51e8534dade154af1f0be553d47d56 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -204,13 +204,13 @@ func (m *modelDialogCmp) ID() dialogs.DialogID { func (m *modelDialogCmp) modelTypeRadio() string { t := styles.CurrentTheme() - choices := []string{"Large", "Small"} + choices := []string{"Large Task", "Small Task"} iconSelected := "◉" iconUnselected := "○" if m.modelType == LargeModelType { - return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1]) + return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1]) } - return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1]) + return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1]) } func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd { From f3f4d9d59d80490d34a3c256812e30fe9baf0a1e Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 30 Jun 2025 15:00:05 +0200 Subject: [PATCH 29/40] chore: change how summarize sessions gets triggered --- internal/config/config.go | 38 +++++++++++++++++++++++++++++++++++++ internal/llm/agent/agent.go | 5 +++++ internal/tui/tui.go | 19 +++++++++++++++++-- 3 files changed, 60 insertions(+), 2 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index 0cccbdfffe5dc1c42e21b03bd9a20d7112005bee..e0e54ff9139fabf299c57aadab8b384335260f22 100644 --- 
a/internal/config/config.go +++ b/internal/config/config.go @@ -860,6 +860,44 @@ func GetAgentModel(agentID AgentID) Model { return Model{} } +// GetAgentEffectiveMaxTokens returns the effective max tokens for an agent, +// considering any overrides from the preferred model configuration +func GetAgentEffectiveMaxTokens(agentID AgentID) int64 { + cfg := Get() + agent, ok := cfg.Agents[agentID] + if !ok { + logging.Error("Agent not found", "agent_id", agentID) + return 0 + } + + var preferredModel PreferredModel + switch agent.Model { + case LargeModel: + preferredModel = cfg.Models.Large + case SmallModel: + preferredModel = cfg.Models.Small + default: + logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model) + preferredModel = cfg.Models.Large // Fallback to large model + } + + // Get the base model configuration + baseModel := GetAgentModel(agentID) + if baseModel.ID == "" { + return 0 + } + + // Start with the default max tokens from the base model + maxTokens := baseModel.DefaultMaxTokens + + // Override with preferred model max tokens if set + if preferredModel.MaxTokens > 0 { + maxTokens = preferredModel.MaxTokens + } + + return maxTokens +} + func GetAgentProvider(agentID AgentID) ProviderConfig { cfg := Get() agent, ok := cfg.Agents[agentID] diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index d165921f639ffee7127e4044c42d154f091a0dca..5568c8a619287619900e8d5e5d5d44c2e85de446 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -50,6 +50,7 @@ type AgentEvent struct { type Service interface { pubsub.Suscriber[AgentEvent] Model() config.Model + EffectiveMaxTokens() int64 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) Cancel(sessionID string) CancelAll() @@ -229,6 +230,10 @@ func (a *agent) Model() config.Model { return config.GetAgentModel(a.agentCfg.ID) } +func (a *agent) EffectiveMaxTokens() int64 { + 
return config.GetAgentEffectiveMaxTokens(a.agentCfg.ID) +} + func (a *agent) Cancel(sessionID string) { // Cancel regular requests if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists { diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 46a3f3c2a4ee2a6f4b2441a3d10e98f7e46eca2a..a37a55cfbbd05eba217b10691a15d0caed7c183a 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -228,8 +228,23 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if err == nil { model := a.app.CoderAgent.Model() contextWindow := model.ContextWindow - tokens := session.CompletionTokens + session.PromptTokens - if (tokens >= int64(float64(contextWindow)*0.95)) && !config.Get().Options.DisableAutoSummarize { + usedTokens := session.CompletionTokens + session.PromptTokens + remainingTokens := contextWindow - usedTokens + + // Get effective max tokens for this agent (considering overrides) + maxTokens := a.app.CoderAgent.EffectiveMaxTokens() + + // Apply 10% margin to max tokens + maxTokensWithMargin := int64(float64(maxTokens) * 1.1) + + // Trigger auto-summarize if remaining tokens < max tokens + 10% margin + // Also ensure we have a reasonable minimum threshold to avoid too-frequent summaries + minThreshold := int64(1000) // Minimum 1000 tokens remaining before triggering + if maxTokensWithMargin < minThreshold { + maxTokensWithMargin = minThreshold + } + + if remainingTokens < maxTokensWithMargin && !config.Get().Options.DisableAutoSummarize { // Show compact confirmation dialog cmds = append(cmds, util.CmdHandler(dialogs.OpenDialogMsg{ Model: compact.NewCompactDialogCmp(a.app.CoderAgent, a.selectedSessionID, false), From d0fcb3f25dc0ec1e0393fe94b3ea7fee1e62306c Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 30 Jun 2025 17:26:09 +0200 Subject: [PATCH 30/40] chore: fix lint --- internal/db/connect.go | 3 ++- internal/tui/tui.go | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/internal/db/connect.go 
b/internal/db/connect.go index 6452dabdb3a5de6ffb0f618062631dfe4b25102d..95df2e491d13918d8ea66a3f8312438a42e81b4f 100644 --- a/internal/db/connect.go +++ b/internal/db/connect.go @@ -1,6 +1,7 @@ package db import ( + "context" "database/sql" "fmt" "os" @@ -15,7 +16,7 @@ import ( "github.com/pressly/goose/v3" ) -func Connect() (*sql.DB, error) { +func Connect(ctx context.Context) (*sql.DB, error) { dataDir := config.Get().Options.DataDirectory if dataDir == "" { return nil, fmt.Errorf("data.dir is not set") diff --git a/internal/tui/tui.go b/internal/tui/tui.go index a37a55cfbbd05eba217b10691a15d0caed7c183a..fb77e5a8f30d8f4cd290d3a8d4026694c690a109 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -230,20 +230,20 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { contextWindow := model.ContextWindow usedTokens := session.CompletionTokens + session.PromptTokens remainingTokens := contextWindow - usedTokens - + // Get effective max tokens for this agent (considering overrides) maxTokens := a.app.CoderAgent.EffectiveMaxTokens() - + // Apply 10% margin to max tokens maxTokensWithMargin := int64(float64(maxTokens) * 1.1) - + // Trigger auto-summarize if remaining tokens < max tokens + 10% margin // Also ensure we have a reasonable minimum threshold to avoid too-frequent summaries minThreshold := int64(1000) // Minimum 1000 tokens remaining before triggering if maxTokensWithMargin < minThreshold { maxTokensWithMargin = minThreshold } - + if remainingTokens < maxTokensWithMargin && !config.Get().Options.DisableAutoSummarize { // Show compact confirmation dialog cmds = append(cmds, util.CmdHandler(dialogs.OpenDialogMsg{ From 5bcb64760b0a212f1c7cbe4e2695fb6a1e490cd6 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 30 Jun 2025 17:36:13 +0200 Subject: [PATCH 31/40] chore: rebase --- go.mod | 14 ++++---------- go.sum | 2 -- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index 
7795febc9cf458a4830ac5ecfe849c3557024786..a24a21f859d4a0ea3ba8ecf203252da84823ada7 100644 --- a/go.mod +++ b/go.mod @@ -40,13 +40,6 @@ require ( mvdan.cc/sh/v3 v3.11.0 ) -require ( - github.com/bahlo/generic-list-go v0.2.0 // indirect - github.com/buger/jsonparser v1.1.1 // indirect - github.com/mailru/easyjson v0.7.7 // indirect - github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect -) - require ( cloud.google.com/go v0.116.0 // indirect cloud.google.com/go/auth v0.13.0 // indirect @@ -72,6 +65,8 @@ require ( github.com/aws/smithy-go v1.20.3 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymerick/douceur v0.2.0 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect github.com/charmbracelet/colorprofile v0.3.1 // indirect github.com/charmbracelet/x/cellbuf v0.0.14-0.20250516160309-24eee56f89fa // indirect github.com/charmbracelet/x/exp/slice v0.0.0-20250611152503-f53cdd7e01ef @@ -95,6 +90,7 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 + github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect github.com/mfridman/interpolate v0.0.2 // indirect @@ -110,15 +106,13 @@ require ( github.com/rivo/uniseg v0.4.7 github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect github.com/sethvargo/go-retry v0.3.0 // indirect - github.com/sourcegraph/conc v0.3.0 // indirect - github.com/spf13/afero v1.12.0 - github.com/spf13/cast v1.7.1 // indirect github.com/spf13/pflag v1.0.6 // indirect github.com/tetratelabs/wazero v1.9.0 // indirect github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect 
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yuin/goldmark v1.7.8 // indirect diff --git a/go.sum b/go.sum index 234b9663cab709b9babde6e324374fa7022cb1bc..799dce9b3d9be5fc202a3bc5300e61bdb430ac2a 100644 --- a/go.sum +++ b/go.sum @@ -88,8 +88,6 @@ github.com/charmbracelet/x/ansi v0.9.3-0.20250602153603-fb931ed90413 h1:L07QkDqR github.com/charmbracelet/x/ansi v0.9.3-0.20250602153603-fb931ed90413/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE= github.com/charmbracelet/x/cellbuf v0.0.14-0.20250516160309-24eee56f89fa h1:lphz0Z3rsiOtMYiz8axkT24i9yFiueDhJbzyNUADmME= github.com/charmbracelet/x/cellbuf v0.0.14-0.20250516160309-24eee56f89fa/go.mod h1:xBlh2Yi3DL3zy/2n15kITpg0YZardf/aa/hgUaIM6Rk= -github.com/charmbracelet/x/exp/charmtone v0.0.0-20250603201427-c31516f43444 h1:IJDiTgVE56gkAGfq0lBEloWgkXMk4hl/bmuPoicI4R0= -github.com/charmbracelet/x/exp/charmtone v0.0.0-20250603201427-c31516f43444/go.mod h1:T9jr8CzFpjhFVHjNjKwbAD7KwBNyFnj2pntAO7F2zw0= github.com/charmbracelet/x/exp/charmtone v0.0.0-20250627134340-c144409e381c h1:2GELBLPgfSbHU53bsQhR9XIgNuVZ6w+Rz8RWV5Lq+A4= github.com/charmbracelet/x/exp/charmtone v0.0.0-20250627134340-c144409e381c/go.mod h1:T9jr8CzFpjhFVHjNjKwbAD7KwBNyFnj2pntAO7F2zw0= github.com/charmbracelet/x/exp/golden v0.0.0-20250207160936-21c02780d27a h1:FsHEJ52OC4VuTzU8t+n5frMjLvpYWEznSr/u8tnkCYw= From 2b6766a58dd38dc5337b4b921b736a5f4cb804e3 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 30 Jun 2025 18:29:46 +0200 Subject: [PATCH 32/40] chore: fix shell --- internal/config/shell.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/config/shell.go b/internal/config/shell.go index 74931bfefc3a9e16e830fac2c3478a6f0d5396f2..b7c3c8c5a787def8ff28aec677193f5ac58b652a 100644 --- a/internal/config/shell.go +++ b/internal/config/shell.go @@ -7,8 +7,8 @@ import ( "strings" "time" - 
"github.com/charmbracelet/crush/internal/llm/tools/shell" "github.com/charmbracelet/crush/internal/logging" + "github.com/charmbracelet/crush/internal/shell" ) // ExecuteCommand executes a shell command and returns the output @@ -18,7 +18,7 @@ func ExecuteCommand(ctx context.Context, command string, workingDir string) (str workingDir = WorkingDirectory() } - persistentShell := shell.GetPersistentShell(workingDir) + persistentShell := shell.NewShell(&shell.Options{WorkingDir: workingDir}) stdout, stderr, err := persistentShell.Exec(ctx, command) if err != nil { From 8992d029b2cc6e9df9daf8782a58dbc097b27b93 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 30 Jun 2025 21:40:36 +0200 Subject: [PATCH 33/40] chore: small fixes to the config --- crush-schema.json | 7 +- internal/config/config.go | 20 +++--- internal/llm/provider/openai.go | 2 +- .../tui/components/dialogs/models/models.go | 70 +++++++++++++++++++ 4 files changed, 83 insertions(+), 16 deletions(-) diff --git a/crush-schema.json b/crush-schema.json index 35bd388ceca5b0aca5401d66db963b77f4f584f8..680ba31196e276c290fd5040b36b23c26cb12414 100644 --- a/crush-schema.json +++ b/crush-schema.json @@ -187,7 +187,7 @@ "title": "Model ID", "description": "Unique identifier for the model" }, - "model": { + "name": { "type": "string", "title": "Model Name", "description": "Display name of the model" @@ -252,10 +252,7 @@ "type": "object", "required": [ "id", - "model", - "cost_per_1m_in", - "cost_per_1m_out", - "cost_per_1m_in_cached", + "name", "cost_per_1m_out_cached", "context_window", "default_max_tokens", diff --git a/internal/config/config.go b/internal/config/config.go index e0e54ff9139fabf299c57aadab8b384335260f22..b8a70505da30a1e3f274e95ef89f606d7b6db9d4 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -56,18 +56,18 @@ const ( ) type Model struct { - ID string `json:"id" jsonschema:"title=Model ID,description=Unique identifier for the model"` - Name string `json:"model" 
jsonschema:"title=Model Name,description=Display name of the model"` - CostPer1MIn float64 `json:"cost_per_1m_in" jsonschema:"title=Input Cost,description=Cost per 1 million input tokens,minimum=0"` - CostPer1MOut float64 `json:"cost_per_1m_out" jsonschema:"title=Output Cost,description=Cost per 1 million output tokens,minimum=0"` - CostPer1MInCached float64 `json:"cost_per_1m_in_cached" jsonschema:"title=Cached Input Cost,description=Cost per 1 million cached input tokens,minimum=0"` - CostPer1MOutCached float64 `json:"cost_per_1m_out_cached" jsonschema:"title=Cached Output Cost,description=Cost per 1 million cached output tokens,minimum=0"` + ID string `json:"id" jsonschema:"title=Model ID,description=Unique identifier for the model, the API model"` + Name string `json:"name" jsonschema:"title=Model Name,description=Display name of the model"` + CostPer1MIn float64 `json:"cost_per_1m_in,omitempty" jsonschema:"title=Input Cost,description=Cost per 1 million input tokens,minimum=0"` + CostPer1MOut float64 `json:"cost_per_1m_out,omitempty" jsonschema:"title=Output Cost,description=Cost per 1 million output tokens,minimum=0"` + CostPer1MInCached float64 `json:"cost_per_1m_in_cached,omitempty" jsonschema:"title=Cached Input Cost,description=Cost per 1 million cached input tokens,minimum=0"` + CostPer1MOutCached float64 `json:"cost_per_1m_out_cached,omitempty" jsonschema:"title=Cached Output Cost,description=Cost per 1 million cached output tokens,minimum=0"` ContextWindow int64 `json:"context_window" jsonschema:"title=Context Window,description=Maximum context window size in tokens,minimum=1"` DefaultMaxTokens int64 `json:"default_max_tokens" jsonschema:"title=Default Max Tokens,description=Default maximum tokens for responses,minimum=1"` - CanReason bool `json:"can_reason" jsonschema:"title=Can Reason,description=Whether the model supports reasoning capabilities"` - ReasoningEffort string `json:"reasoning_effort" jsonschema:"title=Reasoning Effort,description=Default 
reasoning effort level for reasoning models"` - HasReasoningEffort bool `json:"has_reasoning_effort" jsonschema:"title=Has Reasoning Effort,description=Whether the model supports reasoning effort configuration"` - SupportsImages bool `json:"supports_attachments" jsonschema:"title=Supports Images,description=Whether the model supports image attachments"` + CanReason bool `json:"can_reason,omitempty" jsonschema:"title=Can Reason,description=Whether the model supports reasoning capabilities"` + ReasoningEffort string `json:"reasoning_effort,omitempty" jsonschema:"title=Reasoning Effort,description=Default reasoning effort level for reasoning models"` + HasReasoningEffort bool `json:"has_reasoning_effort,omitempty" jsonschema:"title=Has Reasoning Effort,description=Whether the model supports reasoning effort configuration"` + SupportsImages bool `json:"supports_attachments,omitempty" jsonschema:"title=Supports Images,description=Whether the model supports image attachments"` } type VertexAIOptions struct { diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index e045029651f3e9fc158c9f38cf810584e2c06724..bddf820c2d3ccf9bba1a683ed4fe469d05fa31bf 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -183,7 +183,7 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar case "high": params.ReasoningEffort = shared.ReasoningEffortHigh default: - params.ReasoningEffort = shared.ReasoningEffortMedium + params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort) } } else { params.MaxTokens = openai.Int(maxTokens) diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index b5ca9fe79c51e8534dade154af1f0be553d47d56..aa7a505bd19af72a55e134fc0b077085a761faa6 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -1,6 +1,8 @@ package models import ( + "slices" + 
"github.com/charmbracelet/bubbles/v2/help" "github.com/charmbracelet/bubbles/v2/key" tea "github.com/charmbracelet/bubbletea/v2" @@ -228,7 +230,75 @@ func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd { currentModel = cfg.Models.Small } + // Create a map to track which providers we've already added + addedProviders := make(map[provider.InferenceProvider]bool) + + // First, add any configured providers that are not in the known providers list + // These should appear at the top of the list + knownProviders := provider.KnownProviders() + for providerID, providerConfig := range cfg.Providers { + if providerConfig.Disabled { + continue + } + + // Check if this provider is not in the known providers list + if !slices.Contains(knownProviders, providerID) { + // Convert config provider to provider.Provider format + configProvider := provider.Provider{ + Name: string(providerID), // Use provider ID as name for unknown providers + ID: providerID, + Models: make([]provider.Model, len(providerConfig.Models)), + } + + // Convert models + for i, model := range providerConfig.Models { + configProvider.Models[i] = provider.Model{ + ID: model.ID, + Name: model.Name, + CostPer1MIn: model.CostPer1MIn, + CostPer1MOut: model.CostPer1MOut, + CostPer1MInCached: model.CostPer1MInCached, + CostPer1MOutCached: model.CostPer1MOutCached, + ContextWindow: model.ContextWindow, + DefaultMaxTokens: model.DefaultMaxTokens, + CanReason: model.CanReason, + HasReasoningEffort: model.HasReasoningEffort, + DefaultReasoningEffort: model.ReasoningEffort, + SupportsImages: model.SupportsImages, + } + } + + // Add this unknown provider to the list + name := configProvider.Name + if name == "" { + name = string(configProvider.ID) + } + modelItems = append(modelItems, commands.NewItemSection(name)) + for _, model := range configProvider.Models { + modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{ + Provider: configProvider, + Model: model, + })) + if model.ID 
== currentModel.ModelID && configProvider.ID == currentModel.Provider { + selectIndex = len(modelItems) - 1 // Set the selected index to the current model + } + } + addedProviders[providerID] = true + } + } + + // Then add the known providers from the predefined list for _, provider := range providers { + // Skip if we already added this provider as an unknown provider + if addedProviders[provider.ID] { + continue + } + + // Check if this provider is configured and not disabled + if providerConfig, exists := cfg.Providers[provider.ID]; exists && providerConfig.Disabled { + continue + } + name := provider.Name if name == "" { name = string(provider.ID) From 11b65a0849f59807b6e87288bf913a53967a3ff9 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 30 Jun 2025 22:20:37 +0200 Subject: [PATCH 34/40] chore: remove old docs --- CRUSH.md | 5 ----- 1 file changed, 5 deletions(-) diff --git a/CRUSH.md b/CRUSH.md index 6ab75f143258e63d4881bf17447304b7d1ca6ba5..c308db631e006dd1c3834b6b470a02f4c41ff53b 100644 --- a/CRUSH.md +++ b/CRUSH.md @@ -47,11 +47,6 @@ func TestYourFunction(t *testing.T) { } ``` -The mock providers include: -- **Anthropic**: claude-3-opus, claude-3-haiku, claude-3-5-sonnet-20241022, claude-3-5-haiku-20241022 -- **OpenAI**: gpt-4, gpt-3.5-turbo, gpt-4-turbo, gpt-4o, gpt-4o-mini -- **Gemini**: gemini-2.5-pro, gemini-2.5-flash - ## Formatting - ALWAYS format any Go code you write. 
From 9f098d10b229f6370eefa0d064d58cd9cab7c686 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 30 Jun 2025 22:31:11 +0200 Subject: [PATCH 35/40] chore: remove debug log --- internal/tui/components/logs/details.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/tui/components/logs/details.go b/internal/tui/components/logs/details.go index 46ec1ce02d1e4573c04f99f45e1af2a2a3a4a731..c73876aff8a6898809a170fec176def30af81314 100644 --- a/internal/tui/components/logs/details.go +++ b/internal/tui/components/logs/details.go @@ -161,7 +161,6 @@ func (i *detailCmp) GetSize() (int, int) { } func (i *detailCmp) SetSize(width int, height int) tea.Cmd { - logging.Info("Setting size for detail component", "width", width, "height", height) i.width = width i.height = height i.viewport.SetWidth(i.width - 4) From 290f1f470634195e329cff031d3933871f719e4e Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 30 Jun 2025 22:43:04 +0200 Subject: [PATCH 36/40] chore: small fixes --- internal/lsp/client.go | 4 ++-- internal/tui/components/dialogs/models/keys.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/lsp/client.go b/internal/lsp/client.go index c04f10a8a924f8725609aace7d5363fe1751a791..24ff0238c355edb5499640b93f9e06f0f07568c9 100644 --- a/internal/lsp/client.go +++ b/internal/lsp/client.go @@ -96,10 +96,10 @@ func NewClient(ctx context.Context, command string, args ...string) (*Client, er go func() { scanner := bufio.NewScanner(stderr) for scanner.Scan() { - fmt.Fprintf(os.Stderr, "LSP Server: %s\n", scanner.Text()) + logging.Error("LSP Server", "err", scanner.Text()) } if err := scanner.Err(); err != nil { - fmt.Fprintf(os.Stderr, "Error reading stderr: %v\n", err) + logging.Error("Error reading", "err", err) } }() diff --git a/internal/tui/components/dialogs/models/keys.go b/internal/tui/components/dialogs/models/keys.go index dd34f82860b21a98a6ae4a5eea57aa052b383f33..bb70785172bac66d9fda905172572c881b2ecd35 100644 --- 
a/internal/tui/components/dialogs/models/keys.go +++ b/internal/tui/components/dialogs/models/keys.go @@ -28,7 +28,7 @@ func DefaultKeyMap() KeyMap { ), Tab: key.NewBinding( key.WithKeys("tab"), - key.WithHelp("tab", "toggle model type"), + key.WithHelp("tab", "toggle type"), ), Close: key.NewBinding( key.WithKeys("esc"), From 1f6eef460ec921c435cba5bd58228dfa8adf6ef3 Mon Sep 17 00:00:00 2001 From: Gedy Palomino <36518098+gedzeppelin@users.noreply.github.com> Date: Tue, 1 Jul 2025 04:50:43 -0500 Subject: [PATCH 37/40] fix(mcp): ensure required field if nil (#278) --- internal/llm/agent/mcp-tools.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index 2375606416e144db5ada7b0ab4309c7987aa8080..59a15bdd72ff620c92eb4ee3144a0a4b13276f67 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -33,11 +33,15 @@ type MCPClient interface { } func (b *mcpTool) Info() tools.ToolInfo { + required := b.tool.InputSchema.Required + if required == nil { + required = make([]string, 0) + } return tools.ToolInfo{ Name: fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name), Description: b.tool.Description, Parameters: b.tool.InputSchema.Properties, - Required: b.tool.InputSchema.Required, + Required: required, } } From f0571f5f5adef12eba9ddf6d07223a043d63dca8 Mon Sep 17 00:00:00 2001 From: Aldehir Rojas Date: Tue, 1 Jul 2025 04:52:19 -0500 Subject: [PATCH 38/40] fix(tool/grep): always show file names with rg (#271) --- internal/llm/tools/grep.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go index f20d61ef1ed44f50235f4ba19b8ea44ba7043eb6..1d2d008cf3d641f03cebcd97b73956d11e84a2a1 100644 --- a/internal/llm/tools/grep.go +++ b/internal/llm/tools/grep.go @@ -211,7 +211,7 @@ func searchWithRipgrep(pattern, path, include string) ([]grepMatch, error) { } // Use -n to show line numbers and include the matched 
line - args := []string{"-n", pattern} + args := []string{"-H", "-n", pattern} if include != "" { args = append(args, "--glob", include) } From 979384a3a23e9476a8e886cc65fc4ed152c986bb Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Tue, 1 Jul 2025 12:15:22 +0200 Subject: [PATCH 39/40] chore: fix lint --- internal/logging/logger.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/internal/logging/logger.go b/internal/logging/logger.go index 98e7b23ae3b4025acbcd2585a7cfd3f3c9230623..ef33ba299855f24f0b48ae5af9ed80a069ba7a31 100644 --- a/internal/logging/logger.go +++ b/internal/logging/logger.go @@ -4,6 +4,7 @@ import ( "fmt" "log/slog" "os" + // "path/filepath" "encoding/json" "runtime" @@ -22,6 +23,7 @@ func getCaller() string { } return caller } + func Info(msg string, args ...any) { source := getCaller() slog.Info(msg, append([]any{"source", source}, args...)...) @@ -122,7 +124,7 @@ func AppendToSessionLogFile(sessionId string, filename string, content string) s filePath := fmt.Sprintf("%s/%s", sessionPath, filename) - f, err := os.OpenFile(filePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + f, err := os.OpenFile(filePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) if err != nil { Error("Failed to open session log file", "filepath", filePath, "error", err) return "" From 142ff9e9735befd24922d893bff64951d4114aca Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Tue, 1 Jul 2025 12:18:40 +0200 Subject: [PATCH 40/40] chore: remove unnecessary comment --- internal/llm/agent/agent.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index ec4f160e650d4eafa497b92b77b8dc26c2aca40b..cd2e2fdaccc9108af3bab8a0072baad062585846 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -503,12 +503,6 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg tool = availableTool break } - // Monkey patch for Copilot Sonnet-4 tool repetition 
obfuscation - // if strings.HasPrefix(toolCall.Name, availableTool.Info().Name) && - // strings.HasPrefix(toolCall.Name, availableTool.Info().Name+availableTool.Info().Name) { - // tool = availableTool - // break - // } } // Tool not found