Detailed changes
@@ -72,7 +72,8 @@ to assist developers in writing, debugging, and understanding code directly from
}
cwd = c
}
- _, err := config.Load(cwd, debug)
+
+ _, err := config.Init(cwd, debug)
if err != nil {
return err
}
@@ -1,3 +1,4 @@
+// TODO: FIX THIS
package main
import (
@@ -6,7 +7,6 @@ import (
"os"
"github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/llm/models"
)
// JSONSchemaType represents a JSON Schema type
@@ -192,22 +192,10 @@ func generateSchema() map[string]any {
},
}
- // Add known providers
- knownProviders := []string{
- string(models.ProviderAnthropic),
- string(models.ProviderOpenAI),
- string(models.ProviderGemini),
- string(models.ProviderGROQ),
- string(models.ProviderOpenRouter),
- string(models.ProviderBedrock),
- string(models.ProviderAzure),
- string(models.ProviderVertexAI),
- }
-
providerSchema["additionalProperties"].(map[string]any)["properties"].(map[string]any)["provider"] = map[string]any{
"type": "string",
"description": "Provider type",
- "enum": knownProviders,
+ "enum": []string{},
}
schema["properties"].(map[string]any)["providers"] = providerSchema
@@ -241,9 +229,7 @@ func generateSchema() map[string]any {
// Add model enum
modelEnum := []string{}
- for modelID := range models.SupportedModels {
- modelEnum = append(modelEnum, string(modelID))
- }
+
agentSchema["additionalProperties"].(map[string]any)["properties"].(map[string]any)["model"].(map[string]any)["enum"] = modelEnum
// Add specific agent properties
@@ -251,7 +237,6 @@ func generateSchema() map[string]any {
knownAgents := []string{
string(config.AgentCoder),
string(config.AgentTask),
- string(config.AgentTitle),
}
for _, agentName := range knownAgents {
@@ -9,7 +9,7 @@ import (
"sync"
"time"
- "github.com/charmbracelet/crush/internal/config"
+ configv2 "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/db"
"github.com/charmbracelet/crush/internal/format"
"github.com/charmbracelet/crush/internal/history"
@@ -55,18 +55,21 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) {
// Initialize LSP clients in the background
go app.initLSPClients(ctx)
+ cfg := configv2.Get()
+
+ coderAgentCfg := cfg.Agents[configv2.AgentCoder]
+ if coderAgentCfg.ID == "" {
+ return nil, fmt.Errorf("coder agent configuration is missing")
+ }
+
var err error
app.CoderAgent, err = agent.NewAgent(
- config.AgentCoder,
+ coderAgentCfg,
+ app.Permissions,
app.Sessions,
app.Messages,
- agent.CoderAgentTools(
- app.Permissions,
- app.Sessions,
- app.Messages,
- app.History,
- app.LSPClients,
- ),
+ app.History,
+ app.LSPClients,
)
if err != nil {
logging.Error("Failed to create coder agent", err)
@@ -1,67 +1,132 @@
-// Package config manages application configuration from various sources.
package config
import (
"encoding/json"
+ "errors"
"fmt"
"log/slog"
+ "maps"
"os"
"path/filepath"
+ "slices"
"strings"
+ "sync"
- "github.com/charmbracelet/crush/internal/llm/models"
+ "github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/logging"
- "github.com/spf13/afero"
- "github.com/spf13/viper"
)
-// MCPType defines the type of MCP (Model Control Protocol) server.
-type MCPType string
-
-// Supported MCP types
const (
- MCPStdio MCPType = "stdio"
- MCPSse MCPType = "sse"
+ defaultDataDirectory = ".crush"
+ defaultLogLevel = "info"
+ appName = "crush"
+
+ MaxTokensFallbackDefault = 4096
)
-// MCPServer defines the configuration for a Model Control Protocol server.
-type MCPServer struct {
- Command string `json:"command"`
- Env []string `json:"env"`
- Args []string `json:"args"`
- Type MCPType `json:"type"`
- URL string `json:"url"`
- Headers map[string]string `json:"headers"`
+var defaultContextPaths = []string{
+ ".github/copilot-instructions.md",
+ ".cursorrules",
+ ".cursor/rules/",
+ "CLAUDE.md",
+ "CLAUDE.local.md",
+ "GEMINI.md",
+ "gemini.md",
+ "crush.md",
+ "crush.local.md",
+ "Crush.md",
+ "Crush.local.md",
+ "CRUSH.md",
+ "CRUSH.local.md",
}
-type AgentName string
+type AgentID string
const (
- AgentCoder AgentName = "coder"
- AgentSummarizer AgentName = "summarizer"
- AgentTask AgentName = "task"
- AgentTitle AgentName = "title"
+ AgentCoder AgentID = "coder"
+ AgentTask AgentID = "task"
)
-// Agent defines configuration for different LLM models and their token limits.
-type Agent struct {
- Model models.ModelID `json:"model"`
- MaxTokens int64 `json:"maxTokens"`
- ReasoningEffort string `json:"reasoningEffort"` // For openai models low,medium,heigh
+type Model struct {
+ ID string `json:"id"`
+ Name string `json:"model"`
+ CostPer1MIn float64 `json:"cost_per_1m_in"`
+ CostPer1MOut float64 `json:"cost_per_1m_out"`
+ CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
+ CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
+ ContextWindow int64 `json:"context_window"`
+ DefaultMaxTokens int64 `json:"default_max_tokens"`
+ CanReason bool `json:"can_reason"`
+ ReasoningEffort string `json:"reasoning_effort"`
+ SupportsImages bool `json:"supports_attachments"`
}
-// Provider defines configuration for an LLM provider.
-type Provider struct {
- APIKey string `json:"apiKey"`
- Disabled bool `json:"disabled"`
+type VertexAIOptions struct {
+ APIKey string `json:"api_key,omitempty"`
+ Project string `json:"project,omitempty"`
+ Location string `json:"location,omitempty"`
}
-// Data defines storage configuration.
-type Data struct {
- Directory string `json:"directory,omitempty"`
+type ProviderConfig struct {
+ ID provider.InferenceProvider `json:"id"`
+ BaseURL string `json:"base_url,omitempty"`
+ ProviderType provider.Type `json:"provider_type"`
+ APIKey string `json:"api_key,omitempty"`
+ Disabled bool `json:"disabled"`
+ ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
+ // used for e.x for vertex to set the project
+ ExtraParams map[string]string `json:"extra_params,omitempty"`
+
+ DefaultLargeModel string `json:"default_large_model,omitempty"`
+ DefaultSmallModel string `json:"default_small_model,omitempty"`
+
+ Models []Model `json:"models,omitempty"`
+}
+
+type Agent struct {
+ ID AgentID `json:"id"`
+ Name string `json:"name"`
+ Description string `json:"description,omitempty"`
+ // This is the id of the system prompt used by the agent
+ Disabled bool `json:"disabled"`
+
+ Provider provider.InferenceProvider `json:"provider"`
+ Model string `json:"model"`
+
+ // The available tools for the agent
+ // if this is nil, all tools are available
+ AllowedTools []string `json:"allowed_tools"`
+
+ // this tells us which MCPs are available for this agent
+ // if this is empty all mcps are available
+ // the string array is the list of tools from the AllowedMCP the agent has available
+ // if the string array is nil, all tools from the AllowedMCP are available
+ AllowedMCP map[string][]string `json:"allowed_mcp"`
+
+ // The list of LSPs that this agent can use
+ // if this is nil, all LSPs are available
+ AllowedLSP []string `json:"allowed_lsp"`
+
+ // Overrides the context paths for this agent
+ ContextPaths []string `json:"context_paths"`
+}
+
+type MCPType string
+
+const (
+ MCPStdio MCPType = "stdio"
+ MCPSse MCPType = "sse"
+)
+
+type MCP struct {
+ Command string `json:"command"`
+ Env []string `json:"env"`
+ Args []string `json:"args"`
+ Type MCPType `json:"type"`
+ URL string `json:"url"`
+ Headers map[string]string `json:"headers"`
}
-// LSPConfig defines configuration for Language Server Protocol integration.
type LSPConfig struct {
Disabled bool `json:"enabled"`
Command string `json:"command"`
@@ -69,98 +134,72 @@ type LSPConfig struct {
Options any `json:"options"`
}
-// TUIConfig defines the configuration for the Terminal User Interface.
-type TUIConfig struct {
- Theme string `json:"theme,omitempty"`
+type TUIOptions struct {
+ CompactMode bool `json:"compact_mode"`
+ // Here we can add themes later or any TUI related options
}
-// Config is the main configuration structure for the application.
-type Config struct {
- Data Data `json:"data"`
- WorkingDir string `json:"wd,omitempty"`
- MCPServers map[string]MCPServer `json:"mcpServers,omitempty"`
- Providers map[models.InferenceProvider]Provider `json:"providers,omitempty"`
- LSP map[string]LSPConfig `json:"lsp,omitempty"`
- Agents map[AgentName]Agent `json:"agents,omitempty"`
- Debug bool `json:"debug,omitempty"`
- DebugLSP bool `json:"debugLSP,omitempty"`
- ContextPaths []string `json:"contextPaths,omitempty"`
- TUI TUIConfig `json:"tui"`
- AutoCompact bool `json:"autoCompact,omitempty"`
+type Options struct {
+ ContextPaths []string `json:"context_paths"`
+ TUI TUIOptions `json:"tui"`
+ Debug bool `json:"debug"`
+ DebugLSP bool `json:"debug_lsp"`
+ DisableAutoSummarize bool `json:"disable_auto_summarize"`
+ // Relative to the cwd
+ DataDirectory string `json:"data_directory"`
}
-// Application constants
-const (
- defaultDataDirectory = ".crush"
- defaultLogLevel = "info"
- appName = "crush"
-
- MaxTokensFallbackDefault = 4096
-)
+type PreferredModel struct {
+ ModelID string `json:"model_id"`
+ Provider provider.InferenceProvider `json:"provider"`
+}
-var defaultContextPaths = []string{
- ".github/copilot-instructions.md",
- ".cursorrules",
- ".cursor/rules/",
- "CLAUDE.md",
- "CLAUDE.local.md",
- "GEMINI.md",
- "gemini.md",
- "crush.md",
- "crush.local.md",
- "Crush.md",
- "Crush.local.md",
- "CRUSH.md",
- "CRUSH.local.md",
+type PreferredModels struct {
+ Large PreferredModel `json:"large"`
+ Small PreferredModel `json:"small"`
}
-// Global configuration instance
-var cfg *Config
+type Config struct {
+ Models PreferredModels `json:"models"`
+ // List of configured providers
+ Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"`
-// Load initializes the configuration from environment variables and config files.
-// If debug is true, debug mode is enabled and log level is set to debug.
-// It returns an error if configuration loading fails.
-func Load(workingDir string, debug bool) (*Config, error) {
- if cfg != nil {
- return cfg, nil
- }
+ // List of configured agents
+ Agents map[AgentID]Agent `json:"agents,omitempty"`
- cfg = &Config{
- WorkingDir: workingDir,
- MCPServers: make(map[string]MCPServer),
- Providers: make(map[models.InferenceProvider]Provider),
- LSP: make(map[string]LSPConfig),
- }
+ // List of configured MCPs
+ MCP map[string]MCP `json:"mcp,omitempty"`
- configureViper()
- setDefaults(debug)
+ // List of configured LSPs
+ LSP map[string]LSPConfig `json:"lsp,omitempty"`
- // Read global config
- if err := readConfig(viper.ReadInConfig()); err != nil {
- return cfg, err
- }
+ // Miscellaneous options
+ Options Options `json:"options"`
+}
- // Load and merge local config
- mergeLocalConfig(workingDir)
+var (
+ instance *Config // The single instance of the Singleton
+ cwd string
+ once sync.Once // Ensures the initialization happens only once
- setProviderDefaults()
+)
- // Apply configuration to the struct
- if err := viper.Unmarshal(cfg); err != nil {
- return cfg, fmt.Errorf("failed to unmarshal config: %w", err)
- }
+func loadConfig(cwd string, debug bool) (*Config, error) {
+ // First read the global config file
+ cfgPath := ConfigPath()
- applyDefaultValues()
+ cfg := defaultConfigBasedOnEnv()
+ cfg.Options.Debug = debug
defaultLevel := slog.LevelInfo
- if cfg.Debug {
+ if cfg.Options.Debug {
defaultLevel = slog.LevelDebug
}
if os.Getenv("CRUSH_DEV_DEBUG") == "true" {
- loggingFile := fmt.Sprintf("%s/%s", cfg.Data.Directory, "debug.log")
+ loggingFile := fmt.Sprintf("%s/%s", cfg.Options.DataDirectory, "debug.log")
// if file does not exist create it
if _, err := os.Stat(loggingFile); os.IsNotExist(err) {
- if err := os.MkdirAll(cfg.Data.Directory, 0o755); err != nil {
+ if err := os.MkdirAll(cfg.Options.DataDirectory, 0o755); err != nil {
return cfg, fmt.Errorf("failed to create directory: %w", err)
}
if _, err := os.Create(loggingFile); err != nil {
@@ -184,734 +223,530 @@ func Load(workingDir string, debug bool) (*Config, error) {
}))
slog.SetDefault(logger)
}
-
- // Validate configuration
- if err := Validate(); err != nil {
- return cfg, fmt.Errorf("config validation failed: %w", err)
- }
-
- if cfg.Agents == nil {
- cfg.Agents = make(map[AgentName]Agent)
- }
-
- // Override the max tokens for title agent
- cfg.Agents[AgentTitle] = Agent{
- Model: cfg.Agents[AgentTitle].Model,
- MaxTokens: 80,
- }
- return cfg, nil
-}
-
-type configFinder struct {
- appName string
- dotPrefix bool
- paths []string
-}
-
-func (f configFinder) Find(fsys afero.Fs) ([]string, error) {
- var configFiles []string
- configName := fmt.Sprintf("%s.json", f.appName)
- if f.dotPrefix {
- configName = fmt.Sprintf(".%s.json", f.appName)
- }
- paths := []string{}
- for _, p := range f.paths {
- if p == "" {
- continue
+ var globalCfg *Config
+ if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) {
+ // some other error occurred while checking the file
+ return nil, err
+ } else if err == nil {
+ // config file exists, read it
+ file, err := os.ReadFile(cfgPath)
+ if err != nil {
+ return nil, err
}
- paths = append(paths, os.ExpandEnv(p))
- }
-
- for _, path := range paths {
- if path == "" {
- continue
+ globalCfg = &Config{}
+ if err := json.Unmarshal(file, globalCfg); err != nil {
+ return nil, err
}
-
- configPath := filepath.Join(path, configName)
- if exists, err := afero.Exists(fsys, configPath); err == nil && exists {
- configFiles = append(configFiles, configPath)
+ } else {
+ // config file does not exist, create a new one
+ globalCfg = &Config{}
+ }
+
+ var localConfig *Config
+ // Global config loaded, now read the local config file
+ localConfigPath := filepath.Join(cwd, "crush.json")
+ if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
+ // some other error occurred while checking the file
+ return nil, err
+ } else if err == nil {
+ // local config file exists, read it
+ file, err := os.ReadFile(localConfigPath)
+ if err != nil {
+ return nil, err
+ }
+ localConfig = &Config{}
+ if err := json.Unmarshal(file, localConfig); err != nil {
+ return nil, err
}
}
- return configFiles, nil
-}
-// configureViper sets up viper's configuration paths and environment variables.
-func configureViper() {
- viper.SetConfigType("json")
-
- // Create the three finders
- windowsFinder := configFinder{appName: appName, dotPrefix: false, paths: []string{
- "$USERPROFILE",
- fmt.Sprintf("$APPDATA/%s", appName),
- fmt.Sprintf("$LOCALAPPDATA/%s", appName),
- }}
-
- unixFinder := configFinder{appName: appName, dotPrefix: false, paths: []string{
- "$HOME",
- fmt.Sprintf("$XDG_CONFIG_HOME/%s", appName),
- fmt.Sprintf("$HOME/.config/%s", appName),
- }}
-
- localFinder := configFinder{appName: appName, dotPrefix: true, paths: []string{
- ".",
- }}
-
- // Use all finders with viper
- viper.SetOptions(viper.WithFinder(viper.Finders(windowsFinder, unixFinder, localFinder)))
- viper.SetEnvPrefix(strings.ToUpper(appName))
- viper.AutomaticEnv()
-}
+ // merge options
+ mergeOptions(cfg, globalCfg, localConfig)
-// setDefaults configures default values for configuration options.
-func setDefaults(debug bool) {
- viper.SetDefault("data.directory", defaultDataDirectory)
- viper.SetDefault("contextPaths", defaultContextPaths)
- viper.SetDefault("tui.theme", "crush")
- viper.SetDefault("autoCompact", true)
-
- if debug {
- viper.SetDefault("debug", true)
- viper.Set("log.level", "debug")
- } else {
- viper.SetDefault("debug", false)
- viper.SetDefault("log.level", defaultLogLevel)
+ mergeProviderConfigs(cfg, globalCfg, localConfig)
+ // no providers found the app is not initialized yet
+ if len(cfg.Providers) == 0 {
+ return cfg, nil
}
-}
+ preferredProvider := getPreferredProvider(cfg.Providers)
+ cfg.Models = PreferredModels{
+ Large: PreferredModel{
+ ModelID: preferredProvider.DefaultLargeModel,
+ Provider: preferredProvider.ID,
+ },
+ Small: PreferredModel{
+ ModelID: preferredProvider.DefaultSmallModel,
+ Provider: preferredProvider.ID,
+ },
+ }
+
+ mergeModels(cfg, globalCfg, localConfig)
+
+ if preferredProvider == nil {
+ return nil, errors.New("no valid providers configured")
+ }
+
+ agents := map[AgentID]Agent{
+ AgentCoder: {
+ ID: AgentCoder,
+ Name: "Coder",
+ Description: "An agent that helps with executing coding tasks.",
+ Provider: cfg.Models.Large.Provider,
+ Model: cfg.Models.Large.ModelID,
+ ContextPaths: cfg.Options.ContextPaths,
+ // All tools allowed
+ },
+ AgentTask: {
+ ID: AgentTask,
+ Name: "Task",
+ Description: "An agent that helps with searching for context and finding implementation details.",
+ Provider: cfg.Models.Large.Provider,
+ Model: cfg.Models.Large.ModelID,
+ ContextPaths: cfg.Options.ContextPaths,
+ AllowedTools: []string{
+ "glob",
+ "grep",
+ "ls",
+ "sourcegraph",
+ "view",
+ },
+ // NO MCPs or LSPs by default
+ AllowedMCP: map[string][]string{},
+ AllowedLSP: []string{},
+ },
+ }
+ cfg.Agents = agents
+ mergeAgents(cfg, globalCfg, localConfig)
+ mergeMCPs(cfg, globalCfg, localConfig)
+ mergeLSPs(cfg, globalCfg, localConfig)
-// setProviderDefaults configures LLM provider defaults based on provider provided by
-// environment variables and configuration file.
-func setProviderDefaults() {
- // Set all API keys we can find in the environment
- if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" {
- viper.SetDefault("providers.anthropic.apiKey", apiKey)
- }
- if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" {
- viper.SetDefault("providers.openai.apiKey", apiKey)
- }
- if apiKey := os.Getenv("GEMINI_API_KEY"); apiKey != "" {
- viper.SetDefault("providers.gemini.apiKey", apiKey)
- }
- if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" {
- viper.SetDefault("providers.groq.apiKey", apiKey)
- }
- if apiKey := os.Getenv("OPENROUTER_API_KEY"); apiKey != "" {
- viper.SetDefault("providers.openrouter.apiKey", apiKey)
- }
- if apiKey := os.Getenv("XAI_API_KEY"); apiKey != "" {
- viper.SetDefault("providers.xai.apiKey", apiKey)
- }
- if apiKey := os.Getenv("AZURE_OPENAI_ENDPOINT"); apiKey != "" {
- // api-key may be empty when using Entra ID credentials – that's okay
- viper.SetDefault("providers.azure.apiKey", os.Getenv("AZURE_OPENAI_API_KEY"))
- }
-
- // Use this order to set the default models
- // 1. Anthropic
- // 2. OpenAI
- // 3. Google Gemini
- // 4. Groq
- // 5. OpenRouter
- // 6. AWS Bedrock
- // 7. Azure
- // 8. Google Cloud VertexAI
-
- // Anthropic configuration
- if key := viper.GetString("providers.anthropic.apiKey"); strings.TrimSpace(key) != "" {
- viper.SetDefault("agents.coder.model", models.Claude4Sonnet)
- viper.SetDefault("agents.summarizer.model", models.Claude4Sonnet)
- viper.SetDefault("agents.task.model", models.Claude4Sonnet)
- viper.SetDefault("agents.title.model", models.Claude4Sonnet)
- return
- }
-
- // OpenAI configuration
- if key := viper.GetString("providers.openai.apiKey"); strings.TrimSpace(key) != "" {
- viper.SetDefault("agents.coder.model", models.GPT41)
- viper.SetDefault("agents.summarizer.model", models.GPT41)
- viper.SetDefault("agents.task.model", models.GPT41Mini)
- viper.SetDefault("agents.title.model", models.GPT41Mini)
- return
- }
-
- // Google Gemini configuration
- if key := viper.GetString("providers.gemini.apiKey"); strings.TrimSpace(key) != "" {
- viper.SetDefault("agents.coder.model", models.Gemini25)
- viper.SetDefault("agents.summarizer.model", models.Gemini25)
- viper.SetDefault("agents.task.model", models.Gemini25Flash)
- viper.SetDefault("agents.title.model", models.Gemini25Flash)
- return
- }
-
- // Groq configuration
- if key := viper.GetString("providers.groq.apiKey"); strings.TrimSpace(key) != "" {
- viper.SetDefault("agents.coder.model", models.QWENQwq)
- viper.SetDefault("agents.summarizer.model", models.QWENQwq)
- viper.SetDefault("agents.task.model", models.QWENQwq)
- viper.SetDefault("agents.title.model", models.QWENQwq)
- return
- }
-
- // OpenRouter configuration
- if key := viper.GetString("providers.openrouter.apiKey"); strings.TrimSpace(key) != "" {
- viper.SetDefault("agents.coder.model", models.OpenRouterClaude37Sonnet)
- viper.SetDefault("agents.summarizer.model", models.OpenRouterClaude37Sonnet)
- viper.SetDefault("agents.task.model", models.OpenRouterClaude37Sonnet)
- viper.SetDefault("agents.title.model", models.OpenRouterClaude35Haiku)
- return
- }
-
- // XAI configuration
- if key := viper.GetString("providers.xai.apiKey"); strings.TrimSpace(key) != "" {
- viper.SetDefault("agents.coder.model", models.XAIGrok3Beta)
- viper.SetDefault("agents.summarizer.model", models.XAIGrok3Beta)
- viper.SetDefault("agents.task.model", models.XAIGrok3Beta)
- viper.SetDefault("agents.title.model", models.XAiGrok3MiniFastBeta)
- return
- }
-
- // AWS Bedrock configuration
- if hasAWSCredentials() {
- viper.SetDefault("agents.coder.model", models.BedrockClaude37Sonnet)
- viper.SetDefault("agents.summarizer.model", models.BedrockClaude37Sonnet)
- viper.SetDefault("agents.task.model", models.BedrockClaude37Sonnet)
- viper.SetDefault("agents.title.model", models.BedrockClaude37Sonnet)
- return
- }
-
- // Azure OpenAI configuration
- if os.Getenv("AZURE_OPENAI_ENDPOINT") != "" {
- viper.SetDefault("agents.coder.model", models.AzureGPT41)
- viper.SetDefault("agents.summarizer.model", models.AzureGPT41)
- viper.SetDefault("agents.task.model", models.AzureGPT41Mini)
- viper.SetDefault("agents.title.model", models.AzureGPT41Mini)
- return
- }
-
- // Google Cloud VertexAI configuration
- if hasVertexAICredentials() {
- viper.SetDefault("agents.coder.model", models.VertexAIGemini25)
- viper.SetDefault("agents.summarizer.model", models.VertexAIGemini25)
- viper.SetDefault("agents.task.model", models.VertexAIGemini25Flash)
- viper.SetDefault("agents.title.model", models.VertexAIGemini25Flash)
- return
- }
+ return cfg, nil
}
-// hasAWSCredentials checks if AWS credentials are available in the environment.
-func hasAWSCredentials() bool {
- // Check for explicit AWS credentials
- if os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "" {
- return true
- }
-
- // Check for AWS profile
- if os.Getenv("AWS_PROFILE") != "" || os.Getenv("AWS_DEFAULT_PROFILE") != "" {
- return true
- }
+func Init(workingDir string, debug bool) (*Config, error) {
+ var err error
+ once.Do(func() {
+ cwd = workingDir
+ instance, err = loadConfig(cwd, debug)
+ if err != nil {
+ logging.Error("Failed to load config", "error", err)
+ }
+ })
- // Check for AWS region
- if os.Getenv("AWS_REGION") != "" || os.Getenv("AWS_DEFAULT_REGION") != "" {
- return true
- }
+ return instance, err
+}
- // Check if running on EC2 with instance profile
- if os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
- os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
- return true
+func Get() *Config {
+ if instance == nil {
+ // TODO: Handle this better
+ panic("Config not initialized. Call InitConfig first.")
}
-
- return false
+ return instance
}
-// hasVertexAICredentials checks if VertexAI credentials are available in the environment.
-func hasVertexAICredentials() bool {
- // Check for explicit VertexAI parameters
- if os.Getenv("VERTEXAI_PROJECT") != "" && os.Getenv("VERTEXAI_LOCATION") != "" {
- return true
+func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
+ providers := Providers()
+ for _, p := range providers {
+ if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
+ return &providerConfig
+ }
}
- // Check for Google Cloud project and location
- if os.Getenv("GOOGLE_CLOUD_PROJECT") != "" && (os.Getenv("GOOGLE_CLOUD_REGION") != "" || os.Getenv("GOOGLE_CLOUD_LOCATION") != "") {
- return true
+ // if none found return the first configured provider
+ for _, providerConfig := range configuredProviders {
+ if !providerConfig.Disabled {
+ return &providerConfig
+ }
}
- return false
+ return nil
}
-// readConfig handles the result of reading a configuration file.
-func readConfig(err error) error {
- if err == nil {
- return nil
+func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
+ if other.APIKey != "" {
+ base.APIKey = other.APIKey
}
-
- // It's okay if the config file doesn't exist
- if _, ok := err.(viper.ConfigFileNotFoundError); ok {
- return nil
+ // Only change these options if the provider is not a known provider
+ if !slices.Contains(provider.KnownProviders(), p) {
+ if other.BaseURL != "" {
+ base.BaseURL = other.BaseURL
+ }
+ if other.ProviderType != "" {
+ base.ProviderType = other.ProviderType
+ }
+ if len(base.ExtraHeaders) > 0 {
+ if base.ExtraHeaders == nil {
+ base.ExtraHeaders = make(map[string]string)
+ }
+ maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
+ }
+ if len(other.ExtraParams) > 0 {
+ if base.ExtraParams == nil {
+ base.ExtraParams = make(map[string]string)
+ }
+ maps.Copy(base.ExtraParams, other.ExtraParams)
+ }
}
- return fmt.Errorf("failed to read config: %w", err)
-}
-
-// mergeLocalConfig loads and merges configuration from the local directory.
-func mergeLocalConfig(workingDir string) {
- local := viper.New()
- local.SetConfigName(fmt.Sprintf(".%s", appName))
- local.SetConfigType("json")
- local.AddConfigPath(workingDir)
-
- // Merge local config if it exists
- if err := local.ReadInConfig(); err == nil {
- viper.MergeConfigMap(local.AllSettings())
+ if other.Disabled {
+ base.Disabled = other.Disabled
}
-}
-// applyDefaultValues sets default values for configuration fields that need processing.
-func applyDefaultValues() {
- // Set default MCP type if not specified
- for k, v := range cfg.MCPServers {
- if v.Type == "" {
- v.Type = MCPStdio
- cfg.MCPServers[k] = v
- }
+ if other.DefaultLargeModel != "" {
+ base.DefaultLargeModel = other.DefaultLargeModel
}
-}
-
-// It validates model IDs and providers, ensuring they are supported.
-func validateAgent(cfg *Config, name AgentName, agent Agent) error {
- // Check if model exists
- model, modelExists := models.SupportedModels[agent.Model]
- if !modelExists {
- logging.Warn("unsupported model configured, reverting to default",
- "agent", name,
- "configured_model", agent.Model)
-
- // Set default model based on available providers
- if setDefaultModelForAgent(name) {
- logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model)
- } else {
- return fmt.Errorf("no valid provider available for agent %s", name)
- }
- return nil
- }
-
- // Check if provider for the model is configured
- provider := model.Provider
- providerCfg, providerExists := cfg.Providers[provider]
-
- if !providerExists {
- // Provider not configured, check if we have environment variables
- apiKey := getProviderAPIKey(provider)
- if apiKey == "" {
- logging.Warn("provider not configured for model, reverting to default",
- "agent", name,
- "model", agent.Model,
- "provider", provider)
-
- // Set default model based on available providers
- if setDefaultModelForAgent(name) {
- logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model)
- } else {
- return fmt.Errorf("no valid provider available for agent %s", name)
+ // Add new models if they don't exist
+ if other.Models != nil {
+ for _, model := range other.Models {
+ // check if the model already exists
+ exists := false
+ for _, existingModel := range base.Models {
+ if existingModel.ID == model.ID {
+ exists = true
+ break
+ }
}
- } else {
- // Add provider with API key from environment
- cfg.Providers[provider] = Provider{
- APIKey: apiKey,
- }
- logging.Info("added provider from environment", "provider", provider)
- }
- } else if providerCfg.Disabled || providerCfg.APIKey == "" {
- // Provider is disabled or has no API key
- logging.Warn("provider is disabled or has no API key, reverting to default",
- "agent", name,
- "model", agent.Model,
- "provider", provider)
-
- // Set default model based on available providers
- if setDefaultModelForAgent(name) {
- logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model)
- } else {
- return fmt.Errorf("no valid provider available for agent %s", name)
- }
- }
-
- // Validate max tokens
- if agent.MaxTokens <= 0 {
- logging.Warn("invalid max tokens, setting to default",
- "agent", name,
- "model", agent.Model,
- "max_tokens", agent.MaxTokens)
-
- // Update the agent with default max tokens
- updatedAgent := cfg.Agents[name]
- if model.DefaultMaxTokens > 0 {
- updatedAgent.MaxTokens = model.DefaultMaxTokens
- } else {
- updatedAgent.MaxTokens = MaxTokensFallbackDefault
- }
- cfg.Agents[name] = updatedAgent
- } else if model.ContextWindow > 0 && agent.MaxTokens > model.ContextWindow/2 {
- // Ensure max tokens doesn't exceed half the context window (reasonable limit)
- logging.Warn("max tokens exceeds half the context window, adjusting",
- "agent", name,
- "model", agent.Model,
- "max_tokens", agent.MaxTokens,
- "context_window", model.ContextWindow)
-
- // Update the agent with adjusted max tokens
- updatedAgent := cfg.Agents[name]
- updatedAgent.MaxTokens = model.ContextWindow / 2
- cfg.Agents[name] = updatedAgent
- }
-
- // Validate reasoning effort for models that support reasoning
- if model.CanReason && provider == models.ProviderOpenAI || provider == models.ProviderLocal {
- if agent.ReasoningEffort == "" {
- // Set default reasoning effort for models that support it
- logging.Info("setting default reasoning effort for model that supports reasoning",
- "agent", name,
- "model", agent.Model)
-
- // Update the agent with default reasoning effort
- updatedAgent := cfg.Agents[name]
- updatedAgent.ReasoningEffort = "medium"
- cfg.Agents[name] = updatedAgent
- } else {
- // Check if reasoning effort is valid (low, medium, high)
- effort := strings.ToLower(agent.ReasoningEffort)
- if effort != "low" && effort != "medium" && effort != "high" {
- logging.Warn("invalid reasoning effort, setting to medium",
- "agent", name,
- "model", agent.Model,
- "reasoning_effort", agent.ReasoningEffort)
-
- // Update the agent with valid reasoning effort
- updatedAgent := cfg.Agents[name]
- updatedAgent.ReasoningEffort = "medium"
- cfg.Agents[name] = updatedAgent
+ if !exists {
+ base.Models = append(base.Models, model)
}
}
- } else if !model.CanReason && agent.ReasoningEffort != "" {
- // Model doesn't support reasoning but reasoning effort is set
- logging.Warn("model doesn't support reasoning but reasoning effort is set, ignoring",
- "agent", name,
- "model", agent.Model,
- "reasoning_effort", agent.ReasoningEffort)
-
- // Update the agent to remove reasoning effort
- updatedAgent := cfg.Agents[name]
- updatedAgent.ReasoningEffort = ""
- cfg.Agents[name] = updatedAgent
}
- return nil
+ return base
}
-// Validate checks if the configuration is valid and applies defaults where needed.
-func Validate() error {
- if cfg == nil {
- return fmt.Errorf("config not loaded")
- }
-
- // Validate agent models
- for name, agent := range cfg.Agents {
- if err := validateAgent(cfg, name, agent); err != nil {
- return err
+func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
+ if !slices.Contains(provider.KnownProviders(), p) {
+ if providerConfig.ProviderType != provider.TypeOpenAI {
+ return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
}
- }
-
- // Validate providers
- for provider, providerCfg := range cfg.Providers {
- if providerCfg.APIKey == "" && !providerCfg.Disabled {
- logging.Warn("provider has no API key, marking as disabled", "provider", provider)
- providerCfg.Disabled = true
- cfg.Providers[provider] = providerCfg
+ if providerConfig.BaseURL == "" {
+ return errors.New("base URL must be set for custom providers")
}
- }
-
- // Validate LSP configurations
- for language, lspConfig := range cfg.LSP {
- if lspConfig.Command == "" && !lspConfig.Disabled {
- logging.Warn("LSP configuration has no command, marking as disabled", "language", language)
- lspConfig.Disabled = true
- cfg.LSP[language] = lspConfig
+ if providerConfig.APIKey == "" {
+ return errors.New("API key must be set for custom providers")
}
}
-
return nil
}
-// getProviderAPIKey gets the API key for a provider from environment variables
-func getProviderAPIKey(provider models.InferenceProvider) string {
- switch provider {
- case models.ProviderAnthropic:
- return os.Getenv("ANTHROPIC_API_KEY")
- case models.ProviderOpenAI:
- return os.Getenv("OPENAI_API_KEY")
- case models.ProviderGemini:
- return os.Getenv("GEMINI_API_KEY")
- case models.ProviderGROQ:
- return os.Getenv("GROQ_API_KEY")
- case models.ProviderAzure:
- return os.Getenv("AZURE_OPENAI_API_KEY")
- case models.ProviderOpenRouter:
- return os.Getenv("OPENROUTER_API_KEY")
- case models.ProviderBedrock:
- if hasAWSCredentials() {
- return "aws-credentials-available"
- }
- case models.ProviderVertexAI:
- if hasVertexAICredentials() {
- return "vertex-ai-credentials-available"
- }
- }
- return ""
-}
-
-// setDefaultModelForAgent sets a default model for an agent based on available providers
-func setDefaultModelForAgent(agent AgentName) bool {
- // Check providers in order of preference
- if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" {
- maxTokens := int64(5000)
- if agent == AgentTitle {
- maxTokens = 80
+func mergeModels(base, global, local *Config) {
+ for _, cfg := range []*Config{global, local} {
+ if cfg == nil {
+ continue
}
- cfg.Agents[agent] = Agent{
- Model: models.Claude37Sonnet,
- MaxTokens: maxTokens,
+ if cfg.Models.Large.ModelID != "" && cfg.Models.Large.Provider != "" {
+ base.Models.Large = cfg.Models.Large
}
- return true
- }
- if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" {
- var model models.ModelID
- maxTokens := int64(5000)
- reasoningEffort := ""
-
- switch agent {
- case AgentTitle:
- model = models.GPT41Mini
- maxTokens = 80
- case AgentTask:
- model = models.GPT41Mini
- default:
- model = models.GPT41
+ if cfg.Models.Small.ModelID != "" && cfg.Models.Small.Provider != "" {
+ base.Models.Small = cfg.Models.Small
}
+ }
+}
- // Check if model supports reasoning
- if modelInfo, ok := models.SupportedModels[model]; ok && modelInfo.CanReason {
- reasoningEffort = "medium"
+func mergeOptions(base, global, local *Config) {
+ for _, cfg := range []*Config{global, local} {
+ if cfg == nil {
+ continue
+ }
+ baseOptions := base.Options
+ other := cfg.Options
+ if len(other.ContextPaths) > 0 {
+ baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
}
- cfg.Agents[agent] = Agent{
- Model: model,
- MaxTokens: maxTokens,
- ReasoningEffort: reasoningEffort,
+ if other.TUI.CompactMode {
+ baseOptions.TUI.CompactMode = other.TUI.CompactMode
}
- return true
- }
- if apiKey := os.Getenv("OPENROUTER_API_KEY"); apiKey != "" {
- var model models.ModelID
- maxTokens := int64(5000)
- reasoningEffort := ""
+ if other.Debug {
+ baseOptions.Debug = other.Debug
+ }
- switch agent {
- case AgentTitle:
- model = models.OpenRouterClaude35Haiku
- maxTokens = 80
- case AgentTask:
- model = models.OpenRouterClaude37Sonnet
- default:
- model = models.OpenRouterClaude37Sonnet
+ if other.DebugLSP {
+ baseOptions.DebugLSP = other.DebugLSP
}
- // Check if model supports reasoning
- if modelInfo, ok := models.SupportedModels[model]; ok && modelInfo.CanReason {
- reasoningEffort = "medium"
+ if other.DisableAutoSummarize {
+ baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
}
- cfg.Agents[agent] = Agent{
- Model: model,
- MaxTokens: maxTokens,
- ReasoningEffort: reasoningEffort,
+ if other.DataDirectory != "" {
+ baseOptions.DataDirectory = other.DataDirectory
}
- return true
+ base.Options = baseOptions
}
+}
- if apiKey := os.Getenv("GEMINI_API_KEY"); apiKey != "" {
- var model models.ModelID
- maxTokens := int64(5000)
-
- if agent == AgentTitle {
- model = models.Gemini25Flash
- maxTokens = 80
- } else {
- model = models.Gemini25
+func mergeAgents(base, global, local *Config) {
+ for _, cfg := range []*Config{global, local} {
+ if cfg == nil {
+ continue
}
-
- cfg.Agents[agent] = Agent{
- Model: model,
- MaxTokens: maxTokens,
+ for agentID, newAgent := range cfg.Agents {
+ if _, ok := base.Agents[agentID]; !ok {
+ newAgent.ID = agentID // Ensure the ID is set correctly
+ base.Agents[agentID] = newAgent
+ } else {
+ switch agentID {
+ case AgentCoder:
+ baseAgent := base.Agents[agentID]
+ if newAgent.Model != "" && newAgent.Provider != "" {
+ baseAgent.Model = newAgent.Model
+ baseAgent.Provider = newAgent.Provider
+ }
+ baseAgent.AllowedMCP = newAgent.AllowedMCP
+ baseAgent.AllowedLSP = newAgent.AllowedLSP
+ base.Agents[agentID] = baseAgent
+ default:
+ baseAgent := base.Agents[agentID]
+ baseAgent.Name = newAgent.Name
+ baseAgent.Description = newAgent.Description
+ baseAgent.Disabled = newAgent.Disabled
+ if newAgent.Model == "" || newAgent.Provider == "" {
+ baseAgent.Provider = base.Models.Large.Provider
+ baseAgent.Model = base.Models.Large.ModelID
+ }
+ baseAgent.AllowedTools = newAgent.AllowedTools
+ baseAgent.AllowedMCP = newAgent.AllowedMCP
+ baseAgent.AllowedLSP = newAgent.AllowedLSP
+ base.Agents[agentID] = baseAgent
+ }
+ }
}
- return true
}
+}
- if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" {
- maxTokens := int64(5000)
- if agent == AgentTitle {
- maxTokens = 80
+func mergeMCPs(base, global, local *Config) {
+ for _, cfg := range []*Config{global, local} {
+ if cfg == nil {
+ continue
}
+ maps.Copy(base.MCP, cfg.MCP)
+ }
+}
- cfg.Agents[agent] = Agent{
- Model: models.QWENQwq,
- MaxTokens: maxTokens,
+func mergeLSPs(base, global, local *Config) {
+ for _, cfg := range []*Config{global, local} {
+ if cfg == nil {
+ continue
}
- return true
+ maps.Copy(base.LSP, cfg.LSP)
}
+}
- if hasAWSCredentials() {
- maxTokens := int64(5000)
- if agent == AgentTitle {
- maxTokens = 80
+func mergeProviderConfigs(base, global, local *Config) {
+ for _, cfg := range []*Config{global, local} {
+ if cfg == nil {
+ continue
}
-
- cfg.Agents[agent] = Agent{
- Model: models.BedrockClaude37Sonnet,
- MaxTokens: maxTokens,
- ReasoningEffort: "medium", // Claude models support reasoning
+ for providerName, globalProvider := range cfg.Providers {
+ if _, ok := base.Providers[providerName]; !ok {
+ base.Providers[providerName] = globalProvider
+ } else {
+ base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider)
+ }
}
- return true
}
- if hasVertexAICredentials() {
- var model models.ModelID
- maxTokens := int64(5000)
-
- if agent == AgentTitle {
- model = models.VertexAIGemini25Flash
- maxTokens = 80
- } else {
- model = models.VertexAIGemini25
+ finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
+ for providerName, providerConfig := range base.Providers {
+ err := validateProvider(providerName, providerConfig)
+ if err != nil {
+ logging.Warn("Skipping provider", "name", providerName, "error", err)
}
+ finalProviders[providerName] = providerConfig
+ }
+ base.Providers = finalProviders
+}
- cfg.Agents[agent] = Agent{
- Model: model,
- MaxTokens: maxTokens,
+func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig {
+ switch providerId {
+ case provider.InferenceProviderAnthropic:
+ return ProviderConfig{
+ ID: providerId,
+ ProviderType: provider.TypeAnthropic,
+ }
+ case provider.InferenceProviderOpenAI:
+ return ProviderConfig{
+ ID: providerId,
+ ProviderType: provider.TypeOpenAI,
+ }
+ case provider.InferenceProviderGemini:
+ return ProviderConfig{
+ ID: providerId,
+ ProviderType: provider.TypeGemini,
+ }
+ case provider.InferenceProviderBedrock:
+ return ProviderConfig{
+ ID: providerId,
+ ProviderType: provider.TypeBedrock,
+ }
+ case provider.InferenceProviderAzure:
+ return ProviderConfig{
+ ID: providerId,
+ ProviderType: provider.TypeAzure,
+ }
+ case provider.InferenceProviderOpenRouter:
+ return ProviderConfig{
+ ID: providerId,
+ ProviderType: provider.TypeOpenAI,
+ BaseURL: "https://openrouter.ai/api/v1",
+ ExtraHeaders: map[string]string{
+ "HTTP-Referer": "crush.charm.land",
+ "X-Title": "Crush",
+ },
+ }
+ case provider.InferenceProviderXAI:
+ return ProviderConfig{
+ ID: providerId,
+ ProviderType: provider.TypeXAI,
+ BaseURL: "https://api.x.ai/v1",
+ }
+ case provider.InferenceProviderVertexAI:
+ return ProviderConfig{
+ ID: providerId,
+ ProviderType: provider.TypeVertexAI,
+ }
+ default:
+ return ProviderConfig{
+ ID: providerId,
+ ProviderType: provider.TypeOpenAI,
}
- return true
}
-
- return false
}
-func updateCfgFile(updateCfg func(config *Config)) error {
- if cfg == nil {
- return fmt.Errorf("config not loaded")
+func defaultConfigBasedOnEnv() *Config {
+ cfg := &Config{
+ Options: Options{
+ DataDirectory: defaultDataDirectory,
+ ContextPaths: defaultContextPaths,
+ },
+ Providers: make(map[provider.InferenceProvider]ProviderConfig),
+ }
+
+ providers := Providers()
+
+ for _, p := range providers {
+ if strings.HasPrefix(p.APIKey, "$") {
+ envVar := strings.TrimPrefix(p.APIKey, "$")
+ if apiKey := os.Getenv(envVar); apiKey != "" {
+ providerConfig := providerDefaultConfig(p.ID)
+ providerConfig.APIKey = apiKey
+ providerConfig.DefaultLargeModel = p.DefaultLargeModelID
+ providerConfig.DefaultSmallModel = p.DefaultSmallModelID
+ baseURL := p.APIEndpoint
+ if strings.HasPrefix(baseURL, "$") {
+ envVar := strings.TrimPrefix(baseURL, "$")
+ if url := os.Getenv(envVar); url != "" {
+ baseURL = url
+ }
+ }
+ providerConfig.BaseURL = baseURL
+ for _, model := range p.Models {
+ providerConfig.Models = append(providerConfig.Models, Model{
+ ID: model.ID,
+ Name: model.Name,
+ CostPer1MIn: model.CostPer1MIn,
+ CostPer1MOut: model.CostPer1MOut,
+ CostPer1MInCached: model.CostPer1MInCached,
+ CostPer1MOutCached: model.CostPer1MOutCached,
+ ContextWindow: model.ContextWindow,
+ DefaultMaxTokens: model.DefaultMaxTokens,
+ CanReason: model.CanReason,
+ SupportsImages: model.SupportsImages,
+ })
+ }
+ cfg.Providers[p.ID] = providerConfig
+ }
+ }
}
+ // TODO: support local models
- // Get the config file path
- configFile := viper.ConfigFileUsed()
- var configData []byte
- if configFile == "" {
- homeDir, err := os.UserHomeDir()
- if err != nil {
- return fmt.Errorf("failed to get home directory: %w", err)
+ if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" {
+ providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
+ providerConfig.ExtraParams = map[string]string{
+ "project": os.Getenv("GOOGLE_CLOUD_PROJECT"),
+ "location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
}
- configFile = filepath.Join(homeDir, fmt.Sprintf(".%s.json", appName))
- logging.Info("config file not found, creating new one", "path", configFile)
- configData = []byte(`{}`)
- } else {
- // Read the existing config file
- data, err := os.ReadFile(configFile)
- if err != nil {
- return fmt.Errorf("failed to read config file: %w", err)
- }
- configData = data
+ cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
}
- // Parse the JSON
- var userCfg *Config
- if err := json.Unmarshal(configData, &userCfg); err != nil {
- return fmt.Errorf("failed to parse config file: %w", err)
+ if hasAWSCredentials() {
+ providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
+ providerConfig.ExtraParams = map[string]string{
+ "region": os.Getenv("AWS_DEFAULT_REGION"),
+ }
+ if providerConfig.ExtraParams["region"] == "" {
+ providerConfig.ExtraParams["region"] = os.Getenv("AWS_REGION")
+ }
+ cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
}
+ return cfg
+}
- updateCfg(userCfg)
+func hasAWSCredentials() bool {
+ if os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "" {
+ return true
+ }
- // Write the updated config back to file
- updatedData, err := json.MarshalIndent(userCfg, "", " ")
- if err != nil {
- return fmt.Errorf("failed to marshal config: %w", err)
+ if os.Getenv("AWS_PROFILE") != "" || os.Getenv("AWS_DEFAULT_PROFILE") != "" {
+ return true
}
- if err := os.WriteFile(configFile, updatedData, 0o644); err != nil {
- return fmt.Errorf("failed to write config file: %w", err)
+ if os.Getenv("AWS_REGION") != "" || os.Getenv("AWS_DEFAULT_REGION") != "" {
+ return true
}
- return nil
-}
+ if os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
+ os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
+ return true
+ }
-// Get returns the current configuration.
-// It's safe to call this function multiple times.
-func Get() *Config {
- return cfg
+ return false
}
-// WorkingDirectory returns the current working directory from the configuration.
func WorkingDirectory() string {
- if cfg == nil {
- panic("config not loaded")
- }
- return cfg.WorkingDir
+ return cwd
}
-func UpdateAgentModel(agentName AgentName, modelID models.ModelID) error {
- if cfg == nil {
- panic("config not loaded")
+func GetAgentModel(agentID AgentID) Model {
+ cfg := Get()
+ agent, ok := cfg.Agents[agentID]
+ if !ok {
+ logging.Error("Agent not found", "agent_id", agentID)
+ return Model{}
}
- existingAgentCfg := cfg.Agents[agentName]
-
- model, ok := models.SupportedModels[modelID]
+ providerConfig, ok := cfg.Providers[agent.Provider]
if !ok {
- return fmt.Errorf("model %s not supported", modelID)
+ logging.Error("Provider not found for agent", "agent_id", agentID, "provider", agent.Provider)
+ return Model{}
}
- maxTokens := existingAgentCfg.MaxTokens
- if model.DefaultMaxTokens > 0 {
- maxTokens = model.DefaultMaxTokens
+ for _, model := range providerConfig.Models {
+ if model.ID == agent.Model {
+ return model
+ }
}
- newAgentCfg := Agent{
- Model: modelID,
- MaxTokens: maxTokens,
- ReasoningEffort: existingAgentCfg.ReasoningEffort,
- }
- cfg.Agents[agentName] = newAgentCfg
+ logging.Error("Model not found for agent", "agent_id", agentID, "model", agent.Model)
+ return Model{}
+}
- if err := validateAgent(cfg, agentName, newAgentCfg); err != nil {
- // revert config update on failure
- cfg.Agents[agentName] = existingAgentCfg
- return fmt.Errorf("failed to update agent model: %w", err)
+func GetProviderModel(provider provider.InferenceProvider, modelID string) Model {
+ cfg := Get()
+ providerConfig, ok := cfg.Providers[provider]
+ if !ok {
+ logging.Error("Provider not found", "provider", provider)
+ return Model{}
}
- return updateCfgFile(func(config *Config) {
- if config.Agents == nil {
- config.Agents = make(map[AgentName]Agent)
+ for _, model := range providerConfig.Models {
+ if model.ID == modelID {
+ return model
}
- config.Agents[agentName] = newAgentCfg
- })
-}
-
-// UpdateTheme updates the theme in the configuration and writes it to the config file.
-func UpdateTheme(themeName string) error {
- if cfg == nil {
- return fmt.Errorf("config not loaded")
}
- // Update the in-memory config
- cfg.TUI.Theme = themeName
-
- // Update the file config
- return updateCfgFile(func(config *Config) {
- config.TUI.Theme = themeName
- })
+ logging.Error("Model not found for provider", "provider", provider, "model_id", modelID)
+ return Model{}
}
@@ -1,4 +1,4 @@
-package configv2
+package config
import (
"encoding/json"
@@ -28,7 +28,7 @@ func TestConfigWithEnv(t *testing.T) {
os.Setenv("GEMINI_API_KEY", "test-gemini-key")
os.Setenv("XAI_API_KEY", "test-xai-key")
os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key")
- cfg := InitConfig(cwdDir)
+ cfg, _ := Init(cwdDir, false)
data, _ := json.MarshalIndent(cfg, "", " ")
fmt.Println(string(data))
assert.Len(t, cfg.Providers, 5)
@@ -1,4 +1,4 @@
-package configv2
+package config
import (
"fmt"
@@ -17,23 +17,20 @@ type ProjectInitFlag struct {
Initialized bool `json:"initialized"`
}
-// ShouldShowInitDialog checks if the initialization dialog should be shown for the current directory
-func ShouldShowInitDialog() (bool, error) {
- if cfg == nil {
+// ProjectNeedsInitialization checks if the current project needs initialization
+func ProjectNeedsInitialization() (bool, error) {
+ if instance == nil {
return false, fmt.Errorf("config not loaded")
}
- // Create the flag file path
- flagFilePath := filepath.Join(cfg.Data.Directory, InitFlagFilename)
+ flagFilePath := filepath.Join(instance.Options.DataDirectory, InitFlagFilename)
// Check if the flag file exists
_, err := os.Stat(flagFilePath)
if err == nil {
- // File exists, don't show the dialog
return false, nil
}
- // If the error is not "file not found", return the error
if !os.IsNotExist(err) {
return false, fmt.Errorf("failed to check init flag file: %w", err)
}
@@ -44,11 +41,9 @@ func ShouldShowInitDialog() (bool, error) {
return false, fmt.Errorf("failed to check for CRUSH.md files: %w", err)
}
if crushExists {
- // CRUSH.md already exists, don't show the dialog
return false, nil
}
- // File doesn't exist, show the dialog
return true, nil
}
@@ -75,13 +70,11 @@ func crushMdExists(dir string) (bool, error) {
// MarkProjectInitialized marks the current project as initialized
func MarkProjectInitialized() error {
- if cfg == nil {
+ if instance == nil {
return fmt.Errorf("config not loaded")
}
- // Create the flag file path
- flagFilePath := filepath.Join(cfg.Data.Directory, InitFlagFilename)
+ flagFilePath := filepath.Join(instance.Options.DataDirectory, InitFlagFilename)
- // Create an empty file to mark the project as initialized
file, err := os.Create(flagFilePath)
if err != nil {
return fmt.Errorf("failed to create init flag file: %w", err)
@@ -1,4 +1,4 @@
-package configv2
+package config
import (
"encoding/json"
@@ -1,660 +0,0 @@
-package configv2
-
-import (
- "encoding/json"
- "errors"
- "maps"
- "os"
- "path/filepath"
- "slices"
- "strings"
- "sync"
-
- "github.com/charmbracelet/crush/internal/fur/provider"
- "github.com/charmbracelet/crush/internal/logging"
-)
-
-const (
- defaultDataDirectory = ".crush"
- defaultLogLevel = "info"
- appName = "crush"
-
- MaxTokensFallbackDefault = 4096
-)
-
-var defaultContextPaths = []string{
- ".github/copilot-instructions.md",
- ".cursorrules",
- ".cursor/rules/",
- "CLAUDE.md",
- "CLAUDE.local.md",
- "crush.md",
- "crush.local.md",
- "Crush.md",
- "Crush.local.md",
- "CRUSH.md",
- "CRUSH.local.md",
-}
-
-type AgentID string
-
-const (
- AgentCoder AgentID = "coder"
- AgentTask AgentID = "task"
- AgentTitle AgentID = "title"
- AgentSummarize AgentID = "summarize"
-)
-
-type Model struct {
- ID string `json:"id"`
- Name string `json:"model"`
- CostPer1MIn float64 `json:"cost_per_1m_in"`
- CostPer1MOut float64 `json:"cost_per_1m_out"`
- CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
- CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
- ContextWindow int64 `json:"context_window"`
- DefaultMaxTokens int64 `json:"default_max_tokens"`
- CanReason bool `json:"can_reason"`
- ReasoningEffort string `json:"reasoning_effort"`
- SupportsImages bool `json:"supports_attachments"`
-}
-
-type VertexAIOptions struct {
- APIKey string `json:"api_key,omitempty"`
- Project string `json:"project,omitempty"`
- Location string `json:"location,omitempty"`
-}
-
-type ProviderConfig struct {
- ID provider.InferenceProvider `json:"id"`
- BaseURL string `json:"base_url,omitempty"`
- ProviderType provider.Type `json:"provider_type"`
- APIKey string `json:"api_key,omitempty"`
- Disabled bool `json:"disabled"`
- ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
- // used for e.x for vertex to set the project
- ExtraParams map[string]string `json:"extra_params,omitempty"`
-
- DefaultLargeModel string `json:"default_large_model,omitempty"`
- DefaultSmallModel string `json:"default_small_model,omitempty"`
-
- Models []Model `json:"models,omitempty"`
-}
-
-type Agent struct {
- ID AgentID `json:"id"`
- Name string `json:"name"`
- Description string `json:"description,omitempty"`
- // This is the id of the system prompt used by the agent
- Disabled bool `json:"disabled"`
-
- Provider provider.InferenceProvider `json:"provider"`
- Model string `json:"model"`
-
- // The available tools for the agent
- // if this is nil, all tools are available
- AllowedTools []string `json:"allowed_tools"`
-
- // this tells us which MCPs are available for this agent
- // if this is empty all mcps are available
- // the string array is the list of tools from the AllowedMCP the agent has available
- // if the string array is nil, all tools from the AllowedMCP are available
- AllowedMCP map[string][]string `json:"allowed_mcp"`
-
- // The list of LSPs that this agent can use
- // if this is nil, all LSPs are available
- AllowedLSP []string `json:"allowed_lsp"`
-
- // Overrides the context paths for this agent
- ContextPaths []string `json:"context_paths"`
-}
-
-type MCPType string
-
-const (
- MCPStdio MCPType = "stdio"
- MCPSse MCPType = "sse"
-)
-
-type MCP struct {
- Command string `json:"command"`
- Env []string `json:"env"`
- Args []string `json:"args"`
- Type MCPType `json:"type"`
- URL string `json:"url"`
- Headers map[string]string `json:"headers"`
-}
-
-type LSPConfig struct {
- Disabled bool `json:"enabled"`
- Command string `json:"command"`
- Args []string `json:"args"`
- Options any `json:"options"`
-}
-
-type TUIOptions struct {
- CompactMode bool `json:"compact_mode"`
- // Here we can add themes later or any TUI related options
-}
-
-type Options struct {
- ContextPaths []string `json:"context_paths"`
- TUI TUIOptions `json:"tui"`
- Debug bool `json:"debug"`
- DebugLSP bool `json:"debug_lsp"`
- DisableAutoSummarize bool `json:"disable_auto_summarize"`
- // Relative to the cwd
- DataDirectory string `json:"data_directory"`
-}
-
-type Config struct {
- // List of configured providers
- Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"`
-
- // List of configured agents
- Agents map[AgentID]Agent `json:"agents,omitempty"`
-
- // List of configured MCPs
- MCP map[string]MCP `json:"mcp,omitempty"`
-
- // List of configured LSPs
- LSP map[string]LSPConfig `json:"lsp,omitempty"`
-
- // Miscellaneous options
- Options Options `json:"options"`
-}
-
-var (
- instance *Config // The single instance of the Singleton
- cwd string
- once sync.Once // Ensures the initialization happens only once
-
-)
-
-func loadConfig(cwd string) (*Config, error) {
- // First read the global config file
- cfgPath := ConfigPath()
-
- cfg := defaultConfigBasedOnEnv()
-
- var globalCfg *Config
- if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) {
- // some other error occurred while checking the file
- return nil, err
- } else if err == nil {
- // config file exists, read it
- file, err := os.ReadFile(cfgPath)
- if err != nil {
- return nil, err
- }
- globalCfg = &Config{}
- if err := json.Unmarshal(file, globalCfg); err != nil {
- return nil, err
- }
- } else {
- // config file does not exist, create a new one
- globalCfg = &Config{}
- }
-
- var localConfig *Config
- // Global config loaded, now read the local config file
- localConfigPath := filepath.Join(cwd, "crush.json")
- if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
- // some other error occurred while checking the file
- return nil, err
- } else if err == nil {
- // local config file exists, read it
- file, err := os.ReadFile(localConfigPath)
- if err != nil {
- return nil, err
- }
- localConfig = &Config{}
- if err := json.Unmarshal(file, localConfig); err != nil {
- return nil, err
- }
- }
-
- // merge options
- mergeOptions(cfg, globalCfg, localConfig)
-
- mergeProviderConfigs(cfg, globalCfg, localConfig)
- // no providers found the app is not initialized yet
- if len(cfg.Providers) == 0 {
- return cfg, nil
- }
- preferredProvider := getPreferredProvider(cfg.Providers)
-
- if preferredProvider == nil {
- return nil, errors.New("no valid providers configured")
- }
-
- agents := map[AgentID]Agent{
- AgentCoder: {
- ID: AgentCoder,
- Name: "Coder",
- Description: "An agent that helps with executing coding tasks.",
- Provider: preferredProvider.ID,
- Model: preferredProvider.DefaultLargeModel,
- ContextPaths: cfg.Options.ContextPaths,
- // All tools allowed
- },
- AgentTask: {
- ID: AgentTask,
- Name: "Task",
- Description: "An agent that helps with searching for context and finding implementation details.",
- Provider: preferredProvider.ID,
- Model: preferredProvider.DefaultLargeModel,
- ContextPaths: cfg.Options.ContextPaths,
- AllowedTools: []string{
- "glob",
- "grep",
- "ls",
- "sourcegraph",
- "view",
- },
- // NO MCPs or LSPs by default
- AllowedMCP: map[string][]string{},
- AllowedLSP: []string{},
- },
- AgentTitle: {
- ID: AgentTitle,
- Name: "Title",
- Description: "An agent that helps with generating titles for sessions.",
- Provider: preferredProvider.ID,
- Model: preferredProvider.DefaultSmallModel,
- ContextPaths: cfg.Options.ContextPaths,
- AllowedTools: []string{},
- // NO MCPs or LSPs by default
- AllowedMCP: map[string][]string{},
- AllowedLSP: []string{},
- },
- AgentSummarize: {
- ID: AgentSummarize,
- Name: "Summarize",
- Description: "An agent that helps with summarizing sessions.",
- Provider: preferredProvider.ID,
- Model: preferredProvider.DefaultSmallModel,
- ContextPaths: cfg.Options.ContextPaths,
- AllowedTools: []string{},
- // NO MCPs or LSPs by default
- AllowedMCP: map[string][]string{},
- AllowedLSP: []string{},
- },
- }
- cfg.Agents = agents
- mergeAgents(cfg, globalCfg, localConfig)
- mergeMCPs(cfg, globalCfg, localConfig)
- mergeLSPs(cfg, globalCfg, localConfig)
-
- return cfg, nil
-}
-
-func InitConfig(workingDir string) *Config {
- once.Do(func() {
- cwd = workingDir
- cfg, err := loadConfig(cwd)
- if err != nil {
- // TODO: Handle this better
- panic("Failed to load config: " + err.Error())
- }
- instance = cfg
- })
-
- return instance
-}
-
-func GetConfig() *Config {
- if instance == nil {
- // TODO: Handle this better
- panic("Config not initialized. Call InitConfig first.")
- }
- return instance
-}
-
-func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
- providers := Providers()
- for _, p := range providers {
- if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
- return &providerConfig
- }
- }
- // if none found return the first configured provider
- for _, providerConfig := range configuredProviders {
- if !providerConfig.Disabled {
- return &providerConfig
- }
- }
- return nil
-}
-
-func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
- if other.APIKey != "" {
- base.APIKey = other.APIKey
- }
- // Only change these options if the provider is not a known provider
- if !slices.Contains(provider.KnownProviders(), p) {
- if other.BaseURL != "" {
- base.BaseURL = other.BaseURL
- }
- if other.ProviderType != "" {
- base.ProviderType = other.ProviderType
- }
- if len(base.ExtraHeaders) > 0 {
- if base.ExtraHeaders == nil {
- base.ExtraHeaders = make(map[string]string)
- }
- maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
- }
- if len(other.ExtraParams) > 0 {
- if base.ExtraParams == nil {
- base.ExtraParams = make(map[string]string)
- }
- maps.Copy(base.ExtraParams, other.ExtraParams)
- }
- }
-
- if other.Disabled {
- base.Disabled = other.Disabled
- }
-
- if other.DefaultLargeModel != "" {
- base.DefaultLargeModel = other.DefaultLargeModel
- }
- // Add new models if they don't exist
- if other.Models != nil {
- for _, model := range other.Models {
- // check if the model already exists
- exists := false
- for _, existingModel := range base.Models {
- if existingModel.ID == model.ID {
- exists = true
- break
- }
- }
- if !exists {
- base.Models = append(base.Models, model)
- }
- }
- }
-
- return base
-}
-
-func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
- if !slices.Contains(provider.KnownProviders(), p) {
- if providerConfig.ProviderType != provider.TypeOpenAI {
- return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
- }
- if providerConfig.BaseURL == "" {
- return errors.New("base URL must be set for custom providers")
- }
- if providerConfig.APIKey == "" {
- return errors.New("API key must be set for custom providers")
- }
- }
- return nil
-}
-
-func mergeOptions(base, global, local *Config) {
- for _, cfg := range []*Config{global, local} {
- if cfg == nil {
- continue
- }
- baseOptions := base.Options
- other := cfg.Options
- if len(other.ContextPaths) > 0 {
- baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
- }
-
- if other.TUI.CompactMode {
- baseOptions.TUI.CompactMode = other.TUI.CompactMode
- }
-
- if other.Debug {
- baseOptions.Debug = other.Debug
- }
-
- if other.DebugLSP {
- baseOptions.DebugLSP = other.DebugLSP
- }
-
- if other.DisableAutoSummarize {
- baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
- }
-
- if other.DataDirectory != "" {
- baseOptions.DataDirectory = other.DataDirectory
- }
- base.Options = baseOptions
- }
-}
-
-func mergeAgents(base, global, local *Config) {
- for _, cfg := range []*Config{global, local} {
- if cfg == nil {
- continue
- }
- for agentID, newAgent := range cfg.Agents {
- if _, ok := base.Agents[agentID]; !ok {
- newAgent.ID = agentID // Ensure the ID is set correctly
- base.Agents[agentID] = newAgent
- } else {
- switch agentID {
- case AgentCoder:
- baseAgent := base.Agents[agentID]
- baseAgent.Model = newAgent.Model
- baseAgent.Provider = newAgent.Provider
- baseAgent.AllowedMCP = newAgent.AllowedMCP
- baseAgent.AllowedLSP = newAgent.AllowedLSP
- base.Agents[agentID] = baseAgent
- case AgentTask:
- baseAgent := base.Agents[agentID]
- baseAgent.Model = newAgent.Model
- baseAgent.Provider = newAgent.Provider
- base.Agents[agentID] = baseAgent
- case AgentTitle:
- baseAgent := base.Agents[agentID]
- baseAgent.Model = newAgent.Model
- baseAgent.Provider = newAgent.Provider
- base.Agents[agentID] = baseAgent
- case AgentSummarize:
- baseAgent := base.Agents[agentID]
- baseAgent.Model = newAgent.Model
- baseAgent.Provider = newAgent.Provider
- base.Agents[agentID] = baseAgent
- default:
- baseAgent := base.Agents[agentID]
- baseAgent.Name = newAgent.Name
- baseAgent.Description = newAgent.Description
- baseAgent.Disabled = newAgent.Disabled
- baseAgent.Provider = newAgent.Provider
- baseAgent.Model = newAgent.Model
- baseAgent.AllowedTools = newAgent.AllowedTools
- baseAgent.AllowedMCP = newAgent.AllowedMCP
- baseAgent.AllowedLSP = newAgent.AllowedLSP
- base.Agents[agentID] = baseAgent
-
- }
- }
- }
- }
-}
-
-func mergeMCPs(base, global, local *Config) {
- for _, cfg := range []*Config{global, local} {
- if cfg == nil {
- continue
- }
- maps.Copy(base.MCP, cfg.MCP)
- }
-}
-
-func mergeLSPs(base, global, local *Config) {
- for _, cfg := range []*Config{global, local} {
- if cfg == nil {
- continue
- }
- maps.Copy(base.LSP, cfg.LSP)
- }
-}
-
-func mergeProviderConfigs(base, global, local *Config) {
- for _, cfg := range []*Config{global, local} {
- if cfg == nil {
- continue
- }
- for providerName, globalProvider := range cfg.Providers {
- if _, ok := base.Providers[providerName]; !ok {
- base.Providers[providerName] = globalProvider
- } else {
- base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider)
- }
- }
- }
-
- finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
- for providerName, providerConfig := range base.Providers {
- err := validateProvider(providerName, providerConfig)
- if err != nil {
- logging.Warn("Skipping provider", "name", providerName, "error", err)
- }
- finalProviders[providerName] = providerConfig
- }
- base.Providers = finalProviders
-}
-
-func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig {
- switch providerId {
- case provider.InferenceProviderAnthropic:
- return ProviderConfig{
- ID: providerId,
- ProviderType: provider.TypeAnthropic,
- }
- case provider.InferenceProviderOpenAI:
- return ProviderConfig{
- ID: providerId,
- ProviderType: provider.TypeOpenAI,
- }
- case provider.InferenceProviderGemini:
- return ProviderConfig{
- ID: providerId,
- ProviderType: provider.TypeGemini,
- }
- case provider.InferenceProviderBedrock:
- return ProviderConfig{
- ID: providerId,
- ProviderType: provider.TypeBedrock,
- }
- case provider.InferenceProviderAzure:
- return ProviderConfig{
- ID: providerId,
- ProviderType: provider.TypeAzure,
- }
- case provider.InferenceProviderOpenRouter:
- return ProviderConfig{
- ID: providerId,
- ProviderType: provider.TypeOpenAI,
- BaseURL: "https://openrouter.ai/api/v1",
- ExtraHeaders: map[string]string{
- "HTTP-Referer": "crush.charm.land",
- "X-Title": "Crush",
- },
- }
- case provider.InferenceProviderXAI:
- return ProviderConfig{
- ID: providerId,
- ProviderType: provider.TypeXAI,
- BaseURL: "https://api.x.ai/v1",
- }
- case provider.InferenceProviderVertexAI:
- return ProviderConfig{
- ID: providerId,
- ProviderType: provider.TypeVertexAI,
- }
- default:
- return ProviderConfig{
- ID: providerId,
- ProviderType: provider.TypeOpenAI,
- }
- }
-}
-
-func defaultConfigBasedOnEnv() *Config {
- cfg := &Config{
- Options: Options{
- DataDirectory: defaultDataDirectory,
- ContextPaths: defaultContextPaths,
- },
- Providers: make(map[provider.InferenceProvider]ProviderConfig),
- }
-
- providers := Providers()
-
- for _, p := range providers {
- if strings.HasPrefix(p.APIKey, "$") {
- envVar := strings.TrimPrefix(p.APIKey, "$")
- if apiKey := os.Getenv(envVar); apiKey != "" {
- providerConfig := providerDefaultConfig(p.ID)
- providerConfig.APIKey = apiKey
- providerConfig.DefaultLargeModel = p.DefaultLargeModelID
- providerConfig.DefaultSmallModel = p.DefaultSmallModelID
- for _, model := range p.Models {
- providerConfig.Models = append(providerConfig.Models, Model{
- ID: model.ID,
- Name: model.Name,
- CostPer1MIn: model.CostPer1MIn,
- CostPer1MOut: model.CostPer1MOut,
- CostPer1MInCached: model.CostPer1MInCached,
- CostPer1MOutCached: model.CostPer1MOutCached,
- ContextWindow: model.ContextWindow,
- DefaultMaxTokens: model.DefaultMaxTokens,
- CanReason: model.CanReason,
- SupportsImages: model.SupportsImages,
- })
- }
- cfg.Providers[p.ID] = providerConfig
- }
- }
- }
- // TODO: support local models
-
- if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" {
- providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
- providerConfig.ExtraParams = map[string]string{
- "project": os.Getenv("GOOGLE_CLOUD_PROJECT"),
- "location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
- }
- cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
- }
-
- if hasAWSCredentials() {
- providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
- cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
- }
- return cfg
-}
-
-func hasAWSCredentials() bool {
- if os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "" {
- return true
- }
-
- if os.Getenv("AWS_PROFILE") != "" || os.Getenv("AWS_DEFAULT_PROFILE") != "" {
- return true
- }
-
- if os.Getenv("AWS_REGION") != "" || os.Getenv("AWS_DEFAULT_REGION") != "" {
- return true
- }
-
- if os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
- os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
- return true
- }
-
- return false
-}
-
-func WorkingDirectory() string {
- return cwd
-}
@@ -1,7 +1,6 @@
package db
import (
- "context"
"database/sql"
"fmt"
"os"
@@ -16,8 +15,8 @@ import (
"github.com/pressly/goose/v3"
)
-func Connect(ctx context.Context) (*sql.DB, error) {
- dataDir := config.Get().Data.Directory
+func Connect() (*sql.DB, error) {
+ dataDir := config.Get().Options.DataDirectory
if dataDir == "" {
return nil, fmt.Errorf("data.dir is not set")
}
@@ -17,12 +17,13 @@ INSERT INTO messages (
role,
parts,
model,
+ provider,
created_at,
updated_at
) VALUES (
- ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
+ ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
)
-RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at
+RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at, provider
`
type CreateMessageParams struct {
@@ -31,6 +32,7 @@ type CreateMessageParams struct {
Role string `json:"role"`
Parts string `json:"parts"`
Model sql.NullString `json:"model"`
+ Provider sql.NullString `json:"provider"`
}
func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (Message, error) {
@@ -40,6 +42,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M
arg.Role,
arg.Parts,
arg.Model,
+ arg.Provider,
)
var i Message
err := row.Scan(
@@ -51,6 +54,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M
&i.CreatedAt,
&i.UpdatedAt,
&i.FinishedAt,
+ &i.Provider,
)
return i, err
}
@@ -76,7 +80,7 @@ func (q *Queries) DeleteSessionMessages(ctx context.Context, sessionID string) e
}
const getMessage = `-- name: GetMessage :one
-SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at
+SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider
FROM messages
WHERE id = ? LIMIT 1
`
@@ -93,12 +97,13 @@ func (q *Queries) GetMessage(ctx context.Context, id string) (Message, error) {
&i.CreatedAt,
&i.UpdatedAt,
&i.FinishedAt,
+ &i.Provider,
)
return i, err
}
const listMessagesBySession = `-- name: ListMessagesBySession :many
-SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at
+SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider
FROM messages
WHERE session_id = ?
ORDER BY created_at ASC
@@ -122,6 +127,7 @@ func (q *Queries) ListMessagesBySession(ctx context.Context, sessionID string) (
&i.CreatedAt,
&i.UpdatedAt,
&i.FinishedAt,
+ &i.Provider,
); err != nil {
return nil, err
}
@@ -0,0 +1,11 @@
+-- +goose Up
+-- +goose StatementBegin
+-- Add provider column to messages table
+ALTER TABLE messages ADD COLUMN provider TEXT;
+-- +goose StatementEnd
+
+-- +goose Down
+-- +goose StatementBegin
+-- Remove provider column from messages table
+ALTER TABLE messages DROP COLUMN provider;
+-- +goose StatementEnd
@@ -27,6 +27,7 @@ type Message struct {
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
FinishedAt sql.NullInt64 `json:"finished_at"`
+ Provider sql.NullString `json:"provider"`
}
type Session struct {
@@ -16,10 +16,11 @@ INSERT INTO messages (
role,
parts,
model,
+ provider,
created_at,
updated_at
) VALUES (
- ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
+ ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
)
RETURNING *;
@@ -10,7 +10,7 @@ import (
"github.com/charmbracelet/crush/internal/fur/provider"
)
-const defaultURL = "http://localhost:8080"
+const defaultURL = "https://fur.charmcli.dev"
// Client represents a client for the fur service.
type Client struct {
@@ -6,14 +6,13 @@ type Type string
// All the supported AI provider types.
const (
- TypeOpenAI Type = "openai"
- TypeAnthropic Type = "anthropic"
- TypeGemini Type = "gemini"
- TypeAzure Type = "azure"
- TypeBedrock Type = "bedrock"
- TypeVertexAI Type = "vertexai"
- TypeXAI Type = "xai"
- TypeOpenRouter Type = "openrouter"
+ TypeOpenAI Type = "openai"
+ TypeAnthropic Type = "anthropic"
+ TypeGemini Type = "gemini"
+ TypeAzure Type = "azure"
+ TypeBedrock Type = "bedrock"
+ TypeVertexAI Type = "vertexai"
+ TypeXAI Type = "xai"
)
// InferenceProvider represents the inference provider identifier.
@@ -5,17 +5,15 @@ import (
"encoding/json"
"fmt"
- "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/llm/tools"
- "github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/session"
)
type agentTool struct {
- sessions session.Service
- messages message.Service
- lspClients map[string]*lsp.Client
+ agent Service
+ sessions session.Service
+ messages message.Service
}
const (
@@ -58,17 +56,12 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
return tools.ToolResponse{}, fmt.Errorf("session_id and message_id are required")
}
- agent, err := NewAgent(config.AgentTask, b.sessions, b.messages, TaskAgentTools(b.lspClients))
- if err != nil {
- return tools.ToolResponse{}, fmt.Errorf("error creating agent: %s", err)
- }
-
session, err := b.sessions.CreateTaskSession(ctx, call.ID, sessionID, "New Agent Session")
if err != nil {
return tools.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
}
- done, err := agent.Run(ctx, session.ID, params.Prompt)
+ done, err := b.agent.Run(ctx, session.ID, params.Prompt)
if err != nil {
return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", err)
}
@@ -101,13 +94,13 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
}
func NewAgentTool(
- Sessions session.Service,
- Messages message.Service,
- LspClients map[string]*lsp.Client,
+ agent Service,
+ sessions session.Service,
+ messages message.Service,
) tools.BaseTool {
return &agentTool{
- sessions: Sessions,
- messages: Messages,
- lspClients: LspClients,
+ sessions: sessions,
+ messages: messages,
+ agent: agent,
}
}
@@ -4,16 +4,18 @@ import (
"context"
"errors"
"fmt"
+ "slices"
"strings"
"sync"
"time"
- "github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/llm/models"
+ configv2 "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/history"
"github.com/charmbracelet/crush/internal/llm/prompt"
"github.com/charmbracelet/crush/internal/llm/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/logging"
+ "github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/permission"
"github.com/charmbracelet/crush/internal/pubsub"
@@ -47,71 +49,198 @@ type AgentEvent struct {
type Service interface {
pubsub.Suscriber[AgentEvent]
- Model() models.Model
+ Model() configv2.Model
Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
Cancel(sessionID string)
CancelAll()
IsSessionBusy(sessionID string) bool
IsBusy() bool
- Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error)
+ Update(model configv2.PreferredModel) (configv2.Model, error)
Summarize(ctx context.Context, sessionID string) error
}
type agent struct {
*pubsub.Broker[AgentEvent]
+ agentCfg configv2.Agent
sessions session.Service
messages message.Service
- tools []tools.BaseTool
- provider provider.Provider
+ tools []tools.BaseTool
+ provider provider.Provider
+ providerID string
- titleProvider provider.Provider
- summarizeProvider provider.Provider
+ titleProvider provider.Provider
+ summarizeProvider provider.Provider
+ summarizeProviderID string
activeRequests sync.Map
}
+var agentPromptMap = map[configv2.AgentID]prompt.PromptID{
+ configv2.AgentCoder: prompt.PromptCoder,
+ configv2.AgentTask: prompt.PromptTask,
+}
+
func NewAgent(
- agentName config.AgentName,
+ agentCfg configv2.Agent,
+ // These services are needed in the tools
+ permissions permission.Service,
sessions session.Service,
messages message.Service,
- agentTools []tools.BaseTool,
+ history history.Service,
+ lspClients map[string]*lsp.Client,
) (Service, error) {
- agentProvider, err := createAgentProvider(agentName)
+ ctx := context.Background()
+ cfg := configv2.Get()
+ otherTools := GetMcpTools(ctx, permissions)
+ if len(lspClients) > 0 {
+ otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
+ }
+
+ allTools := []tools.BaseTool{
+ tools.NewBashTool(permissions),
+ tools.NewEditTool(lspClients, permissions, history),
+ tools.NewFetchTool(permissions),
+ tools.NewGlobTool(),
+ tools.NewGrepTool(),
+ tools.NewLsTool(),
+ tools.NewSourcegraphTool(),
+ tools.NewViewTool(lspClients),
+ tools.NewWriteTool(lspClients, permissions, history),
+ }
+
+ if agentCfg.ID == configv2.AgentCoder {
+ taskAgentCfg := configv2.Get().Agents[configv2.AgentTask]
+ if taskAgentCfg.ID == "" {
+ return nil, fmt.Errorf("task agent not found in config")
+ }
+ taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create task agent: %w", err)
+ }
+
+ allTools = append(
+ allTools,
+ NewAgentTool(
+ taskAgent,
+ sessions,
+ messages,
+ ),
+ )
+ }
+
+ allTools = append(allTools, otherTools...)
+ var providerCfg configv2.ProviderConfig
+ for _, p := range cfg.Providers {
+ if p.ID == agentCfg.Provider {
+ providerCfg = p
+ break
+ }
+ }
+ if providerCfg.ID == "" {
+ return nil, fmt.Errorf("provider %s not found in config", agentCfg.Provider)
+ }
+
+ var model configv2.Model
+ for _, m := range providerCfg.Models {
+ if m.ID == agentCfg.Model {
+ model = m
+ break
+ }
+ }
+ if model.ID == "" {
+ return nil, fmt.Errorf("model %s not found in provider %s", agentCfg.Model, agentCfg.Provider)
+ }
+
+ promptID := agentPromptMap[agentCfg.ID]
+ if promptID == "" {
+ promptID = prompt.PromptDefault
+ }
+ opts := []provider.ProviderClientOption{
+ provider.WithModel(model),
+ provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
+ provider.WithMaxTokens(model.DefaultMaxTokens),
+ }
+ agentProvider, err := provider.NewProviderV2(providerCfg, opts...)
if err != nil {
return nil, err
}
- var titleProvider provider.Provider
- // Only generate titles for the coder agent
- if agentName == config.AgentCoder {
- titleProvider, err = createAgentProvider(config.AgentTitle)
- if err != nil {
- return nil, err
+
+ smallModelCfg := cfg.Models.Small
+ var smallModel configv2.Model
+
+ var smallModelProviderCfg configv2.ProviderConfig
+ if smallModelCfg.Provider == providerCfg.ID {
+ smallModelProviderCfg = providerCfg
+ } else {
+ for _, p := range cfg.Providers {
+ if p.ID == smallModelCfg.Provider {
+ smallModelProviderCfg = p
+ break
+ }
+ }
+ if smallModelProviderCfg.ID == "" {
+ return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
}
}
- var summarizeProvider provider.Provider
- if agentName == config.AgentCoder {
- summarizeProvider, err = createAgentProvider(config.AgentSummarizer)
- if err != nil {
- return nil, err
+ for _, m := range smallModelProviderCfg.Models {
+ if m.ID == smallModelCfg.ModelID {
+ smallModel = m
+ break
+ }
+ }
+ if smallModel.ID == "" {
+ return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
+ }
+
+ titleOpts := []provider.ProviderClientOption{
+ provider.WithModel(smallModel),
+ provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
+ provider.WithMaxTokens(40),
+ }
+ titleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...)
+ if err != nil {
+ return nil, err
+ }
+ summarizeOpts := []provider.ProviderClientOption{
+ provider.WithModel(smallModel),
+ provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
+ provider.WithMaxTokens(smallModel.DefaultMaxTokens),
+ }
+ summarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...)
+ if err != nil {
+ return nil, err
+ }
+
+ agentTools := []tools.BaseTool{}
+ if agentCfg.AllowedTools == nil {
+ agentTools = allTools
+ } else {
+ for _, tool := range allTools {
+ if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
+ agentTools = append(agentTools, tool)
+ }
}
}
agent := &agent{
- Broker: pubsub.NewBroker[AgentEvent](),
- provider: agentProvider,
- messages: messages,
- sessions: sessions,
- tools: agentTools,
- titleProvider: titleProvider,
- summarizeProvider: summarizeProvider,
- activeRequests: sync.Map{},
+ Broker: pubsub.NewBroker[AgentEvent](),
+ agentCfg: agentCfg,
+ provider: agentProvider,
+ providerID: string(providerCfg.ID),
+ messages: messages,
+ sessions: sessions,
+ tools: agentTools,
+ titleProvider: titleProvider,
+ summarizeProvider: summarizeProvider,
+ summarizeProviderID: string(smallModelProviderCfg.ID),
+ activeRequests: sync.Map{},
}
return agent, nil
}
-func (a *agent) Model() models.Model {
+func (a *agent) Model() configv2.Model {
return a.provider.Model()
}
@@ -207,7 +336,7 @@ func (a *agent) err(err error) AgentEvent {
}
func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
- if !a.provider.Model().SupportsAttachments && attachments != nil {
+ if !a.provider.Model().SupportsImages && attachments != nil {
attachments = nil
}
events := make(chan AgentEvent)
@@ -327,9 +456,10 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
- Role: message.Assistant,
- Parts: []message.ContentPart{},
- Model: a.provider.Model().ID,
+ Role: message.Assistant,
+ Parts: []message.ContentPart{},
+ Model: a.provider.Model().ID,
+ Provider: a.providerID,
})
if err != nil {
return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
@@ -424,8 +554,9 @@ out:
parts = append(parts, tr)
}
msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
- Role: message.Tool,
- Parts: parts,
+ Role: message.Tool,
+ Parts: parts,
+ Provider: a.providerID,
})
if err != nil {
return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
@@ -484,7 +615,7 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg
return nil
}
-func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
+func (a *agent) TrackUsage(ctx context.Context, sessionID string, model configv2.Model, usage provider.TokenUsage) error {
sess, err := a.sessions.Get(ctx, sessionID)
if err != nil {
return fmt.Errorf("failed to get session: %w", err)
@@ -506,21 +637,48 @@ func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.M
return nil
}
-func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error) {
+func (a *agent) Update(modelCfg configv2.PreferredModel) (configv2.Model, error) {
if a.IsBusy() {
- return models.Model{}, fmt.Errorf("cannot change model while processing requests")
+ return configv2.Model{}, fmt.Errorf("cannot change model while processing requests")
}
- if err := config.UpdateAgentModel(agentName, modelID); err != nil {
- return models.Model{}, fmt.Errorf("failed to update config: %w", err)
+ cfg := configv2.Get()
+ var providerCfg configv2.ProviderConfig
+ for _, p := range cfg.Providers {
+ if p.ID == modelCfg.Provider {
+ providerCfg = p
+ break
+ }
+ }
+ if providerCfg.ID == "" {
+ return configv2.Model{}, fmt.Errorf("provider %s not found in config", modelCfg.Provider)
}
- provider, err := createAgentProvider(agentName)
- if err != nil {
- return models.Model{}, fmt.Errorf("failed to create provider for model %s: %w", modelID, err)
+ var model configv2.Model
+ for _, m := range providerCfg.Models {
+ if m.ID == modelCfg.ModelID {
+ model = m
+ break
+ }
+ }
+ if model.ID == "" {
+ return configv2.Model{}, fmt.Errorf("model %s not found in provider %s", modelCfg.ModelID, modelCfg.Provider)
}
- a.provider = provider
+ promptID := agentPromptMap[a.agentCfg.ID]
+ if promptID == "" {
+ promptID = prompt.PromptDefault
+ }
+ opts := []provider.ProviderClientOption{
+ provider.WithModel(model),
+ provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
+ provider.WithMaxTokens(model.DefaultMaxTokens),
+ }
+ agentProvider, err := provider.NewProviderV2(providerCfg, opts...)
+ if err != nil {
+ return configv2.Model{}, err
+ }
+ a.provider = agentProvider
return a.provider.Model(), nil
}
@@ -654,7 +812,8 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error {
Time: time.Now().Unix(),
},
},
- Model: a.summarizeProvider.Model().ID,
+ Model: a.summarizeProvider.Model().ID,
+ Provider: a.summarizeProviderID,
})
if err != nil {
event = AgentEvent{
@@ -705,51 +864,3 @@ func (a *agent) CancelAll() {
return true
})
}
-
-func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
- cfg := config.Get()
- agentConfig, ok := cfg.Agents[agentName]
- if !ok {
- return nil, fmt.Errorf("agent %s not found", agentName)
- }
- model, ok := models.SupportedModels[agentConfig.Model]
- if !ok {
- return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
- }
-
- providerCfg, ok := cfg.Providers[model.Provider]
- if !ok {
- return nil, fmt.Errorf("provider %s not supported", model.Provider)
- }
- if providerCfg.Disabled {
- return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
- }
- maxTokens := model.DefaultMaxTokens
- if agentConfig.MaxTokens > 0 {
- maxTokens = agentConfig.MaxTokens
- }
- opts := []provider.ProviderClientOption{
- provider.WithAPIKey(providerCfg.APIKey),
- provider.WithModel(model),
- provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
- provider.WithMaxTokens(maxTokens),
- }
- // TODO: reimplement
- // if model.Provider == models.ProviderOpenAI || model.Provider == models.ProviderLocal && model.CanReason {
- // opts = append(
- // opts,
- // provider.WithOpenAIOptions(
- // provider.WithReasoningEffort(agentConfig.ReasoningEffort),
- // ),
- // )
- // }
- agentProvider, err := provider.NewProvider(
- model.Provider,
- opts...,
- )
- if err != nil {
- return nil, fmt.Errorf("could not create provider: %v", err)
- }
-
- return agentProvider, nil
-}
@@ -18,7 +18,7 @@ import (
type mcpTool struct {
mcpName string
tool mcp.Tool
- mcpConfig config.MCPServer
+ mcpConfig config.MCP
permissions permission.Service
}
@@ -128,7 +128,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
return tools.NewTextErrorResponse("invalid mcp type"), nil
}
-func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPServer) tools.BaseTool {
+func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCP) tools.BaseTool {
return &mcpTool{
mcpName: name,
tool: tool,
@@ -139,7 +139,7 @@ func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpC
var mcpTools []tools.BaseTool
-func getTools(ctx context.Context, name string, m config.MCPServer, permissions permission.Service, c MCPClient) []tools.BaseTool {
+func getTools(ctx context.Context, name string, m config.MCP, permissions permission.Service, c MCPClient) []tools.BaseTool {
var stdioTools []tools.BaseTool
initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
@@ -170,7 +170,7 @@ func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.Ba
if len(mcpTools) > 0 {
return mcpTools
}
- for name, m := range config.Get().MCPServers {
+ for name, m := range config.Get().MCP {
switch m.Type {
case config.MCPStdio:
c, err := client.NewStdioMCPClient(
@@ -1,50 +0,0 @@
-package agent
-
-import (
- "context"
-
- "github.com/charmbracelet/crush/internal/history"
- "github.com/charmbracelet/crush/internal/llm/tools"
- "github.com/charmbracelet/crush/internal/lsp"
- "github.com/charmbracelet/crush/internal/message"
- "github.com/charmbracelet/crush/internal/permission"
- "github.com/charmbracelet/crush/internal/session"
-)
-
-func CoderAgentTools(
- permissions permission.Service,
- sessions session.Service,
- messages message.Service,
- history history.Service,
- lspClients map[string]*lsp.Client,
-) []tools.BaseTool {
- ctx := context.Background()
- otherTools := GetMcpTools(ctx, permissions)
- if len(lspClients) > 0 {
- otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
- }
- return append(
- []tools.BaseTool{
- tools.NewBashTool(permissions),
- tools.NewEditTool(lspClients, permissions, history),
- tools.NewFetchTool(permissions),
- tools.NewGlobTool(),
- tools.NewGrepTool(),
- tools.NewLsTool(),
- tools.NewSourcegraphTool(),
- tools.NewViewTool(lspClients),
- tools.NewWriteTool(lspClients, permissions, history),
- NewAgentTool(sessions, messages, lspClients),
- }, otherTools...,
- )
-}
-
-func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool {
- return []tools.BaseTool{
- tools.NewGlobTool(),
- tools.NewGrepTool(),
- tools.NewLsTool(),
- tools.NewSourcegraphTool(),
- tools.NewViewTool(lspClients),
- }
-}
@@ -1,111 +0,0 @@
-package models
-
-const (
- ProviderAnthropic InferenceProvider = "anthropic"
-
- // Models
- Claude35Sonnet ModelID = "claude-3.5-sonnet"
- Claude3Haiku ModelID = "claude-3-haiku"
- Claude37Sonnet ModelID = "claude-3.7-sonnet"
- Claude35Haiku ModelID = "claude-3.5-haiku"
- Claude3Opus ModelID = "claude-3-opus"
- Claude4Opus ModelID = "claude-4-opus"
- Claude4Sonnet ModelID = "claude-4-sonnet"
-)
-
-// https://docs.anthropic.com/en/docs/about-claude/models/all-models
-var AnthropicModels = map[ModelID]Model{
- Claude35Sonnet: {
- ID: Claude35Sonnet,
- Name: "Claude 3.5 Sonnet",
- Provider: ProviderAnthropic,
- APIModel: "claude-3-5-sonnet-latest",
- CostPer1MIn: 3.0,
- CostPer1MInCached: 3.75,
- CostPer1MOutCached: 0.30,
- CostPer1MOut: 15.0,
- ContextWindow: 200000,
- DefaultMaxTokens: 5000,
- SupportsAttachments: true,
- },
- Claude3Haiku: {
- ID: Claude3Haiku,
- Name: "Claude 3 Haiku",
- Provider: ProviderAnthropic,
- APIModel: "claude-3-haiku-20240307", // doesn't support "-latest"
- CostPer1MIn: 0.25,
- CostPer1MInCached: 0.30,
- CostPer1MOutCached: 0.03,
- CostPer1MOut: 1.25,
- ContextWindow: 200000,
- DefaultMaxTokens: 4096,
- SupportsAttachments: true,
- },
- Claude37Sonnet: {
- ID: Claude37Sonnet,
- Name: "Claude 3.7 Sonnet",
- Provider: ProviderAnthropic,
- APIModel: "claude-3-7-sonnet-latest",
- CostPer1MIn: 3.0,
- CostPer1MInCached: 3.75,
- CostPer1MOutCached: 0.30,
- CostPer1MOut: 15.0,
- ContextWindow: 200000,
- DefaultMaxTokens: 50000,
- CanReason: true,
- SupportsAttachments: true,
- },
- Claude35Haiku: {
- ID: Claude35Haiku,
- Name: "Claude 3.5 Haiku",
- Provider: ProviderAnthropic,
- APIModel: "claude-3-5-haiku-latest",
- CostPer1MIn: 0.80,
- CostPer1MInCached: 1.0,
- CostPer1MOutCached: 0.08,
- CostPer1MOut: 4.0,
- ContextWindow: 200000,
- DefaultMaxTokens: 4096,
- SupportsAttachments: true,
- },
- Claude3Opus: {
- ID: Claude3Opus,
- Name: "Claude 3 Opus",
- Provider: ProviderAnthropic,
- APIModel: "claude-3-opus-latest",
- CostPer1MIn: 15.0,
- CostPer1MInCached: 18.75,
- CostPer1MOutCached: 1.50,
- CostPer1MOut: 75.0,
- ContextWindow: 200000,
- DefaultMaxTokens: 4096,
- SupportsAttachments: true,
- },
- Claude4Sonnet: {
- ID: Claude4Sonnet,
- Name: "Claude 4 Sonnet",
- Provider: ProviderAnthropic,
- APIModel: "claude-sonnet-4-20250514",
- CostPer1MIn: 3.0,
- CostPer1MInCached: 3.75,
- CostPer1MOutCached: 0.30,
- CostPer1MOut: 15.0,
- ContextWindow: 200000,
- DefaultMaxTokens: 50000,
- CanReason: true,
- SupportsAttachments: true,
- },
- Claude4Opus: {
- ID: Claude4Opus,
- Name: "Claude 4 Opus",
- Provider: ProviderAnthropic,
- APIModel: "claude-opus-4-20250514",
- CostPer1MIn: 15.0,
- CostPer1MInCached: 18.75,
- CostPer1MOutCached: 1.50,
- CostPer1MOut: 75.0,
- ContextWindow: 200000,
- DefaultMaxTokens: 4096,
- SupportsAttachments: true,
- },
-}
@@ -1,168 +0,0 @@
-package models
-
-const ProviderAzure InferenceProvider = "azure"
-
-const (
- AzureGPT41 ModelID = "azure.gpt-4.1"
- AzureGPT41Mini ModelID = "azure.gpt-4.1-mini"
- AzureGPT41Nano ModelID = "azure.gpt-4.1-nano"
- AzureGPT45Preview ModelID = "azure.gpt-4.5-preview"
- AzureGPT4o ModelID = "azure.gpt-4o"
- AzureGPT4oMini ModelID = "azure.gpt-4o-mini"
- AzureO1 ModelID = "azure.o1"
- AzureO1Mini ModelID = "azure.o1-mini"
- AzureO3 ModelID = "azure.o3"
- AzureO3Mini ModelID = "azure.o3-mini"
- AzureO4Mini ModelID = "azure.o4-mini"
-)
-
-var AzureModels = map[ModelID]Model{
- AzureGPT41: {
- ID: AzureGPT41,
- Name: "Azure OpenAI – GPT 4.1",
- Provider: ProviderAzure,
- APIModel: "gpt-4.1",
- CostPer1MIn: OpenAIModels[GPT41].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[GPT41].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[GPT41].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[GPT41].CostPer1MOutCached,
- ContextWindow: OpenAIModels[GPT41].ContextWindow,
- DefaultMaxTokens: OpenAIModels[GPT41].DefaultMaxTokens,
- SupportsAttachments: true,
- },
- AzureGPT41Mini: {
- ID: AzureGPT41Mini,
- Name: "Azure OpenAI – GPT 4.1 mini",
- Provider: ProviderAzure,
- APIModel: "gpt-4.1-mini",
- CostPer1MIn: OpenAIModels[GPT41Mini].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[GPT41Mini].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[GPT41Mini].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[GPT41Mini].CostPer1MOutCached,
- ContextWindow: OpenAIModels[GPT41Mini].ContextWindow,
- DefaultMaxTokens: OpenAIModels[GPT41Mini].DefaultMaxTokens,
- SupportsAttachments: true,
- },
- AzureGPT41Nano: {
- ID: AzureGPT41Nano,
- Name: "Azure OpenAI – GPT 4.1 nano",
- Provider: ProviderAzure,
- APIModel: "gpt-4.1-nano",
- CostPer1MIn: OpenAIModels[GPT41Nano].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[GPT41Nano].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[GPT41Nano].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[GPT41Nano].CostPer1MOutCached,
- ContextWindow: OpenAIModels[GPT41Nano].ContextWindow,
- DefaultMaxTokens: OpenAIModels[GPT41Nano].DefaultMaxTokens,
- SupportsAttachments: true,
- },
- AzureGPT45Preview: {
- ID: AzureGPT45Preview,
- Name: "Azure OpenAI – GPT 4.5 preview",
- Provider: ProviderAzure,
- APIModel: "gpt-4.5-preview",
- CostPer1MIn: OpenAIModels[GPT45Preview].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[GPT45Preview].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[GPT45Preview].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[GPT45Preview].CostPer1MOutCached,
- ContextWindow: OpenAIModels[GPT45Preview].ContextWindow,
- DefaultMaxTokens: OpenAIModels[GPT45Preview].DefaultMaxTokens,
- SupportsAttachments: true,
- },
- AzureGPT4o: {
- ID: AzureGPT4o,
- Name: "Azure OpenAI – GPT-4o",
- Provider: ProviderAzure,
- APIModel: "gpt-4o",
- CostPer1MIn: OpenAIModels[GPT4o].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[GPT4o].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[GPT4o].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[GPT4o].CostPer1MOutCached,
- ContextWindow: OpenAIModels[GPT4o].ContextWindow,
- DefaultMaxTokens: OpenAIModels[GPT4o].DefaultMaxTokens,
- SupportsAttachments: true,
- },
- AzureGPT4oMini: {
- ID: AzureGPT4oMini,
- Name: "Azure OpenAI – GPT-4o mini",
- Provider: ProviderAzure,
- APIModel: "gpt-4o-mini",
- CostPer1MIn: OpenAIModels[GPT4oMini].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[GPT4oMini].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[GPT4oMini].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[GPT4oMini].CostPer1MOutCached,
- ContextWindow: OpenAIModels[GPT4oMini].ContextWindow,
- DefaultMaxTokens: OpenAIModels[GPT4oMini].DefaultMaxTokens,
- SupportsAttachments: true,
- },
- AzureO1: {
- ID: AzureO1,
- Name: "Azure OpenAI – O1",
- Provider: ProviderAzure,
- APIModel: "o1",
- CostPer1MIn: OpenAIModels[O1].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[O1].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[O1].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[O1].CostPer1MOutCached,
- ContextWindow: OpenAIModels[O1].ContextWindow,
- DefaultMaxTokens: OpenAIModels[O1].DefaultMaxTokens,
- CanReason: OpenAIModels[O1].CanReason,
- SupportsAttachments: true,
- },
- AzureO1Mini: {
- ID: AzureO1Mini,
- Name: "Azure OpenAI – O1 mini",
- Provider: ProviderAzure,
- APIModel: "o1-mini",
- CostPer1MIn: OpenAIModels[O1Mini].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[O1Mini].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[O1Mini].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[O1Mini].CostPer1MOutCached,
- ContextWindow: OpenAIModels[O1Mini].ContextWindow,
- DefaultMaxTokens: OpenAIModels[O1Mini].DefaultMaxTokens,
- CanReason: OpenAIModels[O1Mini].CanReason,
- SupportsAttachments: true,
- },
- AzureO3: {
- ID: AzureO3,
- Name: "Azure OpenAI – O3",
- Provider: ProviderAzure,
- APIModel: "o3",
- CostPer1MIn: OpenAIModels[O3].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[O3].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[O3].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[O3].CostPer1MOutCached,
- ContextWindow: OpenAIModels[O3].ContextWindow,
- DefaultMaxTokens: OpenAIModels[O3].DefaultMaxTokens,
- CanReason: OpenAIModels[O3].CanReason,
- SupportsAttachments: true,
- },
- AzureO3Mini: {
- ID: AzureO3Mini,
- Name: "Azure OpenAI – O3 mini",
- Provider: ProviderAzure,
- APIModel: "o3-mini",
- CostPer1MIn: OpenAIModels[O3Mini].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[O3Mini].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[O3Mini].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[O3Mini].CostPer1MOutCached,
- ContextWindow: OpenAIModels[O3Mini].ContextWindow,
- DefaultMaxTokens: OpenAIModels[O3Mini].DefaultMaxTokens,
- CanReason: OpenAIModels[O3Mini].CanReason,
- SupportsAttachments: false,
- },
- AzureO4Mini: {
- ID: AzureO4Mini,
- Name: "Azure OpenAI – O4 mini",
- Provider: ProviderAzure,
- APIModel: "o4-mini",
- CostPer1MIn: OpenAIModels[O4Mini].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[O4Mini].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[O4Mini].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[O4Mini].CostPer1MOutCached,
- ContextWindow: OpenAIModels[O4Mini].ContextWindow,
- DefaultMaxTokens: OpenAIModels[O4Mini].DefaultMaxTokens,
- CanReason: OpenAIModels[O4Mini].CanReason,
- SupportsAttachments: true,
- },
-}
@@ -1,67 +0,0 @@
-package models
-
-const (
- ProviderGemini InferenceProvider = "gemini"
-
- // Models
- Gemini25Flash ModelID = "gemini-2.5-flash"
- Gemini25 ModelID = "gemini-2.5"
- Gemini20Flash ModelID = "gemini-2.0-flash"
- Gemini20FlashLite ModelID = "gemini-2.0-flash-lite"
-)
-
-var GeminiModels = map[ModelID]Model{
- Gemini25Flash: {
- ID: Gemini25Flash,
- Name: "Gemini 2.5 Flash",
- Provider: ProviderGemini,
- APIModel: "gemini-2.5-flash-preview-04-17",
- CostPer1MIn: 0.15,
- CostPer1MInCached: 0,
- CostPer1MOutCached: 0,
- CostPer1MOut: 0.60,
- ContextWindow: 1000000,
- DefaultMaxTokens: 50000,
- SupportsAttachments: true,
- },
- Gemini25: {
- ID: Gemini25,
- Name: "Gemini 2.5 Pro",
- Provider: ProviderGemini,
- APIModel: "gemini-2.5-pro-preview-05-06",
- CostPer1MIn: 1.25,
- CostPer1MInCached: 0,
- CostPer1MOutCached: 0,
- CostPer1MOut: 10,
- ContextWindow: 1000000,
- DefaultMaxTokens: 50000,
- SupportsAttachments: true,
- },
-
- Gemini20Flash: {
- ID: Gemini20Flash,
- Name: "Gemini 2.0 Flash",
- Provider: ProviderGemini,
- APIModel: "gemini-2.0-flash",
- CostPer1MIn: 0.10,
- CostPer1MInCached: 0,
- CostPer1MOutCached: 0,
- CostPer1MOut: 0.40,
- ContextWindow: 1000000,
- DefaultMaxTokens: 6000,
- SupportsAttachments: true,
- },
- Gemini20FlashLite: {
- ID: Gemini20FlashLite,
- Name: "Gemini 2.0 Flash Lite",
- Provider: ProviderGemini,
- APIModel: "gemini-2.0-flash-lite",
- CostPer1MIn: 0.05,
- CostPer1MInCached: 0,
- CostPer1MOutCached: 0,
- CostPer1MOut: 0.30,
- ContextWindow: 1000000,
- DefaultMaxTokens: 6000,
- SupportsAttachments: true,
- },
-}
@@ -1,87 +0,0 @@
-package models
-
-const (
- ProviderGROQ InferenceProvider = "groq"
-
- // GROQ
- QWENQwq ModelID = "qwen-qwq"
-
- // GROQ preview models
- Llama4Scout ModelID = "meta-llama/llama-4-scout-17b-16e-instruct"
- Llama4Maverick ModelID = "meta-llama/llama-4-maverick-17b-128e-instruct"
- Llama3_3_70BVersatile ModelID = "llama-3.3-70b-versatile"
- DeepseekR1DistillLlama70b ModelID = "deepseek-r1-distill-llama-70b"
-)
-
-var GroqModels = map[ModelID]Model{
- //
- // GROQ
- QWENQwq: {
- ID: QWENQwq,
- Name: "Qwen Qwq",
- Provider: ProviderGROQ,
- APIModel: "qwen-qwq-32b",
- CostPer1MIn: 0.29,
- CostPer1MInCached: 0.275,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 0.39,
- ContextWindow: 128_000,
- DefaultMaxTokens: 50000,
- // for some reason, the groq api doesn't like the reasoningEffort parameter
- CanReason: false,
- SupportsAttachments: false,
- },
-
- Llama4Scout: {
- ID: Llama4Scout,
- Name: "Llama4Scout",
- Provider: ProviderGROQ,
- APIModel: "meta-llama/llama-4-scout-17b-16e-instruct",
- CostPer1MIn: 0.11,
- CostPer1MInCached: 0,
- CostPer1MOutCached: 0,
- CostPer1MOut: 0.34,
- ContextWindow: 128_000, // 10M when?
- SupportsAttachments: true,
- },
-
- Llama4Maverick: {
- ID: Llama4Maverick,
- Name: "Llama4Maverick",
- Provider: ProviderGROQ,
- APIModel: "meta-llama/llama-4-maverick-17b-128e-instruct",
- CostPer1MIn: 0.20,
- CostPer1MInCached: 0,
- CostPer1MOutCached: 0,
- CostPer1MOut: 0.20,
- ContextWindow: 128_000,
- SupportsAttachments: true,
- },
-
- Llama3_3_70BVersatile: {
- ID: Llama3_3_70BVersatile,
- Name: "Llama3_3_70BVersatile",
- Provider: ProviderGROQ,
- APIModel: "llama-3.3-70b-versatile",
- CostPer1MIn: 0.59,
- CostPer1MInCached: 0,
- CostPer1MOutCached: 0,
- CostPer1MOut: 0.79,
- ContextWindow: 128_000,
- SupportsAttachments: false,
- },
-
- DeepseekR1DistillLlama70b: {
- ID: DeepseekR1DistillLlama70b,
- Name: "DeepseekR1DistillLlama70b",
- Provider: ProviderGROQ,
- APIModel: "deepseek-r1-distill-llama-70b",
- CostPer1MIn: 0.75,
- CostPer1MInCached: 0,
- CostPer1MOutCached: 0,
- CostPer1MOut: 0.99,
- ContextWindow: 128_000,
- CanReason: true,
- SupportsAttachments: false,
- },
-}
@@ -1,206 +0,0 @@
-package models
-
-import (
- "cmp"
- "context"
- "encoding/json"
- "net/http"
- "net/url"
- "os"
- "regexp"
- "strings"
- "unicode"
-
- "github.com/charmbracelet/crush/internal/logging"
- "github.com/spf13/viper"
-)
-
-const (
- ProviderLocal InferenceProvider = "local"
-
- localModelsPath = "v1/models"
- lmStudioBetaModelsPath = "api/v0/models"
-)
-
-func init() {
- if endpoint := os.Getenv("LOCAL_ENDPOINT"); endpoint != "" {
- localEndpoint, err := url.Parse(endpoint)
- if err != nil {
- logging.Debug("Failed to parse local endpoint",
- "error", err,
- "endpoint", endpoint,
- )
- return
- }
-
- load := func(url *url.URL, path string) []localModel {
- url.Path = path
- return listLocalModels(url.String())
- }
-
- models := load(localEndpoint, lmStudioBetaModelsPath)
-
- if len(models) == 0 {
- models = load(localEndpoint, localModelsPath)
- }
-
- if len(models) == 0 {
- logging.Debug("No local models found",
- "endpoint", endpoint,
- )
- return
- }
-
- loadLocalModels(models)
-
- viper.SetDefault("providers.local.apiKey", "dummy")
- }
-}
-
-type localModelList struct {
- Data []localModel `json:"data"`
-}
-
-type localModel struct {
- ID string `json:"id"`
- Object string `json:"object"`
- Type string `json:"type"`
- Publisher string `json:"publisher"`
- Arch string `json:"arch"`
- CompatibilityType string `json:"compatibility_type"`
- Quantization string `json:"quantization"`
- State string `json:"state"`
- MaxContextLength int64 `json:"max_context_length"`
- LoadedContextLength int64 `json:"loaded_context_length"`
-}
-
-func listLocalModels(modelsEndpoint string) []localModel {
- res, err := http.NewRequestWithContext(context.Background(), http.MethodGet, modelsEndpoint, nil)
- if err != nil {
- logging.Debug("Failed to list local models",
- "error", err,
- "endpoint", modelsEndpoint,
- )
- }
- defer res.Body.Close()
-
- if res.Response.StatusCode != http.StatusOK {
- logging.Debug("Failed to list local models",
- "status", res.Response.Status,
- "endpoint", modelsEndpoint,
- )
- }
-
- var modelList localModelList
- if err = json.NewDecoder(res.Body).Decode(&modelList); err != nil {
- logging.Debug("Failed to list local models",
- "error", err,
- "endpoint", modelsEndpoint,
- )
- }
-
- var supportedModels []localModel
- for _, model := range modelList.Data {
- if strings.HasSuffix(modelsEndpoint, lmStudioBetaModelsPath) {
- if model.Object != "model" || model.Type != "llm" {
- logging.Debug("Skipping unsupported LMStudio model",
- "endpoint", modelsEndpoint,
- "id", model.ID,
- "object", model.Object,
- "type", model.Type,
- )
-
- continue
- }
- }
-
- supportedModels = append(supportedModels, model)
- }
-
- return supportedModels
-}
-
-func loadLocalModels(models []localModel) {
- for i, m := range models {
- model := convertLocalModel(m)
- SupportedModels[model.ID] = model
-
- if i == 0 || m.State == "loaded" {
- viper.SetDefault("agents.coder.model", model.ID)
- viper.SetDefault("agents.summarizer.model", model.ID)
- viper.SetDefault("agents.task.model", model.ID)
- viper.SetDefault("agents.title.model", model.ID)
- }
- }
-}
-
-func convertLocalModel(model localModel) Model {
- return Model{
- ID: ModelID("local." + model.ID),
- Name: friendlyModelName(model.ID),
- Provider: ProviderLocal,
- APIModel: model.ID,
- ContextWindow: cmp.Or(model.LoadedContextLength, 4096),
- DefaultMaxTokens: cmp.Or(model.LoadedContextLength, 4096),
- CanReason: true,
- SupportsAttachments: true,
- }
-}
-
-var modelInfoRegex = regexp.MustCompile(`(?i)^([a-z0-9]+)(?:[-_]?([rv]?\d[\.\d]*))?(?:[-_]?([a-z]+))?.*`)
-
-func friendlyModelName(modelID string) string {
- mainID := modelID
- tag := ""
-
- if slash := strings.LastIndex(mainID, "/"); slash != -1 {
- mainID = mainID[slash+1:]
- }
-
- if at := strings.Index(modelID, "@"); at != -1 {
- mainID = modelID[:at]
- tag = modelID[at+1:]
- }
-
- match := modelInfoRegex.FindStringSubmatch(mainID)
- if match == nil {
- return modelID
- }
-
- capitalize := func(s string) string {
- if s == "" {
- return ""
- }
- runes := []rune(s)
- runes[0] = unicode.ToUpper(runes[0])
- return string(runes)
- }
-
- family := capitalize(match[1])
- version := ""
- label := ""
-
- if len(match) > 2 && match[2] != "" {
- version = strings.ToUpper(match[2])
- }
-
- if len(match) > 3 && match[3] != "" {
- label = capitalize(match[3])
- }
-
- var parts []string
- if family != "" {
- parts = append(parts, family)
- }
- if version != "" {
- parts = append(parts, version)
- }
- if label != "" {
- parts = append(parts, label)
- }
- if tag != "" {
- parts = append(parts, tag)
- }
-
- return strings.Join(parts, " ")
-}
@@ -1,74 +0,0 @@
-package models
-
-import "maps"
-
-type (
- ModelID string
- InferenceProvider string
-)
-
-type Model struct {
- ID ModelID `json:"id"`
- Name string `json:"name"`
- Provider InferenceProvider `json:"provider"`
- APIModel string `json:"api_model"`
- CostPer1MIn float64 `json:"cost_per_1m_in"`
- CostPer1MOut float64 `json:"cost_per_1m_out"`
- CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
- CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
- ContextWindow int64 `json:"context_window"`
- DefaultMaxTokens int64 `json:"default_max_tokens"`
- CanReason bool `json:"can_reason"`
- SupportsAttachments bool `json:"supports_attachments"`
-}
-
-// Model IDs
-const ( // GEMINI
- // Bedrock
- BedrockClaude37Sonnet ModelID = "bedrock.claude-3.7-sonnet"
-)
-
-const (
- ProviderBedrock InferenceProvider = "bedrock"
- // ForTests
- ProviderMock InferenceProvider = "__mock"
-)
-
-var SupportedModels = map[ModelID]Model{
- // Bedrock
- BedrockClaude37Sonnet: {
- ID: BedrockClaude37Sonnet,
- Name: "Bedrock: Claude 3.7 Sonnet",
- Provider: ProviderBedrock,
- APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0",
- CostPer1MIn: 3.0,
- CostPer1MInCached: 3.75,
- CostPer1MOutCached: 0.30,
- CostPer1MOut: 15.0,
- },
-}
-
-var KnownProviders = []InferenceProvider{
- ProviderAnthropic,
- ProviderOpenAI,
- ProviderGemini,
- ProviderAzure,
- ProviderGROQ,
- ProviderLocal,
- ProviderOpenRouter,
- ProviderVertexAI,
- ProviderBedrock,
- ProviderXAI,
- ProviderMock,
-}
-
-func init() {
- maps.Copy(SupportedModels, AnthropicModels)
- maps.Copy(SupportedModels, OpenAIModels)
- maps.Copy(SupportedModels, GeminiModels)
- maps.Copy(SupportedModels, GroqModels)
- maps.Copy(SupportedModels, AzureModels)
- maps.Copy(SupportedModels, OpenRouterModels)
- maps.Copy(SupportedModels, XAIModels)
- maps.Copy(SupportedModels, VertexAIGeminiModels)
-}
@@ -1,181 +0,0 @@
-package models
-
-const (
- ProviderOpenAI InferenceProvider = "openai"
-
- GPT41 ModelID = "gpt-4.1"
- GPT41Mini ModelID = "gpt-4.1-mini"
- GPT41Nano ModelID = "gpt-4.1-nano"
- GPT45Preview ModelID = "gpt-4.5-preview"
- GPT4o ModelID = "gpt-4o"
- GPT4oMini ModelID = "gpt-4o-mini"
- O1 ModelID = "o1"
- O1Pro ModelID = "o1-pro"
- O1Mini ModelID = "o1-mini"
- O3 ModelID = "o3"
- O3Mini ModelID = "o3-mini"
- O4Mini ModelID = "o4-mini"
-)
-
-var OpenAIModels = map[ModelID]Model{
- GPT41: {
- ID: GPT41,
- Name: "GPT 4.1",
- Provider: ProviderOpenAI,
- APIModel: "gpt-4.1",
- CostPer1MIn: 2.00,
- CostPer1MInCached: 0.50,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 8.00,
- ContextWindow: 1_047_576,
- DefaultMaxTokens: 20000,
- SupportsAttachments: true,
- },
- GPT41Mini: {
- ID: GPT41Mini,
- Name: "GPT 4.1 mini",
- Provider: ProviderOpenAI,
- APIModel: "gpt-4.1",
- CostPer1MIn: 0.40,
- CostPer1MInCached: 0.10,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 1.60,
- ContextWindow: 200_000,
- DefaultMaxTokens: 20000,
- SupportsAttachments: true,
- },
- GPT41Nano: {
- ID: GPT41Nano,
- Name: "GPT 4.1 nano",
- Provider: ProviderOpenAI,
- APIModel: "gpt-4.1-nano",
- CostPer1MIn: 0.10,
- CostPer1MInCached: 0.025,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 0.40,
- ContextWindow: 1_047_576,
- DefaultMaxTokens: 20000,
- SupportsAttachments: true,
- },
- GPT45Preview: {
- ID: GPT45Preview,
- Name: "GPT 4.5 preview",
- Provider: ProviderOpenAI,
- APIModel: "gpt-4.5-preview",
- CostPer1MIn: 75.00,
- CostPer1MInCached: 37.50,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 150.00,
- ContextWindow: 128_000,
- DefaultMaxTokens: 15000,
- SupportsAttachments: true,
- },
- GPT4o: {
- ID: GPT4o,
- Name: "GPT 4o",
- Provider: ProviderOpenAI,
- APIModel: "gpt-4o",
- CostPer1MIn: 2.50,
- CostPer1MInCached: 1.25,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 10.00,
- ContextWindow: 128_000,
- DefaultMaxTokens: 4096,
- SupportsAttachments: true,
- },
- GPT4oMini: {
- ID: GPT4oMini,
- Name: "GPT 4o mini",
- Provider: ProviderOpenAI,
- APIModel: "gpt-4o-mini",
- CostPer1MIn: 0.15,
- CostPer1MInCached: 0.075,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 0.60,
- ContextWindow: 128_000,
- SupportsAttachments: true,
- },
- O1: {
- ID: O1,
- Name: "O1",
- Provider: ProviderOpenAI,
- APIModel: "o1",
- CostPer1MIn: 15.00,
- CostPer1MInCached: 7.50,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 60.00,
- ContextWindow: 200_000,
- DefaultMaxTokens: 50000,
- CanReason: true,
- SupportsAttachments: true,
- },
- O1Pro: {
- ID: O1Pro,
- Name: "o1 pro",
- Provider: ProviderOpenAI,
- APIModel: "o1-pro",
- CostPer1MIn: 150.00,
- CostPer1MInCached: 0.0,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 600.00,
- ContextWindow: 200_000,
- DefaultMaxTokens: 50000,
- CanReason: true,
- SupportsAttachments: true,
- },
- O1Mini: {
- ID: O1Mini,
- Name: "o1 mini",
- Provider: ProviderOpenAI,
- APIModel: "o1-mini",
- CostPer1MIn: 1.10,
- CostPer1MInCached: 0.55,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 4.40,
- ContextWindow: 128_000,
- DefaultMaxTokens: 50000,
- CanReason: true,
- SupportsAttachments: true,
- },
- O3: {
- ID: O3,
- Name: "o3",
- Provider: ProviderOpenAI,
- APIModel: "o3",
- CostPer1MIn: 10.00,
- CostPer1MInCached: 2.50,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 40.00,
- ContextWindow: 200_000,
- CanReason: true,
- SupportsAttachments: true,
- },
- O3Mini: {
- ID: O3Mini,
- Name: "o3 mini",
- Provider: ProviderOpenAI,
- APIModel: "o3-mini",
- CostPer1MIn: 1.10,
- CostPer1MInCached: 0.55,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 4.40,
- ContextWindow: 200_000,
- DefaultMaxTokens: 50000,
- CanReason: true,
- SupportsAttachments: false,
- },
- O4Mini: {
- ID: O4Mini,
- Name: "o4 mini",
- Provider: ProviderOpenAI,
- APIModel: "o4-mini",
- CostPer1MIn: 1.10,
- CostPer1MInCached: 0.275,
- CostPer1MOutCached: 0.0,
- CostPer1MOut: 4.40,
- ContextWindow: 128_000,
- DefaultMaxTokens: 50000,
- CanReason: true,
- SupportsAttachments: true,
- },
-}
@@ -1,276 +0,0 @@
-package models
-
-const (
- ProviderOpenRouter InferenceProvider = "openrouter"
-
- OpenRouterGPT41 ModelID = "openrouter.gpt-4.1"
- OpenRouterGPT41Mini ModelID = "openrouter.gpt-4.1-mini"
- OpenRouterGPT41Nano ModelID = "openrouter.gpt-4.1-nano"
- OpenRouterGPT45Preview ModelID = "openrouter.gpt-4.5-preview"
- OpenRouterGPT4o ModelID = "openrouter.gpt-4o"
- OpenRouterGPT4oMini ModelID = "openrouter.gpt-4o-mini"
- OpenRouterO1 ModelID = "openrouter.o1"
- OpenRouterO1Pro ModelID = "openrouter.o1-pro"
- OpenRouterO1Mini ModelID = "openrouter.o1-mini"
- OpenRouterO3 ModelID = "openrouter.o3"
- OpenRouterO3Mini ModelID = "openrouter.o3-mini"
- OpenRouterO4Mini ModelID = "openrouter.o4-mini"
- OpenRouterGemini25Flash ModelID = "openrouter.gemini-2.5-flash"
- OpenRouterGemini25 ModelID = "openrouter.gemini-2.5"
- OpenRouterClaude35Sonnet ModelID = "openrouter.claude-3.5-sonnet"
- OpenRouterClaude3Haiku ModelID = "openrouter.claude-3-haiku"
- OpenRouterClaude37Sonnet ModelID = "openrouter.claude-3.7-sonnet"
- OpenRouterClaude35Haiku ModelID = "openrouter.claude-3.5-haiku"
- OpenRouterClaude3Opus ModelID = "openrouter.claude-3-opus"
- OpenRouterDeepSeekR1Free ModelID = "openrouter.deepseek-r1-free"
-)
-
-var OpenRouterModels = map[ModelID]Model{
- OpenRouterGPT41: {
- ID: OpenRouterGPT41,
- Name: "OpenRouter – GPT 4.1",
- Provider: ProviderOpenRouter,
- APIModel: "openai/gpt-4.1",
- CostPer1MIn: OpenAIModels[GPT41].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[GPT41].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[GPT41].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[GPT41].CostPer1MOutCached,
- ContextWindow: OpenAIModels[GPT41].ContextWindow,
- DefaultMaxTokens: OpenAIModels[GPT41].DefaultMaxTokens,
- },
- OpenRouterGPT41Mini: {
- ID: OpenRouterGPT41Mini,
- Name: "OpenRouter – GPT 4.1 mini",
- Provider: ProviderOpenRouter,
- APIModel: "openai/gpt-4.1-mini",
- CostPer1MIn: OpenAIModels[GPT41Mini].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[GPT41Mini].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[GPT41Mini].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[GPT41Mini].CostPer1MOutCached,
- ContextWindow: OpenAIModels[GPT41Mini].ContextWindow,
- DefaultMaxTokens: OpenAIModels[GPT41Mini].DefaultMaxTokens,
- },
- OpenRouterGPT41Nano: {
- ID: OpenRouterGPT41Nano,
- Name: "OpenRouter – GPT 4.1 nano",
- Provider: ProviderOpenRouter,
- APIModel: "openai/gpt-4.1-nano",
- CostPer1MIn: OpenAIModels[GPT41Nano].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[GPT41Nano].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[GPT41Nano].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[GPT41Nano].CostPer1MOutCached,
- ContextWindow: OpenAIModels[GPT41Nano].ContextWindow,
- DefaultMaxTokens: OpenAIModels[GPT41Nano].DefaultMaxTokens,
- },
- OpenRouterGPT45Preview: {
- ID: OpenRouterGPT45Preview,
- Name: "OpenRouter – GPT 4.5 preview",
- Provider: ProviderOpenRouter,
- APIModel: "openai/gpt-4.5-preview",
- CostPer1MIn: OpenAIModels[GPT45Preview].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[GPT45Preview].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[GPT45Preview].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[GPT45Preview].CostPer1MOutCached,
- ContextWindow: OpenAIModels[GPT45Preview].ContextWindow,
- DefaultMaxTokens: OpenAIModels[GPT45Preview].DefaultMaxTokens,
- },
- OpenRouterGPT4o: {
- ID: OpenRouterGPT4o,
- Name: "OpenRouter – GPT 4o",
- Provider: ProviderOpenRouter,
- APIModel: "openai/gpt-4o",
- CostPer1MIn: OpenAIModels[GPT4o].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[GPT4o].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[GPT4o].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[GPT4o].CostPer1MOutCached,
- ContextWindow: OpenAIModels[GPT4o].ContextWindow,
- DefaultMaxTokens: OpenAIModels[GPT4o].DefaultMaxTokens,
- },
- OpenRouterGPT4oMini: {
- ID: OpenRouterGPT4oMini,
- Name: "OpenRouter – GPT 4o mini",
- Provider: ProviderOpenRouter,
- APIModel: "openai/gpt-4o-mini",
- CostPer1MIn: OpenAIModels[GPT4oMini].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[GPT4oMini].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[GPT4oMini].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[GPT4oMini].CostPer1MOutCached,
- ContextWindow: OpenAIModels[GPT4oMini].ContextWindow,
- },
- OpenRouterO1: {
- ID: OpenRouterO1,
- Name: "OpenRouter – O1",
- Provider: ProviderOpenRouter,
- APIModel: "openai/o1",
- CostPer1MIn: OpenAIModels[O1].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[O1].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[O1].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[O1].CostPer1MOutCached,
- ContextWindow: OpenAIModels[O1].ContextWindow,
- DefaultMaxTokens: OpenAIModels[O1].DefaultMaxTokens,
- CanReason: OpenAIModels[O1].CanReason,
- },
- OpenRouterO1Pro: {
- ID: OpenRouterO1Pro,
- Name: "OpenRouter – o1 pro",
- Provider: ProviderOpenRouter,
- APIModel: "openai/o1-pro",
- CostPer1MIn: OpenAIModels[O1Pro].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[O1Pro].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[O1Pro].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[O1Pro].CostPer1MOutCached,
- ContextWindow: OpenAIModels[O1Pro].ContextWindow,
- DefaultMaxTokens: OpenAIModels[O1Pro].DefaultMaxTokens,
- CanReason: OpenAIModels[O1Pro].CanReason,
- },
- OpenRouterO1Mini: {
- ID: OpenRouterO1Mini,
- Name: "OpenRouter – o1 mini",
- Provider: ProviderOpenRouter,
- APIModel: "openai/o1-mini",
- CostPer1MIn: OpenAIModels[O1Mini].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[O1Mini].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[O1Mini].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[O1Mini].CostPer1MOutCached,
- ContextWindow: OpenAIModels[O1Mini].ContextWindow,
- DefaultMaxTokens: OpenAIModels[O1Mini].DefaultMaxTokens,
- CanReason: OpenAIModels[O1Mini].CanReason,
- },
- OpenRouterO3: {
- ID: OpenRouterO3,
- Name: "OpenRouter – o3",
- Provider: ProviderOpenRouter,
- APIModel: "openai/o3",
- CostPer1MIn: OpenAIModels[O3].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[O3].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[O3].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[O3].CostPer1MOutCached,
- ContextWindow: OpenAIModels[O3].ContextWindow,
- DefaultMaxTokens: OpenAIModels[O3].DefaultMaxTokens,
- CanReason: OpenAIModels[O3].CanReason,
- },
- OpenRouterO3Mini: {
- ID: OpenRouterO3Mini,
- Name: "OpenRouter – o3 mini",
- Provider: ProviderOpenRouter,
- APIModel: "openai/o3-mini-high",
- CostPer1MIn: OpenAIModels[O3Mini].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[O3Mini].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[O3Mini].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[O3Mini].CostPer1MOutCached,
- ContextWindow: OpenAIModels[O3Mini].ContextWindow,
- DefaultMaxTokens: OpenAIModels[O3Mini].DefaultMaxTokens,
- CanReason: OpenAIModels[O3Mini].CanReason,
- },
- OpenRouterO4Mini: {
- ID: OpenRouterO4Mini,
- Name: "OpenRouter – o4 mini",
- Provider: ProviderOpenRouter,
- APIModel: "openai/o4-mini-high",
- CostPer1MIn: OpenAIModels[O4Mini].CostPer1MIn,
- CostPer1MInCached: OpenAIModels[O4Mini].CostPer1MInCached,
- CostPer1MOut: OpenAIModels[O4Mini].CostPer1MOut,
- CostPer1MOutCached: OpenAIModels[O4Mini].CostPer1MOutCached,
- ContextWindow: OpenAIModels[O4Mini].ContextWindow,
- DefaultMaxTokens: OpenAIModels[O4Mini].DefaultMaxTokens,
- CanReason: OpenAIModels[O4Mini].CanReason,
- },
- OpenRouterGemini25Flash: {
- ID: OpenRouterGemini25Flash,
- Name: "OpenRouter – Gemini 2.5 Flash",
- Provider: ProviderOpenRouter,
- APIModel: "google/gemini-2.5-flash-preview:thinking",
- CostPer1MIn: GeminiModels[Gemini25Flash].CostPer1MIn,
- CostPer1MInCached: GeminiModels[Gemini25Flash].CostPer1MInCached,
- CostPer1MOut: GeminiModels[Gemini25Flash].CostPer1MOut,
- CostPer1MOutCached: GeminiModels[Gemini25Flash].CostPer1MOutCached,
- ContextWindow: GeminiModels[Gemini25Flash].ContextWindow,
- DefaultMaxTokens: GeminiModels[Gemini25Flash].DefaultMaxTokens,
- },
- OpenRouterGemini25: {
- ID: OpenRouterGemini25,
- Name: "OpenRouter – Gemini 2.5 Pro",
- Provider: ProviderOpenRouter,
- APIModel: "google/gemini-2.5-pro-preview-03-25",
- CostPer1MIn: GeminiModels[Gemini25].CostPer1MIn,
- CostPer1MInCached: GeminiModels[Gemini25].CostPer1MInCached,
- CostPer1MOut: GeminiModels[Gemini25].CostPer1MOut,
- CostPer1MOutCached: GeminiModels[Gemini25].CostPer1MOutCached,
- ContextWindow: GeminiModels[Gemini25].ContextWindow,
- DefaultMaxTokens: GeminiModels[Gemini25].DefaultMaxTokens,
- },
- OpenRouterClaude35Sonnet: {
- ID: OpenRouterClaude35Sonnet,
- Name: "OpenRouter – Claude 3.5 Sonnet",
- Provider: ProviderOpenRouter,
- APIModel: "anthropic/claude-3.5-sonnet",
- CostPer1MIn: AnthropicModels[Claude35Sonnet].CostPer1MIn,
- CostPer1MInCached: AnthropicModels[Claude35Sonnet].CostPer1MInCached,
- CostPer1MOut: AnthropicModels[Claude35Sonnet].CostPer1MOut,
- CostPer1MOutCached: AnthropicModels[Claude35Sonnet].CostPer1MOutCached,
- ContextWindow: AnthropicModels[Claude35Sonnet].ContextWindow,
- DefaultMaxTokens: AnthropicModels[Claude35Sonnet].DefaultMaxTokens,
- },
- OpenRouterClaude3Haiku: {
- ID: OpenRouterClaude3Haiku,
- Name: "OpenRouter – Claude 3 Haiku",
- Provider: ProviderOpenRouter,
- APIModel: "anthropic/claude-3-haiku",
- CostPer1MIn: AnthropicModels[Claude3Haiku].CostPer1MIn,
- CostPer1MInCached: AnthropicModels[Claude3Haiku].CostPer1MInCached,
- CostPer1MOut: AnthropicModels[Claude3Haiku].CostPer1MOut,
- CostPer1MOutCached: AnthropicModels[Claude3Haiku].CostPer1MOutCached,
- ContextWindow: AnthropicModels[Claude3Haiku].ContextWindow,
- DefaultMaxTokens: AnthropicModels[Claude3Haiku].DefaultMaxTokens,
- },
- OpenRouterClaude37Sonnet: {
- ID: OpenRouterClaude37Sonnet,
- Name: "OpenRouter – Claude 3.7 Sonnet",
- Provider: ProviderOpenRouter,
- APIModel: "anthropic/claude-3.7-sonnet",
- CostPer1MIn: AnthropicModels[Claude37Sonnet].CostPer1MIn,
- CostPer1MInCached: AnthropicModels[Claude37Sonnet].CostPer1MInCached,
- CostPer1MOut: AnthropicModels[Claude37Sonnet].CostPer1MOut,
- CostPer1MOutCached: AnthropicModels[Claude37Sonnet].CostPer1MOutCached,
- ContextWindow: AnthropicModels[Claude37Sonnet].ContextWindow,
- DefaultMaxTokens: AnthropicModels[Claude37Sonnet].DefaultMaxTokens,
- CanReason: AnthropicModels[Claude37Sonnet].CanReason,
- },
- OpenRouterClaude35Haiku: {
- ID: OpenRouterClaude35Haiku,
- Name: "OpenRouter – Claude 3.5 Haiku",
- Provider: ProviderOpenRouter,
- APIModel: "anthropic/claude-3.5-haiku",
- CostPer1MIn: AnthropicModels[Claude35Haiku].CostPer1MIn,
- CostPer1MInCached: AnthropicModels[Claude35Haiku].CostPer1MInCached,
- CostPer1MOut: AnthropicModels[Claude35Haiku].CostPer1MOut,
- CostPer1MOutCached: AnthropicModels[Claude35Haiku].CostPer1MOutCached,
- ContextWindow: AnthropicModels[Claude35Haiku].ContextWindow,
- DefaultMaxTokens: AnthropicModels[Claude35Haiku].DefaultMaxTokens,
- },
- OpenRouterClaude3Opus: {
- ID: OpenRouterClaude3Opus,
- Name: "OpenRouter – Claude 3 Opus",
- Provider: ProviderOpenRouter,
- APIModel: "anthropic/claude-3-opus",
- CostPer1MIn: AnthropicModels[Claude3Opus].CostPer1MIn,
- CostPer1MInCached: AnthropicModels[Claude3Opus].CostPer1MInCached,
- CostPer1MOut: AnthropicModels[Claude3Opus].CostPer1MOut,
- CostPer1MOutCached: AnthropicModels[Claude3Opus].CostPer1MOutCached,
- ContextWindow: AnthropicModels[Claude3Opus].ContextWindow,
- DefaultMaxTokens: AnthropicModels[Claude3Opus].DefaultMaxTokens,
- },
-
- OpenRouterDeepSeekR1Free: {
- ID: OpenRouterDeepSeekR1Free,
- Name: "OpenRouter – DeepSeek R1 Free",
- Provider: ProviderOpenRouter,
- APIModel: "deepseek/deepseek-r1-0528:free",
- CostPer1MIn: 0,
- CostPer1MInCached: 0,
- CostPer1MOut: 0,
- CostPer1MOutCached: 0,
- ContextWindow: 163_840,
- DefaultMaxTokens: 10000,
- },
-}
@@ -1,38 +0,0 @@
-package models
-
-const (
- ProviderVertexAI InferenceProvider = "vertexai"
-
- // Models
- VertexAIGemini25Flash ModelID = "vertexai.gemini-2.5-flash"
- VertexAIGemini25 ModelID = "vertexai.gemini-2.5"
-)
-
-var VertexAIGeminiModels = map[ModelID]Model{
- VertexAIGemini25Flash: {
- ID: VertexAIGemini25Flash,
- Name: "VertexAI: Gemini 2.5 Flash",
- Provider: ProviderVertexAI,
- APIModel: "gemini-2.5-flash-preview-04-17",
- CostPer1MIn: GeminiModels[Gemini25Flash].CostPer1MIn,
- CostPer1MInCached: GeminiModels[Gemini25Flash].CostPer1MInCached,
- CostPer1MOut: GeminiModels[Gemini25Flash].CostPer1MOut,
- CostPer1MOutCached: GeminiModels[Gemini25Flash].CostPer1MOutCached,
- ContextWindow: GeminiModels[Gemini25Flash].ContextWindow,
- DefaultMaxTokens: GeminiModels[Gemini25Flash].DefaultMaxTokens,
- SupportsAttachments: true,
- },
- VertexAIGemini25: {
- ID: VertexAIGemini25,
- Name: "VertexAI: Gemini 2.5 Pro",
- Provider: ProviderVertexAI,
- APIModel: "gemini-2.5-pro-preview-03-25",
- CostPer1MIn: GeminiModels[Gemini25].CostPer1MIn,
- CostPer1MInCached: GeminiModels[Gemini25].CostPer1MInCached,
- CostPer1MOut: GeminiModels[Gemini25].CostPer1MOut,
- CostPer1MOutCached: GeminiModels[Gemini25].CostPer1MOutCached,
- ContextWindow: GeminiModels[Gemini25].ContextWindow,
- DefaultMaxTokens: GeminiModels[Gemini25].DefaultMaxTokens,
- SupportsAttachments: true,
- },
-}
@@ -1,61 +0,0 @@
-package models
-
-const (
- ProviderXAI InferenceProvider = "xai"
-
- XAIGrok3Beta ModelID = "grok-3-beta"
- XAIGrok3MiniBeta ModelID = "grok-3-mini-beta"
- XAIGrok3FastBeta ModelID = "grok-3-fast-beta"
- XAiGrok3MiniFastBeta ModelID = "grok-3-mini-fast-beta"
-)
-
-var XAIModels = map[ModelID]Model{
- XAIGrok3Beta: {
- ID: XAIGrok3Beta,
- Name: "Grok3 Beta",
- Provider: ProviderXAI,
- APIModel: "grok-3-beta",
- CostPer1MIn: 3.0,
- CostPer1MInCached: 0,
- CostPer1MOut: 15,
- CostPer1MOutCached: 0,
- ContextWindow: 131_072,
- DefaultMaxTokens: 20_000,
- },
- XAIGrok3MiniBeta: {
- ID: XAIGrok3MiniBeta,
- Name: "Grok3 Mini Beta",
- Provider: ProviderXAI,
- APIModel: "grok-3-mini-beta",
- CostPer1MIn: 0.3,
- CostPer1MInCached: 0,
- CostPer1MOut: 0.5,
- CostPer1MOutCached: 0,
- ContextWindow: 131_072,
- DefaultMaxTokens: 20_000,
- },
- XAIGrok3FastBeta: {
- ID: XAIGrok3FastBeta,
- Name: "Grok3 Fast Beta",
- Provider: ProviderXAI,
- APIModel: "grok-3-fast-beta",
- CostPer1MIn: 5,
- CostPer1MInCached: 0,
- CostPer1MOut: 25,
- CostPer1MOutCached: 0,
- ContextWindow: 131_072,
- DefaultMaxTokens: 20_000,
- },
- XAiGrok3MiniFastBeta: {
- ID: XAiGrok3MiniFastBeta,
- Name: "Grok3 Mini Fast Beta",
- Provider: ProviderXAI,
- APIModel: "grok-3-mini-fast-beta",
- CostPer1MIn: 0.6,
- CostPer1MInCached: 0,
- CostPer1MOut: 4.0,
- CostPer1MOutCached: 0,
- ContextWindow: 131_072,
- DefaultMaxTokens: 20_000,
- },
-}
@@ -9,19 +9,27 @@ import (
"time"
"github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/llm/models"
+ "github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
+ "github.com/charmbracelet/crush/internal/logging"
)
-func CoderPrompt(provider models.InferenceProvider) string {
+func CoderPrompt(p provider.InferenceProvider, contextFiles ...string) string {
basePrompt := baseAnthropicCoderPrompt
- switch provider {
- case models.ProviderOpenAI:
+ switch p {
+ case provider.InferenceProviderOpenAI:
basePrompt = baseOpenAICoderPrompt
}
envInfo := getEnvironmentInfo()
- return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
+ basePrompt = fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
+
+ contextContent := getContextFromPaths(contextFiles)
+ logging.Debug("Context content", "Context", contextContent)
+ if contextContent != "" {
+ return fmt.Sprintf("%s\n\n# Project-Specific Context\n Make sure to follow the instructions in the context below\n%s", basePrompt, contextContent)
+ }
+ return basePrompt
}
const baseOpenAICoderPrompt = `
@@ -1,60 +1,44 @@
package prompt
import (
- "fmt"
"os"
"path/filepath"
"strings"
"sync"
"github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/llm/models"
- "github.com/charmbracelet/crush/internal/logging"
+ "github.com/charmbracelet/crush/internal/fur/provider"
)
-func GetAgentPrompt(agentName config.AgentName, provider models.InferenceProvider) string {
+type PromptID string
+
+const (
+ PromptCoder PromptID = "coder"
+ PromptTitle PromptID = "title"
+ PromptTask PromptID = "task"
+ PromptSummarizer PromptID = "summarizer"
+ PromptDefault PromptID = "default"
+)
+
+func GetPrompt(promptID PromptID, provider provider.InferenceProvider, contextPaths ...string) string {
basePrompt := ""
- switch agentName {
- case config.AgentCoder:
+ switch promptID {
+ case PromptCoder:
basePrompt = CoderPrompt(provider)
- case config.AgentTitle:
+ case PromptTitle:
basePrompt = TitlePrompt(provider)
- case config.AgentTask:
+ case PromptTask:
basePrompt = TaskPrompt(provider)
- case config.AgentSummarizer:
+ case PromptSummarizer:
basePrompt = SummarizerPrompt(provider)
default:
basePrompt = "You are a helpful assistant"
}
-
- if agentName == config.AgentCoder || agentName == config.AgentTask {
- // Add context from project-specific instruction files if they exist
- contextContent := getContextFromPaths()
- logging.Debug("Context content", "Context", contextContent)
- if contextContent != "" {
- return fmt.Sprintf("%s\n\n# Project-Specific Context\n Make sure to follow the instructions in the context below\n%s", basePrompt, contextContent)
- }
- }
return basePrompt
}
-var (
- onceContext sync.Once
- contextContent string
-)
-
-func getContextFromPaths() string {
- onceContext.Do(func() {
- var (
- cfg = config.Get()
- workDir = cfg.WorkingDir
- contextPaths = cfg.ContextPaths
- )
-
- contextContent = processContextPaths(workDir, contextPaths)
- })
-
- return contextContent
+func getContextFromPaths(contextPaths []string) string {
+ return processContextPaths(config.WorkingDirectory(), contextPaths)
}
func processContextPaths(workDir string, paths []string) string {
@@ -15,16 +15,10 @@ func TestGetContextFromPaths(t *testing.T) {
t.Parallel()
tmpDir := t.TempDir()
- _, err := config.Load(tmpDir, false)
+ _, err := config.Init(tmpDir, false)
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
- cfg := config.Get()
- cfg.WorkingDir = tmpDir
- cfg.ContextPaths = []string{
- "file.txt",
- "directory/",
- }
testFiles := []string{
"file.txt",
"directory/file_a.txt",
@@ -34,7 +28,12 @@ func TestGetContextFromPaths(t *testing.T) {
createTestFiles(t, tmpDir, testFiles)
- context := getContextFromPaths()
+ context := getContextFromPaths(
+ []string{
+ "file.txt",
+ "directory/",
+ },
+ )
expectedContext := fmt.Sprintf("# From:%s/file.txt\nfile.txt: test content\n# From:%s/directory/file_a.txt\ndirectory/file_a.txt: test content\n# From:%s/directory/file_b.txt\ndirectory/file_b.txt: test content\n# From:%s/directory/file_c.txt\ndirectory/file_c.txt: test content", tmpDir, tmpDir, tmpDir, tmpDir)
assert.Equal(t, expectedContext, context)
}
@@ -1,8 +1,10 @@
package prompt
-import "github.com/charmbracelet/crush/internal/llm/models"
+import (
+ "github.com/charmbracelet/crush/internal/fur/provider"
+)
-func SummarizerPrompt(_ models.InferenceProvider) string {
+func SummarizerPrompt(_ provider.InferenceProvider) string {
return `You are a helpful AI assistant tasked with summarizing conversations.
When asked to summarize, provide a detailed but concise summary of the conversation.
@@ -3,10 +3,10 @@ package prompt
import (
"fmt"
- "github.com/charmbracelet/crush/internal/llm/models"
+ "github.com/charmbracelet/crush/internal/fur/provider"
)
-func TaskPrompt(_ models.InferenceProvider) string {
+func TaskPrompt(_ provider.InferenceProvider) string {
agentPrompt := `You are an agent for Crush. Given the user's prompt, you should use the tools available to you to answer the user's question.
Notes:
1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
@@ -1,8 +1,10 @@
package prompt
-import "github.com/charmbracelet/crush/internal/llm/models"
+import (
+ "github.com/charmbracelet/crush/internal/fur/provider"
+)
-func TitlePrompt(_ models.InferenceProvider) string {
+func TitlePrompt(_ provider.InferenceProvider) string {
return `you will generate a short title based on the first message a user begins a conversation with
- ensure it is not more than 50 characters long
- the title should be a summary of the user's message
@@ -13,7 +13,7 @@ import (
"github.com/anthropics/anthropic-sdk-go/bedrock"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/llm/models"
+ "github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/logging"
"github.com/charmbracelet/crush/internal/message"
@@ -59,7 +59,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic
var contentBlocks []anthropic.ContentBlockParamUnion
contentBlocks = append(contentBlocks, content)
for _, binaryContent := range msg.BinaryContent() {
- base64Image := binaryContent.String(models.ProviderAnthropic)
+ base64Image := binaryContent.String(provider.InferenceProviderAnthropic)
imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image)
contentBlocks = append(contentBlocks, imageBlock)
}
@@ -164,7 +164,7 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
// }
return anthropic.MessageNewParams{
- Model: anthropic.Model(a.providerOptions.model.APIModel),
+ Model: anthropic.Model(a.providerOptions.model.ID),
MaxTokens: a.providerOptions.maxTokens,
Temperature: temperature,
Messages: messages,
@@ -184,7 +184,7 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
cfg := config.Get()
- if cfg.Debug {
+ if cfg.Options.Debug {
jsonData, _ := json.Marshal(preparedMessages)
logging.Debug("Prepared messages", "messages", string(jsonData))
}
@@ -233,7 +233,7 @@ func (a *anthropicClient) send(ctx context.Context, messages []message.Message,
func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
cfg := config.Get()
- if cfg.Debug {
+ if cfg.Options.Debug {
// jsonData, _ := json.Marshal(preparedMessages)
// logging.Debug("Prepared messages", "messages", string(jsonData))
}
@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
- "os"
"strings"
"github.com/charmbracelet/crush/internal/llm/tools"
@@ -19,14 +18,8 @@ type bedrockClient struct {
type BedrockClient ProviderClient
func newBedrockClient(opts providerClientOptions) BedrockClient {
- // Apply bedrock specific options if they are added in the future
-
// Get AWS region from environment
- region := os.Getenv("AWS_REGION")
- if region == "" {
- region = os.Getenv("AWS_DEFAULT_REGION")
- }
-
+ region := opts.extraParams["region"]
if region == "" {
region = "us-east-1" // default region
}
@@ -39,11 +32,11 @@ func newBedrockClient(opts providerClientOptions) BedrockClient {
// Prefix the model name with region
regionPrefix := region[:2]
- modelName := opts.model.APIModel
- opts.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, modelName)
+ modelName := opts.model.ID
+ opts.model.ID = fmt.Sprintf("%s.%s", regionPrefix, modelName)
// Determine which provider to use based on the model
- if strings.Contains(string(opts.model.APIModel), "anthropic") {
+ if strings.Contains(string(opts.model.ID), "anthropic") {
// Create Anthropic client with Bedrock configuration
anthropicOpts := opts
// TODO: later find a way to check if the AWS account has caching enabled
@@ -157,7 +157,7 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
geminiMessages := g.convertMessages(messages)
cfg := config.Get()
- if cfg.Debug {
+ if cfg.Options.Debug {
jsonData, _ := json.Marshal(geminiMessages)
logging.Debug("Prepared messages", "messages", string(jsonData))
}
@@ -173,7 +173,7 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
if len(tools) > 0 {
config.Tools = g.convertTools(tools)
}
- chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)
+ chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.ID, config, history)
attempts := 0
for {
@@ -245,7 +245,7 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
geminiMessages := g.convertMessages(messages)
cfg := config.Get()
- if cfg.Debug {
+ if cfg.Options.Debug {
jsonData, _ := json.Marshal(geminiMessages)
logging.Debug("Prepared messages", "messages", string(jsonData))
}
@@ -261,7 +261,7 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
if len(tools) > 0 {
config.Tools = g.convertTools(tools)
}
- chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)
+ chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.ID, config, history)
attempts := 0
eventChan := make(chan ProviderEvent)
@@ -9,7 +9,7 @@ import (
"time"
"github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/llm/models"
+ "github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/logging"
"github.com/charmbracelet/crush/internal/message"
@@ -68,7 +68,7 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag
textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
for _, binaryContent := range msg.BinaryContent() {
- imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(models.ProviderOpenAI)}
+ imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(provider.InferenceProviderOpenAI)}
imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
@@ -153,7 +153,7 @@ func (o *openaiClient) finishReason(reason string) message.FinishReason {
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
params := openai.ChatCompletionNewParams{
- Model: openai.ChatModel(o.providerOptions.model.APIModel),
+ Model: openai.ChatModel(o.providerOptions.model.ID),
Messages: messages,
Tools: tools,
}
@@ -180,7 +180,7 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
cfg := config.Get()
- if cfg.Debug {
+ if cfg.Options.Debug {
jsonData, _ := json.Marshal(params)
logging.Debug("Prepared messages", "messages", string(jsonData))
}
@@ -237,7 +237,7 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
}
cfg := config.Get()
- if cfg.Debug {
+ if cfg.Options.Debug {
jsonData, _ := json.Marshal(params)
logging.Debug("Prepared messages", "messages", string(jsonData))
}
@@ -3,9 +3,9 @@ package provider
import (
"context"
"fmt"
- "os"
- "github.com/charmbracelet/crush/internal/llm/models"
+ configv2 "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/message"
)
@@ -55,17 +55,18 @@ type Provider interface {
StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
- Model() models.Model
+ Model() configv2.Model
}
type providerClientOptions struct {
baseURL string
apiKey string
- model models.Model
+ model configv2.Model
disableCache bool
maxTokens int64
systemMessage string
extraHeaders map[string]string
+ extraParams map[string]string
}
type ProviderClientOption func(*providerClientOptions)
@@ -80,77 +81,6 @@ type baseProvider[C ProviderClient] struct {
client C
}
-func NewProvider(providerName models.InferenceProvider, opts ...ProviderClientOption) (Provider, error) {
- clientOptions := providerClientOptions{}
- for _, o := range opts {
- o(&clientOptions)
- }
- switch providerName {
- case models.ProviderAnthropic:
- return &baseProvider[AnthropicClient]{
- options: clientOptions,
- client: newAnthropicClient(clientOptions, false),
- }, nil
- case models.ProviderOpenAI:
- return &baseProvider[OpenAIClient]{
- options: clientOptions,
- client: newOpenAIClient(clientOptions),
- }, nil
- case models.ProviderGemini:
- return &baseProvider[GeminiClient]{
- options: clientOptions,
- client: newGeminiClient(clientOptions),
- }, nil
- case models.ProviderBedrock:
- return &baseProvider[BedrockClient]{
- options: clientOptions,
- client: newBedrockClient(clientOptions),
- }, nil
- case models.ProviderGROQ:
- clientOptions.baseURL = "https://api.groq.com/openai/v1"
- return &baseProvider[OpenAIClient]{
- options: clientOptions,
- client: newOpenAIClient(clientOptions),
- }, nil
- case models.ProviderAzure:
- return &baseProvider[AzureClient]{
- options: clientOptions,
- client: newAzureClient(clientOptions),
- }, nil
- case models.ProviderVertexAI:
- return &baseProvider[VertexAIClient]{
- options: clientOptions,
- client: newVertexAIClient(clientOptions),
- }, nil
- case models.ProviderOpenRouter:
- clientOptions.baseURL = "https://openrouter.ai/api/v1"
- clientOptions.extraHeaders = map[string]string{
- "HTTP-Referer": "crush.charm.land",
- "X-Title": "Crush",
- }
- return &baseProvider[OpenAIClient]{
- options: clientOptions,
- client: newOpenAIClient(clientOptions),
- }, nil
- case models.ProviderXAI:
- clientOptions.baseURL = "https://api.x.ai/v1"
- return &baseProvider[OpenAIClient]{
- options: clientOptions,
- client: newOpenAIClient(clientOptions),
- }, nil
- case models.ProviderLocal:
- clientOptions.baseURL = os.Getenv("LOCAL_ENDPOINT")
- return &baseProvider[OpenAIClient]{
- options: clientOptions,
- client: newOpenAIClient(clientOptions),
- }, nil
- case models.ProviderMock:
- // TODO: implement mock client for test
- panic("not implemented")
- }
- return nil, fmt.Errorf("provider not supported: %s", providerName)
-}
-
func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
for _, msg := range messages {
// The message has no content
@@ -167,7 +97,7 @@ func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.M
return p.client.send(ctx, messages, tools)
}
-func (p *baseProvider[C]) Model() models.Model {
+func (p *baseProvider[C]) Model() configv2.Model {
return p.options.model
}
@@ -176,7 +106,7 @@ func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message
return p.client.stream(ctx, messages, tools)
}
-func WithModel(model models.Model) ProviderClientOption {
+func WithModel(model configv2.Model) ProviderClientOption {
return func(options *providerClientOptions) {
options.model = model
}
@@ -199,3 +129,53 @@ func WithSystemMessage(systemMessage string) ProviderClientOption {
options.systemMessage = systemMessage
}
}
+
+func NewProviderV2(cfg configv2.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
+ clientOptions := providerClientOptions{
+ baseURL: cfg.BaseURL,
+ apiKey: cfg.APIKey,
+ extraHeaders: cfg.ExtraHeaders,
+ }
+ for _, o := range opts {
+ o(&clientOptions)
+ }
+ switch cfg.ProviderType {
+ case provider.TypeAnthropic:
+ return &baseProvider[AnthropicClient]{
+ options: clientOptions,
+ client: newAnthropicClient(clientOptions, false),
+ }, nil
+ case provider.TypeOpenAI:
+ return &baseProvider[OpenAIClient]{
+ options: clientOptions,
+ client: newOpenAIClient(clientOptions),
+ }, nil
+ case provider.TypeGemini:
+ return &baseProvider[GeminiClient]{
+ options: clientOptions,
+ client: newGeminiClient(clientOptions),
+ }, nil
+ case provider.TypeBedrock:
+ return &baseProvider[BedrockClient]{
+ options: clientOptions,
+ client: newBedrockClient(clientOptions),
+ }, nil
+ case provider.TypeAzure:
+ return &baseProvider[AzureClient]{
+ options: clientOptions,
+ client: newAzureClient(clientOptions),
+ }, nil
+ case provider.TypeVertexAI:
+ return &baseProvider[VertexAIClient]{
+ options: clientOptions,
+ client: newVertexAIClient(clientOptions),
+ }, nil
+ case provider.TypeXAI:
+ clientOptions.baseURL = "https://api.x.ai/v1"
+ return &baseProvider[OpenAIClient]{
+ options: clientOptions,
+ client: newOpenAIClient(clientOptions),
+ }, nil
+ }
+ return nil, fmt.Errorf("provider not supported: %s", cfg.ProviderType)
+}
@@ -2,7 +2,6 @@ package provider
import (
"context"
- "os"
"github.com/charmbracelet/crush/internal/logging"
"google.golang.org/genai"
@@ -11,9 +10,11 @@ import (
type VertexAIClient ProviderClient
func newVertexAIClient(opts providerClientOptions) VertexAIClient {
+ project := opts.extraHeaders["project"]
+ location := opts.extraHeaders["location"]
client, err := genai.NewClient(context.Background(), &genai.ClientConfig{
- Project: os.Getenv("GOOGLE_CLOUD_PROJECT"),
- Location: os.Getenv("GOOGLE_CLOUD_LOCATION"),
+ Project: project,
+ Location: location,
Backend: genai.BackendVertexAI,
})
if err != nil {
@@ -286,7 +286,7 @@ func (c *Client) SetServerState(state ServerState) {
// WaitForServerReady waits for the server to be ready by polling the server
// with a simple request until it responds successfully or times out
func (c *Client) WaitForServerReady(ctx context.Context) error {
- cnf := config.Get()
+ cfg := config.Get()
// Set initial state
c.SetServerState(StateStarting)
@@ -299,7 +299,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Waiting for LSP server to be ready...")
}
@@ -308,7 +308,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
// For TypeScript-like servers, we need to open some key files first
if serverType == ServerTypeTypeScript {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("TypeScript-like server detected, opening key configuration files")
}
c.openKeyConfigFiles(ctx)
@@ -325,7 +325,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
if err == nil {
// Server responded successfully
c.SetServerState(StateReady)
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("LSP server is ready")
}
return nil
@@ -333,7 +333,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
logging.Debug("LSP server not ready yet", "error", err, "serverType", serverType)
}
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("LSP server not ready yet", "error", err, "serverType", serverType)
}
}
@@ -496,7 +496,7 @@ func (c *Client) pingTypeScriptServer(ctx context.Context) error {
// openTypeScriptFiles finds and opens TypeScript files to help initialize the server
func (c *Client) openTypeScriptFiles(ctx context.Context, workDir string) {
- cnf := config.Get()
+ cfg := config.Get()
filesOpened := 0
maxFilesToOpen := 5 // Limit to a reasonable number of files
@@ -526,7 +526,7 @@ func (c *Client) openTypeScriptFiles(ctx context.Context, workDir string) {
// Try to open the file
if err := c.OpenFile(ctx, path); err == nil {
filesOpened++
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Opened TypeScript file for initialization", "file", path)
}
}
@@ -535,11 +535,11 @@ func (c *Client) openTypeScriptFiles(ctx context.Context, workDir string) {
return nil
})
- if err != nil && cnf.DebugLSP {
+ if err != nil && cfg.Options.DebugLSP {
logging.Debug("Error walking directory for TypeScript files", "error", err)
}
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Opened TypeScript files for initialization", "count", filesOpened)
}
}
@@ -664,7 +664,7 @@ func (c *Client) NotifyChange(ctx context.Context, filepath string) error {
}
func (c *Client) CloseFile(ctx context.Context, filepath string) error {
- cnf := config.Get()
+ cfg := config.Get()
uri := string(protocol.URIFromPath(filepath))
c.openFilesMu.Lock()
@@ -680,7 +680,7 @@ func (c *Client) CloseFile(ctx context.Context, filepath string) error {
},
}
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Closing file", "file", filepath)
}
if err := c.Notify(ctx, "textDocument/didClose", params); err != nil {
@@ -704,7 +704,7 @@ func (c *Client) IsFileOpen(filepath string) bool {
// CloseAllFiles closes all currently open files
func (c *Client) CloseAllFiles(ctx context.Context) {
- cnf := config.Get()
+ cfg := config.Get()
c.openFilesMu.Lock()
filesToClose := make([]string, 0, len(c.openFiles))
@@ -719,12 +719,12 @@ func (c *Client) CloseAllFiles(ctx context.Context) {
// Then close them all
for _, filePath := range filesToClose {
err := c.CloseFile(ctx, filePath)
- if err != nil && cnf.DebugLSP {
+ if err != nil && cfg.Options.DebugLSP {
logging.Warn("Error closing file", "file", filePath, "error", err)
}
}
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Closed all files", "files", filesToClose)
}
}
@@ -82,13 +82,13 @@ func notifyFileWatchRegistration(id string, watchers []protocol.FileSystemWatche
// Notifications
func HandleServerMessage(params json.RawMessage) {
- cnf := config.Get()
+ cfg := config.Get()
var msg struct {
Type int `json:"type"`
Message string `json:"message"`
}
if err := json.Unmarshal(params, &msg); err == nil {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Server message", "type", msg.Type, "message", msg.Message)
}
}
@@ -18,9 +18,9 @@ func WriteMessage(w io.Writer, msg *Message) error {
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
}
- cnf := config.Get()
+ cfg := config.Get()
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Sending message to server", "method", msg.Method, "id", msg.ID)
}
@@ -39,7 +39,7 @@ func WriteMessage(w io.Writer, msg *Message) error {
// ReadMessage reads a single LSP message from the given reader
func ReadMessage(r *bufio.Reader) (*Message, error) {
- cnf := config.Get()
+ cfg := config.Get()
// Read headers
var contentLength int
for {
@@ -49,7 +49,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
}
line = strings.TrimSpace(line)
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Received header", "line", line)
}
@@ -65,7 +65,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
}
}
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Content-Length", "length", contentLength)
}
@@ -76,7 +76,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
return nil, fmt.Errorf("failed to read content: %w", err)
}
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Received content", "content", string(content))
}
@@ -91,11 +91,11 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
// handleMessages reads and dispatches messages in a loop
func (c *Client) handleMessages() {
- cnf := config.Get()
+ cfg := config.Get()
for {
msg, err := ReadMessage(c.stdout)
if err != nil {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Error("Error reading message", "error", err)
}
return
@@ -103,7 +103,7 @@ func (c *Client) handleMessages() {
// Handle server->client request (has both Method and ID)
if msg.Method != "" && msg.ID != 0 {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Received request from server", "method", msg.Method, "id", msg.ID)
}
@@ -157,11 +157,11 @@ func (c *Client) handleMessages() {
c.notificationMu.RUnlock()
if ok {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Handling notification", "method", msg.Method)
}
go handler(msg.Params)
- } else if cnf.DebugLSP {
+ } else if cfg.Options.DebugLSP {
logging.Debug("No handler for notification", "method", msg.Method)
}
continue
@@ -174,12 +174,12 @@ func (c *Client) handleMessages() {
c.handlersMu.RUnlock()
if ok {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Received response for request", "id", msg.ID)
}
ch <- msg
close(ch)
- } else if cnf.DebugLSP {
+ } else if cfg.Options.DebugLSP {
logging.Debug("No handler for response", "id", msg.ID)
}
}
@@ -188,10 +188,10 @@ func (c *Client) handleMessages() {
// Call makes a request and waits for the response
func (c *Client) Call(ctx context.Context, method string, params any, result any) error {
- cnf := config.Get()
+ cfg := config.Get()
id := c.nextID.Add(1)
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Making call", "method", method, "id", id)
}
@@ -217,14 +217,14 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
return fmt.Errorf("failed to send request: %w", err)
}
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Request sent", "method", method, "id", id)
}
// Wait for response
resp := <-ch
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Received response", "id", id)
}
@@ -249,8 +249,8 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
// Notify sends a notification (a request without an ID that doesn't expect a response)
func (c *Client) Notify(ctx context.Context, method string, params any) error {
- cnf := config.Get()
- if cnf.DebugLSP {
+ cfg := config.Get()
+ if cfg.Options.DebugLSP {
logging.Debug("Sending notification", "method", method)
}
@@ -43,7 +43,7 @@ func NewWorkspaceWatcher(client *lsp.Client) *WorkspaceWatcher {
// AddRegistrations adds file watchers to track
func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watchers []protocol.FileSystemWatcher) {
- cnf := config.Get()
+ cfg := config.Get()
logging.Debug("Adding file watcher registrations")
w.registrationMu.Lock()
@@ -53,7 +53,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
w.registrations = append(w.registrations, watchers...)
// Print detailed registration information for debugging
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Adding file watcher registrations",
"id", id,
"watchers", len(watchers),
@@ -122,7 +122,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
highPriorityFilesOpened := w.openHighPriorityFiles(ctx, serverName)
filesOpened += highPriorityFilesOpened
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Opened high-priority files",
"count", highPriorityFilesOpened,
"serverName", serverName)
@@ -130,7 +130,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
// If we've already opened enough high-priority files, we might not need more
if filesOpened >= maxFilesToOpen {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Reached file limit with high-priority files",
"filesOpened", filesOpened,
"maxFiles", maxFilesToOpen)
@@ -148,7 +148,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
// Skip directories that should be excluded
if d.IsDir() {
if path != w.workspacePath && shouldExcludeDir(path) {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Skipping excluded directory", "path", path)
}
return filepath.SkipDir
@@ -176,7 +176,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
})
elapsedTime := time.Since(startTime)
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Limited workspace scan complete",
"filesOpened", filesOpened,
"maxFiles", maxFilesToOpen,
@@ -185,11 +185,11 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
)
}
- if err != nil && cnf.DebugLSP {
+ if err != nil && cfg.Options.DebugLSP {
logging.Debug("Error scanning workspace for files to open", "error", err)
}
}()
- } else if cnf.DebugLSP {
+ } else if cfg.Options.DebugLSP {
logging.Debug("Using on-demand file loading for server", "server", serverName)
}
}
@@ -197,7 +197,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
// openHighPriorityFiles opens important files for the server type
// Returns the number of files opened
func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName string) int {
- cnf := config.Get()
+ cfg := config.Get()
filesOpened := 0
// Define patterns for high-priority files based on server type
@@ -265,7 +265,7 @@ func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName
// Use doublestar.Glob to find files matching the pattern (supports ** patterns)
matches, err := doublestar.Glob(os.DirFS(w.workspacePath), pattern)
if err != nil {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Error finding high-priority files", "pattern", pattern, "error", err)
}
continue
@@ -299,12 +299,12 @@ func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName
for j := i; j < end; j++ {
fullPath := filesToOpen[j]
if err := w.client.OpenFile(ctx, fullPath); err != nil {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Error opening high-priority file", "path", fullPath, "error", err)
}
} else {
filesOpened++
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Opened high-priority file", "path", fullPath)
}
}
@@ -321,7 +321,7 @@ func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName
// WatchWorkspace sets up file watching for a workspace
func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath string) {
- cnf := config.Get()
+ cfg := config.Get()
w.workspacePath = workspacePath
// Store the watcher in the context for later use
@@ -356,7 +356,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
// Skip excluded directories (except workspace root)
if d.IsDir() && path != workspacePath {
if shouldExcludeDir(path) {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Skipping excluded directory", "path", path)
}
return filepath.SkipDir
@@ -409,7 +409,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
}
// Debug logging
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
matched, kind := w.isPathWatched(event.Name)
logging.Debug("File event",
"path", event.Name,
@@ -676,8 +676,8 @@ func (w *WorkspaceWatcher) handleFileEvent(ctx context.Context, uri string, chan
// notifyFileEvent sends a didChangeWatchedFiles notification for a file event
func (w *WorkspaceWatcher) notifyFileEvent(ctx context.Context, uri string, changeType protocol.FileChangeType) error {
- cnf := config.Get()
- if cnf.DebugLSP {
+ cfg := config.Get()
+ if cfg.Options.DebugLSP {
logging.Debug("Notifying file event",
"uri", uri,
"changeType", changeType,
@@ -826,7 +826,7 @@ func shouldExcludeDir(dirPath string) bool {
// shouldExcludeFile returns true if the file should be excluded from opening
func shouldExcludeFile(filePath string) bool {
fileName := filepath.Base(filePath)
- cnf := config.Get()
+ cfg := config.Get()
// Skip dot files
if strings.HasPrefix(fileName, ".") {
return true
@@ -852,12 +852,12 @@ func shouldExcludeFile(filePath string) bool {
// Skip large files
if info.Size() > maxFileSize {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Skipping large file",
"path", filePath,
"size", info.Size(),
"maxSize", maxFileSize,
- "debug", cnf.Debug,
+ "debug", cfg.Options.Debug,
"sizeMB", float64(info.Size())/(1024*1024),
"maxSizeMB", float64(maxFileSize)/(1024*1024),
)
@@ -870,7 +870,7 @@ func shouldExcludeFile(filePath string) bool {
// openMatchingFile opens a file if it matches any of the registered patterns
func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
- cnf := config.Get()
+ cfg := config.Get()
// Skip directories
info, err := os.Stat(path)
if err != nil || info.IsDir() {
@@ -890,10 +890,10 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
// Check if the file is a high-priority file that should be opened immediately
// This helps with project initialization for certain language servers
if isHighPriorityFile(path, serverName) {
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Opening high-priority file", "path", path, "serverName", serverName)
}
- if err := w.client.OpenFile(ctx, path); err != nil && cnf.DebugLSP {
+ if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP {
logging.Error("Error opening high-priority file", "path", path, "error", err)
}
return
@@ -905,7 +905,7 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
// Check file size - for preloading we're more conservative
if info.Size() > (1 * 1024 * 1024) { // 1MB limit for preloaded files
- if cnf.DebugLSP {
+ if cfg.Options.DebugLSP {
logging.Debug("Skipping large file for preloading", "path", path, "size", info.Size())
}
return
@@ -937,7 +937,7 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
if shouldOpen {
// Don't need to check if it's already open - the client.OpenFile handles that
- if err := w.client.OpenFile(ctx, path); err != nil && cnf.DebugLSP {
+ if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP {
logging.Error("Error opening file", "path", path, "error", err)
}
}
@@ -5,7 +5,7 @@ import (
"slices"
"time"
- "github.com/charmbracelet/crush/internal/llm/models"
+ "github.com/charmbracelet/crush/internal/fur/provider"
)
type MessageRole string
@@ -71,9 +71,9 @@ type BinaryContent struct {
Data []byte
}
-func (bc BinaryContent) String(provider models.InferenceProvider) string {
+func (bc BinaryContent) String(p provider.InferenceProvider) string {
base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
- if provider == models.ProviderOpenAI {
+ if p == provider.InferenceProviderOpenAI {
return "data:" + bc.MIMEType + ";base64," + base64Encoded
}
return base64Encoded
@@ -113,7 +113,8 @@ type Message struct {
Role MessageRole
SessionID string
Parts []ContentPart
- Model models.ModelID
+ Model string
+ Provider string
CreatedAt int64
UpdatedAt int64
}
@@ -8,15 +8,15 @@ import (
"time"
"github.com/charmbracelet/crush/internal/db"
- "github.com/charmbracelet/crush/internal/llm/models"
"github.com/charmbracelet/crush/internal/pubsub"
"github.com/google/uuid"
)
type CreateMessageParams struct {
- Role MessageRole
- Parts []ContentPart
- Model models.ModelID
+ Role MessageRole
+ Parts []ContentPart
+ Model string
+ Provider string
}
type Service interface {
@@ -70,6 +70,7 @@ func (s *service) Create(ctx context.Context, sessionID string, params CreateMes
Role: string(params.Role),
Parts: string(partsJSON),
Model: sql.NullString{String: string(params.Model), Valid: true},
+ Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""},
})
if err != nil {
return Message{}, err
@@ -154,7 +155,8 @@ func (s *service) fromDBItem(item db.Message) (Message, error) {
SessionID: item.SessionID,
Role: MessageRole(item.Role),
Parts: parts,
- Model: models.ModelID(item.Model.String),
+ Model: item.Model.String,
+ Provider: item.Provider.String,
CreatedAt: item.CreatedAt,
UpdatedAt: item.UpdatedAt,
}, nil
@@ -7,7 +7,6 @@ import (
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/fsext"
- "github.com/charmbracelet/crush/internal/llm/models"
"github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/lsp/protocol"
"github.com/charmbracelet/crush/internal/pubsub"
@@ -112,11 +111,7 @@ func (h *header) details() string {
parts = append(parts, t.S().Error.Render(fmt.Sprintf("%s%d", styles.ErrorIcon, errorCount)))
}
- cfg := config.Get()
- agentCfg := cfg.Agents[config.AgentCoder]
- selectedModelID := agentCfg.Model
- model := models.SupportedModels[selectedModelID]
-
+ model := config.GetAgentModel(config.AgentCoder)
percentage := (float64(h.session.CompletionTokens+h.session.PromptTokens) / float64(model.ContextWindow)) * 100
formattedPercentage := t.S().Muted.Render(fmt.Sprintf("%d%%", int(percentage)))
parts = append(parts, formattedPercentage)
@@ -10,7 +10,8 @@ import (
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/lipgloss/v2"
- "github.com/charmbracelet/crush/internal/llm/models"
+ "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/tui/components/anim"
"github.com/charmbracelet/crush/internal/tui/components/core"
@@ -290,8 +291,9 @@ func (m *assistantSectionModel) View() tea.View {
duration := finishTime.Sub(m.lastUserMessageTime)
infoMsg := t.S().Subtle.Render(duration.String())
icon := t.S().Subtle.Render(styles.ModelIcon)
- model := t.S().Muted.Render(models.SupportedModels[m.message.Model].Name)
- assistant := fmt.Sprintf("%s %s %s", icon, model, infoMsg)
+ model := config.GetProviderModel(provider.InferenceProvider(m.message.Provider), m.message.Model)
+ modelFormatted := t.S().Muted.Render(model.Name)
+ assistant := fmt.Sprintf("%s %s %s", icon, modelFormatted, infoMsg)
return tea.NewView(
t.S().Base.PaddingLeft(2).Render(
core.Section(assistant, m.width-2),
@@ -13,7 +13,6 @@ import (
"github.com/charmbracelet/crush/internal/diff"
"github.com/charmbracelet/crush/internal/fsext"
"github.com/charmbracelet/crush/internal/history"
- "github.com/charmbracelet/crush/internal/llm/models"
"github.com/charmbracelet/crush/internal/logging"
"github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/lsp/protocol"
@@ -406,7 +405,7 @@ func (m *sidebarCmp) mcpBlock() string {
mcpList := []string{section, ""}
- mcp := config.Get().MCPServers
+ mcp := config.Get().MCP
if len(mcp) == 0 {
return lipgloss.JoinVertical(
lipgloss.Left,
@@ -475,10 +474,7 @@ func formatTokensAndCost(tokens, contextWindow int64, cost float64) string {
}
func (s *sidebarCmp) currentModelBlock() string {
- cfg := config.Get()
- agentCfg := cfg.Agents[config.AgentCoder]
- selectedModelID := agentCfg.Model
- model := models.SupportedModels[selectedModelID]
+ model := config.GetAgentModel(config.AgentCoder)
t := styles.CurrentTheme()
@@ -63,7 +63,7 @@ func buildCommandSources(cfg *config.Config) []commandSource {
// Project directory
sources = append(sources, commandSource{
- path: filepath.Join(cfg.Data.Directory, "commands"),
+ path: filepath.Join(cfg.Options.DataDirectory, "commands"),
prefix: ProjectCommandPrefix,
})
@@ -5,7 +5,7 @@ import (
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/lipgloss/v2"
- "github.com/charmbracelet/crush/internal/config"
+ configv2 "github.com/charmbracelet/crush/internal/config"
cmpChat "github.com/charmbracelet/crush/internal/tui/components/chat"
"github.com/charmbracelet/crush/internal/tui/components/core"
"github.com/charmbracelet/crush/internal/tui/components/dialogs"
@@ -184,7 +184,7 @@ If there are Cursor rules (in .cursor/rules/ or .cursorrules) or Copilot rules (
Add the .crush directory to the .gitignore file if it's not already there.`
// Mark the project as initialized
- if err := config.MarkProjectInitialized(); err != nil {
+ if err := configv2.MarkProjectInitialized(); err != nil {
return util.ReportError(err)
}
@@ -196,7 +196,7 @@ Add the .crush directory to the .gitignore file if it's not already there.`
)
} else {
// Mark the project as initialized without running the command
- if err := config.MarkProjectInitialized(); err != nil {
+ if err := configv2.MarkProjectInitialized(); err != nil {
return util.ReportError(err)
}
}
@@ -1,13 +1,11 @@
package models
import (
- "slices"
-
"github.com/charmbracelet/bubbles/v2/help"
"github.com/charmbracelet/bubbles/v2/key"
tea "github.com/charmbracelet/bubbletea/v2"
- "github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/llm/models"
+ configv2 "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/tui/components/completions"
"github.com/charmbracelet/crush/internal/tui/components/core"
"github.com/charmbracelet/crush/internal/tui/components/core/list"
@@ -26,7 +24,7 @@ const (
// ModelSelectedMsg is sent when a model is selected
type ModelSelectedMsg struct {
- Model models.Model
+ Model configv2.PreferredModel
}
// CloseModelDialogMsg is sent when a model is selected
@@ -37,6 +35,11 @@ type ModelDialog interface {
dialogs.DialogModel
}
+type ModelOption struct {
+ Provider provider.Provider
+ Model provider.Model
+}
+
type modelDialogCmp struct {
width int
wWidth int // Width of the terminal window
@@ -80,47 +83,31 @@ func NewModelDialogCmp() ModelDialog {
}
}
-var ProviderPopularity = map[models.InferenceProvider]int{
- models.ProviderAnthropic: 1,
- models.ProviderOpenAI: 2,
- models.ProviderGemini: 3,
- models.ProviderGROQ: 4,
- models.ProviderOpenRouter: 5,
- models.ProviderBedrock: 6,
- models.ProviderAzure: 7,
- models.ProviderVertexAI: 8,
- models.ProviderXAI: 9,
-}
-
-var ProviderName = map[models.InferenceProvider]string{
- models.ProviderAnthropic: "Anthropic",
- models.ProviderOpenAI: "OpenAI",
- models.ProviderGemini: "Gemini",
- models.ProviderGROQ: "Groq",
- models.ProviderOpenRouter: "OpenRouter",
- models.ProviderBedrock: "AWS Bedrock",
- models.ProviderAzure: "Azure",
- models.ProviderVertexAI: "VertexAI",
- models.ProviderXAI: "xAI",
-}
-
func (m *modelDialogCmp) Init() tea.Cmd {
- cfg := config.Get()
- enabledProviders := getEnabledProviders(cfg)
+ providers := configv2.Providers()
+ cfg := configv2.Get()
+ coderAgent := cfg.Agents[configv2.AgentCoder]
modelItems := []util.Model{}
- for _, provider := range enabledProviders {
- name, ok := ProviderName[provider]
- if !ok {
- name = string(provider) // Fallback to provider ID if name is not defined
+ selectIndex := 0
+ for _, provider := range providers {
+ name := provider.Name
+ if name == "" {
+ name = string(provider.ID)
}
modelItems = append(modelItems, commands.NewItemSection(name))
- for _, model := range getModelsForProvider(provider) {
- modelItems = append(modelItems, completions.NewCompletionItem(model.Name, model))
+ for _, model := range provider.Models {
+ if model.ID == coderAgent.Model && provider.ID == coderAgent.Provider {
+ selectIndex = len(modelItems) // Set the selected index to the current model
+ }
+ modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
+ Provider: provider,
+ Model: model,
+ }))
}
}
- m.modelList.SetItems(modelItems)
- return m.modelList.Init()
+
+ return tea.Sequence(m.modelList.Init(), m.modelList.SetItems(modelItems), m.modelList.SetSelected(selectIndex))
}
func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
@@ -137,11 +124,14 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, nil // No item selected, do nothing
}
items := m.modelList.Items()
- selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(models.Model)
+ selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption)
return m, tea.Sequence(
util.CmdHandler(dialogs.CloseDialogMsg{}),
- util.CmdHandler(ModelSelectedMsg{Model: selectedItem}),
+ util.CmdHandler(ModelSelectedMsg{Model: configv2.PreferredModel{
+ ModelID: selectedItem.Model.ID,
+ Provider: selectedItem.Provider.ID,
+ }}),
)
case key.Matches(msg, m.keyMap.Close):
return m, util.CmdHandler(dialogs.CloseDialogMsg{})
@@ -189,58 +179,6 @@ func (m *modelDialogCmp) listHeight() int {
return min(listHeigh, m.wHeight/2)
}
-func GetSelectedModel(cfg *config.Config) models.Model {
- agentCfg := cfg.Agents[config.AgentCoder]
- selectedModelID := agentCfg.Model
- return models.SupportedModels[selectedModelID]
-}
-
-func getEnabledProviders(cfg *config.Config) []models.InferenceProvider {
- var providers []models.InferenceProvider
- for providerID, provider := range cfg.Providers {
- if !provider.Disabled {
- providers = append(providers, providerID)
- }
- }
-
- // Sort by provider popularity
- slices.SortFunc(providers, func(a, b models.InferenceProvider) int {
- rA := ProviderPopularity[a]
- rB := ProviderPopularity[b]
-
- // models not included in popularity ranking default to last
- if rA == 0 {
- rA = 999
- }
- if rB == 0 {
- rB = 999
- }
- return rA - rB
- })
- return providers
-}
-
-func getModelsForProvider(provider models.InferenceProvider) []models.Model {
- var providerModels []models.Model
- for _, model := range models.SupportedModels {
- if model.Provider == provider {
- providerModels = append(providerModels, model)
- }
- }
-
- // reverse alphabetical order (if llm naming was consistent latest would appear first)
- slices.SortFunc(providerModels, func(a, b models.Model) int {
- if a.Name > b.Name {
- return -1
- } else if a.Name < b.Name {
- return 1
- }
- return 0
- })
-
- return providerModels
-}
-
func (m *modelDialogCmp) Position() (int, int) {
row := m.wHeight/4 - 2 // just a bit above the center
col := m.wWidth / 2
@@ -9,7 +9,6 @@ import (
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/crush/internal/app"
"github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/llm/models"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/session"
"github.com/charmbracelet/crush/internal/tui/components/chat"
@@ -171,14 +170,11 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
util.CmdHandler(ChatFocusedMsg{Focused: false}),
)
case key.Matches(msg, p.keyMap.AddAttachment):
- cfg := config.Get()
- agentCfg := cfg.Agents[config.AgentCoder]
- selectedModelID := agentCfg.Model
- model := models.SupportedModels[selectedModelID]
- if model.SupportsAttachments {
+ model := config.GetAgentModel(config.AgentCoder)
+ if model.SupportsImages {
return p, util.CmdHandler(OpenFilePickerMsg{})
} else {
- return p, util.ReportWarn("File attachments are not supported by the current model: " + string(selectedModelID))
+ return p, util.ReportWarn("File attachments are not supported by the current model: " + model.Name)
}
case key.Matches(msg, p.keyMap.Tab):
if p.session.ID == "" {
@@ -8,6 +8,7 @@ import (
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/crush/internal/app"
"github.com/charmbracelet/crush/internal/config"
+ configv2 "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/llm/agent"
"github.com/charmbracelet/crush/internal/logging"
"github.com/charmbracelet/crush/internal/permission"
@@ -69,7 +70,7 @@ func (a appModel) Init() tea.Cmd {
// Check if we should show the init dialog
cmds = append(cmds, func() tea.Msg {
- shouldShow, err := config.ShouldShowInitDialog()
+ shouldShow, err := configv2.ProjectNeedsInitialization()
if err != nil {
return util.InfoMsg{
Type: util.InfoTypeError,
@@ -172,7 +173,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// Model Switch
case models.ModelSelectedMsg:
- model, err := a.app.CoderAgent.Update(config.AgentCoder, msg.Model.ID)
+ model, err := a.app.CoderAgent.Update(msg.Model)
if err != nil {
return a, util.ReportError(err)
}
@@ -222,7 +223,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
model := a.app.CoderAgent.Model()
contextWindow := model.ContextWindow
tokens := session.CompletionTokens + session.PromptTokens
- if (tokens >= int64(float64(contextWindow)*0.95)) && config.Get().AutoCompact {
+ if (tokens >= int64(float64(contextWindow)*0.95)) && !config.Get().Options.DisableAutoSummarize {
// Show compact confirmation dialog
cmds = append(cmds, util.CmdHandler(dialogs.OpenDialogMsg{
Model: compact.NewCompactDialogCmp(a.app.CoderAgent, a.selectedSessionID, false),