Detailed changes
@@ -1 +1 @@
-{"flagWords":[],"words":["afero","alecthomas","bubbletea","charmbracelet","charmtone","Charple","crush","diffview","Emph","filepicker","Focusable","fsext","GROQ","Guac","imageorient","Lanczos","lipgloss","lsps","lucasb","nfnt","oksvg","Preproc","rasterx","rivo","Sourcegraph","srwiley","Strikethrough","termenv","textinput","trashhalo","uniseg","Unticked","genai","jsonschema"],"version":"0.2","language":"en"}
+{"language":"en","version":"0.2","flagWords":[],"words":["afero","alecthomas","bubbletea","charmbracelet","charmtone","Charple","crush","diffview","Emph","filepicker","Focusable","fsext","GROQ","Guac","imageorient","Lanczos","lipgloss","lsps","lucasb","nfnt","oksvg","Preproc","rasterx","rivo","Sourcegraph","srwiley","Strikethrough","termenv","textinput","trashhalo","uniseg","Unticked","genai","jsonschema","preconfigured","jsons","qjebbs","LOCALAPPDATA","USERPROFILE","stretchr"]}
@@ -40,7 +40,11 @@ require (
mvdan.cc/sh/v3 v3.11.0
)
-require github.com/spf13/cast v1.7.1 // indirect
+require (
+ github.com/joho/godotenv v1.5.1 // indirect
+ github.com/qjebbs/go-jsons v0.0.0-20221222033332-a534c5fc1c4c // indirect
+ github.com/spf13/cast v1.7.1 // indirect
+)
require (
cloud.google.com/go v0.116.0 // indirect
@@ -155,6 +155,8 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=
github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
+github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
+github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
@@ -208,6 +210,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pressly/goose/v3 v3.24.2 h1:c/ie0Gm8rnIVKvnDQ/scHErv46jrDv9b4I0WRcFJzYU=
github.com/pressly/goose/v3 v3.24.2/go.mod h1:kjefwFB0eR4w30Td2Gj2Mznyw94vSP+2jJYkOVNbD1k=
+github.com/qjebbs/go-jsons v0.0.0-20221222033332-a534c5fc1c4c h1:kmzxiX+OB0knCo1V0dkEkdPelzCdAzCURCfmFArn2/A=
+github.com/qjebbs/go-jsons v0.0.0-20221222033332-a534c5fc1c4c/go.mod h1:wNJrtinHyC3YSf6giEh4FJN8+yZV7nXBjvmfjhBIcw4=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
@@ -6,6 +6,8 @@ import (
_ "net/http/pprof" // profiling
+ _ "github.com/joho/godotenv/autoload" // automatically load .env files
+
"github.com/charmbracelet/crush/cmd"
"github.com/charmbracelet/crush/internal/logging"
)
@@ -0,0 +1,124 @@
+package config
+
+import "github.com/charmbracelet/crush/internal/fur/provider"
+
+const (
+ appName = "crush"
+ defaultDataDirectory = ".crush"
+ defaultLogLevel = "info"
+)
+
+var defaultContextPaths = []string{
+ ".github/copilot-instructions.md",
+ ".cursorrules",
+ ".cursor/rules/",
+ "CLAUDE.md",
+ "CLAUDE.local.md",
+ "GEMINI.md",
+ "gemini.md",
+ "crush.md",
+ "crush.local.md",
+ "Crush.md",
+ "Crush.local.md",
+ "CRUSH.md",
+ "CRUSH.local.md",
+}
+
+type SelectedModel struct {
+ // The model id as used by the provider API.
+ // Required.
+ Model string `json:"model"`
+ // The model provider, same as the key/id used in the providers config.
+ // Required.
+ Provider string `json:"provider"`
+
+ // Only used by models that use the openai provider and need this set.
+ ReasoningEffort string `json:"reasoning_effort,omitempty"`
+
+ // Overrides the default model configuration.
+ MaxTokens int64 `json:"max_tokens,omitempty"`
+
+ // Used by anthropic models that can reason to indicate if the model should think.
+ Think bool `json:"think,omitempty"`
+}
+
+type ProviderConfig struct {
+ // The provider's API endpoint.
+ BaseURL string `json:"base_url,omitempty"`
+ // The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
+ Type provider.Type `json:"type,omitempty"`
+ // The provider's API key.
+ APIKey string `json:"api_key,omitempty"`
+ // Marks the provider as disabled.
+ Disable bool `json:"disable,omitempty"`
+
+ // Extra headers to send with each request to the provider.
+ ExtraHeaders map[string]string
+
+ // Used to pass extra parameters to the provider.
+ ExtraParams map[string]string `json:"-"`
+
+ // The provider models
+ Models []provider.Model `json:"models,omitempty"`
+}
+
+type MCPType string
+
+const (
+ MCPStdio MCPType = "stdio"
+ MCPSse MCPType = "sse"
+ MCPHttp MCPType = "http"
+)
+
+type MCP struct {
+ Command string `json:"command,omitempty" `
+ Env []string `json:"env,omitempty"`
+ Args []string `json:"args,omitempty"`
+ Type MCPType `json:"type"`
+ URL string `json:"url,omitempty"`
+
+ // TODO: maybe make it possible to get the value from the env
+ Headers map[string]string `json:"headers,omitempty"`
+}
+
+type LSPConfig struct {
+ Disabled bool `json:"enabled,omitempty"`
+ Command string `json:"command"`
+ Args []string `json:"args,omitempty"`
+ Options any `json:"options,omitempty"`
+}
+
+type TUIOptions struct {
+ CompactMode bool `json:"compact_mode,omitempty"`
+ // Here we can add themes later or any TUI related options
+}
+
+type Options struct {
+ ContextPaths []string `json:"context_paths,omitempty"`
+ TUI *TUIOptions `json:"tui,omitempty"`
+ Debug bool `json:"debug,omitempty"`
+ DebugLSP bool `json:"debug_lsp,omitempty"`
+ DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty"`
+ // Relative to the cwd
+ DataDirectory string `json:"data_directory,omitempty"`
+}
+
+// Config holds the configuration for crush.
+type Config struct {
+ workingDir string `json:"-"`
+ // We currently only support large/small as values here.
+ Models map[string]SelectedModel `json:"models,omitempty"`
+
+ // The providers that are configured
+ Providers map[string]ProviderConfig `json:"providers,omitempty"`
+
+ MCP map[string]MCP `json:"mcp,omitempty"`
+
+ LSP map[string]LSPConfig `json:"lsp,omitempty"`
+
+ Options *Options `json:"options,omitempty"`
+}
+
+func (c *Config) WorkingDir() string {
+ return c.workingDir
+}
@@ -0,0 +1,277 @@
+package config
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "runtime"
+ "slices"
+ "strings"
+
+ "github.com/charmbracelet/crush/internal/fur/client"
+ "github.com/charmbracelet/crush/internal/fur/provider"
+ "github.com/charmbracelet/crush/pkg/env"
+)
+
+// LoadReader config via io.Reader.
+func LoadReader(fd io.Reader) (*Config, error) {
+ data, err := io.ReadAll(fd)
+ if err != nil {
+ return nil, err
+ }
+
+ var config Config
+ err = json.Unmarshal(data, &config)
+ if err != nil {
+ return nil, err
+ }
+ return &config, err
+}
+
+// Load loads the configuration from the default paths.
+func Load(workingDir string, env env.Env) (*Config, error) {
+ // uses default config paths
+ configPaths := []string{
+ globalConfig(),
+ globalConfigData(),
+ filepath.Join(workingDir, fmt.Sprintf("%s.json", appName)),
+ filepath.Join(workingDir, fmt.Sprintf(".%s.json", appName)),
+ }
+ cfg, err := loadFromConfigPaths(configPaths)
+ if err != nil {
+ return nil, fmt.Errorf("failed to load config: %w", err)
+ }
+ // TODO: maybe add a validation step here right after loading
+ // e.x validate the models
+ // e.x validate provider config
+
+ setDefaults(workingDir, cfg)
+
+ // Load known providers, this loads the config from fur
+ providers, err := LoadProviders(client.New())
+ if err != nil {
+ return nil, fmt.Errorf("failed to load providers: %w", err)
+ }
+
+ // Configure providers
+ valueResolver := NewShellVariableResolver(env)
+ if err := configureProviders(cfg, env, valueResolver, providers); err != nil {
+ return nil, fmt.Errorf("failed to configure providers: %w", err)
+ }
+
+ return cfg, nil
+}
+
+func configureProviders(cfg *Config, env env.Env, resolver VariableResolver, knownProviders []provider.Provider) error {
+ for _, p := range knownProviders {
+
+ config, ok := cfg.Providers[string(p.ID)]
+ // if the user configured a known provider we need to allow it to override a couple of parameters
+ if ok {
+ if config.BaseURL != "" {
+ p.APIEndpoint = config.BaseURL
+ }
+ if config.APIKey != "" {
+ p.APIKey = config.APIKey
+ }
+ if len(config.Models) > 0 {
+ models := []provider.Model{}
+ seen := make(map[string]bool)
+
+ for _, model := range config.Models {
+ if seen[model.ID] {
+ continue
+ }
+ seen[model.ID] = true
+ models = append(models, model)
+ }
+ for _, model := range p.Models {
+ if seen[model.ID] {
+ continue
+ }
+ seen[model.ID] = true
+ models = append(models, model)
+ }
+
+ p.Models = models
+ }
+ }
+ prepared := ProviderConfig{
+ BaseURL: p.APIEndpoint,
+ APIKey: p.APIKey,
+ Type: p.Type,
+ Disable: config.Disable,
+ ExtraHeaders: config.ExtraHeaders,
+ ExtraParams: make(map[string]string),
+ Models: p.Models,
+ }
+
+ switch p.ID {
+ // Handle specific providers that require additional configuration
+ case provider.InferenceProviderVertexAI:
+ if !hasVertexCredentials(env) {
+ continue
+ }
+ prepared.ExtraParams["project"] = env.Get("GOOGLE_CLOUD_PROJECT")
+ prepared.ExtraParams["location"] = env.Get("GOOGLE_CLOUD_LOCATION")
+ case provider.InferenceProviderBedrock:
+ if !hasAWSCredentials(env) {
+ continue
+ }
+ for _, model := range p.Models {
+ if !strings.HasPrefix(model.ID, "anthropic.") {
+ return fmt.Errorf("bedrock provider only supports anthropic models for now, found: %s", model.ID)
+ }
+ }
+ default:
+ // if the provider api or endpoint are missing we skip them
+ v, err := resolver.ResolveValue(p.APIKey)
+ if v == "" || err != nil {
+ continue
+ }
+ v, err = resolver.ResolveValue(p.APIEndpoint)
+ if v == "" || err != nil {
+ continue
+ }
+ }
+ cfg.Providers[string(p.ID)] = prepared
+ }
+ return nil
+}
+
+func hasVertexCredentials(env env.Env) bool {
+ useVertex := env.Get("GOOGLE_GENAI_USE_VERTEXAI") == "true"
+ hasProject := env.Get("GOOGLE_CLOUD_PROJECT") != ""
+ hasLocation := env.Get("GOOGLE_CLOUD_LOCATION") != ""
+ return useVertex && hasProject && hasLocation
+}
+
+func hasAWSCredentials(env env.Env) bool {
+ if env.Get("AWS_ACCESS_KEY_ID") != "" && env.Get("AWS_SECRET_ACCESS_KEY") != "" {
+ return true
+ }
+
+ if env.Get("AWS_PROFILE") != "" || env.Get("AWS_DEFAULT_PROFILE") != "" {
+ return true
+ }
+
+ if env.Get("AWS_REGION") != "" || env.Get("AWS_DEFAULT_REGION") != "" {
+ return true
+ }
+
+ if env.Get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
+ env.Get("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
+ return true
+ }
+
+ return false
+}
+
+func setDefaults(workingDir string, cfg *Config) {
+ cfg.workingDir = workingDir
+ if cfg.Options == nil {
+ cfg.Options = &Options{}
+ }
+ if cfg.Options.TUI == nil {
+ cfg.Options.TUI = &TUIOptions{}
+ }
+ if cfg.Options.ContextPaths == nil {
+ cfg.Options.ContextPaths = []string{}
+ }
+ if cfg.Options.DataDirectory == "" {
+ cfg.Options.DataDirectory = filepath.Join(workingDir, defaultDataDirectory)
+ }
+ if cfg.Providers == nil {
+ cfg.Providers = make(map[string]ProviderConfig)
+ }
+ if cfg.Models == nil {
+ cfg.Models = make(map[string]SelectedModel)
+ }
+ if cfg.MCP == nil {
+ cfg.MCP = make(map[string]MCP)
+ }
+ if cfg.LSP == nil {
+ cfg.LSP = make(map[string]LSPConfig)
+ }
+
+ // Add the default context paths if they are not already present
+ cfg.Options.ContextPaths = append(defaultContextPaths, cfg.Options.ContextPaths...)
+ slices.Sort(cfg.Options.ContextPaths)
+ cfg.Options.ContextPaths = slices.Compact(cfg.Options.ContextPaths)
+}
+
+func loadFromConfigPaths(configPaths []string) (*Config, error) {
+ var configs []io.Reader
+
+ for _, path := range configPaths {
+ fd, err := os.Open(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ continue
+ }
+ return nil, fmt.Errorf("failed to open config file %s: %w", path, err)
+ }
+ defer fd.Close()
+
+ configs = append(configs, fd)
+ }
+
+ return loadFromReaders(configs)
+}
+
+func loadFromReaders(readers []io.Reader) (*Config, error) {
+ if len(readers) == 0 {
+ return nil, fmt.Errorf("no configuration readers provided")
+ }
+
+ merged, err := Merge(readers)
+ if err != nil {
+ return nil, fmt.Errorf("failed to merge configuration readers: %w", err)
+ }
+
+ return LoadReader(merged)
+}
+
+func globalConfig() string {
+ xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
+ if xdgConfigHome != "" {
+ return filepath.Join(xdgConfigHome, "crush")
+ }
+
+ // return the path to the main config directory
+ // for windows, it should be in `%LOCALAPPDATA%/crush/`
+ // for linux and macOS, it should be in `$HOME/.config/crush/`
+ if runtime.GOOS == "windows" {
+ localAppData := os.Getenv("LOCALAPPDATA")
+ if localAppData == "" {
+ localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
+ }
+ return filepath.Join(localAppData, appName)
+ }
+
+ return filepath.Join(os.Getenv("HOME"), ".config", appName, fmt.Sprintf("%s.json", appName))
+}
+
+// globalConfigData returns the path to the main data directory for the application.
+// this config is used when the app overrides configurations instead of updating the global config.
+func globalConfigData() string {
+ xdgDataHome := os.Getenv("XDG_DATA_HOME")
+ if xdgDataHome != "" {
+ return filepath.Join(xdgDataHome, appName)
+ }
+
+ // return the path to the main data directory
+ // for windows, it should be in `%LOCALAPPDATA%/crush/`
+ // for linux and macOS, it should be in `$HOME/.local/share/crush/`
+ if runtime.GOOS == "windows" {
+ localAppData := os.Getenv("LOCALAPPDATA")
+ if localAppData == "" {
+ localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
+ }
+ return filepath.Join(localAppData, appName)
+ }
+
+ return filepath.Join(os.Getenv("HOME"), ".local", "share", appName, fmt.Sprintf("%s.json", appName))
+}
@@ -0,0 +1,316 @@
+package config
+
+import (
+ "io"
+ "strings"
+ "testing"
+
+ "github.com/charmbracelet/crush/internal/fur/provider"
+ "github.com/charmbracelet/crush/pkg/env"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestConfig_LoadFromReaders(t *testing.T) {
+ data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
+ data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
+ data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
+
+ loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
+
+ assert.NoError(t, err)
+ assert.NotNil(t, loadedConfig)
+ assert.Len(t, loadedConfig.Providers, 1)
+ assert.Equal(t, "key2", loadedConfig.Providers["openai"].APIKey)
+ assert.Equal(t, "https://api.openai.com/v2", loadedConfig.Providers["openai"].BaseURL)
+}
+
+func TestConfig_setDefaults(t *testing.T) {
+ cfg := &Config{}
+
+ setDefaults("/tmp", cfg)
+
+ assert.NotNil(t, cfg.Options)
+ assert.NotNil(t, cfg.Options.TUI)
+ assert.NotNil(t, cfg.Options.ContextPaths)
+ assert.NotNil(t, cfg.Providers)
+ assert.NotNil(t, cfg.Models)
+ assert.NotNil(t, cfg.LSP)
+ assert.NotNil(t, cfg.MCP)
+ assert.Equal(t, "/tmp/.crush", cfg.Options.DataDirectory)
+ for _, path := range defaultContextPaths {
+ assert.Contains(t, cfg.Options.ContextPaths, path)
+ }
+ assert.Equal(t, "/tmp", cfg.workingDir)
+}
+
+func TestConfig_configureProviders(t *testing.T) {
+ knownProviders := []provider.Provider{
+ {
+ ID: "openai",
+ APIKey: "$OPENAI_API_KEY",
+ APIEndpoint: "https://api.openai.com/v1",
+ Models: []provider.Model{{
+ ID: "test-model",
+ }},
+ },
+ }
+
+ cfg := &Config{}
+ setDefaults("/tmp", cfg)
+ env := env.NewFromMap(map[string]string{
+ "OPENAI_API_KEY": "test-key",
+ })
+ resolver := NewEnvironmentVariableResolver(env)
+ err := configureProviders(cfg, env, resolver, knownProviders)
+ assert.NoError(t, err)
+ assert.Len(t, cfg.Providers, 1)
+
+ // We want to make sure that we keep the configured API key as a placeholder
+ assert.Equal(t, "$OPENAI_API_KEY", cfg.Providers["openai"].APIKey)
+}
+
+func TestConfig_configureProvidersWithOverride(t *testing.T) {
+ knownProviders := []provider.Provider{
+ {
+ ID: "openai",
+ APIKey: "$OPENAI_API_KEY",
+ APIEndpoint: "https://api.openai.com/v1",
+ Models: []provider.Model{{
+ ID: "test-model",
+ }},
+ },
+ }
+
+ cfg := &Config{
+ Providers: map[string]ProviderConfig{
+ "openai": {
+ APIKey: "xyz",
+ BaseURL: "https://api.openai.com/v2",
+ Models: []provider.Model{
+ {
+ ID: "test-model",
+ Name: "Updated",
+ },
+ {
+ ID: "another-model",
+ },
+ },
+ },
+ },
+ }
+ setDefaults("/tmp", cfg)
+ env := env.NewFromMap(map[string]string{
+ "OPENAI_API_KEY": "test-key",
+ })
+ resolver := NewEnvironmentVariableResolver(env)
+ err := configureProviders(cfg, env, resolver, knownProviders)
+ assert.NoError(t, err)
+ assert.Len(t, cfg.Providers, 1)
+
+ // We want to make sure that we keep the configured API key as a placeholder
+ assert.Equal(t, "xyz", cfg.Providers["openai"].APIKey)
+ assert.Equal(t, "https://api.openai.com/v2", cfg.Providers["openai"].BaseURL)
+ assert.Len(t, cfg.Providers["openai"].Models, 2)
+ assert.Equal(t, "Updated", cfg.Providers["openai"].Models[0].Name)
+}
+
+func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
+ knownProviders := []provider.Provider{
+ {
+ ID: "openai",
+ APIKey: "$OPENAI_API_KEY",
+ APIEndpoint: "https://api.openai.com/v1",
+ Models: []provider.Model{{
+ ID: "test-model",
+ }},
+ },
+ }
+
+ cfg := &Config{
+ Providers: map[string]ProviderConfig{
+ "custom": {
+ APIKey: "xyz",
+ BaseURL: "https://api.someendpoint.com/v2",
+ Models: []provider.Model{
+ {
+ ID: "test-model",
+ },
+ },
+ },
+ },
+ }
+ setDefaults("/tmp", cfg)
+ env := env.NewFromMap(map[string]string{
+ "OPENAI_API_KEY": "test-key",
+ })
+ resolver := NewEnvironmentVariableResolver(env)
+ err := configureProviders(cfg, env, resolver, knownProviders)
+ assert.NoError(t, err)
+ // Should be to because of the env variable
+ assert.Len(t, cfg.Providers, 2)
+
+ // We want to make sure that we keep the configured API key as a placeholder
+ assert.Equal(t, "xyz", cfg.Providers["custom"].APIKey)
+ assert.Equal(t, "https://api.someendpoint.com/v2", cfg.Providers["custom"].BaseURL)
+ assert.Len(t, cfg.Providers["custom"].Models, 1)
+
+ _, ok := cfg.Providers["openai"]
+ assert.True(t, ok, "OpenAI provider should still be present")
+}
+
+func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
+ knownProviders := []provider.Provider{
+ {
+ ID: provider.InferenceProviderBedrock,
+ APIKey: "",
+ APIEndpoint: "",
+ Models: []provider.Model{{
+ ID: "anthropic.claude-sonnet-4-20250514-v1:0",
+ }},
+ },
+ }
+
+ cfg := &Config{}
+ setDefaults("/tmp", cfg)
+ env := env.NewFromMap(map[string]string{
+ "AWS_ACCESS_KEY_ID": "test-key-id",
+ "AWS_SECRET_ACCESS_KEY": "test-secret-key",
+ })
+ resolver := NewEnvironmentVariableResolver(env)
+ err := configureProviders(cfg, env, resolver, knownProviders)
+ assert.NoError(t, err)
+ assert.Len(t, cfg.Providers, 1)
+
+ bedrockProvider, ok := cfg.Providers["bedrock"]
+ assert.True(t, ok, "Bedrock provider should be present")
+ assert.Len(t, bedrockProvider.Models, 1)
+ assert.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
+}
+
+func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
+ knownProviders := []provider.Provider{
+ {
+ ID: provider.InferenceProviderBedrock,
+ APIKey: "",
+ APIEndpoint: "",
+ Models: []provider.Model{{
+ ID: "anthropic.claude-sonnet-4-20250514-v1:0",
+ }},
+ },
+ }
+
+ cfg := &Config{}
+ setDefaults("/tmp", cfg)
+ env := env.NewFromMap(map[string]string{})
+ resolver := NewEnvironmentVariableResolver(env)
+ err := configureProviders(cfg, env, resolver, knownProviders)
+ assert.NoError(t, err)
+ // Provider should not be configured without credentials
+ assert.Len(t, cfg.Providers, 0)
+}
+
+func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
+ knownProviders := []provider.Provider{
+ {
+ ID: provider.InferenceProviderBedrock,
+ APIKey: "",
+ APIEndpoint: "",
+ Models: []provider.Model{{
+ ID: "some-random-model",
+ }},
+ },
+ }
+
+ cfg := &Config{}
+ setDefaults("/tmp", cfg)
+ env := env.NewFromMap(map[string]string{
+ "AWS_ACCESS_KEY_ID": "test-key-id",
+ "AWS_SECRET_ACCESS_KEY": "test-secret-key",
+ })
+ resolver := NewEnvironmentVariableResolver(env)
+ err := configureProviders(cfg, env, resolver, knownProviders)
+ assert.Error(t, err)
+}
+
+func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
+ knownProviders := []provider.Provider{
+ {
+ ID: provider.InferenceProviderVertexAI,
+ APIKey: "",
+ APIEndpoint: "",
+ Models: []provider.Model{{
+ ID: "gemini-pro",
+ }},
+ },
+ }
+
+ cfg := &Config{}
+ setDefaults("/tmp", cfg)
+ env := env.NewFromMap(map[string]string{
+ "GOOGLE_GENAI_USE_VERTEXAI": "true",
+ "GOOGLE_CLOUD_PROJECT": "test-project",
+ "GOOGLE_CLOUD_LOCATION": "us-central1",
+ })
+ resolver := NewEnvironmentVariableResolver(env)
+ err := configureProviders(cfg, env, resolver, knownProviders)
+ assert.NoError(t, err)
+ assert.Len(t, cfg.Providers, 1)
+
+ vertexProvider, ok := cfg.Providers["vertexai"]
+ assert.True(t, ok, "VertexAI provider should be present")
+ assert.Len(t, vertexProvider.Models, 1)
+ assert.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
+ assert.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
+ assert.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
+}
+
+func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
+ knownProviders := []provider.Provider{
+ {
+ ID: provider.InferenceProviderVertexAI,
+ APIKey: "",
+ APIEndpoint: "",
+ Models: []provider.Model{{
+ ID: "gemini-pro",
+ }},
+ },
+ }
+
+ cfg := &Config{}
+ setDefaults("/tmp", cfg)
+ env := env.NewFromMap(map[string]string{
+ "GOOGLE_GENAI_USE_VERTEXAI": "false",
+ "GOOGLE_CLOUD_PROJECT": "test-project",
+ "GOOGLE_CLOUD_LOCATION": "us-central1",
+ })
+ resolver := NewEnvironmentVariableResolver(env)
+ err := configureProviders(cfg, env, resolver, knownProviders)
+ assert.NoError(t, err)
+ // Provider should not be configured without proper credentials
+ assert.Len(t, cfg.Providers, 0)
+}
+
+func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
+ knownProviders := []provider.Provider{
+ {
+ ID: provider.InferenceProviderVertexAI,
+ APIKey: "",
+ APIEndpoint: "",
+ Models: []provider.Model{{
+ ID: "gemini-pro",
+ }},
+ },
+ }
+
+ cfg := &Config{}
+ setDefaults("/tmp", cfg)
+ env := env.NewFromMap(map[string]string{
+ "GOOGLE_GENAI_USE_VERTEXAI": "true",
+ "GOOGLE_CLOUD_LOCATION": "us-central1",
+ })
+ resolver := NewEnvironmentVariableResolver(env)
+ err := configureProviders(cfg, env, resolver, knownProviders)
+ assert.NoError(t, err)
+ // Provider should not be configured without project
+ assert.Len(t, cfg.Providers, 0)
+}
@@ -0,0 +1,16 @@
+package config
+
+import (
+ "bytes"
+ "io"
+
+ "github.com/qjebbs/go-jsons"
+)
+
+func Merge(data []io.Reader) (io.Reader, error) {
+ got, err := jsons.Merge(data)
+ if err != nil {
+ return nil, err
+ }
+ return bytes.NewReader(got), nil
+}
@@ -0,0 +1,27 @@
+package config
+
+import (
+ "io"
+ "strings"
+ "testing"
+)
+
+func TestMerge(t *testing.T) {
+ data1 := strings.NewReader(`{"foo": "bar"}`)
+ data2 := strings.NewReader(`{"baz": "qux"}`)
+
+ merged, err := Merge([]io.Reader{data1, data2})
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ expected := `{"baz":"qux","foo":"bar"}`
+ got, err := io.ReadAll(merged)
+ if err != nil {
+ t.Fatalf("expected no error reading merged data, got %v", err)
+ }
+
+ if string(got) != expected {
+ t.Errorf("expected %s, got %s", expected, string(got))
+ }
+}
@@ -0,0 +1,93 @@
+package config
+
+import (
+ "encoding/json"
+ "os"
+ "path/filepath"
+ "runtime"
+ "sync"
+
+ "github.com/charmbracelet/crush/internal/fur/provider"
+)
+
+type ProviderClient interface {
+ GetProviders() ([]provider.Provider, error)
+}
+
+var (
+ providerOnce sync.Once
+ providerList []provider.Provider
+)
+
+// file to cache provider data
+func providerCacheFileData() string {
+ xdgDataHome := os.Getenv("XDG_DATA_HOME")
+ if xdgDataHome != "" {
+ return filepath.Join(xdgDataHome, appName)
+ }
+
+ // return the path to the main data directory
+ // for windows, it should be in `%LOCALAPPDATA%/crush/`
+ // for linux and macOS, it should be in `$HOME/.local/share/crush/`
+ if runtime.GOOS == "windows" {
+ localAppData := os.Getenv("LOCALAPPDATA")
+ if localAppData == "" {
+ localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
+ }
+ return filepath.Join(localAppData, appName)
+ }
+
+ return filepath.Join(os.Getenv("HOME"), ".local", "share", appName, "providers.json")
+}
+
+func saveProvidersInCache(path string, providers []provider.Provider) error {
+ dir := filepath.Dir(path)
+ if err := os.MkdirAll(dir, 0o755); err != nil {
+ return err
+ }
+
+ data, err := json.MarshalIndent(providers, "", " ")
+ if err != nil {
+ return err
+ }
+
+ return os.WriteFile(path, data, 0o644)
+}
+
+func loadProvidersFromCache(path string) ([]provider.Provider, error) {
+ data, err := os.ReadFile(path)
+ if err != nil {
+ return nil, err
+ }
+
+ var providers []provider.Provider
+ err = json.Unmarshal(data, &providers)
+ return providers, err
+}
+
+func loadProviders(path string, client ProviderClient) ([]provider.Provider, error) {
+ providers, err := client.GetProviders()
+ if err != nil {
+ fallbackToCache, err := loadProvidersFromCache(path)
+ if err != nil {
+ return nil, err
+ }
+ providers = fallbackToCache
+ } else {
+ if err := saveProvidersInCache(path, providerList); err != nil {
+ return nil, err
+ }
+ }
+ return providers, nil
+}
+
+func LoadProviders(client ProviderClient) ([]provider.Provider, error) {
+ var err error
+ providerOnce.Do(func() {
+ providerList, err = loadProviders(providerCacheFileData(), client)
+ })
+ if err != nil {
+ return nil, err
+ }
+ return providerList, nil
+}
@@ -0,0 +1,74 @@
+package config
+
+import (
+ "encoding/json"
+ "errors"
+ "os"
+ "testing"
+
+ "github.com/charmbracelet/crush/internal/fur/provider"
+ "github.com/stretchr/testify/assert"
+)
+
+type mockProviderClient struct {
+ shouldFail bool
+}
+
+func (m *mockProviderClient) GetProviders() ([]provider.Provider, error) {
+ if m.shouldFail {
+ return nil, errors.New("failed to load providers")
+ }
+ return []provider.Provider{
+ {
+ Name: "Mock",
+ },
+ }, nil
+}
+
+func TestProvider_loadProvidersNoIssues(t *testing.T) {
+ client := &mockProviderClient{shouldFail: false}
+ tmpPath := t.TempDir() + "/providers.json"
+ providers, err := loadProviders(tmpPath, client)
+ assert.NoError(t, err)
+ assert.NotNil(t, providers)
+ assert.Len(t, providers, 1)
+
+ // check if file got saved
+ fileInfo, err := os.Stat(tmpPath)
+ assert.NoError(t, err)
+ assert.False(t, fileInfo.IsDir(), "Expected a file, not a directory")
+}
+
+func TestProvider_loadProvidersWithIssues(t *testing.T) {
+ client := &mockProviderClient{shouldFail: true}
+ tmpPath := t.TempDir() + "/providers.json"
+ // store providers to a temporary file
+ oldProviders := []provider.Provider{
+ {
+ Name: "OldProvider",
+ },
+ }
+ data, err := json.Marshal(oldProviders)
+ if err != nil {
+ t.Fatalf("Failed to marshal old providers: %v", err)
+ }
+
+ err = os.WriteFile(tmpPath, data, 0o644)
+ if err != nil {
+ t.Fatalf("Failed to write old providers to file: %v", err)
+ }
+ providers, err := loadProviders(tmpPath, client)
+ assert.NoError(t, err)
+ assert.NotNil(t, providers)
+ assert.Len(t, providers, 1)
+ assert.Equal(t, "OldProvider", providers[0].Name, "Expected to keep old provider when loading fails")
+}
+
+func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) {
+ client := &mockProviderClient{shouldFail: true}
+ tmpPath := t.TempDir() + "/providers.json"
+ providers, err := loadProviders(tmpPath, client)
+ assert.Error(t, err)
+ assert.Nil(t, providers, "Expected nil providers when loading fails and no cache exists")
+}
+
@@ -0,0 +1,89 @@
+package config
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/charmbracelet/crush/internal/shell"
+ "github.com/charmbracelet/crush/pkg/env"
+)
+
+type VariableResolver interface {
+ ResolveValue(value string) (string, error)
+}
+
+type Shell interface {
+ Exec(ctx context.Context, command string) (stdout, stderr string, err error)
+}
+
+type shellVariableResolver struct {
+ shell Shell
+ env env.Env
+}
+
+func NewShellVariableResolver(env env.Env) VariableResolver {
+ return &shellVariableResolver{
+ shell: shell.NewShell(
+ &shell.Options{
+ Env: env.Env(),
+ },
+ ),
+ }
+}
+
+// ResolveValue is a method for resolving values, such as environment variables.
+// it will expect strings that start with `$` to be resolved as environment variables or shell commands.
+// if the string does not start with `$`, it will return the string as is.
+func (r *shellVariableResolver) ResolveValue(value string) (string, error) {
+ if !strings.HasPrefix(value, "$") {
+ return value, nil
+ }
+
+ if strings.HasPrefix(value, "$(") && strings.HasSuffix(value, ")") {
+ command := strings.TrimSuffix(strings.TrimPrefix(value, "$("), ")")
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ stdout, _, err := r.shell.Exec(ctx, command)
+ if err != nil {
+ return "", fmt.Errorf("command execution failed: %w", err)
+ }
+ return strings.TrimSpace(stdout), nil
+ }
+
+ if strings.HasPrefix(value, "$") {
+ varName := strings.TrimPrefix(value, "$")
+ value = r.env.Get(varName)
+ if value == "" {
+ return "", fmt.Errorf("environment variable %q not set", varName)
+ }
+ return value, nil
+ }
+ return "", fmt.Errorf("invalid value format: %s", value)
+}
+
+type environmentVariableResolver struct {
+ env env.Env
+}
+
+func NewEnvironmentVariableResolver(env env.Env) VariableResolver {
+ return &environmentVariableResolver{
+ env: env,
+ }
+}
+
+// ResolveValue resolves environment variables from the provided env.Env.
+func (r *environmentVariableResolver) ResolveValue(value string) (string, error) {
+ if !strings.HasPrefix(value, "$") {
+ return value, nil
+ }
+
+ varName := strings.TrimPrefix(value, "$")
+ resolvedValue := r.env.Get(varName)
+ if resolvedValue == "" {
+ return "", fmt.Errorf("environment variable %q not set", varName)
+ }
+ return resolvedValue, nil
+}
@@ -0,0 +1,178 @@
+package config
+
+import (
+ "context"
+ "errors"
+ "testing"
+
+ "github.com/charmbracelet/crush/pkg/env"
+ "github.com/stretchr/testify/assert"
+)
+
+// mockShell implements the Shell interface for testing
+type mockShell struct {
+ execFunc func(ctx context.Context, command string) (stdout, stderr string, err error)
+}
+
+func (m *mockShell) Exec(ctx context.Context, command string) (stdout, stderr string, err error) {
+ if m.execFunc != nil {
+ return m.execFunc(ctx, command)
+ }
+ return "", "", nil
+}
+
+func TestShellVariableResolver_ResolveValue(t *testing.T) {
+ tests := []struct {
+ name string
+ value string
+ envVars map[string]string
+ shellFunc func(ctx context.Context, command string) (stdout, stderr string, err error)
+ expected string
+ expectError bool
+ }{
+ {
+ name: "non-variable string returns as-is",
+ value: "plain-string",
+ expected: "plain-string",
+ },
+ {
+ name: "environment variable resolution",
+ value: "$HOME",
+ envVars: map[string]string{"HOME": "/home/user"},
+ expected: "/home/user",
+ },
+ {
+ name: "missing environment variable returns error",
+ value: "$MISSING_VAR",
+ envVars: map[string]string{},
+ expectError: true,
+ },
+ {
+ name: "shell command execution",
+ value: "$(echo hello)",
+ shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
+ if command == "echo hello" {
+ return "hello\n", "", nil
+ }
+ return "", "", errors.New("unexpected command")
+ },
+ expected: "hello",
+ },
+ {
+ name: "shell command with whitespace trimming",
+ value: "$(echo ' spaced ')",
+ shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
+ if command == "echo ' spaced '" {
+ return " spaced \n", "", nil
+ }
+ return "", "", errors.New("unexpected command")
+ },
+ expected: "spaced",
+ },
+ {
+ name: "shell command execution error",
+ value: "$(false)",
+ shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
+ return "", "", errors.New("command failed")
+ },
+ expectError: true,
+ },
+ {
+ name: "invalid format returns error",
+ value: "$",
+ expectError: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ testEnv := env.NewFromMap(tt.envVars)
+ resolver := &shellVariableResolver{
+ shell: &mockShell{execFunc: tt.shellFunc},
+ env: testEnv,
+ }
+
+ result, err := resolver.ResolveValue(tt.value)
+
+ if tt.expectError {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tt.expected, result)
+ }
+ })
+ }
+}
+
+func TestEnvironmentVariableResolver_ResolveValue(t *testing.T) {
+ tests := []struct {
+ name string
+ value string
+ envVars map[string]string
+ expected string
+ expectError bool
+ }{
+ {
+ name: "non-variable string returns as-is",
+ value: "plain-string",
+ expected: "plain-string",
+ },
+ {
+ name: "environment variable resolution",
+ value: "$HOME",
+ envVars: map[string]string{"HOME": "/home/user"},
+ expected: "/home/user",
+ },
+ {
+ name: "environment variable with complex value",
+ value: "$PATH",
+ envVars: map[string]string{"PATH": "/usr/bin:/bin:/usr/local/bin"},
+ expected: "/usr/bin:/bin:/usr/local/bin",
+ },
+ {
+ name: "missing environment variable returns error",
+ value: "$MISSING_VAR",
+ envVars: map[string]string{},
+ expectError: true,
+ },
+ {
+ name: "empty environment variable returns error",
+ value: "$EMPTY_VAR",
+ envVars: map[string]string{"EMPTY_VAR": ""},
+ expectError: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ testEnv := env.NewFromMap(tt.envVars)
+ resolver := NewEnvironmentVariableResolver(testEnv)
+
+ result, err := resolver.ResolveValue(tt.value)
+
+ if tt.expectError {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tt.expected, result)
+ }
+ })
+ }
+}
+
+func TestNewShellVariableResolver(t *testing.T) {
+ testEnv := env.NewFromMap(map[string]string{"TEST": "value"})
+ resolver := NewShellVariableResolver(testEnv)
+
+ assert.NotNil(t, resolver)
+ assert.Implements(t, (*VariableResolver)(nil), resolver)
+}
+
+func TestNewEnvironmentVariableResolver(t *testing.T) {
+ testEnv := env.NewFromMap(map[string]string{"TEST": "value"})
+ resolver := NewEnvironmentVariableResolver(testEnv)
+
+ assert.NotNil(t, resolver)
+ assert.Implements(t, (*VariableResolver)(nil), resolver)
+}
+
@@ -0,0 +1,55 @@
+package env
+
+import "os"
+
+type Env interface {
+ Get(key string) string
+ Env() []string
+}
+
+type osEnv struct{}
+
+// Get implements Env.
+func (o *osEnv) Get(key string) string {
+ return os.Getenv(key)
+}
+
+func (o *osEnv) Env() []string {
+ env := os.Environ()
+ if len(env) == 0 {
+ return nil
+ }
+ return env
+}
+
+func New() Env {
+ return &osEnv{}
+}
+
+type mapEnv struct {
+ m map[string]string
+}
+
+// Get implements Env.
+func (m *mapEnv) Get(key string) string {
+ if value, ok := m.m[key]; ok {
+ return value
+ }
+ return ""
+}
+
+// Env implements Env.
+func (m *mapEnv) Env() []string {
+ if len(m.m) == 0 {
+ return nil
+ }
+ env := make([]string, 0, len(m.m))
+ for k, v := range m.m {
+ env = append(env, k+"="+v)
+ }
+ return env
+}
+
+func NewFromMap(m map[string]string) Env {
+ return &mapEnv{m: m}
+}
@@ -1,51 +0,0 @@
-## TODOs before release
-
-- [x] Implement help
- - [x] Show full help
- - [x] Make help dependent on the focused pane and page
-- [x] Implement current model in the sidebar
-- [x] Implement LSP errors
-- [x] Implement changed files
- - [x] Implement initial load
- - [x] Implement realtime file changes
-- [ ] Events when tool error
-- [ ] Support bash commands
-- [ ] Editor attachments fixes
- - [ ] Reimplement removing attachments
-- [ ] Fix the logs view
- - [ ] Review the implementation
- - [ ] The page lags
- - [ ] Make the logs long lived ?
-- [ ] Add all possible actions to the commands
-- [ ] Parallel tool calls and permissions
- - [ ] Run the tools in parallel and add results in parallel
- - [ ] Show multiple permissions dialogs
-- [ ] Add another space around buttons
-- [ ] Completions
- - [ ] Should change the help to show the completions stuff
- - [ ] Should make it wider
- - [ ] Tab and ctrl+y should accept
- - [ ] Words should line up
- - [ ] If there are no completions and cick tab/ctrl+y/enter it should close it
-- [ ] Investigate messages issues
- - [ ] Make the agent separator look like the
- - [ ] Cleanup tool calls (watch all states)
- - [ ] Weird behavior sometimes the message does not update
- - [ ] Message length (I saw the message go beyond the correct length when there are errors)
- - [ ] Address UX issues
- - [ ] Fix issue with numbers (padding) view tool
-- [x] Implement responsive mode
-- [ ] Update interactive mode to use the spinner
-- [ ] Revisit the core list component
- - [ ] This component has become super complex we might need to fix this.
-- [ ] Handle correct LSP and MCP status icon
-- [x] Investigate ways to make the spinner less CPU intensive
-- [ ] General cleanup and documentation
-- [ ] Update the readme
-
-## Maybe
-
-- [ ] Revisit the provider/model/configs
-- [ ] Implement correct persistent shell
-- [ ] Store file read/write time somewhere so that the we can make sure that even if we restart we do not need to re-read the same file
-- [ ] Send updates to the UI when new LSP diagnostics are available