From 565ab85eb91102ff7a9b03f950ac13ae4391f6ab Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 27 Jun 2025 10:55:44 +0200 Subject: [PATCH] chore: move to the new config --- cmd/root.go | 3 +- cmd/schema/main.go | 21 +- internal/app/app.go | 21 +- internal/config/config.go | 1295 +++++++---------- internal/{config_v2 => config}/config_test.go | 4 +- internal/{config_v2 => config}/fs.go | 2 +- internal/config/init.go | 19 +- internal/{config_v2 => config}/provider.go | 2 +- internal/config_v2/config.go | 660 --------- internal/db/connect.go | 5 +- internal/db/messages.sql.go | 14 +- ...0250627000000_add_provider_to_messages.sql | 11 + internal/db/models.go | 1 + internal/db/sql/messages.sql | 3 +- internal/fur/client/client.go | 2 +- internal/fur/provider/provider.go | 15 +- internal/llm/agent/agent-tool.go | 27 +- internal/llm/agent/agent.go | 301 ++-- internal/llm/agent/mcp-tools.go | 8 +- internal/llm/agent/tools.go | 50 - internal/llm/models/anthropic.go | 111 -- internal/llm/models/azure.go | 168 --- internal/llm/models/gemini.go | 67 - internal/llm/models/groq.go | 87 -- internal/llm/models/local.go | 206 --- internal/llm/models/models.go | 74 - internal/llm/models/openai.go | 181 --- internal/llm/models/openrouter.go | 276 ---- internal/llm/models/vertexai.go | 38 - internal/llm/models/xai.go | 61 - internal/llm/prompt/coder.go | 18 +- internal/llm/prompt/prompt.go | 54 +- internal/llm/prompt/prompt_test.go | 15 +- internal/llm/prompt/summarizer.go | 6 +- internal/llm/prompt/task.go | 4 +- internal/llm/prompt/title.go | 6 +- internal/llm/provider/anthropic.go | 10 +- internal/llm/provider/bedrock.go | 15 +- internal/llm/provider/gemini.go | 8 +- internal/llm/provider/openai.go | 10 +- internal/llm/provider/provider.go | 134 +- internal/llm/provider/vertexai.go | 7 +- internal/lsp/client.go | 28 +- internal/lsp/handlers.go | 4 +- internal/lsp/transport.go | 38 +- internal/lsp/watcher/watcher.go | 50 +- internal/message/content.go | 9 +- internal/message/message.go | 12 +- internal/tui/components/chat/header/header.go | 7 +- .../tui/components/chat/messages/messages.go | 8 +- .../tui/components/chat/sidebar/sidebar.go | 8 +- .../tui/components/dialogs/commands/loader.go | 2 +- internal/tui/components/dialogs/init/init.go | 6 +- .../tui/components/dialogs/models/models.go | 124 +- internal/tui/page/chat/chat.go | 10 +- internal/tui/tui.go | 7 +- 56 files changed, 1096 insertions(+), 3237 deletions(-) rename internal/{config_v2 => config}/config_test.go (94%) rename internal/{config_v2 => config}/fs.go (99%) rename internal/{config_v2 => config}/provider.go (98%) delete mode 100644 internal/config_v2/config.go create mode 100644 internal/db/migrations/20250627000000_add_provider_to_messages.sql delete mode 100644 internal/llm/agent/tools.go delete mode 100644 internal/llm/models/anthropic.go delete mode 100644 internal/llm/models/azure.go delete mode 100644 internal/llm/models/gemini.go delete mode 100644 internal/llm/models/groq.go delete mode 100644 internal/llm/models/local.go delete mode 100644 internal/llm/models/models.go delete mode 100644 internal/llm/models/openai.go delete mode 100644 internal/llm/models/openrouter.go delete mode 100644 internal/llm/models/vertexai.go delete mode 100644 internal/llm/models/xai.go diff --git a/cmd/root.go b/cmd/root.go index 2b5f79cf0337c386196d783ad9d18e2e1380aa5b..d741b859178e6c524b4b4e3a61863f144840812c 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -72,7 +72,8 @@ to assist developers in writing, debugging, and understanding code directly 
from } cwd = c } - _, err := config.Load(cwd, debug) + + _, err := config.Init(cwd, debug) if err != nil { return err } diff --git a/cmd/schema/main.go b/cmd/schema/main.go index da5353c0ec7353bfa3ec9b35760b735ecc2c9ccd..9eb88769fd84772628df5332d3dcc1b1b234ac90 100644 --- a/cmd/schema/main.go +++ b/cmd/schema/main.go @@ -1,3 +1,4 @@ +// TODO: FIX THIS package main import ( @@ -6,7 +7,6 @@ import ( "os" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/llm/models" ) // JSONSchemaType represents a JSON Schema type @@ -192,22 +192,10 @@ func generateSchema() map[string]any { }, } - // Add known providers - knownProviders := []string{ - string(models.ProviderAnthropic), - string(models.ProviderOpenAI), - string(models.ProviderGemini), - string(models.ProviderGROQ), - string(models.ProviderOpenRouter), - string(models.ProviderBedrock), - string(models.ProviderAzure), - string(models.ProviderVertexAI), - } - providerSchema["additionalProperties"].(map[string]any)["properties"].(map[string]any)["provider"] = map[string]any{ "type": "string", "description": "Provider type", - "enum": knownProviders, + "enum": []string{}, } schema["properties"].(map[string]any)["providers"] = providerSchema @@ -241,9 +229,7 @@ func generateSchema() map[string]any { // Add model enum modelEnum := []string{} - for modelID := range models.SupportedModels { - modelEnum = append(modelEnum, string(modelID)) - } + agentSchema["additionalProperties"].(map[string]any)["properties"].(map[string]any)["model"].(map[string]any)["enum"] = modelEnum // Add specific agent properties @@ -251,7 +237,6 @@ func generateSchema() map[string]any { knownAgents := []string{ string(config.AgentCoder), string(config.AgentTask), - string(config.AgentTitle), } for _, agentName := range knownAgents { diff --git a/internal/app/app.go b/internal/app/app.go index e7472059a9f3fad360172c353f5d9a188529d177..75042e89648779cf50a4376aa01aa3b6ac8e72a0 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -9,7 +9,7 @@ import ( "sync" "time" - "github.com/charmbracelet/crush/internal/config" + configv2 "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/db" "github.com/charmbracelet/crush/internal/format" "github.com/charmbracelet/crush/internal/history" @@ -55,18 +55,21 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) { // Initialize LSP clients in the background go app.initLSPClients(ctx) + cfg := configv2.Get() + + coderAgentCfg := cfg.Agents[configv2.AgentCoder] + if coderAgentCfg.ID == "" { + return nil, fmt.Errorf("coder agent configuration is missing") + } + var err error app.CoderAgent, err = agent.NewAgent( - config.AgentCoder, + coderAgentCfg, + app.Permissions, app.Sessions, app.Messages, - agent.CoderAgentTools( - app.Permissions, - app.Sessions, - app.Messages, - app.History, - app.LSPClients, - ), + app.History, + app.LSPClients, ) if err != nil { logging.Error("Failed to create coder agent", err) diff --git a/internal/config/config.go b/internal/config/config.go index 3944cb1374582f9af0eeb7bfadd05ef5f9a8c198..13444a5ccc8e99bdaa57a6156151b45a40176c09 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,67 +1,132 @@ -// Package config manages application configuration from various sources. 
package config import ( "encoding/json" + "errors" "fmt" "log/slog" + "maps" "os" "path/filepath" + "slices" "strings" + "sync" - "github.com/charmbracelet/crush/internal/llm/models" + "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/logging" - "github.com/spf13/afero" - "github.com/spf13/viper" ) -// MCPType defines the type of MCP (Model Control Protocol) server. -type MCPType string - -// Supported MCP types const ( - MCPStdio MCPType = "stdio" - MCPSse MCPType = "sse" + defaultDataDirectory = ".crush" + defaultLogLevel = "info" + appName = "crush" + + MaxTokensFallbackDefault = 4096 ) -// MCPServer defines the configuration for a Model Control Protocol server. -type MCPServer struct { - Command string `json:"command"` - Env []string `json:"env"` - Args []string `json:"args"` - Type MCPType `json:"type"` - URL string `json:"url"` - Headers map[string]string `json:"headers"` +var defaultContextPaths = []string{ + ".github/copilot-instructions.md", + ".cursorrules", + ".cursor/rules/", + "CLAUDE.md", + "CLAUDE.local.md", + "GEMINI.md", + "gemini.md", + "crush.md", + "crush.local.md", + "Crush.md", + "Crush.local.md", + "CRUSH.md", + "CRUSH.local.md", } -type AgentName string +type AgentID string const ( - AgentCoder AgentName = "coder" - AgentSummarizer AgentName = "summarizer" - AgentTask AgentName = "task" - AgentTitle AgentName = "title" + AgentCoder AgentID = "coder" + AgentTask AgentID = "task" ) -// Agent defines configuration for different LLM models and their token limits. -type Agent struct { - Model models.ModelID `json:"model"` - MaxTokens int64 `json:"maxTokens"` - ReasoningEffort string `json:"reasoningEffort"` // For openai models low,medium,heigh +type Model struct { + ID string `json:"id"` + Name string `json:"model"` + CostPer1MIn float64 `json:"cost_per_1m_in"` + CostPer1MOut float64 `json:"cost_per_1m_out"` + CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` + CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` + ContextWindow int64 `json:"context_window"` + DefaultMaxTokens int64 `json:"default_max_tokens"` + CanReason bool `json:"can_reason"` + ReasoningEffort string `json:"reasoning_effort"` + SupportsImages bool `json:"supports_attachments"` } -// Provider defines configuration for an LLM provider. -type Provider struct { - APIKey string `json:"apiKey"` - Disabled bool `json:"disabled"` +type VertexAIOptions struct { + APIKey string `json:"api_key,omitempty"` + Project string `json:"project,omitempty"` + Location string `json:"location,omitempty"` } -// Data defines storage configuration. 
-type Data struct {
-	Directory string `json:"directory,omitempty"`
+type ProviderConfig struct {
+	ID           provider.InferenceProvider `json:"id"`
+	BaseURL      string                     `json:"base_url,omitempty"`
+	ProviderType provider.Type              `json:"provider_type"`
+	APIKey       string                     `json:"api_key,omitempty"`
+	Disabled     bool                       `json:"disabled"`
+	ExtraHeaders map[string]string          `json:"extra_headers,omitempty"`
+	// used, e.g., by Vertex AI to set the project
+	ExtraParams map[string]string `json:"extra_params,omitempty"`
+
+	DefaultLargeModel string `json:"default_large_model,omitempty"`
+	DefaultSmallModel string `json:"default_small_model,omitempty"`
+
+	Models []Model `json:"models,omitempty"`
+}
+
+type Agent struct {
+	ID          AgentID `json:"id"`
+	Name        string  `json:"name"`
+	Description string  `json:"description,omitempty"`
+	Disabled    bool    `json:"disabled"`
+
+	Provider provider.InferenceProvider `json:"provider"`
+	Model    string                     `json:"model"`
+
+	// The tools available to the agent.
+	// If nil, all tools are available.
+	AllowedTools []string `json:"allowed_tools"`
+
+	// The MCPs available to this agent.
+	// If empty, all MCPs are available.
+	// Each entry's string slice lists the tools from that MCP the agent may use;
+	// if the slice is nil, all tools from that MCP are available.
+	AllowedMCP map[string][]string `json:"allowed_mcp"`
+
+	// The LSPs this agent can use.
+	// If nil, all LSPs are available.
+	AllowedLSP []string `json:"allowed_lsp"`
+
+	// Overrides the context paths for this agent.
+	ContextPaths []string `json:"context_paths"`
+}
+
+type MCPType string
+
+const (
+	MCPStdio MCPType = "stdio"
+	MCPSse   MCPType = "sse"
+)
+
+type MCP struct {
+	Command string            `json:"command"`
+	Env     []string          `json:"env"`
+	Args    []string          `json:"args"`
+	Type    MCPType           `json:"type"`
+	URL     string            `json:"url"`
+	Headers map[string]string `json:"headers"`
 }
 
-// LSPConfig defines configuration for Language Server Protocol integration.
 type LSPConfig struct {
 	Disabled bool `json:"enabled"`
 	Command  string `json:"command"`
@@ -69,98 +134,72 @@ type LSPConfig struct {
 	Options any `json:"options"`
 }
 
-// TUIConfig defines the configuration for the Terminal User Interface.
-type TUIConfig struct {
-	Theme string `json:"theme,omitempty"`
+type TUIOptions struct {
+	CompactMode bool `json:"compact_mode"`
+	// Themes and other TUI-related options can be added here later.
 }
 
-// Config is the main configuration structure for the application.
-type Config struct { - Data Data `json:"data"` - WorkingDir string `json:"wd,omitempty"` - MCPServers map[string]MCPServer `json:"mcpServers,omitempty"` - Providers map[models.InferenceProvider]Provider `json:"providers,omitempty"` - LSP map[string]LSPConfig `json:"lsp,omitempty"` - Agents map[AgentName]Agent `json:"agents,omitempty"` - Debug bool `json:"debug,omitempty"` - DebugLSP bool `json:"debugLSP,omitempty"` - ContextPaths []string `json:"contextPaths,omitempty"` - TUI TUIConfig `json:"tui"` - AutoCompact bool `json:"autoCompact,omitempty"` +type Options struct { + ContextPaths []string `json:"context_paths"` + TUI TUIOptions `json:"tui"` + Debug bool `json:"debug"` + DebugLSP bool `json:"debug_lsp"` + DisableAutoSummarize bool `json:"disable_auto_summarize"` + // Relative to the cwd + DataDirectory string `json:"data_directory"` } -// Application constants -const ( - defaultDataDirectory = ".crush" - defaultLogLevel = "info" - appName = "crush" - - MaxTokensFallbackDefault = 4096 -) +type PreferredModel struct { + ModelID string `json:"model_id"` + Provider provider.InferenceProvider `json:"provider"` +} -var defaultContextPaths = []string{ - ".github/copilot-instructions.md", - ".cursorrules", - ".cursor/rules/", - "CLAUDE.md", - "CLAUDE.local.md", - "GEMINI.md", - "gemini.md", - "crush.md", - "crush.local.md", - "Crush.md", - "Crush.local.md", - "CRUSH.md", - "CRUSH.local.md", +type PreferredModels struct { + Large PreferredModel `json:"large"` + Small PreferredModel `json:"small"` } -// Global configuration instance -var cfg *Config +type Config struct { + Models PreferredModels `json:"models"` + // List of configured providers + Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"` -// Load initializes the configuration from environment variables and config files. -// If debug is true, debug mode is enabled and log level is set to debug. -// It returns an error if configuration loading fails. 
-func Load(workingDir string, debug bool) (*Config, error) { - if cfg != nil { - return cfg, nil - } + // List of configured agents + Agents map[AgentID]Agent `json:"agents,omitempty"` - cfg = &Config{ - WorkingDir: workingDir, - MCPServers: make(map[string]MCPServer), - Providers: make(map[models.InferenceProvider]Provider), - LSP: make(map[string]LSPConfig), - } + // List of configured MCPs + MCP map[string]MCP `json:"mcp,omitempty"` - configureViper() - setDefaults(debug) + // List of configured LSPs + LSP map[string]LSPConfig `json:"lsp,omitempty"` - // Read global config - if err := readConfig(viper.ReadInConfig()); err != nil { - return cfg, err - } + // Miscellaneous options + Options Options `json:"options"` +} - // Load and merge local config - mergeLocalConfig(workingDir) +var ( + instance *Config // The single instance of the Singleton + cwd string + once sync.Once // Ensures the initialization happens only once - setProviderDefaults() +) - // Apply configuration to the struct - if err := viper.Unmarshal(cfg); err != nil { - return cfg, fmt.Errorf("failed to unmarshal config: %w", err) - } +func loadConfig(cwd string, debug bool) (*Config, error) { + // First read the global config file + cfgPath := ConfigPath() - applyDefaultValues() + cfg := defaultConfigBasedOnEnv() + cfg.Options.Debug = debug defaultLevel := slog.LevelInfo - if cfg.Debug { + if cfg.Options.Debug { defaultLevel = slog.LevelDebug } if os.Getenv("CRUSH_DEV_DEBUG") == "true" { - loggingFile := fmt.Sprintf("%s/%s", cfg.Data.Directory, "debug.log") + loggingFile := fmt.Sprintf("%s/%s", cfg.Options.DataDirectory, "debug.log") // if file does not exist create it if _, err := os.Stat(loggingFile); os.IsNotExist(err) { - if err := os.MkdirAll(cfg.Data.Directory, 0o755); err != nil { + if err := os.MkdirAll(cfg.Options.DataDirectory, 0o755); err != nil { return cfg, fmt.Errorf("failed to create directory: %w", err) } if _, err := os.Create(loggingFile); err != nil { @@ -184,734 +223,530 @@ func Load(workingDir string, debug bool) (*Config, error) { })) slog.SetDefault(logger) } - - // Validate configuration - if err := Validate(); err != nil { - return cfg, fmt.Errorf("config validation failed: %w", err) - } - - if cfg.Agents == nil { - cfg.Agents = make(map[AgentName]Agent) - } - - // Override the max tokens for title agent - cfg.Agents[AgentTitle] = Agent{ - Model: cfg.Agents[AgentTitle].Model, - MaxTokens: 80, - } - return cfg, nil -} - -type configFinder struct { - appName string - dotPrefix bool - paths []string -} - -func (f configFinder) Find(fsys afero.Fs) ([]string, error) { - var configFiles []string - configName := fmt.Sprintf("%s.json", f.appName) - if f.dotPrefix { - configName = fmt.Sprintf(".%s.json", f.appName) - } - paths := []string{} - for _, p := range f.paths { - if p == "" { - continue + var globalCfg *Config + if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) { + // some other error occurred while checking the file + return nil, err + } else if err == nil { + // config file exists, read it + file, err := os.ReadFile(cfgPath) + if err != nil { + return nil, err } - paths = append(paths, os.ExpandEnv(p)) - } - - for _, path := range paths { - if path == "" { - continue + globalCfg = &Config{} + if err := json.Unmarshal(file, globalCfg); err != nil { + return nil, err } - - configPath := filepath.Join(path, configName) - if exists, err := afero.Exists(fsys, configPath); err == nil && exists { - configFiles = append(configFiles, configPath) + } else { + // config file does not exist, 
use an empty config
+		globalCfg = &Config{}
+	}
+
+	var localConfig *Config
+	// Global config loaded, now read the local config file
+	localConfigPath := filepath.Join(cwd, "crush.json")
+	if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
+		// some other error occurred while checking the file
+		return nil, err
+	} else if err == nil {
+		// local config file exists, read it
+		file, err := os.ReadFile(localConfigPath)
+		if err != nil {
+			return nil, err
+		}
+		localConfig = &Config{}
+		if err := json.Unmarshal(file, localConfig); err != nil {
+			return nil, err
 		}
 	}
-	return configFiles, nil
-}
+	// merge options; values from the local config win over the global one
+	mergeOptions(cfg, globalCfg, localConfig)

-// configureViper sets up viper's configuration paths and environment variables.
-func configureViper() {
-	viper.SetConfigType("json")
-
-	// Create the three finders
-	windowsFinder := configFinder{appName: appName, dotPrefix: false, paths: []string{
-		"$USERPROFILE",
-		fmt.Sprintf("$APPDATA/%s", appName),
-		fmt.Sprintf("$LOCALAPPDATA/%s", appName),
-	}}
-
-	unixFinder := configFinder{appName: appName, dotPrefix: false, paths: []string{
-		"$HOME",
-		fmt.Sprintf("$XDG_CONFIG_HOME/%s", appName),
-		fmt.Sprintf("$HOME/.config/%s", appName),
-	}}
-
-	localFinder := configFinder{appName: appName, dotPrefix: true, paths: []string{
-		".",
-	}}
-
-	// Use all finders with viper
-	viper.SetOptions(viper.WithFinder(viper.Finders(windowsFinder, unixFinder, localFinder)))
-	viper.SetEnvPrefix(strings.ToUpper(appName))
-	viper.AutomaticEnv()
-}

-// setDefaults configures default values for configuration options.
-func setDefaults(debug bool) {
-	viper.SetDefault("data.directory", defaultDataDirectory)
-	viper.SetDefault("contextPaths", defaultContextPaths)
-	viper.SetDefault("tui.theme", "crush")
-	viper.SetDefault("autoCompact", true)
-
-	if debug {
-		viper.SetDefault("debug", true)
-		viper.Set("log.level", "debug")
-	} else {
-		viper.SetDefault("debug", false)
-		viper.SetDefault("log.level", defaultLogLevel)
+	mergeProviderConfigs(cfg, globalCfg, localConfig)
+	// no providers found; the app is not initialized yet
+	if len(cfg.Providers) == 0 {
+		return cfg, nil
 	}
-}
+	// check for a usable provider before dereferencing it below
+	preferredProvider := getPreferredProvider(cfg.Providers)
+	if preferredProvider == nil {
+		return nil, errors.New("no valid providers configured")
+	}
+	cfg.Models = PreferredModels{
+		Large: PreferredModel{
+			ModelID:  preferredProvider.DefaultLargeModel,
+			Provider: preferredProvider.ID,
+		},
+		Small: PreferredModel{
+			ModelID:  preferredProvider.DefaultSmallModel,
+			Provider: preferredProvider.ID,
+		},
+	}
+
+	mergeModels(cfg, globalCfg, localConfig)
+
+	agents := map[AgentID]Agent{
+		AgentCoder: {
+			ID:           AgentCoder,
+			Name:         "Coder",
+			Description:  "An agent that helps with executing coding tasks.",
+			Provider:     cfg.Models.Large.Provider,
+			Model:        cfg.Models.Large.ModelID,
+			ContextPaths: cfg.Options.ContextPaths,
+			// All tools allowed
+		},
+		AgentTask: {
+			ID:           AgentTask,
+			Name:         "Task",
+			Description:  "An agent that helps with searching for context and finding implementation details.",
+			Provider:     cfg.Models.Large.Provider,
+			Model:        cfg.Models.Large.ModelID,
+			ContextPaths: cfg.Options.ContextPaths,
+			AllowedTools: []string{
+				"glob",
+				"grep",
+				"ls",
+				"sourcegraph",
+				"view",
+			},
+			// NO MCPs or LSPs by default
+			AllowedMCP: map[string][]string{},
+			AllowedLSP: []string{},
+		},
+	}
+	cfg.Agents = agents
+	mergeAgents(cfg, globalCfg, localConfig)
+	mergeMCPs(cfg, globalCfg, localConfig)
+	mergeLSPs(cfg, globalCfg, localConfig)

-// setProviderDefaults
configures LLM provider defaults based on provider provided by -// environment variables and configuration file. -func setProviderDefaults() { - // Set all API keys we can find in the environment - if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" { - viper.SetDefault("providers.anthropic.apiKey", apiKey) - } - if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" { - viper.SetDefault("providers.openai.apiKey", apiKey) - } - if apiKey := os.Getenv("GEMINI_API_KEY"); apiKey != "" { - viper.SetDefault("providers.gemini.apiKey", apiKey) - } - if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" { - viper.SetDefault("providers.groq.apiKey", apiKey) - } - if apiKey := os.Getenv("OPENROUTER_API_KEY"); apiKey != "" { - viper.SetDefault("providers.openrouter.apiKey", apiKey) - } - if apiKey := os.Getenv("XAI_API_KEY"); apiKey != "" { - viper.SetDefault("providers.xai.apiKey", apiKey) - } - if apiKey := os.Getenv("AZURE_OPENAI_ENDPOINT"); apiKey != "" { - // api-key may be empty when using Entra ID credentials – that's okay - viper.SetDefault("providers.azure.apiKey", os.Getenv("AZURE_OPENAI_API_KEY")) - } - - // Use this order to set the default models - // 1. Anthropic - // 2. OpenAI - // 3. Google Gemini - // 4. Groq - // 5. OpenRouter - // 6. AWS Bedrock - // 7. Azure - // 8. Google Cloud VertexAI - - // Anthropic configuration - if key := viper.GetString("providers.anthropic.apiKey"); strings.TrimSpace(key) != "" { - viper.SetDefault("agents.coder.model", models.Claude4Sonnet) - viper.SetDefault("agents.summarizer.model", models.Claude4Sonnet) - viper.SetDefault("agents.task.model", models.Claude4Sonnet) - viper.SetDefault("agents.title.model", models.Claude4Sonnet) - return - } - - // OpenAI configuration - if key := viper.GetString("providers.openai.apiKey"); strings.TrimSpace(key) != "" { - viper.SetDefault("agents.coder.model", models.GPT41) - viper.SetDefault("agents.summarizer.model", models.GPT41) - viper.SetDefault("agents.task.model", models.GPT41Mini) - viper.SetDefault("agents.title.model", models.GPT41Mini) - return - } - - // Google Gemini configuration - if key := viper.GetString("providers.gemini.apiKey"); strings.TrimSpace(key) != "" { - viper.SetDefault("agents.coder.model", models.Gemini25) - viper.SetDefault("agents.summarizer.model", models.Gemini25) - viper.SetDefault("agents.task.model", models.Gemini25Flash) - viper.SetDefault("agents.title.model", models.Gemini25Flash) - return - } - - // Groq configuration - if key := viper.GetString("providers.groq.apiKey"); strings.TrimSpace(key) != "" { - viper.SetDefault("agents.coder.model", models.QWENQwq) - viper.SetDefault("agents.summarizer.model", models.QWENQwq) - viper.SetDefault("agents.task.model", models.QWENQwq) - viper.SetDefault("agents.title.model", models.QWENQwq) - return - } - - // OpenRouter configuration - if key := viper.GetString("providers.openrouter.apiKey"); strings.TrimSpace(key) != "" { - viper.SetDefault("agents.coder.model", models.OpenRouterClaude37Sonnet) - viper.SetDefault("agents.summarizer.model", models.OpenRouterClaude37Sonnet) - viper.SetDefault("agents.task.model", models.OpenRouterClaude37Sonnet) - viper.SetDefault("agents.title.model", models.OpenRouterClaude35Haiku) - return - } - - // XAI configuration - if key := viper.GetString("providers.xai.apiKey"); strings.TrimSpace(key) != "" { - viper.SetDefault("agents.coder.model", models.XAIGrok3Beta) - viper.SetDefault("agents.summarizer.model", models.XAIGrok3Beta) - viper.SetDefault("agents.task.model", models.XAIGrok3Beta) - 
viper.SetDefault("agents.title.model", models.XAiGrok3MiniFastBeta) - return - } - - // AWS Bedrock configuration - if hasAWSCredentials() { - viper.SetDefault("agents.coder.model", models.BedrockClaude37Sonnet) - viper.SetDefault("agents.summarizer.model", models.BedrockClaude37Sonnet) - viper.SetDefault("agents.task.model", models.BedrockClaude37Sonnet) - viper.SetDefault("agents.title.model", models.BedrockClaude37Sonnet) - return - } - - // Azure OpenAI configuration - if os.Getenv("AZURE_OPENAI_ENDPOINT") != "" { - viper.SetDefault("agents.coder.model", models.AzureGPT41) - viper.SetDefault("agents.summarizer.model", models.AzureGPT41) - viper.SetDefault("agents.task.model", models.AzureGPT41Mini) - viper.SetDefault("agents.title.model", models.AzureGPT41Mini) - return - } - - // Google Cloud VertexAI configuration - if hasVertexAICredentials() { - viper.SetDefault("agents.coder.model", models.VertexAIGemini25) - viper.SetDefault("agents.summarizer.model", models.VertexAIGemini25) - viper.SetDefault("agents.task.model", models.VertexAIGemini25Flash) - viper.SetDefault("agents.title.model", models.VertexAIGemini25Flash) - return - } + return cfg, nil } -// hasAWSCredentials checks if AWS credentials are available in the environment. -func hasAWSCredentials() bool { - // Check for explicit AWS credentials - if os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "" { - return true - } - - // Check for AWS profile - if os.Getenv("AWS_PROFILE") != "" || os.Getenv("AWS_DEFAULT_PROFILE") != "" { - return true - } +func Init(workingDir string, debug bool) (*Config, error) { + var err error + once.Do(func() { + cwd = workingDir + instance, err = loadConfig(cwd, debug) + if err != nil { + logging.Error("Failed to load config", "error", err) + } + }) - // Check for AWS region - if os.Getenv("AWS_REGION") != "" || os.Getenv("AWS_DEFAULT_REGION") != "" { - return true - } + return instance, err +} - // Check if running on EC2 with instance profile - if os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" || - os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" { - return true +func Get() *Config { + if instance == nil { + // TODO: Handle this better + panic("Config not initialized. Call InitConfig first.") } - - return false + return instance } -// hasVertexAICredentials checks if VertexAI credentials are available in the environment. -func hasVertexAICredentials() bool { - // Check for explicit VertexAI parameters - if os.Getenv("VERTEXAI_PROJECT") != "" && os.Getenv("VERTEXAI_LOCATION") != "" { - return true +func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig { + providers := Providers() + for _, p := range providers { + if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled { + return &providerConfig + } } - // Check for Google Cloud project and location - if os.Getenv("GOOGLE_CLOUD_PROJECT") != "" && (os.Getenv("GOOGLE_CLOUD_REGION") != "" || os.Getenv("GOOGLE_CLOUD_LOCATION") != "") { - return true + // if none found return the first configured provider + for _, providerConfig := range configuredProviders { + if !providerConfig.Disabled { + return &providerConfig + } } - return false + return nil } -// readConfig handles the result of reading a configuration file. 
-// readConfig handles the result of reading a configuration file.
-func readConfig(err error) error {
-	if err == nil {
-		return nil
+func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
+	if other.APIKey != "" {
+		base.APIKey = other.APIKey
 	}
-
-	// It's okay if the config file doesn't exist
-	if _, ok := err.(viper.ConfigFileNotFoundError); ok {
-		return nil
+	// Only change these options if the provider is not a known provider
+	if !slices.Contains(provider.KnownProviders(), p) {
+		if other.BaseURL != "" {
+			base.BaseURL = other.BaseURL
+		}
+		if other.ProviderType != "" {
+			base.ProviderType = other.ProviderType
+		}
+		// copy whenever the incoming config carries headers
+		if len(other.ExtraHeaders) > 0 {
+			if base.ExtraHeaders == nil {
+				base.ExtraHeaders = make(map[string]string)
+			}
+			maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
+		}
+		if len(other.ExtraParams) > 0 {
+			if base.ExtraParams == nil {
+				base.ExtraParams = make(map[string]string)
+			}
+			maps.Copy(base.ExtraParams, other.ExtraParams)
+		}
 	}
-	return fmt.Errorf("failed to read config: %w", err)
-}
-
-// mergeLocalConfig loads and merges configuration from the local directory.
-func mergeLocalConfig(workingDir string) {
-	local := viper.New()
-	local.SetConfigName(fmt.Sprintf(".%s", appName))
-	local.SetConfigType("json")
-	local.AddConfigPath(workingDir)
-
-	// Merge local config if it exists
-	if err := local.ReadInConfig(); err == nil {
-		viper.MergeConfigMap(local.AllSettings())
+	if other.Disabled {
+		base.Disabled = other.Disabled
 	}
-}

-// applyDefaultValues sets default values for configuration fields that need processing.
-func applyDefaultValues() {
-	// Set default MCP type if not specified
-	for k, v := range cfg.MCPServers {
-		if v.Type == "" {
-			v.Type = MCPStdio
-			cfg.MCPServers[k] = v
-		}
+	if other.DefaultLargeModel != "" {
+		base.DefaultLargeModel = other.DefaultLargeModel
 	}
-}

-// It validates model IDs and providers, ensuring they are supported.
-func validateAgent(cfg *Config, name AgentName, agent Agent) error { - // Check if model exists - model, modelExists := models.SupportedModels[agent.Model] - if !modelExists { - logging.Warn("unsupported model configured, reverting to default", - "agent", name, - "configured_model", agent.Model) - - // Set default model based on available providers - if setDefaultModelForAgent(name) { - logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model) - } else { - return fmt.Errorf("no valid provider available for agent %s", name) - } - return nil - } - - // Check if provider for the model is configured - provider := model.Provider - providerCfg, providerExists := cfg.Providers[provider] - - if !providerExists { - // Provider not configured, check if we have environment variables - apiKey := getProviderAPIKey(provider) - if apiKey == "" { - logging.Warn("provider not configured for model, reverting to default", - "agent", name, - "model", agent.Model, - "provider", provider) - - // Set default model based on available providers - if setDefaultModelForAgent(name) { - logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model) - } else { - return fmt.Errorf("no valid provider available for agent %s", name) + // Add new models if they don't exist + if other.Models != nil { + for _, model := range other.Models { + // check if the model already exists + exists := false + for _, existingModel := range base.Models { + if existingModel.ID == model.ID { + exists = true + break + } } - } else { - // Add provider with API key from environment - cfg.Providers[provider] = Provider{ - APIKey: apiKey, - } - logging.Info("added provider from environment", "provider", provider) - } - } else if providerCfg.Disabled || providerCfg.APIKey == "" { - // Provider is disabled or has no API key - logging.Warn("provider is disabled or has no API key, reverting to default", - "agent", name, - "model", agent.Model, - "provider", provider) - - // Set default model based on available providers - if setDefaultModelForAgent(name) { - logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model) - } else { - return fmt.Errorf("no valid provider available for agent %s", name) - } - } - - // Validate max tokens - if agent.MaxTokens <= 0 { - logging.Warn("invalid max tokens, setting to default", - "agent", name, - "model", agent.Model, - "max_tokens", agent.MaxTokens) - - // Update the agent with default max tokens - updatedAgent := cfg.Agents[name] - if model.DefaultMaxTokens > 0 { - updatedAgent.MaxTokens = model.DefaultMaxTokens - } else { - updatedAgent.MaxTokens = MaxTokensFallbackDefault - } - cfg.Agents[name] = updatedAgent - } else if model.ContextWindow > 0 && agent.MaxTokens > model.ContextWindow/2 { - // Ensure max tokens doesn't exceed half the context window (reasonable limit) - logging.Warn("max tokens exceeds half the context window, adjusting", - "agent", name, - "model", agent.Model, - "max_tokens", agent.MaxTokens, - "context_window", model.ContextWindow) - - // Update the agent with adjusted max tokens - updatedAgent := cfg.Agents[name] - updatedAgent.MaxTokens = model.ContextWindow / 2 - cfg.Agents[name] = updatedAgent - } - - // Validate reasoning effort for models that support reasoning - if model.CanReason && provider == models.ProviderOpenAI || provider == models.ProviderLocal { - if agent.ReasoningEffort == "" { - // Set default reasoning effort for models that support it - logging.Info("setting default reasoning 
effort for model that supports reasoning", - "agent", name, - "model", agent.Model) - - // Update the agent with default reasoning effort - updatedAgent := cfg.Agents[name] - updatedAgent.ReasoningEffort = "medium" - cfg.Agents[name] = updatedAgent - } else { - // Check if reasoning effort is valid (low, medium, high) - effort := strings.ToLower(agent.ReasoningEffort) - if effort != "low" && effort != "medium" && effort != "high" { - logging.Warn("invalid reasoning effort, setting to medium", - "agent", name, - "model", agent.Model, - "reasoning_effort", agent.ReasoningEffort) - - // Update the agent with valid reasoning effort - updatedAgent := cfg.Agents[name] - updatedAgent.ReasoningEffort = "medium" - cfg.Agents[name] = updatedAgent + if !exists { + base.Models = append(base.Models, model) } } - } else if !model.CanReason && agent.ReasoningEffort != "" { - // Model doesn't support reasoning but reasoning effort is set - logging.Warn("model doesn't support reasoning but reasoning effort is set, ignoring", - "agent", name, - "model", agent.Model, - "reasoning_effort", agent.ReasoningEffort) - - // Update the agent to remove reasoning effort - updatedAgent := cfg.Agents[name] - updatedAgent.ReasoningEffort = "" - cfg.Agents[name] = updatedAgent } - return nil + return base } -// Validate checks if the configuration is valid and applies defaults where needed. -func Validate() error { - if cfg == nil { - return fmt.Errorf("config not loaded") - } - - // Validate agent models - for name, agent := range cfg.Agents { - if err := validateAgent(cfg, name, agent); err != nil { - return err +func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error { + if !slices.Contains(provider.KnownProviders(), p) { + if providerConfig.ProviderType != provider.TypeOpenAI { + return errors.New("invalid provider type: " + string(providerConfig.ProviderType)) } - } - - // Validate providers - for provider, providerCfg := range cfg.Providers { - if providerCfg.APIKey == "" && !providerCfg.Disabled { - logging.Warn("provider has no API key, marking as disabled", "provider", provider) - providerCfg.Disabled = true - cfg.Providers[provider] = providerCfg + if providerConfig.BaseURL == "" { + return errors.New("base URL must be set for custom providers") } - } - - // Validate LSP configurations - for language, lspConfig := range cfg.LSP { - if lspConfig.Command == "" && !lspConfig.Disabled { - logging.Warn("LSP configuration has no command, marking as disabled", "language", language) - lspConfig.Disabled = true - cfg.LSP[language] = lspConfig + if providerConfig.APIKey == "" { + return errors.New("API key must be set for custom providers") } } - return nil } -// getProviderAPIKey gets the API key for a provider from environment variables -func getProviderAPIKey(provider models.InferenceProvider) string { - switch provider { - case models.ProviderAnthropic: - return os.Getenv("ANTHROPIC_API_KEY") - case models.ProviderOpenAI: - return os.Getenv("OPENAI_API_KEY") - case models.ProviderGemini: - return os.Getenv("GEMINI_API_KEY") - case models.ProviderGROQ: - return os.Getenv("GROQ_API_KEY") - case models.ProviderAzure: - return os.Getenv("AZURE_OPENAI_API_KEY") - case models.ProviderOpenRouter: - return os.Getenv("OPENROUTER_API_KEY") - case models.ProviderBedrock: - if hasAWSCredentials() { - return "aws-credentials-available" - } - case models.ProviderVertexAI: - if hasVertexAICredentials() { - return "vertex-ai-credentials-available" - } - } - return "" -} - -// setDefaultModelForAgent 
sets a default model for an agent based on available providers -func setDefaultModelForAgent(agent AgentName) bool { - // Check providers in order of preference - if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" { - maxTokens := int64(5000) - if agent == AgentTitle { - maxTokens = 80 +func mergeModels(base, global, local *Config) { + for _, cfg := range []*Config{global, local} { + if cfg == nil { + continue } - cfg.Agents[agent] = Agent{ - Model: models.Claude37Sonnet, - MaxTokens: maxTokens, + if cfg.Models.Large.ModelID != "" && cfg.Models.Large.Provider != "" { + base.Models.Large = cfg.Models.Large } - return true - } - if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" { - var model models.ModelID - maxTokens := int64(5000) - reasoningEffort := "" - - switch agent { - case AgentTitle: - model = models.GPT41Mini - maxTokens = 80 - case AgentTask: - model = models.GPT41Mini - default: - model = models.GPT41 + if cfg.Models.Small.ModelID != "" && cfg.Models.Small.Provider != "" { + base.Models.Small = cfg.Models.Small } + } +} - // Check if model supports reasoning - if modelInfo, ok := models.SupportedModels[model]; ok && modelInfo.CanReason { - reasoningEffort = "medium" +func mergeOptions(base, global, local *Config) { + for _, cfg := range []*Config{global, local} { + if cfg == nil { + continue + } + baseOptions := base.Options + other := cfg.Options + if len(other.ContextPaths) > 0 { + baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...) } - cfg.Agents[agent] = Agent{ - Model: model, - MaxTokens: maxTokens, - ReasoningEffort: reasoningEffort, + if other.TUI.CompactMode { + baseOptions.TUI.CompactMode = other.TUI.CompactMode } - return true - } - if apiKey := os.Getenv("OPENROUTER_API_KEY"); apiKey != "" { - var model models.ModelID - maxTokens := int64(5000) - reasoningEffort := "" + if other.Debug { + baseOptions.Debug = other.Debug + } - switch agent { - case AgentTitle: - model = models.OpenRouterClaude35Haiku - maxTokens = 80 - case AgentTask: - model = models.OpenRouterClaude37Sonnet - default: - model = models.OpenRouterClaude37Sonnet + if other.DebugLSP { + baseOptions.DebugLSP = other.DebugLSP } - // Check if model supports reasoning - if modelInfo, ok := models.SupportedModels[model]; ok && modelInfo.CanReason { - reasoningEffort = "medium" + if other.DisableAutoSummarize { + baseOptions.DisableAutoSummarize = other.DisableAutoSummarize } - cfg.Agents[agent] = Agent{ - Model: model, - MaxTokens: maxTokens, - ReasoningEffort: reasoningEffort, + if other.DataDirectory != "" { + baseOptions.DataDirectory = other.DataDirectory } - return true + base.Options = baseOptions } +} - if apiKey := os.Getenv("GEMINI_API_KEY"); apiKey != "" { - var model models.ModelID - maxTokens := int64(5000) - - if agent == AgentTitle { - model = models.Gemini25Flash - maxTokens = 80 - } else { - model = models.Gemini25 +func mergeAgents(base, global, local *Config) { + for _, cfg := range []*Config{global, local} { + if cfg == nil { + continue } - - cfg.Agents[agent] = Agent{ - Model: model, - MaxTokens: maxTokens, + for agentID, newAgent := range cfg.Agents { + if _, ok := base.Agents[agentID]; !ok { + newAgent.ID = agentID // Ensure the ID is set correctly + base.Agents[agentID] = newAgent + } else { + switch agentID { + case AgentCoder: + baseAgent := base.Agents[agentID] + if newAgent.Model != "" && newAgent.Provider != "" { + baseAgent.Model = newAgent.Model + baseAgent.Provider = newAgent.Provider + } + baseAgent.AllowedMCP = newAgent.AllowedMCP + 
baseAgent.AllowedLSP = newAgent.AllowedLSP
+					base.Agents[agentID] = baseAgent
+				default:
+					baseAgent := base.Agents[agentID]
+					baseAgent.Name = newAgent.Name
+					baseAgent.Description = newAgent.Description
+					baseAgent.Disabled = newAgent.Disabled
+					if newAgent.Model == "" || newAgent.Provider == "" {
+						baseAgent.Provider = base.Models.Large.Provider
+						baseAgent.Model = base.Models.Large.ModelID
+					}
+					baseAgent.AllowedTools = newAgent.AllowedTools
+					baseAgent.AllowedMCP = newAgent.AllowedMCP
+					baseAgent.AllowedLSP = newAgent.AllowedLSP
+					base.Agents[agentID] = baseAgent
+				}
+			}
+		}
+	}
+}

-	if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" {
-		maxTokens := int64(5000)
-		if agent == AgentTitle {
-			maxTokens = 80
+func mergeMCPs(base, global, local *Config) {
+	for _, cfg := range []*Config{global, local} {
+		if cfg == nil {
+			continue
 		}
+		maps.Copy(base.MCP, cfg.MCP)
+	}
+}

-		cfg.Agents[agent] = Agent{
-			Model:     models.QWENQwq,
-			MaxTokens: maxTokens,
+func mergeLSPs(base, global, local *Config) {
+	for _, cfg := range []*Config{global, local} {
+		if cfg == nil {
+			continue
 		}
-		return true
+		maps.Copy(base.LSP, cfg.LSP)
 	}
+}

-	if hasAWSCredentials() {
-		maxTokens := int64(5000)
-		if agent == AgentTitle {
-			maxTokens = 80
+func mergeProviderConfigs(base, global, local *Config) {
+	for _, cfg := range []*Config{global, local} {
+		if cfg == nil {
+			continue
 		}
-
-		cfg.Agents[agent] = Agent{
-			Model:           models.BedrockClaude37Sonnet,
-			MaxTokens:       maxTokens,
-			ReasoningEffort: "medium", // Claude models support reasoning
+		for providerName, globalProvider := range cfg.Providers {
+			if _, ok := base.Providers[providerName]; !ok {
+				base.Providers[providerName] = globalProvider
+			} else {
+				base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider)
+			}
 		}
-		return true
 	}

-	if hasVertexAICredentials() {
-		var model models.ModelID
-		maxTokens := int64(5000)
-
-		if agent == AgentTitle {
-			model = models.VertexAIGemini25Flash
-			maxTokens = 80
-		} else {
-			model = models.VertexAIGemini25
+	finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
+	for providerName, providerConfig := range base.Providers {
+		err := validateProvider(providerName, providerConfig)
+		if err != nil {
+			logging.Warn("Skipping provider", "name", providerName, "error", err)
+			continue
 		}
+		finalProviders[providerName] = providerConfig
+	}
+	base.Providers = finalProviders
+}

-		cfg.Agents[agent] = Agent{
-			Model:     model,
-			MaxTokens: maxTokens,
+func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig {
+	switch providerId {
+	case provider.InferenceProviderAnthropic:
+		return ProviderConfig{
+			ID:           providerId,
+			ProviderType: provider.TypeAnthropic,
+		}
+	case provider.InferenceProviderOpenAI:
+		return ProviderConfig{
+			ID:           providerId,
+			ProviderType: provider.TypeOpenAI,
+		}
+	case provider.InferenceProviderGemini:
+		return ProviderConfig{
+			ID:           providerId,
+			ProviderType: provider.TypeGemini,
+		}
+	case provider.InferenceProviderBedrock:
+		return ProviderConfig{
+			ID:           providerId,
+			ProviderType: provider.TypeBedrock,
+		}
+	case provider.InferenceProviderAzure:
+		return ProviderConfig{
+			ID:           providerId,
+			ProviderType: provider.TypeAzure,
+		}
+	case provider.InferenceProviderOpenRouter:
+		return ProviderConfig{
+			ID:           providerId,
+			ProviderType: provider.TypeOpenAI,
+			BaseURL:      "https://openrouter.ai/api/v1",
+			ExtraHeaders: map[string]string{
+				"HTTP-Referer": "crush.charm.land",
+				"X-Title":      "Crush",
+			},
+		}
+	case provider.InferenceProviderXAI:
+
return ProviderConfig{ + ID: providerId, + ProviderType: provider.TypeXAI, + BaseURL: "https://api.x.ai/v1", + } + case provider.InferenceProviderVertexAI: + return ProviderConfig{ + ID: providerId, + ProviderType: provider.TypeVertexAI, + } + default: + return ProviderConfig{ + ID: providerId, + ProviderType: provider.TypeOpenAI, } - return true } - - return false } -func updateCfgFile(updateCfg func(config *Config)) error { - if cfg == nil { - return fmt.Errorf("config not loaded") +func defaultConfigBasedOnEnv() *Config { + cfg := &Config{ + Options: Options{ + DataDirectory: defaultDataDirectory, + ContextPaths: defaultContextPaths, + }, + Providers: make(map[provider.InferenceProvider]ProviderConfig), + } + + providers := Providers() + + for _, p := range providers { + if strings.HasPrefix(p.APIKey, "$") { + envVar := strings.TrimPrefix(p.APIKey, "$") + if apiKey := os.Getenv(envVar); apiKey != "" { + providerConfig := providerDefaultConfig(p.ID) + providerConfig.APIKey = apiKey + providerConfig.DefaultLargeModel = p.DefaultLargeModelID + providerConfig.DefaultSmallModel = p.DefaultSmallModelID + baseURL := p.APIEndpoint + if strings.HasPrefix(baseURL, "$") { + envVar := strings.TrimPrefix(baseURL, "$") + if url := os.Getenv(envVar); url != "" { + baseURL = url + } + } + providerConfig.BaseURL = baseURL + for _, model := range p.Models { + providerConfig.Models = append(providerConfig.Models, Model{ + ID: model.ID, + Name: model.Name, + CostPer1MIn: model.CostPer1MIn, + CostPer1MOut: model.CostPer1MOut, + CostPer1MInCached: model.CostPer1MInCached, + CostPer1MOutCached: model.CostPer1MOutCached, + ContextWindow: model.ContextWindow, + DefaultMaxTokens: model.DefaultMaxTokens, + CanReason: model.CanReason, + SupportsImages: model.SupportsImages, + }) + } + cfg.Providers[p.ID] = providerConfig + } + } } + // TODO: support local models - // Get the config file path - configFile := viper.ConfigFileUsed() - var configData []byte - if configFile == "" { - homeDir, err := os.UserHomeDir() - if err != nil { - return fmt.Errorf("failed to get home directory: %w", err) + if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" { + providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI) + providerConfig.ExtraParams = map[string]string{ + "project": os.Getenv("GOOGLE_CLOUD_PROJECT"), + "location": os.Getenv("GOOGLE_CLOUD_LOCATION"), } - configFile = filepath.Join(homeDir, fmt.Sprintf(".%s.json", appName)) - logging.Info("config file not found, creating new one", "path", configFile) - configData = []byte(`{}`) - } else { - // Read the existing config file - data, err := os.ReadFile(configFile) - if err != nil { - return fmt.Errorf("failed to read config file: %w", err) - } - configData = data + cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig } - // Parse the JSON - var userCfg *Config - if err := json.Unmarshal(configData, &userCfg); err != nil { - return fmt.Errorf("failed to parse config file: %w", err) + if hasAWSCredentials() { + providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock) + providerConfig.ExtraParams = map[string]string{ + "region": os.Getenv("AWS_DEFAULT_REGION"), + } + if providerConfig.ExtraParams["region"] == "" { + providerConfig.ExtraParams["region"] = os.Getenv("AWS_REGION") + } + cfg.Providers[provider.InferenceProviderBedrock] = providerConfig } + return cfg +} - updateCfg(userCfg) +func hasAWSCredentials() bool { + if os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") 
!= "" { + return true + } - // Write the updated config back to file - updatedData, err := json.MarshalIndent(userCfg, "", " ") - if err != nil { - return fmt.Errorf("failed to marshal config: %w", err) + if os.Getenv("AWS_PROFILE") != "" || os.Getenv("AWS_DEFAULT_PROFILE") != "" { + return true } - if err := os.WriteFile(configFile, updatedData, 0o644); err != nil { - return fmt.Errorf("failed to write config file: %w", err) + if os.Getenv("AWS_REGION") != "" || os.Getenv("AWS_DEFAULT_REGION") != "" { + return true } - return nil -} + if os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" || + os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" { + return true + } -// Get returns the current configuration. -// It's safe to call this function multiple times. -func Get() *Config { - return cfg + return false } -// WorkingDirectory returns the current working directory from the configuration. func WorkingDirectory() string { - if cfg == nil { - panic("config not loaded") - } - return cfg.WorkingDir + return cwd } -func UpdateAgentModel(agentName AgentName, modelID models.ModelID) error { - if cfg == nil { - panic("config not loaded") +func GetAgentModel(agentID AgentID) Model { + cfg := Get() + agent, ok := cfg.Agents[agentID] + if !ok { + logging.Error("Agent not found", "agent_id", agentID) + return Model{} } - existingAgentCfg := cfg.Agents[agentName] - - model, ok := models.SupportedModels[modelID] + providerConfig, ok := cfg.Providers[agent.Provider] if !ok { - return fmt.Errorf("model %s not supported", modelID) + logging.Error("Provider not found for agent", "agent_id", agentID, "provider", agent.Provider) + return Model{} } - maxTokens := existingAgentCfg.MaxTokens - if model.DefaultMaxTokens > 0 { - maxTokens = model.DefaultMaxTokens + for _, model := range providerConfig.Models { + if model.ID == agent.Model { + return model + } } - newAgentCfg := Agent{ - Model: modelID, - MaxTokens: maxTokens, - ReasoningEffort: existingAgentCfg.ReasoningEffort, - } - cfg.Agents[agentName] = newAgentCfg + logging.Error("Model not found for agent", "agent_id", agentID, "model", agent.Model) + return Model{} +} - if err := validateAgent(cfg, agentName, newAgentCfg); err != nil { - // revert config update on failure - cfg.Agents[agentName] = existingAgentCfg - return fmt.Errorf("failed to update agent model: %w", err) +func GetProviderModel(provider provider.InferenceProvider, modelID string) Model { + cfg := Get() + providerConfig, ok := cfg.Providers[provider] + if !ok { + logging.Error("Provider not found", "provider", provider) + return Model{} } - return updateCfgFile(func(config *Config) { - if config.Agents == nil { - config.Agents = make(map[AgentName]Agent) + for _, model := range providerConfig.Models { + if model.ID == modelID { + return model } - config.Agents[agentName] = newAgentCfg - }) -} - -// UpdateTheme updates the theme in the configuration and writes it to the config file. 
-func UpdateTheme(themeName string) error { - if cfg == nil { - return fmt.Errorf("config not loaded") } - // Update the in-memory config - cfg.TUI.Theme = themeName - - // Update the file config - return updateCfgFile(func(config *Config) { - config.TUI.Theme = themeName - }) + logging.Error("Model not found for provider", "provider", provider, "model_id", modelID) + return Model{} } diff --git a/internal/config_v2/config_test.go b/internal/config/config_test.go similarity index 94% rename from internal/config_v2/config_test.go rename to internal/config/config_test.go index 9bcfcdc78375e1a3a35726b513f04e3cb1e2c3b3..2942c206aa4bb8b81ff3f3fca9a444411359e515 100644 --- a/internal/config_v2/config_test.go +++ b/internal/config/config_test.go @@ -1,4 +1,4 @@ -package configv2 +package config import ( "encoding/json" @@ -28,7 +28,7 @@ func TestConfigWithEnv(t *testing.T) { os.Setenv("GEMINI_API_KEY", "test-gemini-key") os.Setenv("XAI_API_KEY", "test-xai-key") os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key") - cfg := InitConfig(cwdDir) + cfg, _ := Init(cwdDir, false) data, _ := json.MarshalIndent(cfg, "", " ") fmt.Println(string(data)) assert.Len(t, cfg.Providers, 5) diff --git a/internal/config_v2/fs.go b/internal/config/fs.go similarity index 99% rename from internal/config_v2/fs.go rename to internal/config/fs.go index 976267a2a68efb718449f59b3720d0d186720cdf..efa622cf937846370616042de4fe2bcd6f33b7a1 100644 --- a/internal/config_v2/fs.go +++ b/internal/config/fs.go @@ -1,4 +1,4 @@ -package configv2 +package config import ( "fmt" diff --git a/internal/config/init.go b/internal/config/init.go index 1b603fbb846aba45230cd0f4683cb465e14db69a..f17e1db28e41cc44e168765e55e88311423e1102 100644 --- a/internal/config/init.go +++ b/internal/config/init.go @@ -17,23 +17,20 @@ type ProjectInitFlag struct { Initialized bool `json:"initialized"` } -// ShouldShowInitDialog checks if the initialization dialog should be shown for the current directory -func ShouldShowInitDialog() (bool, error) { - if cfg == nil { +// ProjectNeedsInitialization checks if the current project needs initialization +func ProjectNeedsInitialization() (bool, error) { + if instance == nil { return false, fmt.Errorf("config not loaded") } - // Create the flag file path - flagFilePath := filepath.Join(cfg.Data.Directory, InitFlagFilename) + flagFilePath := filepath.Join(instance.Options.DataDirectory, InitFlagFilename) // Check if the flag file exists _, err := os.Stat(flagFilePath) if err == nil { - // File exists, don't show the dialog return false, nil } - // If the error is not "file not found", return the error if !os.IsNotExist(err) { return false, fmt.Errorf("failed to check init flag file: %w", err) } @@ -44,11 +41,9 @@ func ShouldShowInitDialog() (bool, error) { return false, fmt.Errorf("failed to check for CRUSH.md files: %w", err) } if crushExists { - // CRUSH.md already exists, don't show the dialog return false, nil } - // File doesn't exist, show the dialog return true, nil } @@ -75,13 +70,11 @@ func crushMdExists(dir string) (bool, error) { // MarkProjectInitialized marks the current project as initialized func MarkProjectInitialized() error { - if cfg == nil { + if instance == nil { return fmt.Errorf("config not loaded") } - // Create the flag file path - flagFilePath := filepath.Join(cfg.Data.Directory, InitFlagFilename) + flagFilePath := filepath.Join(instance.Options.DataDirectory, InitFlagFilename) - // Create an empty file to mark the project as initialized file, err := os.Create(flagFilePath) if err != nil { 
return fmt.Errorf("failed to create init flag file: %w", err) diff --git a/internal/config_v2/provider.go b/internal/config/provider.go similarity index 98% rename from internal/config_v2/provider.go rename to internal/config/provider.go index ec6b5bdb701876af4705c9e78fcc55a87646edd2..4c2b61ff6d5d86f62a8a1833a6ea91b500bbc7b0 100644 --- a/internal/config_v2/provider.go +++ b/internal/config/provider.go @@ -1,4 +1,4 @@ -package configv2 +package config import ( "encoding/json" diff --git a/internal/config_v2/config.go b/internal/config_v2/config.go deleted file mode 100644 index 9f7f2ad14356531150cca4f05952fb390c716c68..0000000000000000000000000000000000000000 --- a/internal/config_v2/config.go +++ /dev/null @@ -1,660 +0,0 @@ -package configv2 - -import ( - "encoding/json" - "errors" - "maps" - "os" - "path/filepath" - "slices" - "strings" - "sync" - - "github.com/charmbracelet/crush/internal/fur/provider" - "github.com/charmbracelet/crush/internal/logging" -) - -const ( - defaultDataDirectory = ".crush" - defaultLogLevel = "info" - appName = "crush" - - MaxTokensFallbackDefault = 4096 -) - -var defaultContextPaths = []string{ - ".github/copilot-instructions.md", - ".cursorrules", - ".cursor/rules/", - "CLAUDE.md", - "CLAUDE.local.md", - "crush.md", - "crush.local.md", - "Crush.md", - "Crush.local.md", - "CRUSH.md", - "CRUSH.local.md", -} - -type AgentID string - -const ( - AgentCoder AgentID = "coder" - AgentTask AgentID = "task" - AgentTitle AgentID = "title" - AgentSummarize AgentID = "summarize" -) - -type Model struct { - ID string `json:"id"` - Name string `json:"model"` - CostPer1MIn float64 `json:"cost_per_1m_in"` - CostPer1MOut float64 `json:"cost_per_1m_out"` - CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` - CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` - ContextWindow int64 `json:"context_window"` - DefaultMaxTokens int64 `json:"default_max_tokens"` - CanReason bool `json:"can_reason"` - ReasoningEffort string `json:"reasoning_effort"` - SupportsImages bool `json:"supports_attachments"` -} - -type VertexAIOptions struct { - APIKey string `json:"api_key,omitempty"` - Project string `json:"project,omitempty"` - Location string `json:"location,omitempty"` -} - -type ProviderConfig struct { - ID provider.InferenceProvider `json:"id"` - BaseURL string `json:"base_url,omitempty"` - ProviderType provider.Type `json:"provider_type"` - APIKey string `json:"api_key,omitempty"` - Disabled bool `json:"disabled"` - ExtraHeaders map[string]string `json:"extra_headers,omitempty"` - // used for e.x for vertex to set the project - ExtraParams map[string]string `json:"extra_params,omitempty"` - - DefaultLargeModel string `json:"default_large_model,omitempty"` - DefaultSmallModel string `json:"default_small_model,omitempty"` - - Models []Model `json:"models,omitempty"` -} - -type Agent struct { - ID AgentID `json:"id"` - Name string `json:"name"` - Description string `json:"description,omitempty"` - // This is the id of the system prompt used by the agent - Disabled bool `json:"disabled"` - - Provider provider.InferenceProvider `json:"provider"` - Model string `json:"model"` - - // The available tools for the agent - // if this is nil, all tools are available - AllowedTools []string `json:"allowed_tools"` - - // this tells us which MCPs are available for this agent - // if this is empty all mcps are available - // the string array is the list of tools from the AllowedMCP the agent has available - // if the string array is nil, all tools from the AllowedMCP are available - 
AllowedMCP map[string][]string `json:"allowed_mcp"` - - // The list of LSPs that this agent can use - // if this is nil, all LSPs are available - AllowedLSP []string `json:"allowed_lsp"` - - // Overrides the context paths for this agent - ContextPaths []string `json:"context_paths"` -} - -type MCPType string - -const ( - MCPStdio MCPType = "stdio" - MCPSse MCPType = "sse" -) - -type MCP struct { - Command string `json:"command"` - Env []string `json:"env"` - Args []string `json:"args"` - Type MCPType `json:"type"` - URL string `json:"url"` - Headers map[string]string `json:"headers"` -} - -type LSPConfig struct { - Disabled bool `json:"enabled"` - Command string `json:"command"` - Args []string `json:"args"` - Options any `json:"options"` -} - -type TUIOptions struct { - CompactMode bool `json:"compact_mode"` - // Here we can add themes later or any TUI related options -} - -type Options struct { - ContextPaths []string `json:"context_paths"` - TUI TUIOptions `json:"tui"` - Debug bool `json:"debug"` - DebugLSP bool `json:"debug_lsp"` - DisableAutoSummarize bool `json:"disable_auto_summarize"` - // Relative to the cwd - DataDirectory string `json:"data_directory"` -} - -type Config struct { - // List of configured providers - Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"` - - // List of configured agents - Agents map[AgentID]Agent `json:"agents,omitempty"` - - // List of configured MCPs - MCP map[string]MCP `json:"mcp,omitempty"` - - // List of configured LSPs - LSP map[string]LSPConfig `json:"lsp,omitempty"` - - // Miscellaneous options - Options Options `json:"options"` -} - -var ( - instance *Config // The single instance of the Singleton - cwd string - once sync.Once // Ensures the initialization happens only once - -) - -func loadConfig(cwd string) (*Config, error) { - // First read the global config file - cfgPath := ConfigPath() - - cfg := defaultConfigBasedOnEnv() - - var globalCfg *Config - if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) { - // some other error occurred while checking the file - return nil, err - } else if err == nil { - // config file exists, read it - file, err := os.ReadFile(cfgPath) - if err != nil { - return nil, err - } - globalCfg = &Config{} - if err := json.Unmarshal(file, globalCfg); err != nil { - return nil, err - } - } else { - // config file does not exist, create a new one - globalCfg = &Config{} - } - - var localConfig *Config - // Global config loaded, now read the local config file - localConfigPath := filepath.Join(cwd, "crush.json") - if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) { - // some other error occurred while checking the file - return nil, err - } else if err == nil { - // local config file exists, read it - file, err := os.ReadFile(localConfigPath) - if err != nil { - return nil, err - } - localConfig = &Config{} - if err := json.Unmarshal(file, localConfig); err != nil { - return nil, err - } - } - - // merge options - mergeOptions(cfg, globalCfg, localConfig) - - mergeProviderConfigs(cfg, globalCfg, localConfig) - // no providers found the app is not initialized yet - if len(cfg.Providers) == 0 { - return cfg, nil - } - preferredProvider := getPreferredProvider(cfg.Providers) - - if preferredProvider == nil { - return nil, errors.New("no valid providers configured") - } - - agents := map[AgentID]Agent{ - AgentCoder: { - ID: AgentCoder, - Name: "Coder", - Description: "An agent that helps with executing coding tasks.", - Provider: preferredProvider.ID, 
- Model: preferredProvider.DefaultLargeModel, - ContextPaths: cfg.Options.ContextPaths, - // All tools allowed - }, - AgentTask: { - ID: AgentTask, - Name: "Task", - Description: "An agent that helps with searching for context and finding implementation details.", - Provider: preferredProvider.ID, - Model: preferredProvider.DefaultLargeModel, - ContextPaths: cfg.Options.ContextPaths, - AllowedTools: []string{ - "glob", - "grep", - "ls", - "sourcegraph", - "view", - }, - // NO MCPs or LSPs by default - AllowedMCP: map[string][]string{}, - AllowedLSP: []string{}, - }, - AgentTitle: { - ID: AgentTitle, - Name: "Title", - Description: "An agent that helps with generating titles for sessions.", - Provider: preferredProvider.ID, - Model: preferredProvider.DefaultSmallModel, - ContextPaths: cfg.Options.ContextPaths, - AllowedTools: []string{}, - // NO MCPs or LSPs by default - AllowedMCP: map[string][]string{}, - AllowedLSP: []string{}, - }, - AgentSummarize: { - ID: AgentSummarize, - Name: "Summarize", - Description: "An agent that helps with summarizing sessions.", - Provider: preferredProvider.ID, - Model: preferredProvider.DefaultSmallModel, - ContextPaths: cfg.Options.ContextPaths, - AllowedTools: []string{}, - // NO MCPs or LSPs by default - AllowedMCP: map[string][]string{}, - AllowedLSP: []string{}, - }, - } - cfg.Agents = agents - mergeAgents(cfg, globalCfg, localConfig) - mergeMCPs(cfg, globalCfg, localConfig) - mergeLSPs(cfg, globalCfg, localConfig) - - return cfg, nil -} - -func InitConfig(workingDir string) *Config { - once.Do(func() { - cwd = workingDir - cfg, err := loadConfig(cwd) - if err != nil { - // TODO: Handle this better - panic("Failed to load config: " + err.Error()) - } - instance = cfg - }) - - return instance -} - -func GetConfig() *Config { - if instance == nil { - // TODO: Handle this better - panic("Config not initialized. 
Call InitConfig first.") - } - return instance -} - -func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig { - providers := Providers() - for _, p := range providers { - if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled { - return &providerConfig - } - } - // if none found return the first configured provider - for _, providerConfig := range configuredProviders { - if !providerConfig.Disabled { - return &providerConfig - } - } - return nil -} - -func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig { - if other.APIKey != "" { - base.APIKey = other.APIKey - } - // Only change these options if the provider is not a known provider - if !slices.Contains(provider.KnownProviders(), p) { - if other.BaseURL != "" { - base.BaseURL = other.BaseURL - } - if other.ProviderType != "" { - base.ProviderType = other.ProviderType - } - if len(base.ExtraHeaders) > 0 { - if base.ExtraHeaders == nil { - base.ExtraHeaders = make(map[string]string) - } - maps.Copy(base.ExtraHeaders, other.ExtraHeaders) - } - if len(other.ExtraParams) > 0 { - if base.ExtraParams == nil { - base.ExtraParams = make(map[string]string) - } - maps.Copy(base.ExtraParams, other.ExtraParams) - } - } - - if other.Disabled { - base.Disabled = other.Disabled - } - - if other.DefaultLargeModel != "" { - base.DefaultLargeModel = other.DefaultLargeModel - } - // Add new models if they don't exist - if other.Models != nil { - for _, model := range other.Models { - // check if the model already exists - exists := false - for _, existingModel := range base.Models { - if existingModel.ID == model.ID { - exists = true - break - } - } - if !exists { - base.Models = append(base.Models, model) - } - } - } - - return base -} - -func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error { - if !slices.Contains(provider.KnownProviders(), p) { - if providerConfig.ProviderType != provider.TypeOpenAI { - return errors.New("invalid provider type: " + string(providerConfig.ProviderType)) - } - if providerConfig.BaseURL == "" { - return errors.New("base URL must be set for custom providers") - } - if providerConfig.APIKey == "" { - return errors.New("API key must be set for custom providers") - } - } - return nil -} - -func mergeOptions(base, global, local *Config) { - for _, cfg := range []*Config{global, local} { - if cfg == nil { - continue - } - baseOptions := base.Options - other := cfg.Options - if len(other.ContextPaths) > 0 { - baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...) 
- } - - if other.TUI.CompactMode { - baseOptions.TUI.CompactMode = other.TUI.CompactMode - } - - if other.Debug { - baseOptions.Debug = other.Debug - } - - if other.DebugLSP { - baseOptions.DebugLSP = other.DebugLSP - } - - if other.DisableAutoSummarize { - baseOptions.DisableAutoSummarize = other.DisableAutoSummarize - } - - if other.DataDirectory != "" { - baseOptions.DataDirectory = other.DataDirectory - } - base.Options = baseOptions - } -} - -func mergeAgents(base, global, local *Config) { - for _, cfg := range []*Config{global, local} { - if cfg == nil { - continue - } - for agentID, newAgent := range cfg.Agents { - if _, ok := base.Agents[agentID]; !ok { - newAgent.ID = agentID // Ensure the ID is set correctly - base.Agents[agentID] = newAgent - } else { - switch agentID { - case AgentCoder: - baseAgent := base.Agents[agentID] - baseAgent.Model = newAgent.Model - baseAgent.Provider = newAgent.Provider - baseAgent.AllowedMCP = newAgent.AllowedMCP - baseAgent.AllowedLSP = newAgent.AllowedLSP - base.Agents[agentID] = baseAgent - case AgentTask: - baseAgent := base.Agents[agentID] - baseAgent.Model = newAgent.Model - baseAgent.Provider = newAgent.Provider - base.Agents[agentID] = baseAgent - case AgentTitle: - baseAgent := base.Agents[agentID] - baseAgent.Model = newAgent.Model - baseAgent.Provider = newAgent.Provider - base.Agents[agentID] = baseAgent - case AgentSummarize: - baseAgent := base.Agents[agentID] - baseAgent.Model = newAgent.Model - baseAgent.Provider = newAgent.Provider - base.Agents[agentID] = baseAgent - default: - baseAgent := base.Agents[agentID] - baseAgent.Name = newAgent.Name - baseAgent.Description = newAgent.Description - baseAgent.Disabled = newAgent.Disabled - baseAgent.Provider = newAgent.Provider - baseAgent.Model = newAgent.Model - baseAgent.AllowedTools = newAgent.AllowedTools - baseAgent.AllowedMCP = newAgent.AllowedMCP - baseAgent.AllowedLSP = newAgent.AllowedLSP - base.Agents[agentID] = baseAgent - - } - } - } - } -} - -func mergeMCPs(base, global, local *Config) { - for _, cfg := range []*Config{global, local} { - if cfg == nil { - continue - } - maps.Copy(base.MCP, cfg.MCP) - } -} - -func mergeLSPs(base, global, local *Config) { - for _, cfg := range []*Config{global, local} { - if cfg == nil { - continue - } - maps.Copy(base.LSP, cfg.LSP) - } -} - -func mergeProviderConfigs(base, global, local *Config) { - for _, cfg := range []*Config{global, local} { - if cfg == nil { - continue - } - for providerName, globalProvider := range cfg.Providers { - if _, ok := base.Providers[providerName]; !ok { - base.Providers[providerName] = globalProvider - } else { - base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider) - } - } - } - - finalProviders := make(map[provider.InferenceProvider]ProviderConfig) - for providerName, providerConfig := range base.Providers { - err := validateProvider(providerName, providerConfig) - if err != nil { - logging.Warn("Skipping provider", "name", providerName, "error", err) - } - finalProviders[providerName] = providerConfig - } - base.Providers = finalProviders -} - -func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig { - switch providerId { - case provider.InferenceProviderAnthropic: - return ProviderConfig{ - ID: providerId, - ProviderType: provider.TypeAnthropic, - } - case provider.InferenceProviderOpenAI: - return ProviderConfig{ - ID: providerId, - ProviderType: provider.TypeOpenAI, - } - case provider.InferenceProviderGemini: - return 
ProviderConfig{ - ID: providerId, - ProviderType: provider.TypeGemini, - } - case provider.InferenceProviderBedrock: - return ProviderConfig{ - ID: providerId, - ProviderType: provider.TypeBedrock, - } - case provider.InferenceProviderAzure: - return ProviderConfig{ - ID: providerId, - ProviderType: provider.TypeAzure, - } - case provider.InferenceProviderOpenRouter: - return ProviderConfig{ - ID: providerId, - ProviderType: provider.TypeOpenAI, - BaseURL: "https://openrouter.ai/api/v1", - ExtraHeaders: map[string]string{ - "HTTP-Referer": "crush.charm.land", - "X-Title": "Crush", - }, - } - case provider.InferenceProviderXAI: - return ProviderConfig{ - ID: providerId, - ProviderType: provider.TypeXAI, - BaseURL: "https://api.x.ai/v1", - } - case provider.InferenceProviderVertexAI: - return ProviderConfig{ - ID: providerId, - ProviderType: provider.TypeVertexAI, - } - default: - return ProviderConfig{ - ID: providerId, - ProviderType: provider.TypeOpenAI, - } - } -} - -func defaultConfigBasedOnEnv() *Config { - cfg := &Config{ - Options: Options{ - DataDirectory: defaultDataDirectory, - ContextPaths: defaultContextPaths, - }, - Providers: make(map[provider.InferenceProvider]ProviderConfig), - } - - providers := Providers() - - for _, p := range providers { - if strings.HasPrefix(p.APIKey, "$") { - envVar := strings.TrimPrefix(p.APIKey, "$") - if apiKey := os.Getenv(envVar); apiKey != "" { - providerConfig := providerDefaultConfig(p.ID) - providerConfig.APIKey = apiKey - providerConfig.DefaultLargeModel = p.DefaultLargeModelID - providerConfig.DefaultSmallModel = p.DefaultSmallModelID - for _, model := range p.Models { - providerConfig.Models = append(providerConfig.Models, Model{ - ID: model.ID, - Name: model.Name, - CostPer1MIn: model.CostPer1MIn, - CostPer1MOut: model.CostPer1MOut, - CostPer1MInCached: model.CostPer1MInCached, - CostPer1MOutCached: model.CostPer1MOutCached, - ContextWindow: model.ContextWindow, - DefaultMaxTokens: model.DefaultMaxTokens, - CanReason: model.CanReason, - SupportsImages: model.SupportsImages, - }) - } - cfg.Providers[p.ID] = providerConfig - } - } - } - // TODO: support local models - - if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" { - providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI) - providerConfig.ExtraParams = map[string]string{ - "project": os.Getenv("GOOGLE_CLOUD_PROJECT"), - "location": os.Getenv("GOOGLE_CLOUD_LOCATION"), - } - cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig - } - - if hasAWSCredentials() { - providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock) - cfg.Providers[provider.InferenceProviderBedrock] = providerConfig - } - return cfg -} - -func hasAWSCredentials() bool { - if os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "" { - return true - } - - if os.Getenv("AWS_PROFILE") != "" || os.Getenv("AWS_DEFAULT_PROFILE") != "" { - return true - } - - if os.Getenv("AWS_REGION") != "" || os.Getenv("AWS_DEFAULT_REGION") != "" { - return true - } - - if os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" || - os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" { - return true - } - - return false -} - -func WorkingDirectory() string { - return cwd -} diff --git a/internal/db/connect.go b/internal/db/connect.go index 9212ce1f097e6877a9ce9b368e77d76e739b673f..6452dabdb3a5de6ffb0f618062631dfe4b25102d 100644 --- a/internal/db/connect.go +++ b/internal/db/connect.go @@ -1,7 +1,6 @@ package db import ( - "context" 
"database/sql" "fmt" "os" @@ -16,8 +15,8 @@ import ( "github.com/pressly/goose/v3" ) -func Connect(ctx context.Context) (*sql.DB, error) { - dataDir := config.Get().Data.Directory +func Connect() (*sql.DB, error) { + dataDir := config.Get().Options.DataDirectory if dataDir == "" { return nil, fmt.Errorf("data.dir is not set") } diff --git a/internal/db/messages.sql.go b/internal/db/messages.sql.go index 2acfe18fdbc63312c49d65e9e3acb1bd24cf4d7e..81f322921db87dde7ade48ce64322aa01004d255 100644 --- a/internal/db/messages.sql.go +++ b/internal/db/messages.sql.go @@ -17,12 +17,13 @@ INSERT INTO messages ( role, parts, model, + provider, created_at, updated_at ) VALUES ( - ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') + ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') ) -RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at +RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at, provider ` type CreateMessageParams struct { @@ -31,6 +32,7 @@ type CreateMessageParams struct { Role string `json:"role"` Parts string `json:"parts"` Model sql.NullString `json:"model"` + Provider sql.NullString `json:"provider"` } func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (Message, error) { @@ -40,6 +42,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M arg.Role, arg.Parts, arg.Model, + arg.Provider, ) var i Message err := row.Scan( @@ -51,6 +54,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M &i.CreatedAt, &i.UpdatedAt, &i.FinishedAt, + &i.Provider, ) return i, err } @@ -76,7 +80,7 @@ func (q *Queries) DeleteSessionMessages(ctx context.Context, sessionID string) e } const getMessage = `-- name: GetMessage :one -SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at +SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider FROM messages WHERE id = ? LIMIT 1 ` @@ -93,12 +97,13 @@ func (q *Queries) GetMessage(ctx context.Context, id string) (Message, error) { &i.CreatedAt, &i.UpdatedAt, &i.FinishedAt, + &i.Provider, ) return i, err } const listMessagesBySession = `-- name: ListMessagesBySession :many -SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at +SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider FROM messages WHERE session_id = ? 
 ORDER BY created_at ASC
@@ -122,6 +127,7 @@ func (q *Queries) ListMessagesBySession(ctx context.Context, sessionID string) (
 			&i.CreatedAt,
 			&i.UpdatedAt,
 			&i.FinishedAt,
+			&i.Provider,
 		); err != nil {
 			return nil, err
 		}
diff --git a/internal/db/migrations/20250627000000_add_provider_to_messages.sql b/internal/db/migrations/20250627000000_add_provider_to_messages.sql
new file mode 100644
index 0000000000000000000000000000000000000000..9bf0ed9749c49640f10407c97deb032f60baaac2
--- /dev/null
+++ b/internal/db/migrations/20250627000000_add_provider_to_messages.sql
@@ -0,0 +1,11 @@
+-- +goose Up
+-- +goose StatementBegin
+-- Add provider column to messages table
+ALTER TABLE messages ADD COLUMN provider TEXT;
+-- +goose StatementEnd
+
+-- +goose Down
+-- +goose StatementBegin
+-- Remove provider column from messages table
+ALTER TABLE messages DROP COLUMN provider;
+-- +goose StatementEnd
\ No newline at end of file
diff --git a/internal/db/models.go b/internal/db/models.go
index ec19f99b213e041331b5d6a14dee3648bc14c1de..ec3e6e10ad990d0f1a3d03a7533c8b1aed184447 100644
--- a/internal/db/models.go
+++ b/internal/db/models.go
@@ -27,6 +27,7 @@ type Message struct {
 	CreatedAt  int64          `json:"created_at"`
 	UpdatedAt  int64          `json:"updated_at"`
 	FinishedAt sql.NullInt64  `json:"finished_at"`
+	Provider   sql.NullString `json:"provider"`
 }
 
 type Session struct {
diff --git a/internal/db/sql/messages.sql b/internal/db/sql/messages.sql
index a59cebe7d00fe5fd7cbd449df681df45e832979a..ea946177591d1e145a59475a1ca9272f3191d4d6 100644
--- a/internal/db/sql/messages.sql
+++ b/internal/db/sql/messages.sql
@@ -16,10 +16,11 @@ INSERT INTO messages (
     role,
     parts,
     model,
+    provider,
     created_at,
     updated_at
 ) VALUES (
-    ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
+    ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
 )
 RETURNING *;
 
diff --git a/internal/fur/client/client.go b/internal/fur/client/client.go
index 263e8317ce8ac92d8820ba5288f2e40d2616e0e1..5f0ddeaeee708d4b5475403ce1874591f7e9bb2c 100644
--- a/internal/fur/client/client.go
+++ b/internal/fur/client/client.go
@@ -10,7 +10,7 @@ import (
 	"github.com/charmbracelet/crush/internal/fur/provider"
 )
 
-const defaultURL = "http://localhost:8080"
+const defaultURL = "https://fur.charmcli.dev"
 
 // Client represents a client for the fur service.
 type Client struct {
diff --git a/internal/fur/provider/provider.go b/internal/fur/provider/provider.go
index 85275f1155eff219c87d85fce3cdcc436f4a4e47..8545694dea70b410a3a1912b82313bde2852d942 100644
--- a/internal/fur/provider/provider.go
+++ b/internal/fur/provider/provider.go
@@ -6,14 +6,13 @@ type Type string
 
 // All the supported AI provider types.
 const (
-	TypeOpenAI     Type = "openai"
-	TypeAnthropic  Type = "anthropic"
-	TypeGemini     Type = "gemini"
-	TypeAzure      Type = "azure"
-	TypeBedrock    Type = "bedrock"
-	TypeVertexAI   Type = "vertexai"
-	TypeXAI        Type = "xai"
-	TypeOpenRouter Type = "openrouter"
+	TypeOpenAI    Type = "openai"
+	TypeAnthropic Type = "anthropic"
+	TypeGemini    Type = "gemini"
+	TypeAzure     Type = "azure"
+	TypeBedrock   Type = "bedrock"
+	TypeVertexAI  Type = "vertexai"
+	TypeXAI       Type = "xai"
 )
 
 // InferenceProvider represents the inference provider identifier.
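Two consequences of the hunks above are easy to miss, so here is a hedged
sketch, not part of the patch itself. First, with TypeOpenRouter removed,
OpenRouter appears to be treated as an OpenAI-compatible provider (the values
below mirror providerDefaultConfig earlier in this patch). Second, the new
messages.provider column flows through the generated sqlc code; the
CreateMessageParams fields not visible in the hunk (ID, SessionID) are
assumptions inferred from the RETURNING list, and all literal values are
placeholders.

// Illustrative sketch only; treat names outside the visible hunks as assumptions.
package example

import (
	"context"
	"database/sql"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/db"
	"github.com/charmbracelet/crush/internal/fur/provider"
)

// openRouterConfig shows how an OpenRouter entry is expected to look now that
// it no longer has a dedicated provider Type: an OpenAI wire format plus a
// custom base URL and headers, as in providerDefaultConfig above.
func openRouterConfig() config.ProviderConfig {
	return config.ProviderConfig{
		ID:           provider.InferenceProviderOpenRouter,
		ProviderType: provider.TypeOpenAI, // OpenAI-compatible API surface
		BaseURL:      "https://openrouter.ai/api/v1",
		ExtraHeaders: map[string]string{
			"HTTP-Referer": "crush.charm.land",
			"X-Title":      "Crush",
		},
	}
}

// saveAssistantMessage persists a message tagged with both its model and the
// provider that produced it, exercising the new column end to end.
func saveAssistantMessage(ctx context.Context, q *db.Queries, sessionID string) error {
	_, err := q.CreateMessage(ctx, db.CreateMessageParams{
		ID:        "msg_123", // hypothetical identifier
		SessionID: sessionID,
		Role:      "assistant",
		Parts:     "[]", // serialized content parts
		Model:     sql.NullString{String: "claude-4-sonnet", Valid: true},
		Provider:  sql.NullString{String: "anthropic", Valid: true}, // new column
	})
	return err
}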
diff --git a/internal/llm/agent/agent-tool.go b/internal/llm/agent/agent-tool.go index 9e5e9bc7844b055c52464032dfc4d75495f9e426..ae15c5867e7321f2ac29e1809f5eb7effb830fdc 100644 --- a/internal/llm/agent/agent-tool.go +++ b/internal/llm/agent/agent-tool.go @@ -5,17 +5,15 @@ import ( "encoding/json" "fmt" - "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/llm/tools" - "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/session" ) type agentTool struct { - sessions session.Service - messages message.Service - lspClients map[string]*lsp.Client + agent Service + sessions session.Service + messages message.Service } const ( @@ -58,17 +56,12 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.ToolResponse{}, fmt.Errorf("session_id and message_id are required") } - agent, err := NewAgent(config.AgentTask, b.sessions, b.messages, TaskAgentTools(b.lspClients)) - if err != nil { - return tools.ToolResponse{}, fmt.Errorf("error creating agent: %s", err) - } - session, err := b.sessions.CreateTaskSession(ctx, call.ID, sessionID, "New Agent Session") if err != nil { return tools.ToolResponse{}, fmt.Errorf("error creating session: %s", err) } - done, err := agent.Run(ctx, session.ID, params.Prompt) + done, err := b.agent.Run(ctx, session.ID, params.Prompt) if err != nil { return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", err) } @@ -101,13 +94,13 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes } func NewAgentTool( - Sessions session.Service, - Messages message.Service, - LspClients map[string]*lsp.Client, + agent Service, + sessions session.Service, + messages message.Service, ) tools.BaseTool { return &agentTool{ - sessions: Sessions, - messages: Messages, - lspClients: LspClients, + sessions: sessions, + messages: messages, + agent: agent, } } diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index ea2a3bd2b11735c1f0422e859adcfa65a82fdb98..f9e97b164aa98fe1ae76490fdfcf336efb43098f 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -4,16 +4,18 @@ import ( "context" "errors" "fmt" + "slices" "strings" "sync" "time" - "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/llm/models" + configv2 "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/history" "github.com/charmbracelet/crush/internal/llm/prompt" "github.com/charmbracelet/crush/internal/llm/provider" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/logging" + "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/pubsub" @@ -47,71 +49,198 @@ type AgentEvent struct { type Service interface { pubsub.Suscriber[AgentEvent] - Model() models.Model + Model() configv2.Model Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) Cancel(sessionID string) CancelAll() IsSessionBusy(sessionID string) bool IsBusy() bool - Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error) + Update(model configv2.PreferredModel) (configv2.Model, error) Summarize(ctx context.Context, sessionID string) error } type agent struct { *pubsub.Broker[AgentEvent] + agentCfg 
configv2.Agent sessions session.Service messages message.Service - tools []tools.BaseTool - provider provider.Provider + tools []tools.BaseTool + provider provider.Provider + providerID string - titleProvider provider.Provider - summarizeProvider provider.Provider + titleProvider provider.Provider + summarizeProvider provider.Provider + summarizeProviderID string activeRequests sync.Map } +var agentPromptMap = map[configv2.AgentID]prompt.PromptID{ + configv2.AgentCoder: prompt.PromptCoder, + configv2.AgentTask: prompt.PromptTask, +} + func NewAgent( - agentName config.AgentName, + agentCfg configv2.Agent, + // These services are needed in the tools + permissions permission.Service, sessions session.Service, messages message.Service, - agentTools []tools.BaseTool, + history history.Service, + lspClients map[string]*lsp.Client, ) (Service, error) { - agentProvider, err := createAgentProvider(agentName) + ctx := context.Background() + cfg := configv2.Get() + otherTools := GetMcpTools(ctx, permissions) + if len(lspClients) > 0 { + otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients)) + } + + allTools := []tools.BaseTool{ + tools.NewBashTool(permissions), + tools.NewEditTool(lspClients, permissions, history), + tools.NewFetchTool(permissions), + tools.NewGlobTool(), + tools.NewGrepTool(), + tools.NewLsTool(), + tools.NewSourcegraphTool(), + tools.NewViewTool(lspClients), + tools.NewWriteTool(lspClients, permissions, history), + } + + if agentCfg.ID == configv2.AgentCoder { + taskAgentCfg := configv2.Get().Agents[configv2.AgentTask] + if taskAgentCfg.ID == "" { + return nil, fmt.Errorf("task agent not found in config") + } + taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients) + if err != nil { + return nil, fmt.Errorf("failed to create task agent: %w", err) + } + + allTools = append( + allTools, + NewAgentTool( + taskAgent, + sessions, + messages, + ), + ) + } + + allTools = append(allTools, otherTools...) + var providerCfg configv2.ProviderConfig + for _, p := range cfg.Providers { + if p.ID == agentCfg.Provider { + providerCfg = p + break + } + } + if providerCfg.ID == "" { + return nil, fmt.Errorf("provider %s not found in config", agentCfg.Provider) + } + + var model configv2.Model + for _, m := range providerCfg.Models { + if m.ID == agentCfg.Model { + model = m + break + } + } + if model.ID == "" { + return nil, fmt.Errorf("model %s not found in provider %s", agentCfg.Model, agentCfg.Provider) + } + + promptID := agentPromptMap[agentCfg.ID] + if promptID == "" { + promptID = prompt.PromptDefault + } + opts := []provider.ProviderClientOption{ + provider.WithModel(model), + provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)), + provider.WithMaxTokens(model.DefaultMaxTokens), + } + agentProvider, err := provider.NewProviderV2(providerCfg, opts...) 
if err != nil { return nil, err } - var titleProvider provider.Provider - // Only generate titles for the coder agent - if agentName == config.AgentCoder { - titleProvider, err = createAgentProvider(config.AgentTitle) - if err != nil { - return nil, err + + smallModelCfg := cfg.Models.Small + var smallModel configv2.Model + + var smallModelProviderCfg configv2.ProviderConfig + if smallModelCfg.Provider == providerCfg.ID { + smallModelProviderCfg = providerCfg + } else { + for _, p := range cfg.Providers { + if p.ID == smallModelCfg.Provider { + smallModelProviderCfg = p + break + } + } + if smallModelProviderCfg.ID == "" { + return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider) } } - var summarizeProvider provider.Provider - if agentName == config.AgentCoder { - summarizeProvider, err = createAgentProvider(config.AgentSummarizer) - if err != nil { - return nil, err + for _, m := range smallModelProviderCfg.Models { + if m.ID == smallModelCfg.ModelID { + smallModel = m + break + } + } + if smallModel.ID == "" { + return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID) + } + + titleOpts := []provider.ProviderClientOption{ + provider.WithModel(smallModel), + provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)), + provider.WithMaxTokens(40), + } + titleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...) + if err != nil { + return nil, err + } + summarizeOpts := []provider.ProviderClientOption{ + provider.WithModel(smallModel), + provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)), + provider.WithMaxTokens(smallModel.DefaultMaxTokens), + } + summarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...) 
+ if err != nil { + return nil, err + } + + agentTools := []tools.BaseTool{} + if agentCfg.AllowedTools == nil { + agentTools = allTools + } else { + for _, tool := range allTools { + if slices.Contains(agentCfg.AllowedTools, tool.Name()) { + agentTools = append(agentTools, tool) + } } } agent := &agent{ - Broker: pubsub.NewBroker[AgentEvent](), - provider: agentProvider, - messages: messages, - sessions: sessions, - tools: agentTools, - titleProvider: titleProvider, - summarizeProvider: summarizeProvider, - activeRequests: sync.Map{}, + Broker: pubsub.NewBroker[AgentEvent](), + agentCfg: agentCfg, + provider: agentProvider, + providerID: string(providerCfg.ID), + messages: messages, + sessions: sessions, + tools: agentTools, + titleProvider: titleProvider, + summarizeProvider: summarizeProvider, + summarizeProviderID: string(smallModelProviderCfg.ID), + activeRequests: sync.Map{}, } return agent, nil } -func (a *agent) Model() models.Model { +func (a *agent) Model() configv2.Model { return a.provider.Model() } @@ -207,7 +336,7 @@ func (a *agent) err(err error) AgentEvent { } func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) { - if !a.provider.Model().SupportsAttachments && attachments != nil { + if !a.provider.Model().SupportsImages && attachments != nil { attachments = nil } events := make(chan AgentEvent) @@ -327,9 +456,10 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools) assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ - Role: message.Assistant, - Parts: []message.ContentPart{}, - Model: a.provider.Model().ID, + Role: message.Assistant, + Parts: []message.ContentPart{}, + Model: a.provider.Model().ID, + Provider: a.providerID, }) if err != nil { return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err) @@ -424,8 +554,9 @@ out: parts = append(parts, tr) } msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{ - Role: message.Tool, - Parts: parts, + Role: message.Tool, + Parts: parts, + Provider: a.providerID, }) if err != nil { return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err) @@ -484,7 +615,7 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg return nil } -func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error { +func (a *agent) TrackUsage(ctx context.Context, sessionID string, model configv2.Model, usage provider.TokenUsage) error { sess, err := a.sessions.Get(ctx, sessionID) if err != nil { return fmt.Errorf("failed to get session: %w", err) @@ -506,21 +637,48 @@ func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.M return nil } -func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error) { +func (a *agent) Update(modelCfg configv2.PreferredModel) (configv2.Model, error) { if a.IsBusy() { - return models.Model{}, fmt.Errorf("cannot change model while processing requests") + return configv2.Model{}, fmt.Errorf("cannot change model while processing requests") } - if err := config.UpdateAgentModel(agentName, modelID); err != nil { - return models.Model{}, fmt.Errorf("failed to update config: %w", err) + cfg := configv2.Get() + var providerCfg configv2.ProviderConfig + for _, p := range cfg.Providers 
{ + if p.ID == modelCfg.Provider { + providerCfg = p + break + } + } + if providerCfg.ID == "" { + return configv2.Model{}, fmt.Errorf("provider %s not found in config", modelCfg.Provider) } - provider, err := createAgentProvider(agentName) - if err != nil { - return models.Model{}, fmt.Errorf("failed to create provider for model %s: %w", modelID, err) + var model configv2.Model + for _, m := range providerCfg.Models { + if m.ID == modelCfg.ModelID { + model = m + break + } + } + if model.ID == "" { + return configv2.Model{}, fmt.Errorf("model %s not found in provider %s", modelCfg.ModelID, modelCfg.Provider) } - a.provider = provider + promptID := agentPromptMap[a.agentCfg.ID] + if promptID == "" { + promptID = prompt.PromptDefault + } + opts := []provider.ProviderClientOption{ + provider.WithModel(model), + provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)), + provider.WithMaxTokens(model.DefaultMaxTokens), + } + agentProvider, err := provider.NewProviderV2(providerCfg, opts...) + if err != nil { + return configv2.Model{}, err + } + a.provider = agentProvider return a.provider.Model(), nil } @@ -654,7 +812,8 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error { Time: time.Now().Unix(), }, }, - Model: a.summarizeProvider.Model().ID, + Model: a.summarizeProvider.Model().ID, + Provider: a.summarizeProviderID, }) if err != nil { event = AgentEvent{ @@ -705,51 +864,3 @@ func (a *agent) CancelAll() { return true }) } - -func createAgentProvider(agentName config.AgentName) (provider.Provider, error) { - cfg := config.Get() - agentConfig, ok := cfg.Agents[agentName] - if !ok { - return nil, fmt.Errorf("agent %s not found", agentName) - } - model, ok := models.SupportedModels[agentConfig.Model] - if !ok { - return nil, fmt.Errorf("model %s not supported", agentConfig.Model) - } - - providerCfg, ok := cfg.Providers[model.Provider] - if !ok { - return nil, fmt.Errorf("provider %s not supported", model.Provider) - } - if providerCfg.Disabled { - return nil, fmt.Errorf("provider %s is not enabled", model.Provider) - } - maxTokens := model.DefaultMaxTokens - if agentConfig.MaxTokens > 0 { - maxTokens = agentConfig.MaxTokens - } - opts := []provider.ProviderClientOption{ - provider.WithAPIKey(providerCfg.APIKey), - provider.WithModel(model), - provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)), - provider.WithMaxTokens(maxTokens), - } - // TODO: reimplement - // if model.Provider == models.ProviderOpenAI || model.Provider == models.ProviderLocal && model.CanReason { - // opts = append( - // opts, - // provider.WithOpenAIOptions( - // provider.WithReasoningEffort(agentConfig.ReasoningEffort), - // ), - // ) - // } - agentProvider, err := provider.NewProvider( - model.Provider, - opts..., - ) - if err != nil { - return nil, fmt.Errorf("could not create provider: %v", err) - } - - return agentProvider, nil -} diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index 8fde02755eb320b8925891a3eca938c3cd7911f9..1950324fa3ed4dbd9de358d18023247b0bb429e7 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -18,7 +18,7 @@ import ( type mcpTool struct { mcpName string tool mcp.Tool - mcpConfig config.MCPServer + mcpConfig config.MCP permissions permission.Service } @@ -128,7 +128,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes return tools.NewTextErrorResponse("invalid mcp type"), nil } -func NewMcpTool(name string, tool mcp.Tool, permissions 
permission.Service, mcpConfig config.MCPServer) tools.BaseTool { +func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCP) tools.BaseTool { return &mcpTool{ mcpName: name, tool: tool, @@ -139,7 +139,7 @@ func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpC var mcpTools []tools.BaseTool -func getTools(ctx context.Context, name string, m config.MCPServer, permissions permission.Service, c MCPClient) []tools.BaseTool { +func getTools(ctx context.Context, name string, m config.MCP, permissions permission.Service, c MCPClient) []tools.BaseTool { var stdioTools []tools.BaseTool initRequest := mcp.InitializeRequest{} initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION @@ -170,7 +170,7 @@ func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.Ba if len(mcpTools) > 0 { return mcpTools } - for name, m := range config.Get().MCPServers { + for name, m := range config.Get().MCP { switch m.Type { case config.MCPStdio: c, err := client.NewStdioMCPClient( diff --git a/internal/llm/agent/tools.go b/internal/llm/agent/tools.go deleted file mode 100644 index 0fe2c530ca6dc30916fd2dfa094ad6303bf39443..0000000000000000000000000000000000000000 --- a/internal/llm/agent/tools.go +++ /dev/null @@ -1,50 +0,0 @@ -package agent - -import ( - "context" - - "github.com/charmbracelet/crush/internal/history" - "github.com/charmbracelet/crush/internal/llm/tools" - "github.com/charmbracelet/crush/internal/lsp" - "github.com/charmbracelet/crush/internal/message" - "github.com/charmbracelet/crush/internal/permission" - "github.com/charmbracelet/crush/internal/session" -) - -func CoderAgentTools( - permissions permission.Service, - sessions session.Service, - messages message.Service, - history history.Service, - lspClients map[string]*lsp.Client, -) []tools.BaseTool { - ctx := context.Background() - otherTools := GetMcpTools(ctx, permissions) - if len(lspClients) > 0 { - otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients)) - } - return append( - []tools.BaseTool{ - tools.NewBashTool(permissions), - tools.NewEditTool(lspClients, permissions, history), - tools.NewFetchTool(permissions), - tools.NewGlobTool(), - tools.NewGrepTool(), - tools.NewLsTool(), - tools.NewSourcegraphTool(), - tools.NewViewTool(lspClients), - tools.NewWriteTool(lspClients, permissions, history), - NewAgentTool(sessions, messages, lspClients), - }, otherTools..., - ) -} - -func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool { - return []tools.BaseTool{ - tools.NewGlobTool(), - tools.NewGrepTool(), - tools.NewLsTool(), - tools.NewSourcegraphTool(), - tools.NewViewTool(lspClients), - } -} diff --git a/internal/llm/models/anthropic.go b/internal/llm/models/anthropic.go deleted file mode 100644 index 85c47def3d94034297265c506c5870f2b449d286..0000000000000000000000000000000000000000 --- a/internal/llm/models/anthropic.go +++ /dev/null @@ -1,111 +0,0 @@ -package models - -const ( - ProviderAnthropic InferenceProvider = "anthropic" - - // Models - Claude35Sonnet ModelID = "claude-3.5-sonnet" - Claude3Haiku ModelID = "claude-3-haiku" - Claude37Sonnet ModelID = "claude-3.7-sonnet" - Claude35Haiku ModelID = "claude-3.5-haiku" - Claude3Opus ModelID = "claude-3-opus" - Claude4Opus ModelID = "claude-4-opus" - Claude4Sonnet ModelID = "claude-4-sonnet" -) - -// https://docs.anthropic.com/en/docs/about-claude/models/all-models -var AnthropicModels = map[ModelID]Model{ - Claude35Sonnet: { - ID: Claude35Sonnet, - Name: "Claude 3.5 
Sonnet", - Provider: ProviderAnthropic, - APIModel: "claude-3-5-sonnet-latest", - CostPer1MIn: 3.0, - CostPer1MInCached: 3.75, - CostPer1MOutCached: 0.30, - CostPer1MOut: 15.0, - ContextWindow: 200000, - DefaultMaxTokens: 5000, - SupportsAttachments: true, - }, - Claude3Haiku: { - ID: Claude3Haiku, - Name: "Claude 3 Haiku", - Provider: ProviderAnthropic, - APIModel: "claude-3-haiku-20240307", // doesn't support "-latest" - CostPer1MIn: 0.25, - CostPer1MInCached: 0.30, - CostPer1MOutCached: 0.03, - CostPer1MOut: 1.25, - ContextWindow: 200000, - DefaultMaxTokens: 4096, - SupportsAttachments: true, - }, - Claude37Sonnet: { - ID: Claude37Sonnet, - Name: "Claude 3.7 Sonnet", - Provider: ProviderAnthropic, - APIModel: "claude-3-7-sonnet-latest", - CostPer1MIn: 3.0, - CostPer1MInCached: 3.75, - CostPer1MOutCached: 0.30, - CostPer1MOut: 15.0, - ContextWindow: 200000, - DefaultMaxTokens: 50000, - CanReason: true, - SupportsAttachments: true, - }, - Claude35Haiku: { - ID: Claude35Haiku, - Name: "Claude 3.5 Haiku", - Provider: ProviderAnthropic, - APIModel: "claude-3-5-haiku-latest", - CostPer1MIn: 0.80, - CostPer1MInCached: 1.0, - CostPer1MOutCached: 0.08, - CostPer1MOut: 4.0, - ContextWindow: 200000, - DefaultMaxTokens: 4096, - SupportsAttachments: true, - }, - Claude3Opus: { - ID: Claude3Opus, - Name: "Claude 3 Opus", - Provider: ProviderAnthropic, - APIModel: "claude-3-opus-latest", - CostPer1MIn: 15.0, - CostPer1MInCached: 18.75, - CostPer1MOutCached: 1.50, - CostPer1MOut: 75.0, - ContextWindow: 200000, - DefaultMaxTokens: 4096, - SupportsAttachments: true, - }, - Claude4Sonnet: { - ID: Claude4Sonnet, - Name: "Claude 4 Sonnet", - Provider: ProviderAnthropic, - APIModel: "claude-sonnet-4-20250514", - CostPer1MIn: 3.0, - CostPer1MInCached: 3.75, - CostPer1MOutCached: 0.30, - CostPer1MOut: 15.0, - ContextWindow: 200000, - DefaultMaxTokens: 50000, - CanReason: true, - SupportsAttachments: true, - }, - Claude4Opus: { - ID: Claude4Opus, - Name: "Claude 4 Opus", - Provider: ProviderAnthropic, - APIModel: "claude-opus-4-20250514", - CostPer1MIn: 15.0, - CostPer1MInCached: 18.75, - CostPer1MOutCached: 1.50, - CostPer1MOut: 75.0, - ContextWindow: 200000, - DefaultMaxTokens: 4096, - SupportsAttachments: true, - }, -} diff --git a/internal/llm/models/azure.go b/internal/llm/models/azure.go deleted file mode 100644 index eb7ae293ee053d953f5bcbb20120089ca6bae95b..0000000000000000000000000000000000000000 --- a/internal/llm/models/azure.go +++ /dev/null @@ -1,168 +0,0 @@ -package models - -const ProviderAzure InferenceProvider = "azure" - -const ( - AzureGPT41 ModelID = "azure.gpt-4.1" - AzureGPT41Mini ModelID = "azure.gpt-4.1-mini" - AzureGPT41Nano ModelID = "azure.gpt-4.1-nano" - AzureGPT45Preview ModelID = "azure.gpt-4.5-preview" - AzureGPT4o ModelID = "azure.gpt-4o" - AzureGPT4oMini ModelID = "azure.gpt-4o-mini" - AzureO1 ModelID = "azure.o1" - AzureO1Mini ModelID = "azure.o1-mini" - AzureO3 ModelID = "azure.o3" - AzureO3Mini ModelID = "azure.o3-mini" - AzureO4Mini ModelID = "azure.o4-mini" -) - -var AzureModels = map[ModelID]Model{ - AzureGPT41: { - ID: AzureGPT41, - Name: "Azure OpenAI – GPT 4.1", - Provider: ProviderAzure, - APIModel: "gpt-4.1", - CostPer1MIn: OpenAIModels[GPT41].CostPer1MIn, - CostPer1MInCached: OpenAIModels[GPT41].CostPer1MInCached, - CostPer1MOut: OpenAIModels[GPT41].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[GPT41].CostPer1MOutCached, - ContextWindow: OpenAIModels[GPT41].ContextWindow, - DefaultMaxTokens: OpenAIModels[GPT41].DefaultMaxTokens, - SupportsAttachments: true, - }, - 
AzureGPT41Mini: { - ID: AzureGPT41Mini, - Name: "Azure OpenAI – GPT 4.1 mini", - Provider: ProviderAzure, - APIModel: "gpt-4.1-mini", - CostPer1MIn: OpenAIModels[GPT41Mini].CostPer1MIn, - CostPer1MInCached: OpenAIModels[GPT41Mini].CostPer1MInCached, - CostPer1MOut: OpenAIModels[GPT41Mini].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[GPT41Mini].CostPer1MOutCached, - ContextWindow: OpenAIModels[GPT41Mini].ContextWindow, - DefaultMaxTokens: OpenAIModels[GPT41Mini].DefaultMaxTokens, - SupportsAttachments: true, - }, - AzureGPT41Nano: { - ID: AzureGPT41Nano, - Name: "Azure OpenAI – GPT 4.1 nano", - Provider: ProviderAzure, - APIModel: "gpt-4.1-nano", - CostPer1MIn: OpenAIModels[GPT41Nano].CostPer1MIn, - CostPer1MInCached: OpenAIModels[GPT41Nano].CostPer1MInCached, - CostPer1MOut: OpenAIModels[GPT41Nano].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[GPT41Nano].CostPer1MOutCached, - ContextWindow: OpenAIModels[GPT41Nano].ContextWindow, - DefaultMaxTokens: OpenAIModels[GPT41Nano].DefaultMaxTokens, - SupportsAttachments: true, - }, - AzureGPT45Preview: { - ID: AzureGPT45Preview, - Name: "Azure OpenAI – GPT 4.5 preview", - Provider: ProviderAzure, - APIModel: "gpt-4.5-preview", - CostPer1MIn: OpenAIModels[GPT45Preview].CostPer1MIn, - CostPer1MInCached: OpenAIModels[GPT45Preview].CostPer1MInCached, - CostPer1MOut: OpenAIModels[GPT45Preview].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[GPT45Preview].CostPer1MOutCached, - ContextWindow: OpenAIModels[GPT45Preview].ContextWindow, - DefaultMaxTokens: OpenAIModels[GPT45Preview].DefaultMaxTokens, - SupportsAttachments: true, - }, - AzureGPT4o: { - ID: AzureGPT4o, - Name: "Azure OpenAI – GPT-4o", - Provider: ProviderAzure, - APIModel: "gpt-4o", - CostPer1MIn: OpenAIModels[GPT4o].CostPer1MIn, - CostPer1MInCached: OpenAIModels[GPT4o].CostPer1MInCached, - CostPer1MOut: OpenAIModels[GPT4o].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[GPT4o].CostPer1MOutCached, - ContextWindow: OpenAIModels[GPT4o].ContextWindow, - DefaultMaxTokens: OpenAIModels[GPT4o].DefaultMaxTokens, - SupportsAttachments: true, - }, - AzureGPT4oMini: { - ID: AzureGPT4oMini, - Name: "Azure OpenAI – GPT-4o mini", - Provider: ProviderAzure, - APIModel: "gpt-4o-mini", - CostPer1MIn: OpenAIModels[GPT4oMini].CostPer1MIn, - CostPer1MInCached: OpenAIModels[GPT4oMini].CostPer1MInCached, - CostPer1MOut: OpenAIModels[GPT4oMini].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[GPT4oMini].CostPer1MOutCached, - ContextWindow: OpenAIModels[GPT4oMini].ContextWindow, - DefaultMaxTokens: OpenAIModels[GPT4oMini].DefaultMaxTokens, - SupportsAttachments: true, - }, - AzureO1: { - ID: AzureO1, - Name: "Azure OpenAI – O1", - Provider: ProviderAzure, - APIModel: "o1", - CostPer1MIn: OpenAIModels[O1].CostPer1MIn, - CostPer1MInCached: OpenAIModels[O1].CostPer1MInCached, - CostPer1MOut: OpenAIModels[O1].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[O1].CostPer1MOutCached, - ContextWindow: OpenAIModels[O1].ContextWindow, - DefaultMaxTokens: OpenAIModels[O1].DefaultMaxTokens, - CanReason: OpenAIModels[O1].CanReason, - SupportsAttachments: true, - }, - AzureO1Mini: { - ID: AzureO1Mini, - Name: "Azure OpenAI – O1 mini", - Provider: ProviderAzure, - APIModel: "o1-mini", - CostPer1MIn: OpenAIModels[O1Mini].CostPer1MIn, - CostPer1MInCached: OpenAIModels[O1Mini].CostPer1MInCached, - CostPer1MOut: OpenAIModels[O1Mini].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[O1Mini].CostPer1MOutCached, - ContextWindow: OpenAIModels[O1Mini].ContextWindow, - DefaultMaxTokens: OpenAIModels[O1Mini].DefaultMaxTokens, - 
CanReason: OpenAIModels[O1Mini].CanReason, - SupportsAttachments: true, - }, - AzureO3: { - ID: AzureO3, - Name: "Azure OpenAI – O3", - Provider: ProviderAzure, - APIModel: "o3", - CostPer1MIn: OpenAIModels[O3].CostPer1MIn, - CostPer1MInCached: OpenAIModels[O3].CostPer1MInCached, - CostPer1MOut: OpenAIModels[O3].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[O3].CostPer1MOutCached, - ContextWindow: OpenAIModels[O3].ContextWindow, - DefaultMaxTokens: OpenAIModels[O3].DefaultMaxTokens, - CanReason: OpenAIModels[O3].CanReason, - SupportsAttachments: true, - }, - AzureO3Mini: { - ID: AzureO3Mini, - Name: "Azure OpenAI – O3 mini", - Provider: ProviderAzure, - APIModel: "o3-mini", - CostPer1MIn: OpenAIModels[O3Mini].CostPer1MIn, - CostPer1MInCached: OpenAIModels[O3Mini].CostPer1MInCached, - CostPer1MOut: OpenAIModels[O3Mini].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[O3Mini].CostPer1MOutCached, - ContextWindow: OpenAIModels[O3Mini].ContextWindow, - DefaultMaxTokens: OpenAIModels[O3Mini].DefaultMaxTokens, - CanReason: OpenAIModels[O3Mini].CanReason, - SupportsAttachments: false, - }, - AzureO4Mini: { - ID: AzureO4Mini, - Name: "Azure OpenAI – O4 mini", - Provider: ProviderAzure, - APIModel: "o4-mini", - CostPer1MIn: OpenAIModels[O4Mini].CostPer1MIn, - CostPer1MInCached: OpenAIModels[O4Mini].CostPer1MInCached, - CostPer1MOut: OpenAIModels[O4Mini].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[O4Mini].CostPer1MOutCached, - ContextWindow: OpenAIModels[O4Mini].ContextWindow, - DefaultMaxTokens: OpenAIModels[O4Mini].DefaultMaxTokens, - CanReason: OpenAIModels[O4Mini].CanReason, - SupportsAttachments: true, - }, -} diff --git a/internal/llm/models/gemini.go b/internal/llm/models/gemini.go deleted file mode 100644 index 9749c6d3409acf7b05cd67690504e2cb3ac4fd39..0000000000000000000000000000000000000000 --- a/internal/llm/models/gemini.go +++ /dev/null @@ -1,67 +0,0 @@ -package models - -const ( - ProviderGemini InferenceProvider = "gemini" - - // Models - Gemini25Flash ModelID = "gemini-2.5-flash" - Gemini25 ModelID = "gemini-2.5" - Gemini20Flash ModelID = "gemini-2.0-flash" - Gemini20FlashLite ModelID = "gemini-2.0-flash-lite" -) - -var GeminiModels = map[ModelID]Model{ - Gemini25Flash: { - ID: Gemini25Flash, - Name: "Gemini 2.5 Flash", - Provider: ProviderGemini, - APIModel: "gemini-2.5-flash-preview-04-17", - CostPer1MIn: 0.15, - CostPer1MInCached: 0, - CostPer1MOutCached: 0, - CostPer1MOut: 0.60, - ContextWindow: 1000000, - DefaultMaxTokens: 50000, - SupportsAttachments: true, - }, - Gemini25: { - ID: Gemini25, - Name: "Gemini 2.5 Pro", - Provider: ProviderGemini, - APIModel: "gemini-2.5-pro-preview-05-06", - CostPer1MIn: 1.25, - CostPer1MInCached: 0, - CostPer1MOutCached: 0, - CostPer1MOut: 10, - ContextWindow: 1000000, - DefaultMaxTokens: 50000, - SupportsAttachments: true, - }, - - Gemini20Flash: { - ID: Gemini20Flash, - Name: "Gemini 2.0 Flash", - Provider: ProviderGemini, - APIModel: "gemini-2.0-flash", - CostPer1MIn: 0.10, - CostPer1MInCached: 0, - CostPer1MOutCached: 0, - CostPer1MOut: 0.40, - ContextWindow: 1000000, - DefaultMaxTokens: 6000, - SupportsAttachments: true, - }, - Gemini20FlashLite: { - ID: Gemini20FlashLite, - Name: "Gemini 2.0 Flash Lite", - Provider: ProviderGemini, - APIModel: "gemini-2.0-flash-lite", - CostPer1MIn: 0.05, - CostPer1MInCached: 0, - CostPer1MOutCached: 0, - CostPer1MOut: 0.30, - ContextWindow: 1000000, - DefaultMaxTokens: 6000, - SupportsAttachments: true, - }, -} diff --git a/internal/llm/models/groq.go b/internal/llm/models/groq.go deleted file mode 
100644 index 39288962c8e42a1acec8a01b3157b10d9b00b5dc..0000000000000000000000000000000000000000 --- a/internal/llm/models/groq.go +++ /dev/null @@ -1,87 +0,0 @@ -package models - -const ( - ProviderGROQ InferenceProvider = "groq" - - // GROQ - QWENQwq ModelID = "qwen-qwq" - - // GROQ preview models - Llama4Scout ModelID = "meta-llama/llama-4-scout-17b-16e-instruct" - Llama4Maverick ModelID = "meta-llama/llama-4-maverick-17b-128e-instruct" - Llama3_3_70BVersatile ModelID = "llama-3.3-70b-versatile" - DeepseekR1DistillLlama70b ModelID = "deepseek-r1-distill-llama-70b" -) - -var GroqModels = map[ModelID]Model{ - // - // GROQ - QWENQwq: { - ID: QWENQwq, - Name: "Qwen Qwq", - Provider: ProviderGROQ, - APIModel: "qwen-qwq-32b", - CostPer1MIn: 0.29, - CostPer1MInCached: 0.275, - CostPer1MOutCached: 0.0, - CostPer1MOut: 0.39, - ContextWindow: 128_000, - DefaultMaxTokens: 50000, - // for some reason, the groq api doesn't like the reasoningEffort parameter - CanReason: false, - SupportsAttachments: false, - }, - - Llama4Scout: { - ID: Llama4Scout, - Name: "Llama4Scout", - Provider: ProviderGROQ, - APIModel: "meta-llama/llama-4-scout-17b-16e-instruct", - CostPer1MIn: 0.11, - CostPer1MInCached: 0, - CostPer1MOutCached: 0, - CostPer1MOut: 0.34, - ContextWindow: 128_000, // 10M when? - SupportsAttachments: true, - }, - - Llama4Maverick: { - ID: Llama4Maverick, - Name: "Llama4Maverick", - Provider: ProviderGROQ, - APIModel: "meta-llama/llama-4-maverick-17b-128e-instruct", - CostPer1MIn: 0.20, - CostPer1MInCached: 0, - CostPer1MOutCached: 0, - CostPer1MOut: 0.20, - ContextWindow: 128_000, - SupportsAttachments: true, - }, - - Llama3_3_70BVersatile: { - ID: Llama3_3_70BVersatile, - Name: "Llama3_3_70BVersatile", - Provider: ProviderGROQ, - APIModel: "llama-3.3-70b-versatile", - CostPer1MIn: 0.59, - CostPer1MInCached: 0, - CostPer1MOutCached: 0, - CostPer1MOut: 0.79, - ContextWindow: 128_000, - SupportsAttachments: false, - }, - - DeepseekR1DistillLlama70b: { - ID: DeepseekR1DistillLlama70b, - Name: "DeepseekR1DistillLlama70b", - Provider: ProviderGROQ, - APIModel: "deepseek-r1-distill-llama-70b", - CostPer1MIn: 0.75, - CostPer1MInCached: 0, - CostPer1MOutCached: 0, - CostPer1MOut: 0.99, - ContextWindow: 128_000, - CanReason: true, - SupportsAttachments: false, - }, -} diff --git a/internal/llm/models/local.go b/internal/llm/models/local.go deleted file mode 100644 index c469e99fd65d5befbfffe5126a31c88eae68e150..0000000000000000000000000000000000000000 --- a/internal/llm/models/local.go +++ /dev/null @@ -1,206 +0,0 @@ -package models - -import ( - "cmp" - "context" - "encoding/json" - "net/http" - "net/url" - "os" - "regexp" - "strings" - "unicode" - - "github.com/charmbracelet/crush/internal/logging" - "github.com/spf13/viper" -) - -const ( - ProviderLocal InferenceProvider = "local" - - localModelsPath = "v1/models" - lmStudioBetaModelsPath = "api/v0/models" -) - -func init() { - if endpoint := os.Getenv("LOCAL_ENDPOINT"); endpoint != "" { - localEndpoint, err := url.Parse(endpoint) - if err != nil { - logging.Debug("Failed to parse local endpoint", - "error", err, - "endpoint", endpoint, - ) - return - } - - load := func(url *url.URL, path string) []localModel { - url.Path = path - return listLocalModels(url.String()) - } - - models := load(localEndpoint, lmStudioBetaModelsPath) - - if len(models) == 0 { - models = load(localEndpoint, localModelsPath) - } - - if len(models) == 0 { - logging.Debug("No local models found", - "endpoint", endpoint, - ) - return - } - - loadLocalModels(models) - - 
viper.SetDefault("providers.local.apiKey", "dummy") - } -} - -type localModelList struct { - Data []localModel `json:"data"` -} - -type localModel struct { - ID string `json:"id"` - Object string `json:"object"` - Type string `json:"type"` - Publisher string `json:"publisher"` - Arch string `json:"arch"` - CompatibilityType string `json:"compatibility_type"` - Quantization string `json:"quantization"` - State string `json:"state"` - MaxContextLength int64 `json:"max_context_length"` - LoadedContextLength int64 `json:"loaded_context_length"` -} - -func listLocalModels(modelsEndpoint string) []localModel { - res, err := http.NewRequestWithContext(context.Background(), http.MethodGet, modelsEndpoint, nil) - if err != nil { - logging.Debug("Failed to list local models", - "error", err, - "endpoint", modelsEndpoint, - ) - } - defer res.Body.Close() - - if res.Response.StatusCode != http.StatusOK { - logging.Debug("Failed to list local models", - "status", res.Response.Status, - "endpoint", modelsEndpoint, - ) - } - - var modelList localModelList - if err = json.NewDecoder(res.Body).Decode(&modelList); err != nil { - logging.Debug("Failed to list local models", - "error", err, - "endpoint", modelsEndpoint, - ) - } - - var supportedModels []localModel - for _, model := range modelList.Data { - if strings.HasSuffix(modelsEndpoint, lmStudioBetaModelsPath) { - if model.Object != "model" || model.Type != "llm" { - logging.Debug("Skipping unsupported LMStudio model", - "endpoint", modelsEndpoint, - "id", model.ID, - "object", model.Object, - "type", model.Type, - ) - - continue - } - } - - supportedModels = append(supportedModels, model) - } - - return supportedModels -} - -func loadLocalModels(models []localModel) { - for i, m := range models { - model := convertLocalModel(m) - SupportedModels[model.ID] = model - - if i == 0 || m.State == "loaded" { - viper.SetDefault("agents.coder.model", model.ID) - viper.SetDefault("agents.summarizer.model", model.ID) - viper.SetDefault("agents.task.model", model.ID) - viper.SetDefault("agents.title.model", model.ID) - } - } -} - -func convertLocalModel(model localModel) Model { - return Model{ - ID: ModelID("local." 
+ model.ID), - Name: friendlyModelName(model.ID), - Provider: ProviderLocal, - APIModel: model.ID, - ContextWindow: cmp.Or(model.LoadedContextLength, 4096), - DefaultMaxTokens: cmp.Or(model.LoadedContextLength, 4096), - CanReason: true, - SupportsAttachments: true, - } -} - -var modelInfoRegex = regexp.MustCompile(`(?i)^([a-z0-9]+)(?:[-_]?([rv]?\d[\.\d]*))?(?:[-_]?([a-z]+))?.*`) - -func friendlyModelName(modelID string) string { - mainID := modelID - tag := "" - - if slash := strings.LastIndex(mainID, "/"); slash != -1 { - mainID = mainID[slash+1:] - } - - if at := strings.Index(modelID, "@"); at != -1 { - mainID = modelID[:at] - tag = modelID[at+1:] - } - - match := modelInfoRegex.FindStringSubmatch(mainID) - if match == nil { - return modelID - } - - capitalize := func(s string) string { - if s == "" { - return "" - } - runes := []rune(s) - runes[0] = unicode.ToUpper(runes[0]) - return string(runes) - } - - family := capitalize(match[1]) - version := "" - label := "" - - if len(match) > 2 && match[2] != "" { - version = strings.ToUpper(match[2]) - } - - if len(match) > 3 && match[3] != "" { - label = capitalize(match[3]) - } - - var parts []string - if family != "" { - parts = append(parts, family) - } - if version != "" { - parts = append(parts, version) - } - if label != "" { - parts = append(parts, label) - } - if tag != "" { - parts = append(parts, tag) - } - - return strings.Join(parts, " ") -} diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go deleted file mode 100644 index 0aefc170d32d1023f0d246a2cc7522e895453a88..0000000000000000000000000000000000000000 --- a/internal/llm/models/models.go +++ /dev/null @@ -1,74 +0,0 @@ -package models - -import "maps" - -type ( - ModelID string - InferenceProvider string -) - -type Model struct { - ID ModelID `json:"id"` - Name string `json:"name"` - Provider InferenceProvider `json:"provider"` - APIModel string `json:"api_model"` - CostPer1MIn float64 `json:"cost_per_1m_in"` - CostPer1MOut float64 `json:"cost_per_1m_out"` - CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` - CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` - ContextWindow int64 `json:"context_window"` - DefaultMaxTokens int64 `json:"default_max_tokens"` - CanReason bool `json:"can_reason"` - SupportsAttachments bool `json:"supports_attachments"` -} - -// Model IDs -const ( // GEMINI - // Bedrock - BedrockClaude37Sonnet ModelID = "bedrock.claude-3.7-sonnet" -) - -const ( - ProviderBedrock InferenceProvider = "bedrock" - // ForTests - ProviderMock InferenceProvider = "__mock" -) - -var SupportedModels = map[ModelID]Model{ - // Bedrock - BedrockClaude37Sonnet: { - ID: BedrockClaude37Sonnet, - Name: "Bedrock: Claude 3.7 Sonnet", - Provider: ProviderBedrock, - APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0", - CostPer1MIn: 3.0, - CostPer1MInCached: 3.75, - CostPer1MOutCached: 0.30, - CostPer1MOut: 15.0, - }, -} - -var KnownProviders = []InferenceProvider{ - ProviderAnthropic, - ProviderOpenAI, - ProviderGemini, - ProviderAzure, - ProviderGROQ, - ProviderLocal, - ProviderOpenRouter, - ProviderVertexAI, - ProviderBedrock, - ProviderXAI, - ProviderMock, -} - -func init() { - maps.Copy(SupportedModels, AnthropicModels) - maps.Copy(SupportedModels, OpenAIModels) - maps.Copy(SupportedModels, GeminiModels) - maps.Copy(SupportedModels, GroqModels) - maps.Copy(SupportedModels, AzureModels) - maps.Copy(SupportedModels, OpenRouterModels) - maps.Copy(SupportedModels, XAIModels) - maps.Copy(SupportedModels, VertexAIGeminiModels) -} diff --git 
a/internal/llm/models/openai.go b/internal/llm/models/openai.go deleted file mode 100644 index e4173277cbdfe1e579068d2981df1e70b9943cb1..0000000000000000000000000000000000000000 --- a/internal/llm/models/openai.go +++ /dev/null @@ -1,181 +0,0 @@ -package models - -const ( - ProviderOpenAI InferenceProvider = "openai" - - GPT41 ModelID = "gpt-4.1" - GPT41Mini ModelID = "gpt-4.1-mini" - GPT41Nano ModelID = "gpt-4.1-nano" - GPT45Preview ModelID = "gpt-4.5-preview" - GPT4o ModelID = "gpt-4o" - GPT4oMini ModelID = "gpt-4o-mini" - O1 ModelID = "o1" - O1Pro ModelID = "o1-pro" - O1Mini ModelID = "o1-mini" - O3 ModelID = "o3" - O3Mini ModelID = "o3-mini" - O4Mini ModelID = "o4-mini" -) - -var OpenAIModels = map[ModelID]Model{ - GPT41: { - ID: GPT41, - Name: "GPT 4.1", - Provider: ProviderOpenAI, - APIModel: "gpt-4.1", - CostPer1MIn: 2.00, - CostPer1MInCached: 0.50, - CostPer1MOutCached: 0.0, - CostPer1MOut: 8.00, - ContextWindow: 1_047_576, - DefaultMaxTokens: 20000, - SupportsAttachments: true, - }, - GPT41Mini: { - ID: GPT41Mini, - Name: "GPT 4.1 mini", - Provider: ProviderOpenAI, - APIModel: "gpt-4.1", - CostPer1MIn: 0.40, - CostPer1MInCached: 0.10, - CostPer1MOutCached: 0.0, - CostPer1MOut: 1.60, - ContextWindow: 200_000, - DefaultMaxTokens: 20000, - SupportsAttachments: true, - }, - GPT41Nano: { - ID: GPT41Nano, - Name: "GPT 4.1 nano", - Provider: ProviderOpenAI, - APIModel: "gpt-4.1-nano", - CostPer1MIn: 0.10, - CostPer1MInCached: 0.025, - CostPer1MOutCached: 0.0, - CostPer1MOut: 0.40, - ContextWindow: 1_047_576, - DefaultMaxTokens: 20000, - SupportsAttachments: true, - }, - GPT45Preview: { - ID: GPT45Preview, - Name: "GPT 4.5 preview", - Provider: ProviderOpenAI, - APIModel: "gpt-4.5-preview", - CostPer1MIn: 75.00, - CostPer1MInCached: 37.50, - CostPer1MOutCached: 0.0, - CostPer1MOut: 150.00, - ContextWindow: 128_000, - DefaultMaxTokens: 15000, - SupportsAttachments: true, - }, - GPT4o: { - ID: GPT4o, - Name: "GPT 4o", - Provider: ProviderOpenAI, - APIModel: "gpt-4o", - CostPer1MIn: 2.50, - CostPer1MInCached: 1.25, - CostPer1MOutCached: 0.0, - CostPer1MOut: 10.00, - ContextWindow: 128_000, - DefaultMaxTokens: 4096, - SupportsAttachments: true, - }, - GPT4oMini: { - ID: GPT4oMini, - Name: "GPT 4o mini", - Provider: ProviderOpenAI, - APIModel: "gpt-4o-mini", - CostPer1MIn: 0.15, - CostPer1MInCached: 0.075, - CostPer1MOutCached: 0.0, - CostPer1MOut: 0.60, - ContextWindow: 128_000, - SupportsAttachments: true, - }, - O1: { - ID: O1, - Name: "O1", - Provider: ProviderOpenAI, - APIModel: "o1", - CostPer1MIn: 15.00, - CostPer1MInCached: 7.50, - CostPer1MOutCached: 0.0, - CostPer1MOut: 60.00, - ContextWindow: 200_000, - DefaultMaxTokens: 50000, - CanReason: true, - SupportsAttachments: true, - }, - O1Pro: { - ID: O1Pro, - Name: "o1 pro", - Provider: ProviderOpenAI, - APIModel: "o1-pro", - CostPer1MIn: 150.00, - CostPer1MInCached: 0.0, - CostPer1MOutCached: 0.0, - CostPer1MOut: 600.00, - ContextWindow: 200_000, - DefaultMaxTokens: 50000, - CanReason: true, - SupportsAttachments: true, - }, - O1Mini: { - ID: O1Mini, - Name: "o1 mini", - Provider: ProviderOpenAI, - APIModel: "o1-mini", - CostPer1MIn: 1.10, - CostPer1MInCached: 0.55, - CostPer1MOutCached: 0.0, - CostPer1MOut: 4.40, - ContextWindow: 128_000, - DefaultMaxTokens: 50000, - CanReason: true, - SupportsAttachments: true, - }, - O3: { - ID: O3, - Name: "o3", - Provider: ProviderOpenAI, - APIModel: "o3", - CostPer1MIn: 10.00, - CostPer1MInCached: 2.50, - CostPer1MOutCached: 0.0, - CostPer1MOut: 40.00, - ContextWindow: 200_000, - CanReason: true, 
- SupportsAttachments: true, - }, - O3Mini: { - ID: O3Mini, - Name: "o3 mini", - Provider: ProviderOpenAI, - APIModel: "o3-mini", - CostPer1MIn: 1.10, - CostPer1MInCached: 0.55, - CostPer1MOutCached: 0.0, - CostPer1MOut: 4.40, - ContextWindow: 200_000, - DefaultMaxTokens: 50000, - CanReason: true, - SupportsAttachments: false, - }, - O4Mini: { - ID: O4Mini, - Name: "o4 mini", - Provider: ProviderOpenAI, - APIModel: "o4-mini", - CostPer1MIn: 1.10, - CostPer1MInCached: 0.275, - CostPer1MOutCached: 0.0, - CostPer1MOut: 4.40, - ContextWindow: 128_000, - DefaultMaxTokens: 50000, - CanReason: true, - SupportsAttachments: true, - }, -} diff --git a/internal/llm/models/openrouter.go b/internal/llm/models/openrouter.go deleted file mode 100644 index 8884e03442d30787fd505ca6a6c518d299748752..0000000000000000000000000000000000000000 --- a/internal/llm/models/openrouter.go +++ /dev/null @@ -1,276 +0,0 @@ -package models - -const ( - ProviderOpenRouter InferenceProvider = "openrouter" - - OpenRouterGPT41 ModelID = "openrouter.gpt-4.1" - OpenRouterGPT41Mini ModelID = "openrouter.gpt-4.1-mini" - OpenRouterGPT41Nano ModelID = "openrouter.gpt-4.1-nano" - OpenRouterGPT45Preview ModelID = "openrouter.gpt-4.5-preview" - OpenRouterGPT4o ModelID = "openrouter.gpt-4o" - OpenRouterGPT4oMini ModelID = "openrouter.gpt-4o-mini" - OpenRouterO1 ModelID = "openrouter.o1" - OpenRouterO1Pro ModelID = "openrouter.o1-pro" - OpenRouterO1Mini ModelID = "openrouter.o1-mini" - OpenRouterO3 ModelID = "openrouter.o3" - OpenRouterO3Mini ModelID = "openrouter.o3-mini" - OpenRouterO4Mini ModelID = "openrouter.o4-mini" - OpenRouterGemini25Flash ModelID = "openrouter.gemini-2.5-flash" - OpenRouterGemini25 ModelID = "openrouter.gemini-2.5" - OpenRouterClaude35Sonnet ModelID = "openrouter.claude-3.5-sonnet" - OpenRouterClaude3Haiku ModelID = "openrouter.claude-3-haiku" - OpenRouterClaude37Sonnet ModelID = "openrouter.claude-3.7-sonnet" - OpenRouterClaude35Haiku ModelID = "openrouter.claude-3.5-haiku" - OpenRouterClaude3Opus ModelID = "openrouter.claude-3-opus" - OpenRouterDeepSeekR1Free ModelID = "openrouter.deepseek-r1-free" -) - -var OpenRouterModels = map[ModelID]Model{ - OpenRouterGPT41: { - ID: OpenRouterGPT41, - Name: "OpenRouter – GPT 4.1", - Provider: ProviderOpenRouter, - APIModel: "openai/gpt-4.1", - CostPer1MIn: OpenAIModels[GPT41].CostPer1MIn, - CostPer1MInCached: OpenAIModels[GPT41].CostPer1MInCached, - CostPer1MOut: OpenAIModels[GPT41].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[GPT41].CostPer1MOutCached, - ContextWindow: OpenAIModels[GPT41].ContextWindow, - DefaultMaxTokens: OpenAIModels[GPT41].DefaultMaxTokens, - }, - OpenRouterGPT41Mini: { - ID: OpenRouterGPT41Mini, - Name: "OpenRouter – GPT 4.1 mini", - Provider: ProviderOpenRouter, - APIModel: "openai/gpt-4.1-mini", - CostPer1MIn: OpenAIModels[GPT41Mini].CostPer1MIn, - CostPer1MInCached: OpenAIModels[GPT41Mini].CostPer1MInCached, - CostPer1MOut: OpenAIModels[GPT41Mini].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[GPT41Mini].CostPer1MOutCached, - ContextWindow: OpenAIModels[GPT41Mini].ContextWindow, - DefaultMaxTokens: OpenAIModels[GPT41Mini].DefaultMaxTokens, - }, - OpenRouterGPT41Nano: { - ID: OpenRouterGPT41Nano, - Name: "OpenRouter – GPT 4.1 nano", - Provider: ProviderOpenRouter, - APIModel: "openai/gpt-4.1-nano", - CostPer1MIn: OpenAIModels[GPT41Nano].CostPer1MIn, - CostPer1MInCached: OpenAIModels[GPT41Nano].CostPer1MInCached, - CostPer1MOut: OpenAIModels[GPT41Nano].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[GPT41Nano].CostPer1MOutCached, - 
ContextWindow: OpenAIModels[GPT41Nano].ContextWindow, - DefaultMaxTokens: OpenAIModels[GPT41Nano].DefaultMaxTokens, - }, - OpenRouterGPT45Preview: { - ID: OpenRouterGPT45Preview, - Name: "OpenRouter – GPT 4.5 preview", - Provider: ProviderOpenRouter, - APIModel: "openai/gpt-4.5-preview", - CostPer1MIn: OpenAIModels[GPT45Preview].CostPer1MIn, - CostPer1MInCached: OpenAIModels[GPT45Preview].CostPer1MInCached, - CostPer1MOut: OpenAIModels[GPT45Preview].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[GPT45Preview].CostPer1MOutCached, - ContextWindow: OpenAIModels[GPT45Preview].ContextWindow, - DefaultMaxTokens: OpenAIModels[GPT45Preview].DefaultMaxTokens, - }, - OpenRouterGPT4o: { - ID: OpenRouterGPT4o, - Name: "OpenRouter – GPT 4o", - Provider: ProviderOpenRouter, - APIModel: "openai/gpt-4o", - CostPer1MIn: OpenAIModels[GPT4o].CostPer1MIn, - CostPer1MInCached: OpenAIModels[GPT4o].CostPer1MInCached, - CostPer1MOut: OpenAIModels[GPT4o].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[GPT4o].CostPer1MOutCached, - ContextWindow: OpenAIModels[GPT4o].ContextWindow, - DefaultMaxTokens: OpenAIModels[GPT4o].DefaultMaxTokens, - }, - OpenRouterGPT4oMini: { - ID: OpenRouterGPT4oMini, - Name: "OpenRouter – GPT 4o mini", - Provider: ProviderOpenRouter, - APIModel: "openai/gpt-4o-mini", - CostPer1MIn: OpenAIModels[GPT4oMini].CostPer1MIn, - CostPer1MInCached: OpenAIModels[GPT4oMini].CostPer1MInCached, - CostPer1MOut: OpenAIModels[GPT4oMini].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[GPT4oMini].CostPer1MOutCached, - ContextWindow: OpenAIModels[GPT4oMini].ContextWindow, - }, - OpenRouterO1: { - ID: OpenRouterO1, - Name: "OpenRouter – O1", - Provider: ProviderOpenRouter, - APIModel: "openai/o1", - CostPer1MIn: OpenAIModels[O1].CostPer1MIn, - CostPer1MInCached: OpenAIModels[O1].CostPer1MInCached, - CostPer1MOut: OpenAIModels[O1].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[O1].CostPer1MOutCached, - ContextWindow: OpenAIModels[O1].ContextWindow, - DefaultMaxTokens: OpenAIModels[O1].DefaultMaxTokens, - CanReason: OpenAIModels[O1].CanReason, - }, - OpenRouterO1Pro: { - ID: OpenRouterO1Pro, - Name: "OpenRouter – o1 pro", - Provider: ProviderOpenRouter, - APIModel: "openai/o1-pro", - CostPer1MIn: OpenAIModels[O1Pro].CostPer1MIn, - CostPer1MInCached: OpenAIModels[O1Pro].CostPer1MInCached, - CostPer1MOut: OpenAIModels[O1Pro].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[O1Pro].CostPer1MOutCached, - ContextWindow: OpenAIModels[O1Pro].ContextWindow, - DefaultMaxTokens: OpenAIModels[O1Pro].DefaultMaxTokens, - CanReason: OpenAIModels[O1Pro].CanReason, - }, - OpenRouterO1Mini: { - ID: OpenRouterO1Mini, - Name: "OpenRouter – o1 mini", - Provider: ProviderOpenRouter, - APIModel: "openai/o1-mini", - CostPer1MIn: OpenAIModels[O1Mini].CostPer1MIn, - CostPer1MInCached: OpenAIModels[O1Mini].CostPer1MInCached, - CostPer1MOut: OpenAIModels[O1Mini].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[O1Mini].CostPer1MOutCached, - ContextWindow: OpenAIModels[O1Mini].ContextWindow, - DefaultMaxTokens: OpenAIModels[O1Mini].DefaultMaxTokens, - CanReason: OpenAIModels[O1Mini].CanReason, - }, - OpenRouterO3: { - ID: OpenRouterO3, - Name: "OpenRouter – o3", - Provider: ProviderOpenRouter, - APIModel: "openai/o3", - CostPer1MIn: OpenAIModels[O3].CostPer1MIn, - CostPer1MInCached: OpenAIModels[O3].CostPer1MInCached, - CostPer1MOut: OpenAIModels[O3].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[O3].CostPer1MOutCached, - ContextWindow: OpenAIModels[O3].ContextWindow, - DefaultMaxTokens: OpenAIModels[O3].DefaultMaxTokens, - 
CanReason: OpenAIModels[O3].CanReason, - }, - OpenRouterO3Mini: { - ID: OpenRouterO3Mini, - Name: "OpenRouter – o3 mini", - Provider: ProviderOpenRouter, - APIModel: "openai/o3-mini-high", - CostPer1MIn: OpenAIModels[O3Mini].CostPer1MIn, - CostPer1MInCached: OpenAIModels[O3Mini].CostPer1MInCached, - CostPer1MOut: OpenAIModels[O3Mini].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[O3Mini].CostPer1MOutCached, - ContextWindow: OpenAIModels[O3Mini].ContextWindow, - DefaultMaxTokens: OpenAIModels[O3Mini].DefaultMaxTokens, - CanReason: OpenAIModels[O3Mini].CanReason, - }, - OpenRouterO4Mini: { - ID: OpenRouterO4Mini, - Name: "OpenRouter – o4 mini", - Provider: ProviderOpenRouter, - APIModel: "openai/o4-mini-high", - CostPer1MIn: OpenAIModels[O4Mini].CostPer1MIn, - CostPer1MInCached: OpenAIModels[O4Mini].CostPer1MInCached, - CostPer1MOut: OpenAIModels[O4Mini].CostPer1MOut, - CostPer1MOutCached: OpenAIModels[O4Mini].CostPer1MOutCached, - ContextWindow: OpenAIModels[O4Mini].ContextWindow, - DefaultMaxTokens: OpenAIModels[O4Mini].DefaultMaxTokens, - CanReason: OpenAIModels[O4Mini].CanReason, - }, - OpenRouterGemini25Flash: { - ID: OpenRouterGemini25Flash, - Name: "OpenRouter – Gemini 2.5 Flash", - Provider: ProviderOpenRouter, - APIModel: "google/gemini-2.5-flash-preview:thinking", - CostPer1MIn: GeminiModels[Gemini25Flash].CostPer1MIn, - CostPer1MInCached: GeminiModels[Gemini25Flash].CostPer1MInCached, - CostPer1MOut: GeminiModels[Gemini25Flash].CostPer1MOut, - CostPer1MOutCached: GeminiModels[Gemini25Flash].CostPer1MOutCached, - ContextWindow: GeminiModels[Gemini25Flash].ContextWindow, - DefaultMaxTokens: GeminiModels[Gemini25Flash].DefaultMaxTokens, - }, - OpenRouterGemini25: { - ID: OpenRouterGemini25, - Name: "OpenRouter – Gemini 2.5 Pro", - Provider: ProviderOpenRouter, - APIModel: "google/gemini-2.5-pro-preview-03-25", - CostPer1MIn: GeminiModels[Gemini25].CostPer1MIn, - CostPer1MInCached: GeminiModels[Gemini25].CostPer1MInCached, - CostPer1MOut: GeminiModels[Gemini25].CostPer1MOut, - CostPer1MOutCached: GeminiModels[Gemini25].CostPer1MOutCached, - ContextWindow: GeminiModels[Gemini25].ContextWindow, - DefaultMaxTokens: GeminiModels[Gemini25].DefaultMaxTokens, - }, - OpenRouterClaude35Sonnet: { - ID: OpenRouterClaude35Sonnet, - Name: "OpenRouter – Claude 3.5 Sonnet", - Provider: ProviderOpenRouter, - APIModel: "anthropic/claude-3.5-sonnet", - CostPer1MIn: AnthropicModels[Claude35Sonnet].CostPer1MIn, - CostPer1MInCached: AnthropicModels[Claude35Sonnet].CostPer1MInCached, - CostPer1MOut: AnthropicModels[Claude35Sonnet].CostPer1MOut, - CostPer1MOutCached: AnthropicModels[Claude35Sonnet].CostPer1MOutCached, - ContextWindow: AnthropicModels[Claude35Sonnet].ContextWindow, - DefaultMaxTokens: AnthropicModels[Claude35Sonnet].DefaultMaxTokens, - }, - OpenRouterClaude3Haiku: { - ID: OpenRouterClaude3Haiku, - Name: "OpenRouter – Claude 3 Haiku", - Provider: ProviderOpenRouter, - APIModel: "anthropic/claude-3-haiku", - CostPer1MIn: AnthropicModels[Claude3Haiku].CostPer1MIn, - CostPer1MInCached: AnthropicModels[Claude3Haiku].CostPer1MInCached, - CostPer1MOut: AnthropicModels[Claude3Haiku].CostPer1MOut, - CostPer1MOutCached: AnthropicModels[Claude3Haiku].CostPer1MOutCached, - ContextWindow: AnthropicModels[Claude3Haiku].ContextWindow, - DefaultMaxTokens: AnthropicModels[Claude3Haiku].DefaultMaxTokens, - }, - OpenRouterClaude37Sonnet: { - ID: OpenRouterClaude37Sonnet, - Name: "OpenRouter – Claude 3.7 Sonnet", - Provider: ProviderOpenRouter, - APIModel: "anthropic/claude-3.7-sonnet", - CostPer1MIn: 
AnthropicModels[Claude37Sonnet].CostPer1MIn, - CostPer1MInCached: AnthropicModels[Claude37Sonnet].CostPer1MInCached, - CostPer1MOut: AnthropicModels[Claude37Sonnet].CostPer1MOut, - CostPer1MOutCached: AnthropicModels[Claude37Sonnet].CostPer1MOutCached, - ContextWindow: AnthropicModels[Claude37Sonnet].ContextWindow, - DefaultMaxTokens: AnthropicModels[Claude37Sonnet].DefaultMaxTokens, - CanReason: AnthropicModels[Claude37Sonnet].CanReason, - }, - OpenRouterClaude35Haiku: { - ID: OpenRouterClaude35Haiku, - Name: "OpenRouter – Claude 3.5 Haiku", - Provider: ProviderOpenRouter, - APIModel: "anthropic/claude-3.5-haiku", - CostPer1MIn: AnthropicModels[Claude35Haiku].CostPer1MIn, - CostPer1MInCached: AnthropicModels[Claude35Haiku].CostPer1MInCached, - CostPer1MOut: AnthropicModels[Claude35Haiku].CostPer1MOut, - CostPer1MOutCached: AnthropicModels[Claude35Haiku].CostPer1MOutCached, - ContextWindow: AnthropicModels[Claude35Haiku].ContextWindow, - DefaultMaxTokens: AnthropicModels[Claude35Haiku].DefaultMaxTokens, - }, - OpenRouterClaude3Opus: { - ID: OpenRouterClaude3Opus, - Name: "OpenRouter – Claude 3 Opus", - Provider: ProviderOpenRouter, - APIModel: "anthropic/claude-3-opus", - CostPer1MIn: AnthropicModels[Claude3Opus].CostPer1MIn, - CostPer1MInCached: AnthropicModels[Claude3Opus].CostPer1MInCached, - CostPer1MOut: AnthropicModels[Claude3Opus].CostPer1MOut, - CostPer1MOutCached: AnthropicModels[Claude3Opus].CostPer1MOutCached, - ContextWindow: AnthropicModels[Claude3Opus].ContextWindow, - DefaultMaxTokens: AnthropicModels[Claude3Opus].DefaultMaxTokens, - }, - - OpenRouterDeepSeekR1Free: { - ID: OpenRouterDeepSeekR1Free, - Name: "OpenRouter – DeepSeek R1 Free", - Provider: ProviderOpenRouter, - APIModel: "deepseek/deepseek-r1-0528:free", - CostPer1MIn: 0, - CostPer1MInCached: 0, - CostPer1MOut: 0, - CostPer1MOutCached: 0, - ContextWindow: 163_840, - DefaultMaxTokens: 10000, - }, -} diff --git a/internal/llm/models/vertexai.go b/internal/llm/models/vertexai.go deleted file mode 100644 index c9b5744b62c28e2529cac44b1e97234158d2eacf..0000000000000000000000000000000000000000 --- a/internal/llm/models/vertexai.go +++ /dev/null @@ -1,38 +0,0 @@ -package models - -const ( - ProviderVertexAI InferenceProvider = "vertexai" - - // Models - VertexAIGemini25Flash ModelID = "vertexai.gemini-2.5-flash" - VertexAIGemini25 ModelID = "vertexai.gemini-2.5" -) - -var VertexAIGeminiModels = map[ModelID]Model{ - VertexAIGemini25Flash: { - ID: VertexAIGemini25Flash, - Name: "VertexAI: Gemini 2.5 Flash", - Provider: ProviderVertexAI, - APIModel: "gemini-2.5-flash-preview-04-17", - CostPer1MIn: GeminiModels[Gemini25Flash].CostPer1MIn, - CostPer1MInCached: GeminiModels[Gemini25Flash].CostPer1MInCached, - CostPer1MOut: GeminiModels[Gemini25Flash].CostPer1MOut, - CostPer1MOutCached: GeminiModels[Gemini25Flash].CostPer1MOutCached, - ContextWindow: GeminiModels[Gemini25Flash].ContextWindow, - DefaultMaxTokens: GeminiModels[Gemini25Flash].DefaultMaxTokens, - SupportsAttachments: true, - }, - VertexAIGemini25: { - ID: VertexAIGemini25, - Name: "VertexAI: Gemini 2.5 Pro", - Provider: ProviderVertexAI, - APIModel: "gemini-2.5-pro-preview-03-25", - CostPer1MIn: GeminiModels[Gemini25].CostPer1MIn, - CostPer1MInCached: GeminiModels[Gemini25].CostPer1MInCached, - CostPer1MOut: GeminiModels[Gemini25].CostPer1MOut, - CostPer1MOutCached: GeminiModels[Gemini25].CostPer1MOutCached, - ContextWindow: GeminiModels[Gemini25].ContextWindow, - DefaultMaxTokens: GeminiModels[Gemini25].DefaultMaxTokens, - SupportsAttachments: true, - }, -} diff 
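[Editor's note] The deletions above drop the whole hard-coded model registry. Two things worth flagging for reviewers auditing what is lost: the removed listLocalModels was quietly broken (it built a request with http.NewRequestWithContext but never executed it, then read res.Body and res.Response off the un-sent request), and the removed friendlyModelName prettifier has no replacement shown in this patch. A self-contained sketch of what that helper produced, assuming only the regexp and fmt packages; the regex is copied verbatim from the deletion above and the input is illustrative:

	re := regexp.MustCompile(`(?i)^([a-z0-9]+)(?:[-_]?([rv]?\d[\.\d]*))?(?:[-_]?([a-z]+))?.*`)
	m := re.FindStringSubmatch("llama-3-instruct")
	fmt.Println(m[1], m[2], m[3]) // prints "llama 3 instruct"; the helper rendered it as "Llama 3 Instruct"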
diff --git a/internal/llm/models/xai.go b/internal/llm/models/xai.go
deleted file mode 100644
index a59eac97ee6fee5db5550663083062099512eddc..0000000000000000000000000000000000000000
--- a/internal/llm/models/xai.go
+++ /dev/null
@@ -1,61 +0,0 @@
-package models
-
-const (
-	ProviderXAI InferenceProvider = "xai"
-
-	XAIGrok3Beta ModelID = "grok-3-beta"
-	XAIGrok3MiniBeta ModelID = "grok-3-mini-beta"
-	XAIGrok3FastBeta ModelID = "grok-3-fast-beta"
-	XAiGrok3MiniFastBeta ModelID = "grok-3-mini-fast-beta"
-)
-
-var XAIModels = map[ModelID]Model{
-	XAIGrok3Beta: {
-		ID: XAIGrok3Beta,
-		Name: "Grok3 Beta",
-		Provider: ProviderXAI,
-		APIModel: "grok-3-beta",
-		CostPer1MIn: 3.0,
-		CostPer1MInCached: 0,
-		CostPer1MOut: 15,
-		CostPer1MOutCached: 0,
-		ContextWindow: 131_072,
-		DefaultMaxTokens: 20_000,
-	},
-	XAIGrok3MiniBeta: {
-		ID: XAIGrok3MiniBeta,
-		Name: "Grok3 Mini Beta",
-		Provider: ProviderXAI,
-		APIModel: "grok-3-mini-beta",
-		CostPer1MIn: 0.3,
-		CostPer1MInCached: 0,
-		CostPer1MOut: 0.5,
-		CostPer1MOutCached: 0,
-		ContextWindow: 131_072,
-		DefaultMaxTokens: 20_000,
-	},
-	XAIGrok3FastBeta: {
-		ID: XAIGrok3FastBeta,
-		Name: "Grok3 Fast Beta",
-		Provider: ProviderXAI,
-		APIModel: "grok-3-fast-beta",
-		CostPer1MIn: 5,
-		CostPer1MInCached: 0,
-		CostPer1MOut: 25,
-		CostPer1MOutCached: 0,
-		ContextWindow: 131_072,
-		DefaultMaxTokens: 20_000,
-	},
-	XAiGrok3MiniFastBeta: {
-		ID: XAiGrok3MiniFastBeta,
-		Name: "Grok3 Mini Fast Beta",
-		Provider: ProviderXAI,
-		APIModel: "grok-3-mini-fast-beta",
-		CostPer1MIn: 0.6,
-		CostPer1MInCached: 0,
-		CostPer1MOut: 4.0,
-		CostPer1MOutCached: 0,
-		ContextWindow: 131_072,
-		DefaultMaxTokens: 20_000,
-	},
-}
diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go
index b272f4e9f263ff596d06aae787e8b5a1c3ac2aec..9f1e5e7c19e739167bb9ab2bd359218e88fd4367 100644
--- a/internal/llm/prompt/coder.go
+++ b/internal/llm/prompt/coder.go
@@ -9,19 +9,27 @@ import (
 	"time"

 	"github.com/charmbracelet/crush/internal/config"
-	"github.com/charmbracelet/crush/internal/llm/models"
+	"github.com/charmbracelet/crush/internal/fur/provider"
 	"github.com/charmbracelet/crush/internal/llm/tools"
+	"github.com/charmbracelet/crush/internal/logging"
 )

-func CoderPrompt(provider models.InferenceProvider) string {
+func CoderPrompt(p provider.InferenceProvider, contextFiles ...string) string {
 	basePrompt := baseAnthropicCoderPrompt
-	switch provider {
-	case models.ProviderOpenAI:
+	switch p {
+	case provider.InferenceProviderOpenAI:
 		basePrompt = baseOpenAICoderPrompt
 	}

 	envInfo := getEnvironmentInfo()

-	return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
+	basePrompt = fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
+
+	contextContent := getContextFromPaths(contextFiles)
+	logging.Debug("Context content", "Context", contextContent)
+	if contextContent != "" {
+		return fmt.Sprintf("%s\n\n# Project-Specific Context\n Make sure to follow the instructions in the context below\n%s", basePrompt, contextContent)
+	}
+	return basePrompt
 }

 const baseOpenAICoderPrompt = `
diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go
index ed75d29c500cce16f16d06892ad8fcabc254a08d..36148edd9c71790c3a4cb06d551cdee06272c8b7 100644
--- a/internal/llm/prompt/prompt.go
+++ b/internal/llm/prompt/prompt.go
@@ -1,60 +1,44 @@
 package prompt

 import (
-	"fmt"
 	"os"
 	"path/filepath"
 	"strings"
 	"sync"

 	"github.com/charmbracelet/crush/internal/config"
-	"github.com/charmbracelet/crush/internal/llm/models"
-	"github.com/charmbracelet/crush/internal/logging"
+	"github.com/charmbracelet/crush/internal/fur/provider"
 )

-func GetAgentPrompt(agentName config.AgentName, provider models.InferenceProvider) string {
+type PromptID string
+
+const (
+	PromptCoder PromptID = "coder"
+	PromptTitle PromptID = "title"
+	PromptTask PromptID = "task"
+	PromptSummarizer PromptID = "summarizer"
+	PromptDefault PromptID = "default"
+)
+
+func GetPrompt(promptID PromptID, provider provider.InferenceProvider, contextPaths ...string) string {
 	basePrompt := ""
-	switch agentName {
-	case config.AgentCoder:
+	switch promptID {
+	case PromptCoder:
 		basePrompt = CoderPrompt(provider)
-	case config.AgentTitle:
+	case PromptTitle:
 		basePrompt = TitlePrompt(provider)
-	case config.AgentTask:
+	case PromptTask:
 		basePrompt = TaskPrompt(provider)
-	case config.AgentSummarizer:
+	case PromptSummarizer:
 		basePrompt = SummarizerPrompt(provider)
 	default:
 		basePrompt = "You are a helpful assistant"
 	}
-
-	if agentName == config.AgentCoder || agentName == config.AgentTask {
-		// Add context from project-specific instruction files if they exist
-		contextContent := getContextFromPaths()
-		logging.Debug("Context content", "Context", contextContent)
-		if contextContent != "" {
-			return fmt.Sprintf("%s\n\n# Project-Specific Context\n Make sure to follow the instructions in the context below\n%s", basePrompt, contextContent)
-		}
-	}
 	return basePrompt
 }

-var (
-	onceContext sync.Once
-	contextContent string
-)
-
-func getContextFromPaths() string {
-	onceContext.Do(func() {
-		var (
-			cfg = config.Get()
-			workDir = cfg.WorkingDir
-			contextPaths = cfg.ContextPaths
-		)
-
-		contextContent = processContextPaths(workDir, contextPaths)
-	})
-
-	return contextContent
+func getContextFromPaths(contextPaths []string) string {
+	return processContextPaths(config.WorkingDirectory(), contextPaths)
 }

 func processContextPaths(workDir string, paths []string) string {
diff --git a/internal/llm/prompt/prompt_test.go b/internal/llm/prompt/prompt_test.go
index a350c55a32260173dabd56e22d9e514e97b3e5a3..41e3fe92c7fb5615b6c93e2aa89bad35820567ef 100644
--- a/internal/llm/prompt/prompt_test.go
+++ b/internal/llm/prompt/prompt_test.go
@@ -15,16 +15,10 @@ func TestGetContextFromPaths(t *testing.T) {
 	t.Parallel()

 	tmpDir := t.TempDir()
-	_, err := config.Load(tmpDir, false)
+	_, err := config.Init(tmpDir, false)
 	if err != nil {
 		t.Fatalf("Failed to load config: %v", err)
 	}
-	cfg := config.Get()
-	cfg.WorkingDir = tmpDir
-	cfg.ContextPaths = []string{
-		"file.txt",
-		"directory/",
-	}
 	testFiles := []string{
 		"file.txt",
 		"directory/file_a.txt",
@@ -34,7 +28,12 @@ func TestGetContextFromPaths(t *testing.T) {

 	createTestFiles(t, tmpDir, testFiles)

-	context := getContextFromPaths()
+	context := getContextFromPaths(
+		[]string{
+			"file.txt",
+			"directory/",
+		},
+	)
 	expectedContext := fmt.Sprintf("# From:%s/file.txt\nfile.txt: test content\n# From:%s/directory/file_a.txt\ndirectory/file_a.txt: test content\n# From:%s/directory/file_b.txt\ndirectory/file_b.txt: test content\n# From:%s/directory/file_c.txt\ndirectory/file_c.txt: test content", tmpDir, tmpDir, tmpDir, tmpDir)
 	assert.Equal(t, expectedContext, context)
 }
diff --git a/internal/llm/prompt/summarizer.go b/internal/llm/prompt/summarizer.go
index f5a1de0f8619252d99082c6ca54e152cc25a7bc7..77d98184bcf985ebb2bc569205b6b4cc77b3d601 100644
--- a/internal/llm/prompt/summarizer.go
+++ b/internal/llm/prompt/summarizer.go
@@ -1,8 +1,10 @@
 package prompt

-import "github.com/charmbracelet/crush/internal/llm/models"
+import (
+	"github.com/charmbracelet/crush/internal/fur/provider"
+)

-func SummarizerPrompt(_ models.InferenceProvider) string {
+func SummarizerPrompt(_ provider.InferenceProvider) string {
 	return `You are a helpful AI assistant tasked with summarizing conversations.

 When asked to summarize, provide a detailed but concise summary of the conversation.
diff --git a/internal/llm/prompt/task.go b/internal/llm/prompt/task.go
index 89acf1f02121ea008359eaa5201222061dad0cff..719c0ef45778814e38b391e86174708edcdd7c3e 100644
--- a/internal/llm/prompt/task.go
+++ b/internal/llm/prompt/task.go
@@ -3,10 +3,10 @@ package prompt
 import (
 	"fmt"

-	"github.com/charmbracelet/crush/internal/llm/models"
+	"github.com/charmbracelet/crush/internal/fur/provider"
 )

-func TaskPrompt(_ models.InferenceProvider) string {
+func TaskPrompt(_ provider.InferenceProvider) string {
 	agentPrompt := `You are an agent for Crush. Given the user's prompt, you should use the tools available to you to answer the user's question.
 Notes:
 1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
diff --git a/internal/llm/prompt/title.go b/internal/llm/prompt/title.go
index 0b3177b37857c24d299df0d6e64393cd60ea23eb..11bab4b6835ac0e53adc578cfddd3133f8b654e5 100644
--- a/internal/llm/prompt/title.go
+++ b/internal/llm/prompt/title.go
@@ -1,8 +1,10 @@
 package prompt

-import "github.com/charmbracelet/crush/internal/llm/models"
+import (
+	"github.com/charmbracelet/crush/internal/fur/provider"
+)

-func TitlePrompt(_ models.InferenceProvider) string {
+func TitlePrompt(_ provider.InferenceProvider) string {
 	return `you will generate a short title based on the first message a user begins a conversation with
 - ensure it is not more than 50 characters long
 - the title should be a summary of the user's message
diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go
index 709a56263e0a8880d444c8ee7e9cab1373e67344..aca4d5b7f0adc4977fb349956be1005186e267e6 100644
--- a/internal/llm/provider/anthropic.go
+++ b/internal/llm/provider/anthropic.go
@@ -13,7 +13,7 @@ import (
 	"github.com/anthropics/anthropic-sdk-go/bedrock"
 	"github.com/anthropics/anthropic-sdk-go/option"
 	"github.com/charmbracelet/crush/internal/config"
-	"github.com/charmbracelet/crush/internal/llm/models"
+	"github.com/charmbracelet/crush/internal/fur/provider"
 	"github.com/charmbracelet/crush/internal/llm/tools"
 	"github.com/charmbracelet/crush/internal/logging"
 	"github.com/charmbracelet/crush/internal/message"
@@ -59,7 +59,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic
 			var contentBlocks []anthropic.ContentBlockParamUnion
 			contentBlocks = append(contentBlocks, content)
 			for _, binaryContent := range msg.BinaryContent() {
-				base64Image := binaryContent.String(models.ProviderAnthropic)
+				base64Image := binaryContent.String(provider.InferenceProviderAnthropic)
 				imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image)
 				contentBlocks = append(contentBlocks, imageBlock)
 			}
@@ -164,7 +164,7 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
 	// }

 	return anthropic.MessageNewParams{
-		Model: anthropic.Model(a.providerOptions.model.APIModel),
+		Model: anthropic.Model(a.providerOptions.model.ID),
 		MaxTokens: a.providerOptions.maxTokens,
 		Temperature: temperature,
 		Messages: messages,
@@ -184,7 +184,7 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
 func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
 	preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
 	cfg := config.Get()
-	if cfg.Debug {
+	if cfg.Options.Debug {
 		jsonData, _ := json.Marshal(preparedMessages)
 		logging.Debug("Prepared messages", "messages", string(jsonData))
 	}
@@ -233,7 +233,7 @@
 func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
 	preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
 	cfg := config.Get()
-	if cfg.Debug {
+	if cfg.Options.Debug {
 		// jsonData, _ := json.Marshal(preparedMessages)
 		// logging.Debug("Prepared messages", "messages", string(jsonData))
 	}
diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go
index 8db9c1e84a4e8496be77e69e612de4abb9ce0c07..6b31c7d7fd6625ad7c2962f409f6c50f01ff726b 100644
--- a/internal/llm/provider/bedrock.go
+++ b/internal/llm/provider/bedrock.go
@@ -4,7 +4,6 @@ import (
 	"context"
 	"errors"
 	"fmt"
-	"os"
 	"strings"

 	"github.com/charmbracelet/crush/internal/llm/tools"
@@ -19,14 +18,8 @@ type bedrockClient struct {
 type BedrockClient ProviderClient

 func newBedrockClient(opts providerClientOptions) BedrockClient {
-	// Apply bedrock specific options if they are added in the future
-	// Get AWS region from environment
-	region := os.Getenv("AWS_REGION")
-	if region == "" {
-		region = os.Getenv("AWS_DEFAULT_REGION")
-	}
-
+	region := opts.extraParams["region"]
 	if region == "" {
 		region = "us-east-1" // default region
 	}
@@ -39,11 +32,11 @@ func newBedrockClient(opts providerClientOptions) BedrockClient {

 	// Prefix the model name with region
 	regionPrefix := region[:2]
-	modelName := opts.model.APIModel
-	opts.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, modelName)
+	modelName := opts.model.ID
+	opts.model.ID = fmt.Sprintf("%s.%s", regionPrefix, modelName)

 	// Determine which provider to use based on the model
-	if strings.Contains(string(opts.model.APIModel), "anthropic") {
+	if strings.Contains(string(opts.model.ID), "anthropic") {
 		// Create Anthropic client with Bedrock configuration
 		anthropicOpts := opts
 		// TODO: later find a way to check if the AWS account has caching enabled
diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go
index dd54dac4491634de06a31ee00f1ffd13ea935076..a91c1eae2427a7629ee1f4de6d6b9abb5944a972 100644
--- a/internal/llm/provider/gemini.go
+++ b/internal/llm/provider/gemini.go
@@ -157,7 +157,7 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
 	geminiMessages := g.convertMessages(messages)
 	cfg := config.Get()

-	if cfg.Debug {
+	if cfg.Options.Debug {
 		jsonData, _ := json.Marshal(geminiMessages)
 		logging.Debug("Prepared messages", "messages", string(jsonData))
 	}
@@ -173,7 +173,7 @@
 	if len(tools) > 0 {
 		config.Tools = g.convertTools(tools)
 	}
-	chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)
+	chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.ID, config, history)

 	attempts := 0
 	for {
@@ -245,7 +245,7 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
 	geminiMessages := g.convertMessages(messages)
 	cfg := config.Get()

-	if cfg.Debug {
+	if cfg.Options.Debug {
 		jsonData, _ := json.Marshal(geminiMessages)
 		logging.Debug("Prepared messages", "messages", string(jsonData))
 	}
@@ -261,7 +261,7 @@
 	if len(tools) > 0 {
 		config.Tools = g.convertTools(tools)
 	}
-	chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)
+	chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.ID, config, history)

 	attempts := 0
 	eventChan := make(chan ProviderEvent)
diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go
index 334312f9e8c41f5d68251d9e7bbd890074fa3982..448ab3674f25053453f51c0f48475db5699ee913 100644
--- a/internal/llm/provider/openai.go
+++ b/internal/llm/provider/openai.go
@@ -9,7 +9,7 @@ import (
 	"time"

 	"github.com/charmbracelet/crush/internal/config"
-	"github.com/charmbracelet/crush/internal/llm/models"
+	"github.com/charmbracelet/crush/internal/fur/provider"
 	"github.com/charmbracelet/crush/internal/llm/tools"
 	"github.com/charmbracelet/crush/internal/logging"
 	"github.com/charmbracelet/crush/internal/message"
@@ -68,7 +68,7 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag
 			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
 			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
 			for _, binaryContent := range msg.BinaryContent() {
-				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(models.ProviderOpenAI)}
+				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(provider.InferenceProviderOpenAI)}
 				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}

 				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
@@ -153,7 +153,7 @@ func (o *openaiClient) finishReason(reason string) message.FinishReason {

 func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
 	params := openai.ChatCompletionNewParams{
-		Model: openai.ChatModel(o.providerOptions.model.APIModel),
+		Model: openai.ChatModel(o.providerOptions.model.ID),
 		Messages: messages,
 		Tools: tools,
 	}
@@ -180,7 +180,7 @@
 func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
 	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
 	cfg := config.Get()
-	if cfg.Debug {
+	if cfg.Options.Debug {
 		jsonData, _ := json.Marshal(params)
 		logging.Debug("Prepared messages", "messages", string(jsonData))
 	}
@@ -237,7 +237,7 @@
 	}

 	cfg := config.Get()
-	if cfg.Debug {
+	if cfg.Options.Debug {
 		jsonData, _ := json.Marshal(params)
 		logging.Debug("Prepared messages", "messages", string(jsonData))
 	}
diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go
index 0d98b74c3292c0aa066dfd0676445e587b800b57..3152cd6a9a7e6fd6a68d0e6b54b6ea6853a38273 100644
--- a/internal/llm/provider/provider.go
+++ b/internal/llm/provider/provider.go
@@ -3,9 +3,9 @@ package provider
 import (
 	"context"
 	"fmt"
-	"os"

-	"github.com/charmbracelet/crush/internal/llm/models"
+	configv2 "github.com/charmbracelet/crush/internal/config"
+	"github.com/charmbracelet/crush/internal/fur/provider"
 	"github.com/charmbracelet/crush/internal/llm/tools"
 	"github.com/charmbracelet/crush/internal/message"
 )
@@ -55,17 +55,18 @@ type Provider interface {
 	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

-	Model() models.Model
+	Model() configv2.Model
 }

 type providerClientOptions struct {
 	baseURL string
 	apiKey string
-	model models.Model
+	model configv2.Model
 	disableCache bool
 	maxTokens int64
 	systemMessage string
 	extraHeaders map[string]string
+	extraParams map[string]string
 }

 type ProviderClientOption func(*providerClientOptions)
@@ -80,77 +81,6 @@ type baseProvider[C ProviderClient] struct {
 	client C
 }

-func NewProvider(providerName models.InferenceProvider, opts ...ProviderClientOption) (Provider, error) {
-	clientOptions := providerClientOptions{}
-	for _, o := range opts {
-		o(&clientOptions)
-	}
-	switch providerName {
-	case models.ProviderAnthropic:
-		return &baseProvider[AnthropicClient]{
-			options: clientOptions,
-			client: newAnthropicClient(clientOptions, false),
-		}, nil
-	case models.ProviderOpenAI:
-		return &baseProvider[OpenAIClient]{
-			options: clientOptions,
-			client: newOpenAIClient(clientOptions),
-		}, nil
-	case models.ProviderGemini:
-		return &baseProvider[GeminiClient]{
-			options: clientOptions,
-			client: newGeminiClient(clientOptions),
-		}, nil
-	case models.ProviderBedrock:
-		return &baseProvider[BedrockClient]{
-			options: clientOptions,
-			client: newBedrockClient(clientOptions),
-		}, nil
-	case models.ProviderGROQ:
-		clientOptions.baseURL = "https://api.groq.com/openai/v1"
-		return &baseProvider[OpenAIClient]{
-			options: clientOptions,
-			client: newOpenAIClient(clientOptions),
-		}, nil
-	case models.ProviderAzure:
-		return &baseProvider[AzureClient]{
-			options: clientOptions,
-			client: newAzureClient(clientOptions),
-		}, nil
-	case models.ProviderVertexAI:
-		return &baseProvider[VertexAIClient]{
-			options: clientOptions,
-			client: newVertexAIClient(clientOptions),
-		}, nil
-	case models.ProviderOpenRouter:
-		clientOptions.baseURL = "https://openrouter.ai/api/v1"
-		clientOptions.extraHeaders = map[string]string{
-			"HTTP-Referer": "crush.charm.land",
-			"X-Title": "Crush",
-		}
-		return &baseProvider[OpenAIClient]{
-			options: clientOptions,
-			client: newOpenAIClient(clientOptions),
-		}, nil
-	case models.ProviderXAI:
-		clientOptions.baseURL = "https://api.x.ai/v1"
-		return &baseProvider[OpenAIClient]{
-			options: clientOptions,
-			client: newOpenAIClient(clientOptions),
-		}, nil
-	case models.ProviderLocal:
-		clientOptions.baseURL = os.Getenv("LOCAL_ENDPOINT")
-		return &baseProvider[OpenAIClient]{
-			options: clientOptions,
-			client: newOpenAIClient(clientOptions),
-		}, nil
-	case models.ProviderMock:
-		// TODO: implement mock client for test
-		panic("not implemented")
-	}
-	return nil, fmt.Errorf("provider not supported: %s", providerName)
-}
-
 func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
 	for _, msg := range messages {
 		// The message has no content
@@ -167,7 +97,7 @@ func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.M
 	return p.client.send(ctx, messages, tools)
 }

-func (p *baseProvider[C]) Model() models.Model {
+func (p *baseProvider[C]) Model() configv2.Model {
 	return p.options.model
 }

@@ -176,7 +106,7 @@ func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message
 	return p.client.stream(ctx, messages, tools)
 }

-func WithModel(model models.Model) ProviderClientOption {
+func WithModel(model configv2.Model) ProviderClientOption {
 	return func(options *providerClientOptions) {
 		options.model = model
 	}
@@ -199,3 +129,53 @@ func WithSystemMessage(systemMessage string) ProviderClientOption {
 		options.systemMessage = systemMessage
 	}
 }
+
+func NewProviderV2(cfg configv2.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
+	clientOptions := providerClientOptions{
+		baseURL: cfg.BaseURL,
+		apiKey: cfg.APIKey,
+		extraHeaders: cfg.ExtraHeaders,
+	}
+	for _, o := range opts {
+		o(&clientOptions)
+	}
+	switch cfg.ProviderType {
+	case provider.TypeAnthropic:
+		return &baseProvider[AnthropicClient]{
+			options: clientOptions,
+			client: newAnthropicClient(clientOptions, false),
+		}, nil
+	case provider.TypeOpenAI:
+		return &baseProvider[OpenAIClient]{
+			options: clientOptions,
+			client: newOpenAIClient(clientOptions),
+		}, nil
+	case provider.TypeGemini:
+		return &baseProvider[GeminiClient]{
+			options: clientOptions,
+			client: newGeminiClient(clientOptions),
+		}, nil
+	case provider.TypeBedrock:
+		return &baseProvider[BedrockClient]{
+			options: clientOptions,
+			client: newBedrockClient(clientOptions),
+		}, nil
+	case provider.TypeAzure:
+		return &baseProvider[AzureClient]{
+			options: clientOptions,
+			client: newAzureClient(clientOptions),
+		}, nil
+	case provider.TypeVertexAI:
+		return &baseProvider[VertexAIClient]{
+			options: clientOptions,
+			client: newVertexAIClient(clientOptions),
+		}, nil
+	case provider.TypeXAI:
+		clientOptions.baseURL = "https://api.x.ai/v1"
+		return &baseProvider[OpenAIClient]{
+			options: clientOptions,
+			client: newOpenAIClient(clientOptions),
+		}, nil
+	}
+	return nil, fmt.Errorf("provider not supported: %s", cfg.ProviderType)
+}
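[Editor's note] NewProviderV2 replaces the old provider-name switch with a switch on cfg.ProviderType, pulling BaseURL, APIKey, and ExtraHeaders straight from the new ProviderConfig, and it returns an error for unknown types instead of panicking. A sketch of the expected wiring; the config literal is illustrative, only fields visible in this hunk are assumed, and model and systemPrompt are assumed to be a configv2.Model and string already in scope:

	pc := configv2.ProviderConfig{
		ProviderType: provider.TypeOpenAI,
		APIKey: os.Getenv("OPENAI_API_KEY"), // hypothetical source for the key
	}
	p, err := NewProviderV2(pc,
		WithModel(model),
		WithSystemMessage(systemPrompt),
	)
	if err != nil {
		// e.g. a ProviderType with no case in the switch above
		logging.Error("unsupported provider", "error", err)
	}
	_ = p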
diff --git a/internal/llm/provider/vertexai.go b/internal/llm/provider/vertexai.go
index 49374d33fa81ab42e9f0c4d6e7905bfa37a6154e..2d95ad3f60db22e1338db3931b0900e83bccab52 100644
--- a/internal/llm/provider/vertexai.go
+++ b/internal/llm/provider/vertexai.go
@@ -2,7 +2,6 @@ package provider

 import (
 	"context"
-	"os"

 	"github.com/charmbracelet/crush/internal/logging"
 	"google.golang.org/genai"
@@ -11,9 +10,11 @@
 type VertexAIClient ProviderClient

 func newVertexAIClient(opts providerClientOptions) VertexAIClient {
+	project := opts.extraHeaders["project"]
+	location := opts.extraHeaders["location"]
 	client, err := genai.NewClient(context.Background(), &genai.ClientConfig{
-		Project: os.Getenv("GOOGLE_CLOUD_PROJECT"),
-		Location: os.Getenv("GOOGLE_CLOUD_LOCATION"),
+		Project: project,
+		Location: location,
 		Backend: genai.BackendVertexAI,
 	})
 	if err != nil {
diff --git a/internal/lsp/client.go b/internal/lsp/client.go
index f65b3dee20a3ee0264742257ca78a116661f1165..c04f10a8a924f8725609aace7d5363fe1751a791 100644
--- a/internal/lsp/client.go
+++ b/internal/lsp/client.go
@@ -286,7 +286,7 @@ func (c *Client) SetServerState(state ServerState) {
 // WaitForServerReady waits for the server to be ready by polling the server
 // with a simple request until it responds successfully or times out
 func (c *Client) WaitForServerReady(ctx context.Context) error {
-	cnf := config.Get()
+	cfg := config.Get()

 	// Set initial state
 	c.SetServerState(StateStarting)
@@ -299,7 +299,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
 	ticker := time.NewTicker(500 * time.Millisecond)
 	defer ticker.Stop()

-	if cnf.DebugLSP {
+	if cfg.Options.DebugLSP {
 		logging.Debug("Waiting for LSP server to be ready...")
 	}

@@ -308,7 +308,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {

 	// For TypeScript-like servers, we need to open some key files first
 	if serverType == ServerTypeTypeScript {
-		if cnf.DebugLSP {
+		if cfg.Options.DebugLSP {
 			logging.Debug("TypeScript-like server detected, opening key configuration files")
 		}
 		c.openKeyConfigFiles(ctx)
@@ -325,7 +325,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
 			if err == nil {
 				// Server responded successfully
 				c.SetServerState(StateReady)
-				if cnf.DebugLSP {
+				if cfg.Options.DebugLSP {
 					logging.Debug("LSP server is ready")
 				}
 				return nil
@@ -333,7 +333,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
 				logging.Debug("LSP server not ready yet", "error", err, "serverType", serverType)
 			}

-			if cnf.DebugLSP {
+			if cfg.Options.DebugLSP {
 				logging.Debug("LSP server not ready yet", "error", err, "serverType", serverType)
 			}
 		}
@@ -496,7 +496,7 @@ func (c *Client) pingTypeScriptServer(ctx context.Context) error {

 // openTypeScriptFiles finds and opens TypeScript files to help initialize the server
 func (c *Client) openTypeScriptFiles(ctx context.Context, workDir string) {
-	cnf := config.Get()
+	cfg := config.Get()
 	filesOpened := 0
 	maxFilesToOpen := 5 // Limit to a reasonable number of files

@@ -526,7 +526,7 @@ func (c *Client) openTypeScriptFiles(ctx context.Context, workDir string) {
 			// Try to open the file
 			if err := c.OpenFile(ctx, path); err == nil {
 				filesOpened++
-				if cnf.DebugLSP {
+				if cfg.Options.DebugLSP {
 					logging.Debug("Opened TypeScript file for initialization", "file", path)
 				}
 			}
@@ -535,11 +535,11 @@ func (c *Client) openTypeScriptFiles(ctx context.Context, workDir string) {
 		return nil
 	})

-	if err != nil && cnf.DebugLSP {
+	if err != nil && cfg.Options.DebugLSP {
 		logging.Debug("Error walking directory for TypeScript files", "error", err)
 	}

-	if cnf.DebugLSP {
+	if cfg.Options.DebugLSP {
 		logging.Debug("Opened TypeScript files for initialization", "count", filesOpened)
 	}
 }
@@ -664,7 +664,7 @@ func (c *Client) NotifyChange(ctx context.Context, filepath string) error {
 }

 func (c *Client) CloseFile(ctx context.Context, filepath string) error {
-	cnf := config.Get()
+	cfg := config.Get()
 	uri := string(protocol.URIFromPath(filepath))

 	c.openFilesMu.Lock()
@@ -680,7 +680,7 @@ func (c *Client) CloseFile(ctx context.Context, filepath string) error {
 		},
 	}

-	if cnf.DebugLSP {
+	if cfg.Options.DebugLSP {
 		logging.Debug("Closing file", "file", filepath)
 	}
 	if err := c.Notify(ctx, "textDocument/didClose", params); err != nil {
@@ -704,7 +704,7 @@ func (c *Client) IsFileOpen(filepath string) bool {

 // CloseAllFiles closes all currently open files
 func (c *Client) CloseAllFiles(ctx context.Context) {
-	cnf := config.Get()
+	cfg := config.Get()
 	c.openFilesMu.Lock()
 	filesToClose := make([]string, 0, len(c.openFiles))
@@ -719,12 +719,12 @@ func (c *Client) CloseAllFiles(ctx context.Context) {
 	// Then close them all
 	for _, filePath := range filesToClose {
 		err := c.CloseFile(ctx, filePath)
-		if err != nil && cnf.DebugLSP {
+		if err != nil && cfg.Options.DebugLSP {
 			logging.Warn("Error closing file", "file", filePath, "error", err)
 		}
 	}

-	if cnf.DebugLSP {
+	if cfg.Options.DebugLSP {
 		logging.Debug("Closed all files", "files", filesToClose)
 	}
 }
diff --git a/internal/lsp/handlers.go b/internal/lsp/handlers.go
index 9eb258d761ee36a909cddec16b72b2a3d933a5b4..f2fbfd0a589651590185fe9f73fc222e5bd6b08d 100644
--- a/internal/lsp/handlers.go
+++ b/internal/lsp/handlers.go
@@ -82,13 +82,13 @@ func notifyFileWatchRegistration(id string, watchers []protocol.FileSystemWatche

 // Notifications
 func HandleServerMessage(params json.RawMessage) {
-	cnf := config.Get()
+	cfg := config.Get()
 	var msg struct {
 		Type int `json:"type"`
 		Message string `json:"message"`
 	}
 	if err := json.Unmarshal(params, &msg); err == nil {
-		if cnf.DebugLSP {
+		if cfg.Options.DebugLSP {
 			logging.Debug("Server message", "type", msg.Type, "message", msg.Message)
 		}
 	}
diff --git a/internal/lsp/transport.go b/internal/lsp/transport.go
index c3d5d762feeccaaa363a189fd8014b705a583681..5433fb552d6ee3dae390dcf74e3e1d9c8b0d74f9 100644
--- a/internal/lsp/transport.go
+++ b/internal/lsp/transport.go
@@ -18,9 +18,9 @@ func WriteMessage(w io.Writer, msg *Message) error {
 	if err != nil {
 		return fmt.Errorf("failed to marshal message: %w", err)
 	}
-	cnf := config.Get()
+	cfg := config.Get()

-	if cnf.DebugLSP {
+	if cfg.Options.DebugLSP {
 		logging.Debug("Sending message to server", "method", msg.Method, "id", msg.ID)
 	}

@@ -39,7 +39,7 @@ func WriteMessage(w io.Writer, msg *Message) error {

 // ReadMessage reads a single LSP message from the given reader
 func ReadMessage(r *bufio.Reader) (*Message, error) {
-	cnf := config.Get()
+	cfg := config.Get()
 	// Read headers
 	var contentLength int
 	for {
@@ -49,7 +49,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
 		}
 		line = strings.TrimSpace(line)

-		if cnf.DebugLSP {
+		if cfg.Options.DebugLSP {
 			logging.Debug("Received header", "line", line)
 		}

@@ -65,7 +65,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
 		}
 	}

-	if cnf.DebugLSP {
+	if cfg.Options.DebugLSP {
 		logging.Debug("Content-Length", "length", contentLength)
 	}

@@ -76,7 +76,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
 		return nil, fmt.Errorf("failed to read content: %w", err)
 	}

-	if cnf.DebugLSP {
+	if cfg.Options.DebugLSP {
 		logging.Debug("Received content", "content", string(content))
 	}

@@ -91,11 +91,11 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {

 // handleMessages reads and dispatches messages in a loop
 func (c *Client) handleMessages() {
-	cnf := config.Get()
+	cfg := config.Get()
 	for {
 		msg, err := ReadMessage(c.stdout)
 		if err != nil {
-			if cnf.DebugLSP {
+			if cfg.Options.DebugLSP {
 				logging.Error("Error reading message", "error", err)
 			}
 			return
@@ -103,7 +103,7 @@ func (c *Client) handleMessages() {

 		// Handle server->client request (has both Method and ID)
 		if msg.Method != "" && msg.ID != 0 {
-			if cnf.DebugLSP {
+			if cfg.Options.DebugLSP {
 				logging.Debug("Received request from server", "method", msg.Method, "id", msg.ID)
 			}

@@ -157,11 +157,11 @@ func (c *Client) handleMessages() {
 			c.notificationMu.RUnlock()

 			if ok {
-				if cnf.DebugLSP {
+				if cfg.Options.DebugLSP {
 					logging.Debug("Handling notification", "method", msg.Method)
 				}
 				go handler(msg.Params)
-			} else if cnf.DebugLSP {
+			} else if cfg.Options.DebugLSP {
 				logging.Debug("No handler for notification", "method", msg.Method)
 			}
 			continue
@@ -174,12 +174,12 @@ func (c *Client) handleMessages() {
 		c.handlersMu.RUnlock()

 		if ok {
-			if cnf.DebugLSP {
+			if cfg.Options.DebugLSP {
 				logging.Debug("Received response for request", "id", msg.ID)
 			}
 			ch <- msg
 			close(ch)
-		} else if cnf.DebugLSP {
+		} else if cfg.Options.DebugLSP {
 			logging.Debug("No handler for response", "id", msg.ID)
 		}
 	}
@@ -188,10 +188,10 @@ func (c *Client) handleMessages() {

 // Call makes a request and waits for the response
 func (c *Client) Call(ctx context.Context, method string, params any, result any) error {
-	cnf := config.Get()
+	cfg := config.Get()
 	id := c.nextID.Add(1)

-	if cnf.DebugLSP {
+	if cfg.Options.DebugLSP {
 		logging.Debug("Making call", "method", method, "id", id)
 	}

@@ -217,14 +217,14 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
 		return fmt.Errorf("failed to send request: %w", err)
 	}

-	if cnf.DebugLSP {
+	if cfg.Options.DebugLSP {
 		logging.Debug("Request sent", "method", method, "id", id)
 	}

 	// Wait for response
 	resp := <-ch

-	if cnf.DebugLSP {
+	if cfg.Options.DebugLSP {
 		logging.Debug("Received response", "id", id)
 	}

@@ -249,8 +249,8 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any

 // Notify sends a notification (a request without an ID that doesn't expect a response)
 func (c *Client) Notify(ctx context.Context, method string, params any) error {
-	cnf := config.Get()
-	if cnf.DebugLSP {
+	cfg := config.Get()
+	if cfg.Options.DebugLSP {
 		logging.Debug("Sending notification", "method", method)
 	}
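[Editor's note] The cnf-to-cfg renames in the LSP files above are mechanical, but they track a real structural change: Debug and DebugLSP now live under cfg.Options instead of on the top level of the config. The guard pattern repeated throughout this patch is simply the following, shown here as a minimal sketch with an illustrative log message:

	cfg := config.Get()
	if cfg.Options.DebugLSP {
		logging.Debug("LSP wire tracing enabled")
	}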
diff --git a/internal/lsp/watcher/watcher.go b/internal/lsp/watcher/watcher.go
index a69b3c10577d0c89ffb8aa9972a928201e2124f6..3c2dc05909bd8e6d473696efd09d22435f68dc10 100644
--- a/internal/lsp/watcher/watcher.go
+++ b/internal/lsp/watcher/watcher.go
@@ -43,7 +43,7 @@ func NewWorkspaceWatcher(client *lsp.Client) *WorkspaceWatcher {

 // AddRegistrations adds file watchers to track
 func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watchers []protocol.FileSystemWatcher) {
-	cnf := config.Get()
+	cfg := config.Get()
 	logging.Debug("Adding file watcher registrations")

 	w.registrationMu.Lock()
@@ -53,7 +53,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
 	w.registrations = append(w.registrations, watchers...)

 	// Print detailed registration information for debugging
-	if cnf.DebugLSP {
+	if cfg.Options.DebugLSP {
 		logging.Debug("Adding file watcher registrations",
 			"id", id,
 			"watchers", len(watchers),
@@ -122,7 +122,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
 			highPriorityFilesOpened := w.openHighPriorityFiles(ctx, serverName)
 			filesOpened += highPriorityFilesOpened

-			if cnf.DebugLSP {
+			if cfg.Options.DebugLSP {
 				logging.Debug("Opened high-priority files",
 					"count", highPriorityFilesOpened,
 					"serverName", serverName)
@@ -130,7 +130,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc

 			// If we've already opened enough high-priority files, we might not need more
 			if filesOpened >= maxFilesToOpen {
-				if cnf.DebugLSP {
+				if cfg.Options.DebugLSP {
 					logging.Debug("Reached file limit with high-priority files",
 						"filesOpened", filesOpened,
 						"maxFiles", maxFilesToOpen)
@@ -148,7 +148,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
 				// Skip directories that should be excluded
 				if d.IsDir() {
 					if path != w.workspacePath && shouldExcludeDir(path) {
-						if cnf.DebugLSP {
+						if cfg.Options.DebugLSP {
 							logging.Debug("Skipping excluded directory", "path", path)
 						}
 						return filepath.SkipDir
@@ -176,7 +176,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
 			})

 			elapsedTime := time.Since(startTime)
-			if cnf.DebugLSP {
+			if cfg.Options.DebugLSP {
 				logging.Debug("Limited workspace scan complete",
 					"filesOpened", filesOpened,
 					"maxFiles", maxFilesToOpen,
@@ -185,11 +185,11 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
 				)
 			}

-			if err != nil && cnf.DebugLSP {
+			if err != nil && cfg.Options.DebugLSP {
 				logging.Debug("Error scanning workspace for files to open", "error", err)
 			}
 		}()
-	} else if cnf.DebugLSP {
+	} else if cfg.Options.DebugLSP {
 		logging.Debug("Using on-demand file loading for server", "server", serverName)
 	}
 }
@@ -197,7 +197,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
 // openHighPriorityFiles opens important files for the server type
 // Returns the number of files opened
 func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName string) int {
-	cnf := config.Get()
+	cfg := config.Get()
 	filesOpened := 0

 	// Define patterns for high-priority files based on server type
@@ -265,7 +265,7 @@ func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName
 		// Use doublestar.Glob to find files matching the pattern (supports ** patterns)
 		matches, err := doublestar.Glob(os.DirFS(w.workspacePath), pattern)
 		if err != nil {
-			if cnf.DebugLSP {
+			if cfg.Options.DebugLSP {
 				logging.Debug("Error finding high-priority files", "pattern", pattern, "error", err)
 			}
 			continue
@@ -299,12 +299,12 @@ func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName
 		for j := i; j < end; j++ {
 			fullPath := filesToOpen[j]
 			if err := w.client.OpenFile(ctx, fullPath); err != nil {
-				if cnf.DebugLSP {
+				if cfg.Options.DebugLSP {
 					logging.Debug("Error opening high-priority file", "path", fullPath, "error", err)
 				}
 			} else {
 				filesOpened++
-				if cnf.DebugLSP {
+				if cfg.Options.DebugLSP {
 					logging.Debug("Opened high-priority file", "path", fullPath)
 				}
 			}
@@ -321,7 +321,7 @@ func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName

 // WatchWorkspace sets up file watching for a workspace
 func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath string) {
-	cnf := config.Get()
+	cfg := config.Get()
 	w.workspacePath = workspacePath

 	// Store the watcher in the context for later use
@@ -356,7 +356,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
 		// Skip excluded directories (except workspace root)
 		if d.IsDir() && path != workspacePath {
 			if shouldExcludeDir(path) {
-				if cnf.DebugLSP {
+				if cfg.Options.DebugLSP {
 					logging.Debug("Skipping excluded directory", "path", path)
 				}
 				return filepath.SkipDir
@@ -409,7 +409,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
 			}

 			// Debug logging
-			if cnf.DebugLSP {
+			if cfg.Options.DebugLSP {
 				matched, kind := w.isPathWatched(event.Name)
 				logging.Debug("File event",
 					"path", event.Name,
@@ -676,8 +676,8 @@ func (w *WorkspaceWatcher) handleFileEvent(ctx context.Context, uri string, chan

 // notifyFileEvent sends a didChangeWatchedFiles notification for a file event
 func (w *WorkspaceWatcher) notifyFileEvent(ctx context.Context, uri string, changeType protocol.FileChangeType) error {
-	cnf := config.Get()
-	if cnf.DebugLSP {
+	cfg := config.Get()
+	if cfg.Options.DebugLSP {
 		logging.Debug("Notifying file event",
 			"uri", uri,
 			"changeType", changeType,
@@ -826,7 +826,7 @@ func shouldExcludeDir(dirPath string) bool {
 // shouldExcludeFile returns true if the file should be excluded from opening
 func shouldExcludeFile(filePath string) bool {
 	fileName := filepath.Base(filePath)
-	cnf := config.Get()
+	cfg := config.Get()
 	// Skip dot files
 	if strings.HasPrefix(fileName, ".") {
 		return true
@@ -852,12 +852,12 @@ func shouldExcludeFile(filePath string) bool {

 	// Skip large files
 	if info.Size() > maxFileSize {
-		if cnf.DebugLSP {
+		if cfg.Options.DebugLSP {
 			logging.Debug("Skipping large file",
 				"path", filePath,
 				"size", info.Size(),
 				"maxSize", maxFileSize,
-				"debug", cnf.Debug,
+				"debug", cfg.Options.Debug,
 				"sizeMB", float64(info.Size())/(1024*1024),
 				"maxSizeMB", float64(maxFileSize)/(1024*1024),
 			)
@@ -870,7 +870,7 @@ func shouldExcludeFile(filePath string) bool {

 // openMatchingFile opens a file if it matches any of the registered patterns
 func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
-	cnf := config.Get()
+	cfg := config.Get()
 	// Skip directories
 	info, err := os.Stat(path)
 	if err != nil || info.IsDir() {
@@ -890,10 +890,10 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
 	// Check if the file is a high-priority file that should be opened immediately
 	// This helps with project initialization for certain language servers
 	if isHighPriorityFile(path, serverName) {
-		if cnf.DebugLSP {
+		if cfg.Options.DebugLSP {
 			logging.Debug("Opening high-priority file", "path", path, "serverName", serverName)
 		}
-		if err := w.client.OpenFile(ctx, path); err != nil && cnf.DebugLSP {
+		if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP {
 			logging.Error("Error opening high-priority file", "path", path, "error", err)
 		}
 		return
@@ -905,7 +905,7 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {

 	// Check file size - for preloading we're more conservative
 	if info.Size() > (1 * 1024 * 1024) { // 1MB limit for preloaded files
-		if cnf.DebugLSP {
+		if cfg.Options.DebugLSP {
 			logging.Debug("Skipping large file for preloading", "path", path, "size", info.Size())
 		}
 		return
@@ -937,7 +937,7 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {

 	if shouldOpen {
 		// Don't need to check if it's already open - the client.OpenFile handles that
-		if err := w.client.OpenFile(ctx, path); err != nil && cnf.DebugLSP {
+		if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP {
 			logging.Error("Error opening file", "path", path, "error", err)
 		}
 	}
diff --git a/internal/message/content.go b/internal/message/content.go
index b9e83ba4dd7fcc96216755a3871f0553b58d88d7..3ab53e381aaf7755c141985ebe740dbc44356471 100644
--- a/internal/message/content.go
+++ b/internal/message/content.go
@@ -5,7 +5,7 @@ import (
 	"slices"
 	"time"

-	"github.com/charmbracelet/crush/internal/llm/models"
+	"github.com/charmbracelet/crush/internal/fur/provider"
 )

 type MessageRole string
@@ -71,9 +71,9 @@ type BinaryContent struct {
 	Data []byte
 }

-func (bc BinaryContent) String(provider models.InferenceProvider) string {
+func (bc BinaryContent) String(p provider.InferenceProvider) string {
 	base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
-	if provider == models.ProviderOpenAI {
+	if p == provider.InferenceProviderOpenAI {
 		return "data:" + bc.MIMEType + ";base64," + base64Encoded
 	}
 	return base64Encoded
@@ -113,7 +113,8 @@ type Message struct {
 	Role MessageRole
 	SessionID string
 	Parts []ContentPart
-	Model models.ModelID
+	Model string
+	Provider string
 	CreatedAt int64
 	UpdatedAt int64
 }
diff --git a/internal/message/message.go b/internal/message/message.go
index 9e241a0b011ee6277402709fdd8be3aefb5df6fe..7cd823bc3129df5f807ec478d9d6c02364c6cfec 100644
--- a/internal/message/message.go
+++ b/internal/message/message.go
@@ -8,15 +8,15 @@ import (
 	"time"

 	"github.com/charmbracelet/crush/internal/db"
-	"github.com/charmbracelet/crush/internal/llm/models"
 	"github.com/charmbracelet/crush/internal/pubsub"
 	"github.com/google/uuid"
 )

 type CreateMessageParams struct {
-	Role  MessageRole
-	Parts []ContentPart
-	Model models.ModelID
+	Role     MessageRole
+	Parts    []ContentPart
+	Model    string
+	Provider string
 }

 type Service interface {
@@ -70,6 +70,7 @@ func (s *service) Create(ctx context.Context, sessionID string, params CreateMes
 		Role: string(params.Role),
 		Parts: string(partsJSON),
 		Model: sql.NullString{String: string(params.Model), Valid: true},
+		Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""},
 	})
 	if err != nil {
 		return Message{}, err
@@ -154,7 +155,8 @@ func (s *service) fromDBItem(item db.Message) (Message, error) {
 		SessionID: item.SessionID,
 		Role: MessageRole(item.Role),
 		Parts: parts,
-		Model: models.ModelID(item.Model.String),
+		Model: item.Model.String,
+		Provider: item.Provider.String,
 		CreatedAt: item.CreatedAt,
 		UpdatedAt: item.UpdatedAt,
 	}, nil
config.GetAgentModel(config.AgentCoder) percentage := (float64(h.session.CompletionTokens+h.session.PromptTokens) / float64(model.ContextWindow)) * 100 formattedPercentage := t.S().Muted.Render(fmt.Sprintf("%d%%", int(percentage))) parts = append(parts, formattedPercentage) diff --git a/internal/tui/components/chat/messages/messages.go b/internal/tui/components/chat/messages/messages.go index d5e95b4e3ebded500f73840fda483d3be53ca71d..98d8b2979a90f46fa5901bc77d1e8b4a5105f04d 100644 --- a/internal/tui/components/chat/messages/messages.go +++ b/internal/tui/components/chat/messages/messages.go @@ -10,7 +10,8 @@ import ( tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/lipgloss/v2" - "github.com/charmbracelet/crush/internal/llm/models" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/tui/components/anim" "github.com/charmbracelet/crush/internal/tui/components/core" @@ -290,8 +291,9 @@ func (m *assistantSectionModel) View() tea.View { duration := finishTime.Sub(m.lastUserMessageTime) infoMsg := t.S().Subtle.Render(duration.String()) icon := t.S().Subtle.Render(styles.ModelIcon) - model := t.S().Muted.Render(models.SupportedModels[m.message.Model].Name) - assistant := fmt.Sprintf("%s %s %s", icon, model, infoMsg) + model := config.GetProviderModel(provider.InferenceProvider(m.message.Provider), m.message.Model) + modelFormatted := t.S().Muted.Render(model.Name) + assistant := fmt.Sprintf("%s %s %s", icon, modelFormatted, infoMsg) return tea.NewView( t.S().Base.PaddingLeft(2).Render( core.Section(assistant, m.width-2), diff --git a/internal/tui/components/chat/sidebar/sidebar.go b/internal/tui/components/chat/sidebar/sidebar.go index 405bd1f0f8c7891db1958e70f97e290dd9a8d411..bfcc74c43a2727138d479af647ba461bdcc7520c 100644 --- a/internal/tui/components/chat/sidebar/sidebar.go +++ b/internal/tui/components/chat/sidebar/sidebar.go @@ -13,7 +13,6 @@ import ( "github.com/charmbracelet/crush/internal/diff" "github.com/charmbracelet/crush/internal/fsext" "github.com/charmbracelet/crush/internal/history" - "github.com/charmbracelet/crush/internal/llm/models" "github.com/charmbracelet/crush/internal/logging" "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/lsp/protocol" @@ -406,7 +405,7 @@ func (m *sidebarCmp) mcpBlock() string { mcpList := []string{section, ""} - mcp := config.Get().MCPServers + mcp := config.Get().MCP if len(mcp) == 0 { return lipgloss.JoinVertical( lipgloss.Left, @@ -475,10 +474,7 @@ func formatTokensAndCost(tokens, contextWindow int64, cost float64) string { } func (s *sidebarCmp) currentModelBlock() string { - cfg := config.Get() - agentCfg := cfg.Agents[config.AgentCoder] - selectedModelID := agentCfg.Model - model := models.SupportedModels[selectedModelID] + model := config.GetAgentModel(config.AgentCoder) t := styles.CurrentTheme() diff --git a/internal/tui/components/dialogs/commands/loader.go b/internal/tui/components/dialogs/commands/loader.go index 9f70afa3cd60342028b6d3fd00e017221c179686..9aee528ee48d0f23e48c417f8bee5bc0e3f381c5 100644 --- a/internal/tui/components/dialogs/commands/loader.go +++ b/internal/tui/components/dialogs/commands/loader.go @@ -63,7 +63,7 @@ func buildCommandSources(cfg *config.Config) []commandSource { // Project directory sources = append(sources, commandSource{ - path: filepath.Join(cfg.Data.Directory, "commands"), + path: 
filepath.Join(cfg.Options.DataDirectory, "commands"), prefix: ProjectCommandPrefix, }) diff --git a/internal/tui/components/dialogs/init/init.go b/internal/tui/components/dialogs/init/init.go index 74d0dc0b3d9d4630b28c4b240fb17fbe611ba21f..4e331198f5984f81db87332e3c998d9477810806 100644 --- a/internal/tui/components/dialogs/init/init.go +++ b/internal/tui/components/dialogs/init/init.go @@ -5,7 +5,7 @@ import ( tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/lipgloss/v2" - "github.com/charmbracelet/crush/internal/config" + configv2 "github.com/charmbracelet/crush/internal/config" cmpChat "github.com/charmbracelet/crush/internal/tui/components/chat" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/dialogs" @@ -184,7 +184,7 @@ If there are Cursor rules (in .cursor/rules/ or .cursorrules) or Copilot rules ( Add the .crush directory to the .gitignore file if it's not already there.` // Mark the project as initialized - if err := config.MarkProjectInitialized(); err != nil { + if err := configv2.MarkProjectInitialized(); err != nil { return util.ReportError(err) } @@ -196,7 +196,7 @@ Add the .crush directory to the .gitignore file if it's not already there.` ) } else { // Mark the project as initialized without running the command - if err := config.MarkProjectInitialized(); err != nil { + if err := configv2.MarkProjectInitialized(); err != nil { return util.ReportError(err) } } diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index 0197b7141560a67008ceac64c31756bd19fff74a..b5f87b16681ea17e2fb303a4b52a3a83ae30eb85 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -1,13 +1,11 @@ package models import ( - "slices" - "github.com/charmbracelet/bubbles/v2/help" "github.com/charmbracelet/bubbles/v2/key" tea "github.com/charmbracelet/bubbletea/v2" - "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/llm/models" + configv2 "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/tui/components/completions" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/core/list" @@ -26,7 +24,7 @@ const ( // ModelSelectedMsg is sent when a model is selected type ModelSelectedMsg struct { - Model models.Model + Model configv2.PreferredModel } // CloseModelDialogMsg is sent when a model is selected @@ -37,6 +35,11 @@ type ModelDialog interface { dialogs.DialogModel } +type ModelOption struct { + Provider provider.Provider + Model provider.Model +} + type modelDialogCmp struct { width int wWidth int // Width of the terminal window @@ -80,47 +83,31 @@ func NewModelDialogCmp() ModelDialog { } } -var ProviderPopularity = map[models.InferenceProvider]int{ - models.ProviderAnthropic: 1, - models.ProviderOpenAI: 2, - models.ProviderGemini: 3, - models.ProviderGROQ: 4, - models.ProviderOpenRouter: 5, - models.ProviderBedrock: 6, - models.ProviderAzure: 7, - models.ProviderVertexAI: 8, - models.ProviderXAI: 9, -} - -var ProviderName = map[models.InferenceProvider]string{ - models.ProviderAnthropic: "Anthropic", - models.ProviderOpenAI: "OpenAI", - models.ProviderGemini: "Gemini", - models.ProviderGROQ: "Groq", - models.ProviderOpenRouter: "OpenRouter", - models.ProviderBedrock: "AWS Bedrock", - 
models.ProviderAzure: "Azure", - models.ProviderVertexAI: "VertexAI", - models.ProviderXAI: "xAI", -} - func (m *modelDialogCmp) Init() tea.Cmd { - cfg := config.Get() - enabledProviders := getEnabledProviders(cfg) + providers := configv2.Providers() + cfg := configv2.Get() + coderAgent := cfg.Agents[configv2.AgentCoder] modelItems := []util.Model{} - for _, provider := range enabledProviders { - name, ok := ProviderName[provider] - if !ok { - name = string(provider) // Fallback to provider ID if name is not defined + selectIndex := 0 + for _, provider := range providers { + name := provider.Name + if name == "" { + name = string(provider.ID) } modelItems = append(modelItems, commands.NewItemSection(name)) - for _, model := range getModelsForProvider(provider) { - modelItems = append(modelItems, completions.NewCompletionItem(model.Name, model)) + for _, model := range provider.Models { + if model.ID == coderAgent.Model && provider.ID == coderAgent.Provider { + selectIndex = len(modelItems) // Set the selected index to the current model + } + modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{ + Provider: provider, + Model: model, + })) } } - m.modelList.SetItems(modelItems) - return m.modelList.Init() + + return tea.Sequence(m.modelList.Init(), m.modelList.SetItems(modelItems), m.modelList.SetSelected(selectIndex)) } func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { @@ -137,11 +124,14 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, nil // No item selected, do nothing } items := m.modelList.Items() - selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(models.Model) + selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption) return m, tea.Sequence( util.CmdHandler(dialogs.CloseDialogMsg{}), - util.CmdHandler(ModelSelectedMsg{Model: selectedItem}), + util.CmdHandler(ModelSelectedMsg{Model: configv2.PreferredModel{ + ModelID: selectedItem.Model.ID, + Provider: selectedItem.Provider.ID, + }}), ) case key.Matches(msg, m.keyMap.Close): return m, util.CmdHandler(dialogs.CloseDialogMsg{}) @@ -189,58 +179,6 @@ func (m *modelDialogCmp) listHeight() int { return min(listHeigh, m.wHeight/2) } -func GetSelectedModel(cfg *config.Config) models.Model { - agentCfg := cfg.Agents[config.AgentCoder] - selectedModelID := agentCfg.Model - return models.SupportedModels[selectedModelID] -} - -func getEnabledProviders(cfg *config.Config) []models.InferenceProvider { - var providers []models.InferenceProvider - for providerID, provider := range cfg.Providers { - if !provider.Disabled { - providers = append(providers, providerID) - } - } - - // Sort by provider popularity - slices.SortFunc(providers, func(a, b models.InferenceProvider) int { - rA := ProviderPopularity[a] - rB := ProviderPopularity[b] - - // models not included in popularity ranking default to last - if rA == 0 { - rA = 999 - } - if rB == 0 { - rB = 999 - } - return rA - rB - }) - return providers -} - -func getModelsForProvider(provider models.InferenceProvider) []models.Model { - var providerModels []models.Model - for _, model := range models.SupportedModels { - if model.Provider == provider { - providerModels = append(providerModels, model) - } - } - - // reverse alphabetical order (if llm naming was consistent latest would appear first) - slices.SortFunc(providerModels, func(a, b models.Model) int { - if a.Name > b.Name { - return -1 - } else if a.Name < b.Name { - return 1 - } - return 0 - }) - - return 
providerModels -} - func (m *modelDialogCmp) Position() (int, int) { row := m.wHeight/4 - 2 // just a bit above the center col := m.wWidth / 2 diff --git a/internal/tui/page/chat/chat.go b/internal/tui/page/chat/chat.go index ffb6debb0f61cb1fcfa7e180b042b3b8325dd2e5..44d623847765175d3c38eb81122fa3d55abc430d 100644 --- a/internal/tui/page/chat/chat.go +++ b/internal/tui/page/chat/chat.go @@ -9,7 +9,6 @@ import ( tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/llm/models" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/session" "github.com/charmbracelet/crush/internal/tui/components/chat" @@ -171,14 +170,11 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { util.CmdHandler(ChatFocusedMsg{Focused: false}), ) case key.Matches(msg, p.keyMap.AddAttachment): - cfg := config.Get() - agentCfg := cfg.Agents[config.AgentCoder] - selectedModelID := agentCfg.Model - model := models.SupportedModels[selectedModelID] - if model.SupportsAttachments { + model := config.GetAgentModel(config.AgentCoder) + if model.SupportsImages { return p, util.CmdHandler(OpenFilePickerMsg{}) } else { - return p, util.ReportWarn("File attachments are not supported by the current model: " + string(selectedModelID)) + return p, util.ReportWarn("File attachments are not supported by the current model: " + model.Name) } case key.Matches(msg, p.keyMap.Tab): if p.session.ID == "" { diff --git a/internal/tui/tui.go b/internal/tui/tui.go index c6dee6532993becfbda24d115b8e1e5d05e4fd60..54978b53576940e6fa478b7d05af514f66641acf 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -8,6 +8,7 @@ import ( tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/config" + configv2 "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/llm/agent" "github.com/charmbracelet/crush/internal/logging" "github.com/charmbracelet/crush/internal/permission" @@ -69,7 +70,7 @@ func (a appModel) Init() tea.Cmd { // Check if we should show the init dialog cmds = append(cmds, func() tea.Msg { - shouldShow, err := config.ShouldShowInitDialog() + shouldShow, err := configv2.ProjectNeedsInitialization() if err != nil { return util.InfoMsg{ Type: util.InfoTypeError, @@ -172,7 +173,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // Model Switch case models.ModelSelectedMsg: - model, err := a.app.CoderAgent.Update(config.AgentCoder, msg.Model.ID) + model, err := a.app.CoderAgent.Update(msg.Model) if err != nil { return a, util.ReportError(err) } @@ -222,7 +223,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { model := a.app.CoderAgent.Model() contextWindow := model.ContextWindow tokens := session.CompletionTokens + session.PromptTokens - if (tokens >= int64(float64(contextWindow)*0.95)) && config.Get().AutoCompact { + if (tokens >= int64(float64(contextWindow)*0.95)) && !config.Get().Options.DisableAutoSummarize { // Show compact confirmation dialog cmds = append(cmds, util.CmdHandler(dialogs.OpenDialogMsg{ Model: compact.NewCompactDialogCmp(a.app.CoderAgent, a.selectedSessionID, false),
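
Illustrative sketch (not part of the patch): the hunks above replace ad-hoc models.SupportedModels lookups with the new config accessors and thread a provider ID through message persistence. The snippet below shows how those pieces are expected to fit together, based only on signatures visible in this diff (config.GetAgentModel, config.Get().Agents, message.Service.Create, the new CreateMessageParams fields). The recordAssistantMessage helper, the message.Assistant role constant, the model.ID field, and the string conversion of the agent's Provider are assumptions for illustration, not code from this change.

package example

import (
	"context"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/message"
)

// recordAssistantMessage persists an assistant message tagged with the
// model and provider currently configured for the coder agent.
func recordAssistantMessage(ctx context.Context, svc message.Service, sessionID string) error {
	// Resolve the concrete model for the coder agent; the same accessor
	// backs the ContextWindow, Name, and SupportsImages checks above.
	model := config.GetAgentModel(config.AgentCoder)
	agentCfg := config.Get().Agents[config.AgentCoder]

	// Messages now carry both a model ID and a provider ID, matching the
	// new CreateMessageParams fields in internal/message/message.go.
	_, err := svc.Create(ctx, sessionID, message.CreateMessageParams{
		Role:     message.Assistant,         // assumed role constant
		Model:    model.ID,                  // assumed provider.Model field
		Provider: string(agentCfg.Provider), // string conversion assumed
	})
	return err
}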