diff --git a/go.mod b/go.mod index 86b3a8541ddc01679f4a6e25a02ea6599b6b3754..8bc77cf3ffe7cdd96131027fe09f5b8f1a50796a 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module github.com/charmbracelet/crush -go 1.24.0 +go 1.24.3 + +replace github.com/charmbracelet/fur => ../fur require ( github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 @@ -15,6 +17,7 @@ require ( github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250607113720-eb5e1cf3b09e github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.3.0.20250609143341-c76fa36f1b94 github.com/charmbracelet/fang v0.1.0 + github.com/charmbracelet/fur v0.0.0-00010101000000-000000000000 github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.1.0.20250523195325-2d1af06b557c github.com/charmbracelet/x/ansi v0.9.3-0.20250602153603-fb931ed90413 diff --git a/internal/config/config.go b/internal/config/config.go index 3b794dee60cb13a0c42413aedce3199ea5998352..3944cb1374582f9af0eeb7bfadd05ef5f9a8c198 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -76,17 +76,17 @@ type TUIConfig struct { // Config is the main configuration structure for the application. 
type Config struct { - Data Data `json:"data"` - WorkingDir string `json:"wd,omitempty"` - MCPServers map[string]MCPServer `json:"mcpServers,omitempty"` - Providers map[models.ModelProvider]Provider `json:"providers,omitempty"` - LSP map[string]LSPConfig `json:"lsp,omitempty"` - Agents map[AgentName]Agent `json:"agents,omitempty"` - Debug bool `json:"debug,omitempty"` - DebugLSP bool `json:"debugLSP,omitempty"` - ContextPaths []string `json:"contextPaths,omitempty"` - TUI TUIConfig `json:"tui"` - AutoCompact bool `json:"autoCompact,omitempty"` + Data Data `json:"data"` + WorkingDir string `json:"wd,omitempty"` + MCPServers map[string]MCPServer `json:"mcpServers,omitempty"` + Providers map[models.InferenceProvider]Provider `json:"providers,omitempty"` + LSP map[string]LSPConfig `json:"lsp,omitempty"` + Agents map[AgentName]Agent `json:"agents,omitempty"` + Debug bool `json:"debug,omitempty"` + DebugLSP bool `json:"debugLSP,omitempty"` + ContextPaths []string `json:"contextPaths,omitempty"` + TUI TUIConfig `json:"tui"` + AutoCompact bool `json:"autoCompact,omitempty"` } // Application constants @@ -128,7 +128,7 @@ func Load(workingDir string, debug bool) (*Config, error) { cfg = &Config{ WorkingDir: workingDir, MCPServers: make(map[string]MCPServer), - Providers: make(map[models.ModelProvider]Provider), + Providers: make(map[models.InferenceProvider]Provider), LSP: make(map[string]LSPConfig), } @@ -640,7 +640,7 @@ func Validate() error { } // getProviderAPIKey gets the API key for a provider from environment variables -func getProviderAPIKey(provider models.ModelProvider) string { +func getProviderAPIKey(provider models.InferenceProvider) string { switch provider { case models.ProviderAnthropic: return os.Getenv("ANTHROPIC_API_KEY") diff --git a/internal/config_v2/config.go b/internal/config_v2/config.go new file mode 100644 index 0000000000000000000000000000000000000000..4ab12a83fe6de3e94105cf00d4045f652dd26cae --- /dev/null +++ b/internal/config_v2/config.go @@ 
-0,0 +1,440 @@ +package configv2 + +import ( + "encoding/json" + "errors" + "maps" + "os" + "path/filepath" + "slices" + "strings" + "sync" + + "github.com/charmbracelet/crush/internal/logging" + "github.com/charmbracelet/fur/pkg/provider" +) + +const ( + defaultDataDirectory = ".crush" + defaultLogLevel = "info" + appName = "crush" + + MaxTokensFallbackDefault = 4096 +) + +type Model struct { + ID string `json:"id"` + Name string `json:"model"` + CostPer1MIn float64 `json:"cost_per_1m_in"` + CostPer1MOut float64 `json:"cost_per_1m_out"` + CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` + CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` + ContextWindow int64 `json:"context_window"` + DefaultMaxTokens int64 `json:"default_max_tokens"` + CanReason bool `json:"can_reason"` + ReasoningEffort string `json:"reasoning_effort"` + SupportsImages bool `json:"supports_attachments"` +} + +type VertexAIOptions struct { + APIKey string `json:"api_key,omitempty"` + Project string `json:"project,omitempty"` + Location string `json:"location,omitempty"` +} + +type ProviderConfig struct { + BaseURL string `json:"base_url,omitempty"` + ProviderType provider.Type `json:"provider_type"` + APIKey string `json:"api_key,omitempty"` + Disabled bool `json:"disabled"` + ExtraHeaders map[string]string `json:"extra_headers,omitempty"` + // used for e.x for vertex to set the project + ExtraParams map[string]string `json:"extra_params,omitempty"` + + DefaultModel string `json:"default_model"` +} + +type Agent struct { + Name string `json:"name"` + // This is the id of the system prompt used by the agent + // TODO: still needs to be implemented + PromptID string `json:"prompt_id"` + Disabled bool `json:"disabled"` + + Provider provider.InferenceProvider `json:"provider"` + Model Model `json:"model"` + + // The available tools for the agent + // if this is empty, all tools are available + AllowedTools []string `json:"allowed_tools"` + + // this tells us which MCPs are available 
for this agent + // if this is empty all mcps are available + // the string array is the list of tools from the MCP the agent has available + // if the string array is empty, all tools from the MCP are available + MCP map[string][]string `json:"mcp"` + + // The list of LSPs that this agent can use + // if this is empty, all LSPs are available + LSP []string `json:"lsp"` + + // Overrides the context paths for this agent + ContextPaths []string `json:"context_paths"` +} + +type MCPType string + +const ( + MCPStdio MCPType = "stdio" + MCPSse MCPType = "sse" +) + +type MCP struct { + Command string `json:"command"` + Env []string `json:"env"` + Args []string `json:"args"` + Type MCPType `json:"type"` + URL string `json:"url"` + Headers map[string]string `json:"headers"` +} + +type LSPConfig struct { + Disabled bool `json:"disabled"` + Command string `json:"command"` + Args []string `json:"args"` + Options any `json:"options"` +} + +type TUIOptions struct { + CompactMode bool `json:"compact_mode"` + // Here we can add themes later or any TUI related options +} + +type Options struct { + ContextPaths []string `json:"context_paths"` + TUI TUIOptions `json:"tui"` + Debug bool `json:"debug"` + DebugLSP bool `json:"debug_lsp"` + DisableAutoSummarize bool `json:"disable_auto_summarize"` + // Relative to the cwd + DataDirectory string `json:"data_directory"` +} + +type Config struct { + // List of configured providers + Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"` + + // List of configured agents + Agents map[string]Agent `json:"agents,omitempty"` + + // List of configured MCPs + MCP map[string]MCP `json:"mcp,omitempty"` + + // List of configured LSPs + LSP map[string]LSPConfig `json:"lsp,omitempty"` + + // Miscellaneous options + Options Options `json:"options"` + + // Used to add models that are not already in the repository + Models map[provider.InferenceProvider][]provider.Model `json:"models,omitempty"` +} + +var ( + instance *Config 
// The single instance of the Singleton + cwd string + once sync.Once // Ensures the initialization happens only once +) + +func loadConfig(cwd string) (*Config, error) { + // First read the global config file + cfgPath := ConfigPath() + + cfg := defaultConfigBasedOnEnv() + + var globalCfg *Config + if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) { + // some other error occurred while checking the file + return nil, err + } else if err == nil { + // config file exists, read it + file, err := os.ReadFile(cfgPath) + if err != nil { + return nil, err + } + globalCfg = &Config{} + if err := json.Unmarshal(file, globalCfg); err != nil { + return nil, err + } + } else { + // config file does not exist, create a new one + globalCfg = &Config{} + } + + var localConfig *Config + // Global config loaded, now read the local config file + localConfigPath := filepath.Join(cwd, "crush.json") + if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) { + // some other error occurred while checking the file + return nil, err + } else if err == nil { + // local config file exists, read it + file, err := os.ReadFile(localConfigPath) + if err != nil { + return nil, err + } + localConfig = &Config{} + if err := json.Unmarshal(file, localConfig); err != nil { + return nil, err + } + } + + // merge options + cfg.Options = mergeOptions(cfg.Options, globalCfg.Options) + if localConfig != nil { cfg.Options = mergeOptions(cfg.Options, localConfig.Options) } + + mergeProviderConfigs(cfg, globalCfg, localConfig) + return cfg, nil +} + +func InitConfig(workingDir string) *Config { + once.Do(func() { + cwd = workingDir + cfg, err := loadConfig(cwd) + if err != nil { + // TODO: Handle this better + panic("Failed to load config: " + err.Error()) + } + instance = cfg + }) + + return instance +} + +func GetConfig() *Config { + if instance == nil { + // TODO: Handle this better + panic("Config not initialized. 
Call InitConfig first.") + } + return instance +} + +func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig { + if other.APIKey != "" { + base.APIKey = other.APIKey + } + // Only change these options if the provider is not a known provider + if !slices.Contains(provider.KnownProviders(), p) { + if other.BaseURL != "" { + base.BaseURL = other.BaseURL + } + if other.ProviderType != "" { + base.ProviderType = other.ProviderType + } + if len(other.ExtraHeaders) > 0 { + if base.ExtraHeaders == nil { + base.ExtraHeaders = make(map[string]string) + } + maps.Copy(base.ExtraHeaders, other.ExtraHeaders) + } + if len(other.ExtraParams) > 0 { + if base.ExtraParams == nil { + base.ExtraParams = make(map[string]string) + } + maps.Copy(base.ExtraParams, other.ExtraParams) + } + } + + if other.Disabled { + base.Disabled = other.Disabled + } + + return base +} + +func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error { + if !slices.Contains(provider.KnownProviders(), p) { + if providerConfig.ProviderType != provider.TypeOpenAI { + return errors.New("invalid provider type: " + string(providerConfig.ProviderType)) + } + if providerConfig.BaseURL == "" { + return errors.New("base URL must be set for custom providers") + } + if providerConfig.APIKey == "" { + return errors.New("API key must be set for custom providers") + } + } + return nil +} + +func mergeOptions(base, other Options) Options { + result := base + + if len(other.ContextPaths) > 0 { + result.ContextPaths = append(result.ContextPaths, other.ContextPaths...) 
+ } + + if other.TUI.CompactMode { + result.TUI.CompactMode = other.TUI.CompactMode + } + + if other.Debug { + result.Debug = other.Debug + } + + if other.DebugLSP { + result.DebugLSP = other.DebugLSP + } + + if other.DisableAutoSummarize { + result.DisableAutoSummarize = other.DisableAutoSummarize + } + + if other.DataDirectory != "" { + result.DataDirectory = other.DataDirectory + } + + return result +} + +func mergeProviderConfigs(base, global, local *Config) { + if global != nil { + for providerName, globalProvider := range global.Providers { + if _, ok := base.Providers[providerName]; !ok { + base.Providers[providerName] = globalProvider + } else { + base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider) + } + } + } + if local != nil { + for providerName, localProvider := range local.Providers { + if _, ok := base.Providers[providerName]; !ok { + base.Providers[providerName] = localProvider + } else { + base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], localProvider) + } + } + } + + finalProviders := make(map[provider.InferenceProvider]ProviderConfig) + for providerName, providerConfig := range base.Providers { + err := validateProvider(providerName, providerConfig) + if err != nil { + logging.Warn("Skipping provider", "name", providerName, "error", err); continue + } + finalProviders[providerName] = providerConfig + } + base.Providers = finalProviders +} + +func providerDefaultConfig(providerName provider.InferenceProvider) ProviderConfig { + switch providerName { + case provider.InferenceProviderAnthropic: + return ProviderConfig{ + ProviderType: provider.TypeAnthropic, + } + case provider.InferenceProviderOpenAI: + return ProviderConfig{ + ProviderType: provider.TypeOpenAI, + } + case provider.InferenceProviderGemini: + return ProviderConfig{ + ProviderType: provider.TypeGemini, + } + case provider.InferenceProviderBedrock: + return ProviderConfig{ + ProviderType: 
provider.TypeBedrock, + } + case provider.InferenceProviderAzure: + return ProviderConfig{ + ProviderType: provider.TypeAzure, + } + case provider.InferenceProviderOpenRouter: + return ProviderConfig{ + ProviderType: provider.TypeOpenAI, + BaseURL: "https://openrouter.ai/api/v1", + ExtraHeaders: map[string]string{ + "HTTP-Referer": "crush.charm.land", + "X-Title": "Crush", + }, + } + case provider.InferenceProviderXAI: + return ProviderConfig{ + ProviderType: provider.TypeXAI, + BaseURL: "https://api.x.ai/v1", + } + case provider.InferenceProviderVertexAI: + return ProviderConfig{ + ProviderType: provider.TypeVertexAI, + } + default: + return ProviderConfig{ + ProviderType: provider.TypeOpenAI, + } + } +} + +func defaultConfigBasedOnEnv() *Config { + cfg := &Config{ + Options: Options{ + DataDirectory: defaultDataDirectory, + }, + Providers: make(map[provider.InferenceProvider]ProviderConfig), + } + + providers := Providers() + + for _, p := range providers { + if strings.HasPrefix(p.APIKey, "$") { + envVar := strings.TrimPrefix(p.APIKey, "$") + if apiKey := os.Getenv(envVar); apiKey != "" { + providerConfig := providerDefaultConfig(p.ID) + providerConfig.APIKey = apiKey + providerConfig.DefaultModel = p.DefaultModelID + cfg.Providers[p.ID] = providerConfig + } + } + } + // TODO: support local models + + if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" { + providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI) + providerConfig.ExtraParams = map[string]string{ + "project": os.Getenv("GOOGLE_CLOUD_PROJECT"), + "location": os.Getenv("GOOGLE_CLOUD_LOCATION"), + } + cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig + } + + if hasAWSCredentials() { + providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock) + cfg.Providers[provider.InferenceProviderBedrock] = providerConfig + } + return cfg +} + +func hasAWSCredentials() bool { + if os.Getenv("AWS_ACCESS_KEY_ID") != "" && 
os.Getenv("AWS_SECRET_ACCESS_KEY") != "" { + return true + } + + if os.Getenv("AWS_PROFILE") != "" || os.Getenv("AWS_DEFAULT_PROFILE") != "" { + return true + } + + if os.Getenv("AWS_REGION") != "" || os.Getenv("AWS_DEFAULT_REGION") != "" { + return true + } + + if os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" || + os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" { + return true + } + + return false +} + +func WorkingDirectory() string { + return cwd +} diff --git a/internal/config_v2/config_test.go b/internal/config_v2/config_test.go new file mode 100644 index 0000000000000000000000000000000000000000..50b829271dcd42213141ecf2b9b72f5890480668 --- /dev/null +++ b/internal/config_v2/config_test.go @@ -0,0 +1,33 @@ +package configv2 + +import ( + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func resetEnvVars() { + os.Setenv("ANTHROPIC_API_KEY", "") + os.Setenv("OPENAI_API_KEY", "") + os.Setenv("GEMINI_API_KEY", "") + os.Setenv("XAI_API_KEY", "") + os.Setenv("OPENROUTER_API_KEY", "") +} + +func TestConfigWithEnv(t *testing.T) { + resetEnvVars() + testConfigDir = t.TempDir() + + cwdDir := t.TempDir() + + os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key") + os.Setenv("OPENAI_API_KEY", "test-openai-key") + os.Setenv("GEMINI_API_KEY", "test-gemini-key") + os.Setenv("XAI_API_KEY", "test-xai-key") + os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key") + cfg := InitConfig(cwdDir) + fmt.Println(cfg) + assert.Len(t, cfg.Providers, 5) +} diff --git a/internal/config_v2/fs.go b/internal/config_v2/fs.go new file mode 100644 index 0000000000000000000000000000000000000000..976267a2a68efb718449f59b3720d0d186720cdf --- /dev/null +++ b/internal/config_v2/fs.go @@ -0,0 +1,71 @@ +package configv2 + +import ( + "fmt" + "os" + "path/filepath" + "runtime" +) + +var testConfigDir string + +func baseConfigPath() string { + if testConfigDir != "" { + return testConfigDir + } + + xdgConfigHome := os.Getenv("XDG_CONFIG_HOME") + if xdgConfigHome != 
"" { + return filepath.Join(xdgConfigHome, "crush") + } + + // return the path to the main config directory + // for windows, it should be in `%LOCALAPPDATA%/crush/` + // for linux and macOS, it should be in `$HOME/.config/crush/` + if runtime.GOOS == "windows" { + localAppData := os.Getenv("LOCALAPPDATA") + if localAppData == "" { + localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local") + } + return filepath.Join(localAppData, appName) + } + + return filepath.Join(os.Getenv("HOME"), ".config", appName) +} + +func baseDataPath() string { + if testConfigDir != "" { + return testConfigDir + } + + xdgDataHome := os.Getenv("XDG_DATA_HOME") + if xdgDataHome != "" { + return filepath.Join(xdgDataHome, appName) + } + + // return the path to the main data directory + // for windows, it should be in `%LOCALAPPDATA%/crush/` + // for linux and macOS, it should be in `$HOME/.local/share/crush/` + if runtime.GOOS == "windows" { + localAppData := os.Getenv("LOCALAPPDATA") + if localAppData == "" { + localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local") + } + return filepath.Join(localAppData, appName) + } + + return filepath.Join(os.Getenv("HOME"), ".local", "share", appName) +} + +func ConfigPath() string { + return filepath.Join(baseConfigPath(), fmt.Sprintf("%s.json", appName)) +} + +func CrushInitialized() bool { + cfgPath := ConfigPath() + if _, err := os.Stat(cfgPath); os.IsNotExist(err) { + // config file does not exist, so Crush is not initialized + return false + } + return true +} diff --git a/internal/config_v2/provider.go b/internal/config_v2/provider.go new file mode 100644 index 0000000000000000000000000000000000000000..94fe2d44d74e3039dcdeaa0dc76e95b840a03125 --- /dev/null +++ b/internal/config_v2/provider.go @@ -0,0 +1,69 @@ +package configv2 + +import ( + "encoding/json" + "os" + "path/filepath" + "sync" + + "github.com/charmbracelet/fur/pkg/client" + "github.com/charmbracelet/fur/pkg/provider" +) + +var fur = 
client.New() + +var ( + providerOnc sync.Once // Ensures the initialization happens only once + providerList []provider.Provider +) + +func providersPath() string { + return filepath.Join(baseDataPath(), "providers.json") +} + +func saveProviders(providers []provider.Provider) error { + path := providersPath() + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + + data, err := json.MarshalIndent(providers, "", " ") + if err != nil { + return err + } + + return os.WriteFile(path, data, 0o644) +} + +func loadProviders() ([]provider.Provider, error) { + path := providersPath() + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + var providers []provider.Provider + err = json.Unmarshal(data, &providers) + return providers, err +} + +func Providers() []provider.Provider { + providerOnc.Do(func() { + // Try to get providers from upstream API + if providers, err := fur.GetProviders(); err == nil { + providerList = providers + // Save providers locally for future fallback + _ = saveProviders(providers) + } else { + // If upstream fails, try to load from local cache + if localProviders, localErr := loadProviders(); localErr == nil { + providerList = localProviders + } else { + // If both fail, return empty list + providerList = []provider.Provider{} + } + } + }) + return providerList +} diff --git a/internal/llm/models/anthropic.go b/internal/llm/models/anthropic.go index 9da03a835126956a74ee16888397abb343811ec4..85c47def3d94034297265c506c5870f2b449d286 100644 --- a/internal/llm/models/anthropic.go +++ b/internal/llm/models/anthropic.go @@ -1,7 +1,7 @@ package models const ( - ProviderAnthropic ModelProvider = "anthropic" + ProviderAnthropic InferenceProvider = "anthropic" // Models Claude35Sonnet ModelID = "claude-3.5-sonnet" diff --git a/internal/llm/models/azure.go b/internal/llm/models/azure.go index 416597302f362b4f2d7c605f7166ced2b200885a..eb7ae293ee053d953f5bcbb20120089ca6bae95b 100644 --- 
a/internal/llm/models/azure.go +++ b/internal/llm/models/azure.go @@ -1,6 +1,6 @@ package models -const ProviderAzure ModelProvider = "azure" +const ProviderAzure InferenceProvider = "azure" const ( AzureGPT41 ModelID = "azure.gpt-4.1" diff --git a/internal/llm/models/gemini.go b/internal/llm/models/gemini.go index 794ec3f0a06a0e9975d110cd4fb89e1427a32552..9749c6d3409acf7b05cd67690504e2cb3ac4fd39 100644 --- a/internal/llm/models/gemini.go +++ b/internal/llm/models/gemini.go @@ -1,7 +1,7 @@ package models const ( - ProviderGemini ModelProvider = "gemini" + ProviderGemini InferenceProvider = "gemini" // Models Gemini25Flash ModelID = "gemini-2.5-flash" diff --git a/internal/llm/models/groq.go b/internal/llm/models/groq.go index 19917f20bb2647e296db681e30b1b0f379bf7349..39288962c8e42a1acec8a01b3157b10d9b00b5dc 100644 --- a/internal/llm/models/groq.go +++ b/internal/llm/models/groq.go @@ -1,7 +1,7 @@ package models const ( - ProviderGROQ ModelProvider = "groq" + ProviderGROQ InferenceProvider = "groq" // GROQ QWENQwq ModelID = "qwen-qwq" diff --git a/internal/llm/models/local.go b/internal/llm/models/local.go index 3a50fdf48fe86167600eceee3cce26b6caac900e..c469e99fd65d5befbfffe5126a31c88eae68e150 100644 --- a/internal/llm/models/local.go +++ b/internal/llm/models/local.go @@ -16,7 +16,7 @@ import ( ) const ( - ProviderLocal ModelProvider = "local" + ProviderLocal InferenceProvider = "local" localModelsPath = "v1/models" lmStudioBetaModelsPath = "api/v0/models" diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go index 50e8723989ccb268a9f515b4c693662654fa38d5..0aefc170d32d1023f0d246a2cc7522e895453a88 100644 --- a/internal/llm/models/models.go +++ b/internal/llm/models/models.go @@ -3,23 +3,23 @@ package models import "maps" type ( - ModelID string - ModelProvider string + ModelID string + InferenceProvider string ) type Model struct { - ID ModelID `json:"id"` - Name string `json:"name"` - Provider ModelProvider `json:"provider"` - APIModel string 
`json:"api_model"` - CostPer1MIn float64 `json:"cost_per_1m_in"` - CostPer1MOut float64 `json:"cost_per_1m_out"` - CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` - CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` - ContextWindow int64 `json:"context_window"` - DefaultMaxTokens int64 `json:"default_max_tokens"` - CanReason bool `json:"can_reason"` - SupportsAttachments bool `json:"supports_attachments"` + ID ModelID `json:"id"` + Name string `json:"name"` + Provider InferenceProvider `json:"provider"` + APIModel string `json:"api_model"` + CostPer1MIn float64 `json:"cost_per_1m_in"` + CostPer1MOut float64 `json:"cost_per_1m_out"` + CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` + CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` + ContextWindow int64 `json:"context_window"` + DefaultMaxTokens int64 `json:"default_max_tokens"` + CanReason bool `json:"can_reason"` + SupportsAttachments bool `json:"supports_attachments"` } // Model IDs @@ -29,9 +29,9 @@ const ( // GEMINI ) const ( - ProviderBedrock ModelProvider = "bedrock" + ProviderBedrock InferenceProvider = "bedrock" // ForTests - ProviderMock ModelProvider = "__mock" + ProviderMock InferenceProvider = "__mock" ) var SupportedModels = map[ModelID]Model{ @@ -48,6 +48,20 @@ var SupportedModels = map[ModelID]Model{ }, } +var KnownProviders = []InferenceProvider{ + ProviderAnthropic, + ProviderOpenAI, + ProviderGemini, + ProviderAzure, + ProviderGROQ, + ProviderLocal, + ProviderOpenRouter, + ProviderVertexAI, + ProviderBedrock, + ProviderXAI, + ProviderMock, +} + func init() { maps.Copy(SupportedModels, AnthropicModels) maps.Copy(SupportedModels, OpenAIModels) diff --git a/internal/llm/models/openai.go b/internal/llm/models/openai.go index abe0e30c53207c0a7cacfe3d27f43495cfec6e8e..e4173277cbdfe1e579068d2981df1e70b9943cb1 100644 --- a/internal/llm/models/openai.go +++ b/internal/llm/models/openai.go @@ -1,7 +1,7 @@ package models const ( - ProviderOpenAI ModelProvider = "openai" + 
ProviderOpenAI InferenceProvider = "openai" GPT41 ModelID = "gpt-4.1" GPT41Mini ModelID = "gpt-4.1-mini" diff --git a/internal/llm/models/openrouter.go b/internal/llm/models/openrouter.go index 10ad5a0409937e060616eda6c188991d979e3ea1..8884e03442d30787fd505ca6a6c518d299748752 100644 --- a/internal/llm/models/openrouter.go +++ b/internal/llm/models/openrouter.go @@ -1,7 +1,7 @@ package models const ( - ProviderOpenRouter ModelProvider = "openrouter" + ProviderOpenRouter InferenceProvider = "openrouter" OpenRouterGPT41 ModelID = "openrouter.gpt-4.1" OpenRouterGPT41Mini ModelID = "openrouter.gpt-4.1-mini" diff --git a/internal/llm/models/vertexai.go b/internal/llm/models/vertexai.go index d71dfc0bed0a8071c89ab22883990413384cd56f..c9b5744b62c28e2529cac44b1e97234158d2eacf 100644 --- a/internal/llm/models/vertexai.go +++ b/internal/llm/models/vertexai.go @@ -1,7 +1,7 @@ package models const ( - ProviderVertexAI ModelProvider = "vertexai" + ProviderVertexAI InferenceProvider = "vertexai" // Models VertexAIGemini25Flash ModelID = "vertexai.gemini-2.5-flash" diff --git a/internal/llm/models/xai.go b/internal/llm/models/xai.go index 00caf3b89750c0789f75f6273d49e38a4cdf6282..a59eac97ee6fee5db5550663083062099512eddc 100644 --- a/internal/llm/models/xai.go +++ b/internal/llm/models/xai.go @@ -1,7 +1,7 @@ package models const ( - ProviderXAI ModelProvider = "xai" + ProviderXAI InferenceProvider = "xai" XAIGrok3Beta ModelID = "grok-3-beta" XAIGrok3MiniBeta ModelID = "grok-3-mini-beta" diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index ea31bfa0297c1ce207e188a7f162e26831927636..b272f4e9f263ff596d06aae787e8b5a1c3ac2aec 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -13,7 +13,7 @@ import ( "github.com/charmbracelet/crush/internal/llm/tools" ) -func CoderPrompt(provider models.ModelProvider) string { +func CoderPrompt(provider models.InferenceProvider) string { basePrompt := baseAnthropicCoderPrompt switch provider { case 
models.ProviderOpenAI: diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go index 9065fd6bc0bb4a69eab19479b9a18b270dddee02..ed75d29c500cce16f16d06892ad8fcabc254a08d 100644 --- a/internal/llm/prompt/prompt.go +++ b/internal/llm/prompt/prompt.go @@ -12,7 +12,7 @@ import ( "github.com/charmbracelet/crush/internal/logging" ) -func GetAgentPrompt(agentName config.AgentName, provider models.ModelProvider) string { +func GetAgentPrompt(agentName config.AgentName, provider models.InferenceProvider) string { basePrompt := "" switch agentName { case config.AgentCoder: diff --git a/internal/llm/prompt/summarizer.go b/internal/llm/prompt/summarizer.go index 87a0f95c66af8b51d07a3a4e792c07dea7dab503..f5a1de0f8619252d99082c6ca54e152cc25a7bc7 100644 --- a/internal/llm/prompt/summarizer.go +++ b/internal/llm/prompt/summarizer.go @@ -2,7 +2,7 @@ package prompt import "github.com/charmbracelet/crush/internal/llm/models" -func SummarizerPrompt(_ models.ModelProvider) string { +func SummarizerPrompt(_ models.InferenceProvider) string { return `You are a helpful AI assistant tasked with summarizing conversations. When asked to summarize, provide a detailed but concise summary of the conversation. diff --git a/internal/llm/prompt/task.go b/internal/llm/prompt/task.go index 53fd67dc2f88928b4fbe9773db0cd1487bcd811a..89acf1f02121ea008359eaa5201222061dad0cff 100644 --- a/internal/llm/prompt/task.go +++ b/internal/llm/prompt/task.go @@ -6,7 +6,7 @@ import ( "github.com/charmbracelet/crush/internal/llm/models" ) -func TaskPrompt(_ models.ModelProvider) string { +func TaskPrompt(_ models.InferenceProvider) string { agentPrompt := `You are an agent for Crush. Given the user's prompt, you should use the tools available to you to answer the user's question. Notes: 1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. 
One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is .", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...". diff --git a/internal/llm/prompt/title.go b/internal/llm/prompt/title.go index 03e47288507fa66bb88605bff4b2194b889cc3f7..0b3177b37857c24d299df0d6e64393cd60ea23eb 100644 --- a/internal/llm/prompt/title.go +++ b/internal/llm/prompt/title.go @@ -2,7 +2,7 @@ package prompt import "github.com/charmbracelet/crush/internal/llm/models" -func TitlePrompt(_ models.ModelProvider) string { +func TitlePrompt(_ models.InferenceProvider) string { return `you will generate a short title based on the first message a user begins a conversation with - ensure it is not more than 50 characters long - the title should be a summary of the user's message diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index d63f73f67cf0455bcccbb06ae70e3dde6a09557c..40c7317fba3eb944ba83421bcee8bf1702882fcb 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -82,7 +82,7 @@ type baseProvider[C ProviderClient] struct { client C } -func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption) (Provider, error) { +func NewProvider(providerName models.InferenceProvider, opts ...ProviderClientOption) (Provider, error) { clientOptions := providerClientOptions{} for _, o := range opts { o(&clientOptions) diff --git a/internal/llm/provider/vertexai.go b/internal/llm/provider/vertexai.go index fe2de2f4588f9dbe583e4f8af85e61eea67d5648..0e01ef9c3a082faea86bd6c76d6e9d53c9f8a933 100644 --- a/internal/llm/provider/vertexai.go +++ b/internal/llm/provider/vertexai.go @@ -17,8 +17,8 @@ func newVertexAIClient(opts providerClientOptions) VertexAIClient { } client, err := genai.NewClient(context.Background(), &genai.ClientConfig{ - Project: 
os.Getenv("VERTEXAI_PROJECT"), - Location: os.Getenv("VERTEXAI_LOCATION"), + Project: os.Getenv("GOOGLE_CLOUD_PROJECT"), + Location: os.Getenv("GOOGLE_CLOUD_LOCATION"), Backend: genai.BackendVertexAI, }) if err != nil { diff --git a/internal/message/content.go b/internal/message/content.go index 383134b596e62a5fc18b2c8404d770fc6a2d4112..b9e83ba4dd7fcc96216755a3871f0553b58d88d7 100644 --- a/internal/message/content.go +++ b/internal/message/content.go @@ -71,7 +71,7 @@ type BinaryContent struct { Data []byte } -func (bc BinaryContent) String(provider models.ModelProvider) string { +func (bc BinaryContent) String(provider models.InferenceProvider) string { base64Encoded := base64.StdEncoding.EncodeToString(bc.Data) if provider == models.ProviderOpenAI { return "data:" + bc.MIMEType + ";base64," + base64Encoded diff --git a/internal/tui/components/chat/sidebar/sidebar.go b/internal/tui/components/chat/sidebar/sidebar.go index afd067adc06d99c3c9da911812750631423231e6..405bd1f0f8c7891db1958e70f97e290dd9a8d411 100644 --- a/internal/tui/components/chat/sidebar/sidebar.go +++ b/internal/tui/components/chat/sidebar/sidebar.go @@ -76,7 +76,7 @@ func NewSidebarCmp(history history.Service, lspClients map[string]*lsp.Client, c } func (m *sidebarCmp) Init() tea.Cmd { - m.logo = m.logoBlock(false) + m.logo = m.logoBlock() m.cwd = cwd() return nil } @@ -231,9 +231,9 @@ func (m *sidebarCmp) loadSessionFiles() tea.Msg { func (m *sidebarCmp) SetSize(width, height int) tea.Cmd { if width < logoBreakpoint && (m.width == 0 || m.width >= logoBreakpoint) { - m.logo = m.logoBlock(true) + m.logo = m.logoBlock() } else if width >= logoBreakpoint && (m.width == 0 || m.width < logoBreakpoint) { - m.logo = m.logoBlock(false) + m.logo = m.logoBlock() } m.width = width @@ -245,9 +245,9 @@ func (m *sidebarCmp) GetSize() (int, int) { return m.width, m.height } -func (m *sidebarCmp) logoBlock(compact bool) string { +func (m *sidebarCmp) logoBlock() string { t := styles.CurrentTheme() - return 
logo.Render(version.Version, compact, logo.Opts{ + return logo.Render(version.Version, true, logo.Opts{ FieldColor: t.Primary, TitleColorA: t.Secondary, TitleColorB: t.Primary, diff --git a/internal/tui/components/chat/splash/keys.go b/internal/tui/components/chat/splash/keys.go new file mode 100644 index 0000000000000000000000000000000000000000..df715c89e86971a0f788915737bf41a212c65b5a --- /dev/null +++ b/internal/tui/components/chat/splash/keys.go @@ -0,0 +1,18 @@ +package splash + +import ( + "github.com/charmbracelet/bubbles/v2/key" +) + +type KeyMap struct { + Cancel key.Binding +} + +func DefaultKeyMap() KeyMap { + return KeyMap{ + Cancel: key.NewBinding( + key.WithKeys("esc"), + key.WithHelp("esc", "cancel"), + ), + } +} diff --git a/internal/tui/components/chat/splash/splash.go b/internal/tui/components/chat/splash/splash.go new file mode 100644 index 0000000000000000000000000000000000000000..75718b25471088607a66a19c7b8a56d36bd5d2d1 --- /dev/null +++ b/internal/tui/components/chat/splash/splash.go @@ -0,0 +1,85 @@ +package splash + +import ( + "github.com/charmbracelet/bubbles/v2/key" + tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/crush/internal/tui/components/core/layout" + "github.com/charmbracelet/crush/internal/tui/components/logo" + "github.com/charmbracelet/crush/internal/tui/styles" + "github.com/charmbracelet/crush/internal/tui/util" + "github.com/charmbracelet/crush/internal/version" + "github.com/charmbracelet/lipgloss/v2" +) + +type Splash interface { + util.Model + layout.Sizeable + layout.Help +} + +type splashCmp struct { + width, height int + keyMap KeyMap + logoRendered string +} + +func New() Splash { + return &splashCmp{ + width: 0, + height: 0, + keyMap: DefaultKeyMap(), + logoRendered: "", + } +} + +// GetSize implements SplashPage. +func (s *splashCmp) GetSize() (int, int) { + return s.width, s.height +} + +// Init implements SplashPage. 
+func (s *splashCmp) Init() tea.Cmd { + return nil +} + +// SetSize implements SplashPage. +func (s *splashCmp) SetSize(width int, height int) tea.Cmd { + s.width = width + s.height = height + s.logoRendered = s.logoBlock() + return nil +} + +// Update implements SplashPage. +func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + return s, s.SetSize(msg.Width, msg.Height) + } + return s, nil +} + +// View implements SplashPage. +func (s *splashCmp) View() tea.View { + content := lipgloss.JoinVertical(lipgloss.Left, s.logoRendered) + return tea.NewView(content) +} + +func (m *splashCmp) logoBlock() string { + t := styles.CurrentTheme() + return logo.Render(version.Version, false, logo.Opts{ + FieldColor: t.Primary, + TitleColorA: t.Secondary, + TitleColorB: t.Primary, + CharmColor: t.Secondary, + VersionColor: t.Primary, + Width: m.width - 2, // -2 for padding + }) +} + +// Bindings implements SplashPage. +func (s *splashCmp) Bindings() []key.Binding { + return []key.Binding{ + s.keyMap.Cancel, + } +} diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index 906d87a9dfe65c1ec09bd5abaf4f9d6865545038..0197b7141560a67008ceac64c31756bd19fff74a 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -80,7 +80,7 @@ func NewModelDialogCmp() ModelDialog { } } -var ProviderPopularity = map[models.ModelProvider]int{ +var ProviderPopularity = map[models.InferenceProvider]int{ models.ProviderAnthropic: 1, models.ProviderOpenAI: 2, models.ProviderGemini: 3, @@ -92,7 +92,7 @@ var ProviderPopularity = map[models.ModelProvider]int{ models.ProviderXAI: 9, } -var ProviderName = map[models.ModelProvider]string{ +var ProviderName = map[models.InferenceProvider]string{ models.ProviderAnthropic: "Anthropic", models.ProviderOpenAI: "OpenAI", models.ProviderGemini: "Gemini", @@ -195,8 +195,8 @@ func 
GetSelectedModel(cfg *config.Config) models.Model { return models.SupportedModels[selectedModelID] } -func getEnabledProviders(cfg *config.Config) []models.ModelProvider { - var providers []models.ModelProvider +func getEnabledProviders(cfg *config.Config) []models.InferenceProvider { + var providers []models.InferenceProvider for providerID, provider := range cfg.Providers { if !provider.Disabled { providers = append(providers, providerID) @@ -204,7 +204,7 @@ func getEnabledProviders(cfg *config.Config) []models.ModelProvider { } // Sort by provider popularity - slices.SortFunc(providers, func(a, b models.ModelProvider) int { + slices.SortFunc(providers, func(a, b models.InferenceProvider) int { rA := ProviderPopularity[a] rB := ProviderPopularity[b] @@ -220,7 +220,7 @@ func getEnabledProviders(cfg *config.Config) []models.ModelProvider { return providers } -func getModelsForProvider(provider models.ModelProvider) []models.Model { +func getModelsForProvider(provider models.InferenceProvider) []models.Model { var providerModels []models.Model for _, model := range models.SupportedModels { if model.Provider == provider { diff --git a/internal/tui/components/logo/logo.go b/internal/tui/components/logo/logo.go index 4b044c9dbd45284c72b7d03636d7399555e5f388..9d170ee6c2a0036ad9e4ca8b11c1a373fbb15080 100644 --- a/internal/tui/components/logo/logo.go +++ b/internal/tui/components/logo/logo.go @@ -26,6 +26,7 @@ type Opts struct { TitleColorB color.Color // right gradient ramp point CharmColor color.Color // Charmâ„¢ text color VersionColor color.Color // Version text color + Width int // width of the rendered logo, used for truncation } // Render renders the Crush logo. Set the argument to true to render the narrow @@ -76,7 +77,7 @@ func Render(version string, compact bool, o Opts) string { } // Right field. - const rightWidth = 15 + rightWidth := max(15, o.Width-crushWidth-leftWidth) // 2 for the gap. 
const stepDownAt = 0 rightField := new(strings.Builder) for i := range fieldHeight {