Detailed changes
@@ -1,6 +1,8 @@
module github.com/charmbracelet/crush
-go 1.24.0
+go 1.24.3
+
+replace github.com/charmbracelet/fur => ../fur
require (
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0
@@ -15,6 +17,7 @@ require (
github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250607113720-eb5e1cf3b09e
github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.3.0.20250609143341-c76fa36f1b94
github.com/charmbracelet/fang v0.1.0
+ github.com/charmbracelet/fur v0.0.0-00010101000000-000000000000
github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe
github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.1.0.20250523195325-2d1af06b557c
github.com/charmbracelet/x/ansi v0.9.3-0.20250602153603-fb931ed90413
@@ -76,17 +76,17 @@ type TUIConfig struct {
// Config is the main configuration structure for the application.
type Config struct {
- Data Data `json:"data"`
- WorkingDir string `json:"wd,omitempty"`
- MCPServers map[string]MCPServer `json:"mcpServers,omitempty"`
- Providers map[models.ModelProvider]Provider `json:"providers,omitempty"`
- LSP map[string]LSPConfig `json:"lsp,omitempty"`
- Agents map[AgentName]Agent `json:"agents,omitempty"`
- Debug bool `json:"debug,omitempty"`
- DebugLSP bool `json:"debugLSP,omitempty"`
- ContextPaths []string `json:"contextPaths,omitempty"`
- TUI TUIConfig `json:"tui"`
- AutoCompact bool `json:"autoCompact,omitempty"`
+ Data Data `json:"data"`
+ WorkingDir string `json:"wd,omitempty"`
+ MCPServers map[string]MCPServer `json:"mcpServers,omitempty"`
+ Providers map[models.InferenceProvider]Provider `json:"providers,omitempty"`
+ LSP map[string]LSPConfig `json:"lsp,omitempty"`
+ Agents map[AgentName]Agent `json:"agents,omitempty"`
+ Debug bool `json:"debug,omitempty"`
+ DebugLSP bool `json:"debugLSP,omitempty"`
+ ContextPaths []string `json:"contextPaths,omitempty"`
+ TUI TUIConfig `json:"tui"`
+ AutoCompact bool `json:"autoCompact,omitempty"`
}
// Application constants
@@ -128,7 +128,7 @@ func Load(workingDir string, debug bool) (*Config, error) {
cfg = &Config{
WorkingDir: workingDir,
MCPServers: make(map[string]MCPServer),
- Providers: make(map[models.ModelProvider]Provider),
+ Providers: make(map[models.InferenceProvider]Provider),
LSP: make(map[string]LSPConfig),
}
@@ -640,7 +640,7 @@ func Validate() error {
}
// getProviderAPIKey gets the API key for a provider from environment variables
-func getProviderAPIKey(provider models.ModelProvider) string {
+func getProviderAPIKey(provider models.InferenceProvider) string {
switch provider {
case models.ProviderAnthropic:
return os.Getenv("ANTHROPIC_API_KEY")
@@ -0,0 +1,440 @@
+package configv2
+
+import (
+ "encoding/json"
+ "errors"
+ "maps"
+ "os"
+ "path/filepath"
+ "slices"
+ "strings"
+ "sync"
+
+ "github.com/charmbracelet/crush/internal/logging"
+ "github.com/charmbracelet/fur/pkg/provider"
+)
+
+const (
+ defaultDataDirectory = ".crush"
+ defaultLogLevel = "info"
+ appName = "crush"
+
+ MaxTokensFallbackDefault = 4096
+)
+
+type Model struct {
+ ID string `json:"id"`
+ Name string `json:"model"`
+ CostPer1MIn float64 `json:"cost_per_1m_in"`
+ CostPer1MOut float64 `json:"cost_per_1m_out"`
+ CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
+ CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
+ ContextWindow int64 `json:"context_window"`
+ DefaultMaxTokens int64 `json:"default_max_tokens"`
+ CanReason bool `json:"can_reason"`
+ ReasoningEffort string `json:"reasoning_effort"`
+ SupportsImages bool `json:"supports_attachments"`
+}
+
+type VertexAIOptions struct {
+ APIKey string `json:"api_key,omitempty"`
+ Project string `json:"project,omitempty"`
+ Location string `json:"location,omitempty"`
+}
+
+type ProviderConfig struct {
+ BaseURL string `json:"base_url,omitempty"`
+ ProviderType provider.Type `json:"provider_type"`
+ APIKey string `json:"api_key,omitempty"`
+ Disabled bool `json:"disabled"`
+ ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
+ // used for e.x for vertex to set the project
+ ExtraParams map[string]string `json:"extra_params,omitempty"`
+
+ DefaultModel string `json:"default_model"`
+}
+
+type Agent struct {
+ Name string `json:"name"`
+ // This is the id of the system prompt used by the agent
+ // TODO: still needs to be implemented
+ PromptID string `json:"prompt_id"`
+ Disabled bool `json:"disabled"`
+
+ Provider provider.InferenceProvider `json:"provider"`
+ Model Model `json:"model"`
+
+ // The available tools for the agent
+ // if this is empty, all tools are available
+ AllowedTools []string `json:"allowed_tools"`
+
+ // this tells us which MCPs are available for this agent
+ // if this is empty all mcps are available
+ // the string array is the list of tools from the MCP the agent has available
+ // if the string array is empty, all tools from the MCP are available
+ MCP map[string][]string `json:"mcp"`
+
+ // The list of LSPs that this agent can use
+ // if this is empty, all LSPs are available
+ LSP []string `json:"lsp"`
+
+ // Overrides the context paths for this agent
+ ContextPaths []string `json:"context_paths"`
+}
+
+type MCPType string
+
+const (
+ MCPStdio MCPType = "stdio"
+ MCPSse MCPType = "sse"
+)
+
+type MCP struct {
+ Command string `json:"command"`
+ Env []string `json:"env"`
+ Args []string `json:"args"`
+ Type MCPType `json:"type"`
+ URL string `json:"url"`
+ Headers map[string]string `json:"headers"`
+}
+
+type LSPConfig struct {
+ Disabled bool `json:"enabled"`
+ Command string `json:"command"`
+ Args []string `json:"args"`
+ Options any `json:"options"`
+}
+
+type TUIOptions struct {
+ CompactMode bool `json:"compact_mode"`
+ // Here we can add themes later or any TUI related options
+}
+
+type Options struct {
+ ContextPaths []string `json:"context_paths"`
+ TUI TUIOptions `json:"tui"`
+ Debug bool `json:"debug"`
+ DebugLSP bool `json:"debug_lsp"`
+ DisableAutoSummarize bool `json:"disable_auto_summarize"`
+ // Relative to the cwd
+ DataDirectory string `json:"data_directory"`
+}
+
+type Config struct {
+ // List of configured providers
+ Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"`
+
+ // List of configured agents
+ Agents map[string]Agent `json:"agents,omitempty"`
+
+ // List of configured MCPs
+ MCP map[string]MCP `json:"mcp,omitempty"`
+
+ // List of configured LSPs
+ LSP map[string]LSPConfig `json:"lsp,omitempty"`
+
+ // Miscellaneous options
+ Options Options `json:"options"`
+
+ // Used to add models that are not already in the repository
+ Models map[provider.InferenceProvider][]provider.Model `json:"models,omitempty"`
+}
+
+var (
+ instance *Config // The single instance of the Singleton
+ cwd string
+ once sync.Once // Ensures the initialization happens only once
+)
+
+func loadConfig(cwd string) (*Config, error) {
+ // First read the global config file
+ cfgPath := ConfigPath()
+
+ cfg := defaultConfigBasedOnEnv()
+
+ var globalCfg *Config
+ if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) {
+ // some other error occurred while checking the file
+ return nil, err
+ } else if err == nil {
+ // config file exists, read it
+ file, err := os.ReadFile(cfgPath)
+ if err != nil {
+ return nil, err
+ }
+ globalCfg = &Config{}
+ if err := json.Unmarshal(file, globalCfg); err != nil {
+ return nil, err
+ }
+ } else {
+ // config file does not exist, create a new one
+ globalCfg = &Config{}
+ }
+
+ var localConfig *Config
+ // Global config loaded, now read the local config file
+ localConfigPath := filepath.Join(cwd, "crush.json")
+ if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
+ // some other error occurred while checking the file
+ return nil, err
+ } else if err == nil {
+ // local config file exists, read it
+ file, err := os.ReadFile(localConfigPath)
+ if err != nil {
+ return nil, err
+ }
+ localConfig = &Config{}
+ if err := json.Unmarshal(file, localConfig); err != nil {
+ return nil, err
+ }
+ }
+
+ // merge options
+ cfg.Options = mergeOptions(cfg.Options, globalCfg.Options)
+ cfg.Options = mergeOptions(cfg.Options, localConfig.Options)
+
+ mergeProviderConfigs(cfg, globalCfg, localConfig)
+ return cfg, nil
+}
+
+func InitConfig(workingDir string) *Config {
+ once.Do(func() {
+ cwd = workingDir
+ cfg, err := loadConfig(cwd)
+ if err != nil {
+ // TODO: Handle this better
+ panic("Failed to load config: " + err.Error())
+ }
+ instance = cfg
+ })
+
+ return instance
+}
+
+func GetConfig() *Config {
+ if instance == nil {
+ // TODO: Handle this better
+ panic("Config not initialized. Call InitConfig first.")
+ }
+ return instance
+}
+
+func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
+ if other.APIKey != "" {
+ base.APIKey = other.APIKey
+ }
+ // Only change these options if the provider is not a known provider
+ if !slices.Contains(provider.KnownProviders(), p) {
+ if other.BaseURL != "" {
+ base.BaseURL = other.BaseURL
+ }
+ if other.ProviderType != "" {
+ base.ProviderType = other.ProviderType
+ }
+ if len(base.ExtraHeaders) > 0 {
+ if base.ExtraHeaders == nil {
+ base.ExtraHeaders = make(map[string]string)
+ }
+ maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
+ }
+ if len(other.ExtraParams) > 0 {
+ if base.ExtraParams == nil {
+ base.ExtraParams = make(map[string]string)
+ }
+ maps.Copy(base.ExtraParams, other.ExtraParams)
+ }
+ }
+
+ if other.Disabled {
+ base.Disabled = other.Disabled
+ }
+
+ return base
+}
+
+func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
+ if !slices.Contains(provider.KnownProviders(), p) {
+ if providerConfig.ProviderType != provider.TypeOpenAI {
+ return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
+ }
+ if providerConfig.BaseURL == "" {
+ return errors.New("base URL must be set for custom providers")
+ }
+ if providerConfig.APIKey == "" {
+ return errors.New("API key must be set for custom providers")
+ }
+ }
+ return nil
+}
+
+func mergeOptions(base, other Options) Options {
+ result := base
+
+ if len(other.ContextPaths) > 0 {
+ base.ContextPaths = append(base.ContextPaths, other.ContextPaths...)
+ }
+
+ if other.TUI.CompactMode {
+ result.TUI.CompactMode = other.TUI.CompactMode
+ }
+
+ if other.Debug {
+ result.Debug = other.Debug
+ }
+
+ if other.DebugLSP {
+ result.DebugLSP = other.DebugLSP
+ }
+
+ if other.DisableAutoSummarize {
+ result.DisableAutoSummarize = other.DisableAutoSummarize
+ }
+
+ if other.DataDirectory != "" {
+ result.DataDirectory = other.DataDirectory
+ }
+
+ return result
+}
+
+func mergeProviderConfigs(base, global, local *Config) {
+ if global != nil {
+ for providerName, globalProvider := range global.Providers {
+ if _, ok := base.Providers[providerName]; !ok {
+ base.Providers[providerName] = globalProvider
+ } else {
+ base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider)
+ }
+ }
+ }
+ if local != nil {
+ for providerName, localProvider := range local.Providers {
+ if _, ok := base.Providers[providerName]; !ok {
+ base.Providers[providerName] = localProvider
+ } else {
+ base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], localProvider)
+ }
+ }
+ }
+
+ finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
+ for providerName, providerConfig := range base.Providers {
+ err := validateProvider(providerName, providerConfig)
+ if err != nil {
+ logging.Warn("Skipping provider", "name", providerName, "error", err)
+ }
+ finalProviders[providerName] = providerConfig
+ }
+ base.Providers = finalProviders
+}
+
+func providerDefaultConfig(providerName provider.InferenceProvider) ProviderConfig {
+ switch providerName {
+ case provider.InferenceProviderAnthropic:
+ return ProviderConfig{
+ ProviderType: provider.TypeAnthropic,
+ }
+ case provider.InferenceProviderOpenAI:
+ return ProviderConfig{
+ ProviderType: provider.TypeOpenAI,
+ }
+ case provider.InferenceProviderGemini:
+ return ProviderConfig{
+ ProviderType: provider.TypeGemini,
+ }
+ case provider.InferenceProviderBedrock:
+ return ProviderConfig{
+ ProviderType: provider.TypeBedrock,
+ }
+ case provider.InferenceProviderAzure:
+ return ProviderConfig{
+ ProviderType: provider.TypeAzure,
+ }
+ case provider.InferenceProviderOpenRouter:
+ return ProviderConfig{
+ ProviderType: provider.TypeOpenAI,
+ BaseURL: "https://openrouter.ai/api/v1",
+ ExtraHeaders: map[string]string{
+ "HTTP-Referer": "crush.charm.land",
+ "X-Title": "Crush",
+ },
+ }
+ case provider.InferenceProviderXAI:
+ return ProviderConfig{
+ ProviderType: provider.TypeXAI,
+ BaseURL: "https://api.x.ai/v1",
+ }
+ case provider.InferenceProviderVertexAI:
+ return ProviderConfig{
+ ProviderType: provider.TypeVertexAI,
+ }
+ default:
+ return ProviderConfig{
+ ProviderType: provider.TypeOpenAI,
+ }
+ }
+}
+
+func defaultConfigBasedOnEnv() *Config {
+ cfg := &Config{
+ Options: Options{
+ DataDirectory: defaultDataDirectory,
+ },
+ Providers: make(map[provider.InferenceProvider]ProviderConfig),
+ }
+
+ providers := Providers()
+
+ for _, p := range providers {
+ if strings.HasPrefix(p.APIKey, "$") {
+ envVar := strings.TrimPrefix(p.APIKey, "$")
+ if apiKey := os.Getenv(envVar); apiKey != "" {
+ providerConfig := providerDefaultConfig(p.ID)
+ providerConfig.APIKey = apiKey
+ providerConfig.DefaultModel = p.DefaultModelID
+ cfg.Providers[p.ID] = providerConfig
+ }
+ }
+ }
+ // TODO: support local models
+
+ if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" {
+ providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
+ providerConfig.ExtraParams = map[string]string{
+ "project": os.Getenv("GOOGLE_CLOUD_PROJECT"),
+ "location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
+ }
+ cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
+ }
+
+ if hasAWSCredentials() {
+ providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
+ cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
+ }
+ return cfg
+}
+
+func hasAWSCredentials() bool {
+ if os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "" {
+ return true
+ }
+
+ if os.Getenv("AWS_PROFILE") != "" || os.Getenv("AWS_DEFAULT_PROFILE") != "" {
+ return true
+ }
+
+ if os.Getenv("AWS_REGION") != "" || os.Getenv("AWS_DEFAULT_REGION") != "" {
+ return true
+ }
+
+ if os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
+ os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
+ return true
+ }
+
+ return false
+}
+
+func WorkingDirectory() string {
+ return cwd
+}
@@ -0,0 +1,33 @@
+package configv2
+
+import (
+ "fmt"
+ "os"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func resetEnvVars() {
+ os.Setenv("ANTHROPIC_API_KEY", "")
+ os.Setenv("OPENAI_API_KEY", "")
+ os.Setenv("GEMINI_API_KEY", "")
+ os.Setenv("XAI_API_KEY", "")
+ os.Setenv("OPENROUTER_API_KEY", "")
+}
+
+func TestConfigWithEnv(t *testing.T) {
+ resetEnvVars()
+ testConfigDir = t.TempDir()
+
+ cwdDir := t.TempDir()
+
+ os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key")
+ os.Setenv("OPENAI_API_KEY", "test-openai-key")
+ os.Setenv("GEMINI_API_KEY", "test-gemini-key")
+ os.Setenv("XAI_API_KEY", "test-xai-key")
+ os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key")
+ cfg := InitConfig(cwdDir)
+ fmt.Println(cfg)
+ assert.Len(t, cfg.Providers, 5)
+}
@@ -0,0 +1,71 @@
+package configv2
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+ "runtime"
+)
+
+var testConfigDir string
+
+func baseConfigPath() string {
+ if testConfigDir != "" {
+ return testConfigDir
+ }
+
+ xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
+ if xdgConfigHome != "" {
+ return filepath.Join(xdgConfigHome, "crush")
+ }
+
+ // return the path to the main config directory
+ // for windows, it should be in `%LOCALAPPDATA%/crush/`
+ // for linux and macOS, it should be in `$HOME/.config/crush/`
+ if runtime.GOOS == "windows" {
+ localAppData := os.Getenv("LOCALAPPDATA")
+ if localAppData == "" {
+ localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
+ }
+ return filepath.Join(localAppData, appName)
+ }
+
+ return filepath.Join(os.Getenv("HOME"), ".config", appName)
+}
+
+func baseDataPath() string {
+ if testConfigDir != "" {
+ return testConfigDir
+ }
+
+ xdgDataHome := os.Getenv("XDG_DATA_HOME")
+ if xdgDataHome != "" {
+ return filepath.Join(xdgDataHome, appName)
+ }
+
+ // return the path to the main data directory
+ // for windows, it should be in `%LOCALAPPDATA%/crush/`
+ // for linux and macOS, it should be in `$HOME/.local/share/crush/`
+ if runtime.GOOS == "windows" {
+ localAppData := os.Getenv("LOCALAPPDATA")
+ if localAppData == "" {
+ localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
+ }
+ return filepath.Join(localAppData, appName)
+ }
+
+ return filepath.Join(os.Getenv("HOME"), ".local", "share", appName)
+}
+
+func ConfigPath() string {
+ return filepath.Join(baseConfigPath(), fmt.Sprintf("%s.json", appName))
+}
+
+func CrushInitialized() bool {
+ cfgPath := ConfigPath()
+ if _, err := os.Stat(cfgPath); os.IsNotExist(err) {
+ // config file does not exist, so Crush is not initialized
+ return false
+ }
+ return true
+}
@@ -0,0 +1,69 @@
+package configv2
+
+import (
+ "encoding/json"
+ "os"
+ "path/filepath"
+ "sync"
+
+ "github.com/charmbracelet/fur/pkg/client"
+ "github.com/charmbracelet/fur/pkg/provider"
+)
+
+var fur = client.New()
+
+var (
+ providerOnc sync.Once // Ensures the initialization happens only once
+ providerList []provider.Provider
+)
+
+func providersPath() string {
+ return filepath.Join(baseDataPath(), "providers.json")
+}
+
+func saveProviders(providers []provider.Provider) error {
+ path := providersPath()
+ dir := filepath.Dir(path)
+ if err := os.MkdirAll(dir, 0o755); err != nil {
+ return err
+ }
+
+ data, err := json.MarshalIndent(providers, "", " ")
+ if err != nil {
+ return err
+ }
+
+ return os.WriteFile(path, data, 0o644)
+}
+
+func loadProviders() ([]provider.Provider, error) {
+ path := providersPath()
+ data, err := os.ReadFile(path)
+ if err != nil {
+ return nil, err
+ }
+
+ var providers []provider.Provider
+ err = json.Unmarshal(data, &providers)
+ return providers, err
+}
+
+func Providers() []provider.Provider {
+ providerOnc.Do(func() {
+ // Try to get providers from upstream API
+ if providers, err := fur.GetProviders(); err == nil {
+ providerList = providers
+ // Save providers locally for future fallback
+ _ = saveProviders(providers)
+ } else {
+ // If upstream fails, try to load from local cache
+ if localProviders, localErr := loadProviders(); localErr == nil {
+ providerList = localProviders
+ } else {
+ // If both fail, return empty list
+ providerList = []provider.Provider{}
+ }
+ }
+ })
+ return providerList
+}
@@ -1,7 +1,7 @@
package models
const (
- ProviderAnthropic ModelProvider = "anthropic"
+ ProviderAnthropic InferenceProvider = "anthropic"
// Models
Claude35Sonnet ModelID = "claude-3.5-sonnet"
@@ -1,6 +1,6 @@
package models
-const ProviderAzure ModelProvider = "azure"
+const ProviderAzure InferenceProvider = "azure"
const (
AzureGPT41 ModelID = "azure.gpt-4.1"
@@ -1,7 +1,7 @@
package models
const (
- ProviderGemini ModelProvider = "gemini"
+ ProviderGemini InferenceProvider = "gemini"
// Models
Gemini25Flash ModelID = "gemini-2.5-flash"
@@ -1,7 +1,7 @@
package models
const (
- ProviderGROQ ModelProvider = "groq"
+ ProviderGROQ InferenceProvider = "groq"
// GROQ
QWENQwq ModelID = "qwen-qwq"
@@ -16,7 +16,7 @@ import (
)
const (
- ProviderLocal ModelProvider = "local"
+ ProviderLocal InferenceProvider = "local"
localModelsPath = "v1/models"
lmStudioBetaModelsPath = "api/v0/models"
@@ -3,23 +3,23 @@ package models
import "maps"
type (
- ModelID string
- ModelProvider string
+ ModelID string
+ InferenceProvider string
)
type Model struct {
- ID ModelID `json:"id"`
- Name string `json:"name"`
- Provider ModelProvider `json:"provider"`
- APIModel string `json:"api_model"`
- CostPer1MIn float64 `json:"cost_per_1m_in"`
- CostPer1MOut float64 `json:"cost_per_1m_out"`
- CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
- CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
- ContextWindow int64 `json:"context_window"`
- DefaultMaxTokens int64 `json:"default_max_tokens"`
- CanReason bool `json:"can_reason"`
- SupportsAttachments bool `json:"supports_attachments"`
+ ID ModelID `json:"id"`
+ Name string `json:"name"`
+ Provider InferenceProvider `json:"provider"`
+ APIModel string `json:"api_model"`
+ CostPer1MIn float64 `json:"cost_per_1m_in"`
+ CostPer1MOut float64 `json:"cost_per_1m_out"`
+ CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
+ CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
+ ContextWindow int64 `json:"context_window"`
+ DefaultMaxTokens int64 `json:"default_max_tokens"`
+ CanReason bool `json:"can_reason"`
+ SupportsAttachments bool `json:"supports_attachments"`
}
// Model IDs
@@ -29,9 +29,9 @@ const ( // GEMINI
)
const (
- ProviderBedrock ModelProvider = "bedrock"
+ ProviderBedrock InferenceProvider = "bedrock"
// ForTests
- ProviderMock ModelProvider = "__mock"
+ ProviderMock InferenceProvider = "__mock"
)
var SupportedModels = map[ModelID]Model{
@@ -48,6 +48,20 @@ var SupportedModels = map[ModelID]Model{
},
}
+var KnownProviders = []InferenceProvider{
+ ProviderAnthropic,
+ ProviderOpenAI,
+ ProviderGemini,
+ ProviderAzure,
+ ProviderGROQ,
+ ProviderLocal,
+ ProviderOpenRouter,
+ ProviderVertexAI,
+ ProviderBedrock,
+ ProviderXAI,
+ ProviderMock,
+}
+
func init() {
maps.Copy(SupportedModels, AnthropicModels)
maps.Copy(SupportedModels, OpenAIModels)
@@ -1,7 +1,7 @@
package models
const (
- ProviderOpenAI ModelProvider = "openai"
+ ProviderOpenAI InferenceProvider = "openai"
GPT41 ModelID = "gpt-4.1"
GPT41Mini ModelID = "gpt-4.1-mini"
@@ -1,7 +1,7 @@
package models
const (
- ProviderOpenRouter ModelProvider = "openrouter"
+ ProviderOpenRouter InferenceProvider = "openrouter"
OpenRouterGPT41 ModelID = "openrouter.gpt-4.1"
OpenRouterGPT41Mini ModelID = "openrouter.gpt-4.1-mini"
@@ -1,7 +1,7 @@
package models
const (
- ProviderVertexAI ModelProvider = "vertexai"
+ ProviderVertexAI InferenceProvider = "vertexai"
// Models
VertexAIGemini25Flash ModelID = "vertexai.gemini-2.5-flash"
@@ -1,7 +1,7 @@
package models
const (
- ProviderXAI ModelProvider = "xai"
+ ProviderXAI InferenceProvider = "xai"
XAIGrok3Beta ModelID = "grok-3-beta"
XAIGrok3MiniBeta ModelID = "grok-3-mini-beta"
@@ -13,7 +13,7 @@ import (
"github.com/charmbracelet/crush/internal/llm/tools"
)
-func CoderPrompt(provider models.ModelProvider) string {
+func CoderPrompt(provider models.InferenceProvider) string {
basePrompt := baseAnthropicCoderPrompt
switch provider {
case models.ProviderOpenAI:
@@ -12,7 +12,7 @@ import (
"github.com/charmbracelet/crush/internal/logging"
)
-func GetAgentPrompt(agentName config.AgentName, provider models.ModelProvider) string {
+func GetAgentPrompt(agentName config.AgentName, provider models.InferenceProvider) string {
basePrompt := ""
switch agentName {
case config.AgentCoder:
@@ -2,7 +2,7 @@ package prompt
import "github.com/charmbracelet/crush/internal/llm/models"
-func SummarizerPrompt(_ models.ModelProvider) string {
+func SummarizerPrompt(_ models.InferenceProvider) string {
return `You are a helpful AI assistant tasked with summarizing conversations.
When asked to summarize, provide a detailed but concise summary of the conversation.
@@ -6,7 +6,7 @@ import (
"github.com/charmbracelet/crush/internal/llm/models"
)
-func TaskPrompt(_ models.ModelProvider) string {
+func TaskPrompt(_ models.InferenceProvider) string {
agentPrompt := `You are an agent for Crush. Given the user's prompt, you should use the tools available to you to answer the user's question.
Notes:
1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
@@ -2,7 +2,7 @@ package prompt
import "github.com/charmbracelet/crush/internal/llm/models"
-func TitlePrompt(_ models.ModelProvider) string {
+func TitlePrompt(_ models.InferenceProvider) string {
return `you will generate a short title based on the first message a user begins a conversation with
- ensure it is not more than 50 characters long
- the title should be a summary of the user's message
@@ -82,7 +82,7 @@ type baseProvider[C ProviderClient] struct {
client C
}
-func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption) (Provider, error) {
+func NewProvider(providerName models.InferenceProvider, opts ...ProviderClientOption) (Provider, error) {
clientOptions := providerClientOptions{}
for _, o := range opts {
o(&clientOptions)
@@ -17,8 +17,8 @@ func newVertexAIClient(opts providerClientOptions) VertexAIClient {
}
client, err := genai.NewClient(context.Background(), &genai.ClientConfig{
- Project: os.Getenv("VERTEXAI_PROJECT"),
- Location: os.Getenv("VERTEXAI_LOCATION"),
+ Project: os.Getenv("GOOGLE_CLOUD_PROJECT"),
+ Location: os.Getenv("GOOGLE_CLOUD_LOCATION"),
Backend: genai.BackendVertexAI,
})
if err != nil {
@@ -71,7 +71,7 @@ type BinaryContent struct {
Data []byte
}
-func (bc BinaryContent) String(provider models.ModelProvider) string {
+func (bc BinaryContent) String(provider models.InferenceProvider) string {
base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
if provider == models.ProviderOpenAI {
return "data:" + bc.MIMEType + ";base64," + base64Encoded
@@ -76,7 +76,7 @@ func NewSidebarCmp(history history.Service, lspClients map[string]*lsp.Client, c
}
func (m *sidebarCmp) Init() tea.Cmd {
- m.logo = m.logoBlock(false)
+ m.logo = m.logoBlock()
m.cwd = cwd()
return nil
}
@@ -231,9 +231,9 @@ func (m *sidebarCmp) loadSessionFiles() tea.Msg {
func (m *sidebarCmp) SetSize(width, height int) tea.Cmd {
if width < logoBreakpoint && (m.width == 0 || m.width >= logoBreakpoint) {
- m.logo = m.logoBlock(true)
+ m.logo = m.logoBlock()
} else if width >= logoBreakpoint && (m.width == 0 || m.width < logoBreakpoint) {
- m.logo = m.logoBlock(false)
+ m.logo = m.logoBlock()
}
m.width = width
@@ -245,9 +245,9 @@ func (m *sidebarCmp) GetSize() (int, int) {
return m.width, m.height
}
-func (m *sidebarCmp) logoBlock(compact bool) string {
+func (m *sidebarCmp) logoBlock() string {
t := styles.CurrentTheme()
- return logo.Render(version.Version, compact, logo.Opts{
+ return logo.Render(version.Version, true, logo.Opts{
FieldColor: t.Primary,
TitleColorA: t.Secondary,
TitleColorB: t.Primary,
@@ -0,0 +1,18 @@
+package splash
+
+import (
+ "github.com/charmbracelet/bubbles/v2/key"
+)
+
+type KeyMap struct {
+ Cancel key.Binding
+}
+
+func DefaultKeyMap() KeyMap {
+ return KeyMap{
+ Cancel: key.NewBinding(
+ key.WithKeys("esc"),
+ key.WithHelp("esc", "cancel"),
+ ),
+ }
+}
@@ -0,0 +1,85 @@
+package splash
+
+import (
+ "github.com/charmbracelet/bubbles/v2/key"
+ tea "github.com/charmbracelet/bubbletea/v2"
+ "github.com/charmbracelet/crush/internal/tui/components/core/layout"
+ "github.com/charmbracelet/crush/internal/tui/components/logo"
+ "github.com/charmbracelet/crush/internal/tui/styles"
+ "github.com/charmbracelet/crush/internal/tui/util"
+ "github.com/charmbracelet/crush/internal/version"
+ "github.com/charmbracelet/lipgloss/v2"
+)
+
+type Splash interface {
+ util.Model
+ layout.Sizeable
+ layout.Help
+}
+
+type splashCmp struct {
+ width, height int
+ keyMap KeyMap
+ logoRendered string
+}
+
+func New() Splash {
+ return &splashCmp{
+ width: 0,
+ height: 0,
+ keyMap: DefaultKeyMap(),
+ logoRendered: "",
+ }
+}
+
+// GetSize implements SplashPage.
+func (s *splashCmp) GetSize() (int, int) {
+ return s.width, s.height
+}
+
+// Init implements SplashPage.
+func (s *splashCmp) Init() tea.Cmd {
+ return nil
+}
+
+// SetSize implements SplashPage.
+func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
+ s.width = width
+ s.height = height
+ s.logoRendered = s.logoBlock()
+ return nil
+}
+
+// Update implements SplashPage.
+func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
+ switch msg := msg.(type) {
+ case tea.WindowSizeMsg:
+ return s, s.SetSize(msg.Width, msg.Height)
+ }
+ return s, nil
+}
+
+// View implements SplashPage.
+func (s *splashCmp) View() tea.View {
+ content := lipgloss.JoinVertical(lipgloss.Left, s.logoRendered)
+ return tea.NewView(content)
+}
+
+func (m *splashCmp) logoBlock() string {
+ t := styles.CurrentTheme()
+ return logo.Render(version.Version, false, logo.Opts{
+ FieldColor: t.Primary,
+ TitleColorA: t.Secondary,
+ TitleColorB: t.Primary,
+ CharmColor: t.Secondary,
+ VersionColor: t.Primary,
+ Width: m.width - 2, // -2 for padding
+ })
+}
+
+// Bindings implements SplashPage.
+func (s *splashCmp) Bindings() []key.Binding {
+ return []key.Binding{
+ s.keyMap.Cancel,
+ }
+}
@@ -80,7 +80,7 @@ func NewModelDialogCmp() ModelDialog {
}
}
-var ProviderPopularity = map[models.ModelProvider]int{
+var ProviderPopularity = map[models.InferenceProvider]int{
models.ProviderAnthropic: 1,
models.ProviderOpenAI: 2,
models.ProviderGemini: 3,
@@ -92,7 +92,7 @@ var ProviderPopularity = map[models.ModelProvider]int{
models.ProviderXAI: 9,
}
-var ProviderName = map[models.ModelProvider]string{
+var ProviderName = map[models.InferenceProvider]string{
models.ProviderAnthropic: "Anthropic",
models.ProviderOpenAI: "OpenAI",
models.ProviderGemini: "Gemini",
@@ -195,8 +195,8 @@ func GetSelectedModel(cfg *config.Config) models.Model {
return models.SupportedModels[selectedModelID]
}
-func getEnabledProviders(cfg *config.Config) []models.ModelProvider {
- var providers []models.ModelProvider
+func getEnabledProviders(cfg *config.Config) []models.InferenceProvider {
+ var providers []models.InferenceProvider
for providerID, provider := range cfg.Providers {
if !provider.Disabled {
providers = append(providers, providerID)
@@ -204,7 +204,7 @@ func getEnabledProviders(cfg *config.Config) []models.ModelProvider {
}
// Sort by provider popularity
- slices.SortFunc(providers, func(a, b models.ModelProvider) int {
+ slices.SortFunc(providers, func(a, b models.InferenceProvider) int {
rA := ProviderPopularity[a]
rB := ProviderPopularity[b]
@@ -220,7 +220,7 @@ func getEnabledProviders(cfg *config.Config) []models.ModelProvider {
return providers
}
-func getModelsForProvider(provider models.ModelProvider) []models.Model {
+func getModelsForProvider(provider models.InferenceProvider) []models.Model {
var providerModels []models.Model
for _, model := range models.SupportedModels {
if model.Provider == provider {
@@ -26,6 +26,7 @@ type Opts struct {
TitleColorB color.Color // right gradient ramp point
CharmColor color.Color // Charmβ’ text color
VersionColor color.Color // Version text color
+ Width int // width of the rendered logo, used for truncation
}
// Render renders the Crush logo. Set the argument to true to render the narrow
@@ -76,7 +77,7 @@ func Render(version string, compact bool, o Opts) string {
}
// Right field.
- const rightWidth = 15
+ rightWidth := max(15, o.Width-crushWidth-leftWidth) // 2 for the gap.
const stepDownAt = 0
rightField := new(strings.Builder)
for i := range fieldHeight {