From d4927ace386b0e25fc978b47e6dbdb5d45b9c096 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 4 Jul 2025 19:41:43 +0200 Subject: [PATCH] chore: initial config tests --- cspell.json | 2 +- go.mod | 6 +- go.sum | 4 + main.go | 2 + pkg/config/config.go | 124 ++++++++++++++ pkg/config/load.go | 277 +++++++++++++++++++++++++++++++ pkg/config/load_test.go | 316 ++++++++++++++++++++++++++++++++++++ pkg/config/merge.go | 16 ++ pkg/config/merge_test.go | 27 +++ pkg/config/provider.go | 93 +++++++++++ pkg/config/provider_test.go | 74 +++++++++ pkg/config/resolve.go | 89 ++++++++++ pkg/config/resolve_test.go | 178 ++++++++++++++++++++ pkg/env/env.go | 55 +++++++ todos.md | 51 ------ 15 files changed, 1261 insertions(+), 53 deletions(-) create mode 100644 pkg/config/config.go create mode 100644 pkg/config/load.go create mode 100644 pkg/config/load_test.go create mode 100644 pkg/config/merge.go create mode 100644 pkg/config/merge_test.go create mode 100644 pkg/config/provider.go create mode 100644 pkg/config/provider_test.go create mode 100644 pkg/config/resolve.go create mode 100644 pkg/config/resolve_test.go create mode 100644 pkg/env/env.go delete mode 100644 todos.md diff --git a/cspell.json b/cspell.json index d98b1326e54c8b62c7ad700fe19b4cbbe3e4f672..5b0877dc174821537da9b67ea67d638965f9a9f7 100644 --- a/cspell.json +++ b/cspell.json @@ -1 +1 @@ -{"flagWords":[],"words":["afero","alecthomas","bubbletea","charmbracelet","charmtone","Charple","crush","diffview","Emph","filepicker","Focusable","fsext","GROQ","Guac","imageorient","Lanczos","lipgloss","lsps","lucasb","nfnt","oksvg","Preproc","rasterx","rivo","Sourcegraph","srwiley","Strikethrough","termenv","textinput","trashhalo","uniseg","Unticked","genai","jsonschema"],"version":"0.2","language":"en"} \ No newline at end of file +{"language":"en","version":"0.2","flagWords":[],"words":["afero","alecthomas","bubbletea","charmbracelet","charmtone","Charple","crush","diffview","Emph","filepicker","Focusable","fsext","GROQ","Guac","imageorient","Lanczos","lipgloss","lsps","lucasb","nfnt","oksvg","Preproc","rasterx","rivo","Sourcegraph","srwiley","Strikethrough","termenv","textinput","trashhalo","uniseg","Unticked","genai","jsonschema","preconfigured","jsons","qjebbs","LOCALAPPDATA","USERPROFILE","stretchr"]} \ No newline at end of file diff --git a/go.mod b/go.mod index 404eeb5d1e2bce529bf240972c3f0745e97aec6a..b0f04c00392c6ad3cb661542f5b9ce1b961da684 100644 --- a/go.mod +++ b/go.mod @@ -40,7 +40,11 @@ require ( mvdan.cc/sh/v3 v3.11.0 ) -require github.com/spf13/cast v1.7.1 // indirect +require ( + github.com/joho/godotenv v1.5.1 // indirect + github.com/qjebbs/go-jsons v0.0.0-20221222033332-a534c5fc1c4c // indirect + github.com/spf13/cast v1.7.1 // indirect +) require ( cloud.google.com/go v0.116.0 // indirect diff --git a/go.sum b/go.sum index fc4d0a1a51d01458e172022c8888d7a5acfdf187..ff706c7c003e9047d2b3e5bf8239646571c67321 100644 --- a/go.sum +++ b/go.sum @@ -155,6 +155,8 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -208,6 +210,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pressly/goose/v3 v3.24.2 h1:c/ie0Gm8rnIVKvnDQ/scHErv46jrDv9b4I0WRcFJzYU= github.com/pressly/goose/v3 v3.24.2/go.mod h1:kjefwFB0eR4w30Td2Gj2Mznyw94vSP+2jJYkOVNbD1k= +github.com/qjebbs/go-jsons v0.0.0-20221222033332-a534c5fc1c4c h1:kmzxiX+OB0knCo1V0dkEkdPelzCdAzCURCfmFArn2/A= +github.com/qjebbs/go-jsons v0.0.0-20221222033332-a534c5fc1c4c/go.mod h1:wNJrtinHyC3YSf6giEh4FJN8+yZV7nXBjvmfjhBIcw4= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= diff --git a/main.go b/main.go index a5305d08d7ae3ede818568d5cf825d1ce52bbf61..ce145401f8828beeae7bb3b7747bdc9339e30405 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,8 @@ import ( _ "net/http/pprof" // profiling + _ "github.com/joho/godotenv/autoload" // automatically load .env files + "github.com/charmbracelet/crush/cmd" "github.com/charmbracelet/crush/internal/logging" ) diff --git a/pkg/config/config.go b/pkg/config/config.go new file mode 100644 index 0000000000000000000000000000000000000000..70f569b4e998230c5770122c164ea5232d67d919 --- /dev/null +++ b/pkg/config/config.go @@ -0,0 +1,124 @@ +package config + +import "github.com/charmbracelet/crush/internal/fur/provider" + +const ( + appName = "crush" + defaultDataDirectory = ".crush" + defaultLogLevel = "info" +) + +var defaultContextPaths = []string{ + ".github/copilot-instructions.md", + ".cursorrules", + ".cursor/rules/", + "CLAUDE.md", + "CLAUDE.local.md", + "GEMINI.md", + "gemini.md", + "crush.md", + "crush.local.md", + "Crush.md", + "Crush.local.md", + "CRUSH.md", + "CRUSH.local.md", +} + +type SelectedModel struct { + // The model id as used by the provider API. + // Required. + Model string `json:"model"` + // The model provider, same as the key/id used in the providers config. + // Required. + Provider string `json:"provider"` + + // Only used by models that use the openai provider and need this set. + ReasoningEffort string `json:"reasoning_effort,omitempty"` + + // Overrides the default model configuration. + MaxTokens int64 `json:"max_tokens,omitempty"` + + // Used by anthropic models that can reason to indicate if the model should think. + Think bool `json:"think,omitempty"` +} + +type ProviderConfig struct { + // The provider's API endpoint. + BaseURL string `json:"base_url,omitempty"` + // The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai. + Type provider.Type `json:"type,omitempty"` + // The provider's API key. + APIKey string `json:"api_key,omitempty"` + // Marks the provider as disabled. + Disable bool `json:"disable,omitempty"` + + // Extra headers to send with each request to the provider. + ExtraHeaders map[string]string + + // Used to pass extra parameters to the provider. + ExtraParams map[string]string `json:"-"` + + // The provider models + Models []provider.Model `json:"models,omitempty"` +} + +type MCPType string + +const ( + MCPStdio MCPType = "stdio" + MCPSse MCPType = "sse" + MCPHttp MCPType = "http" +) + +type MCP struct { + Command string `json:"command,omitempty" ` + Env []string `json:"env,omitempty"` + Args []string `json:"args,omitempty"` + Type MCPType `json:"type"` + URL string `json:"url,omitempty"` + + // TODO: maybe make it possible to get the value from the env + Headers map[string]string `json:"headers,omitempty"` +} + +type LSPConfig struct { + Disabled bool `json:"enabled,omitempty"` + Command string `json:"command"` + Args []string `json:"args,omitempty"` + Options any `json:"options,omitempty"` +} + +type TUIOptions struct { + CompactMode bool `json:"compact_mode,omitempty"` + // Here we can add themes later or any TUI related options +} + +type Options struct { + ContextPaths []string `json:"context_paths,omitempty"` + TUI *TUIOptions `json:"tui,omitempty"` + Debug bool `json:"debug,omitempty"` + DebugLSP bool `json:"debug_lsp,omitempty"` + DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty"` + // Relative to the cwd + DataDirectory string `json:"data_directory,omitempty"` +} + +// Config holds the configuration for crush. +type Config struct { + workingDir string `json:"-"` + // We currently only support large/small as values here. + Models map[string]SelectedModel `json:"models,omitempty"` + + // The providers that are configured + Providers map[string]ProviderConfig `json:"providers,omitempty"` + + MCP map[string]MCP `json:"mcp,omitempty"` + + LSP map[string]LSPConfig `json:"lsp,omitempty"` + + Options *Options `json:"options,omitempty"` +} + +func (c *Config) WorkingDir() string { + return c.workingDir +} diff --git a/pkg/config/load.go b/pkg/config/load.go new file mode 100644 index 0000000000000000000000000000000000000000..19f84625891616a4e60cbdd47566af6bb9caa6a4 --- /dev/null +++ b/pkg/config/load.go @@ -0,0 +1,277 @@ +package config + +import ( + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "runtime" + "slices" + "strings" + + "github.com/charmbracelet/crush/internal/fur/client" + "github.com/charmbracelet/crush/internal/fur/provider" + "github.com/charmbracelet/crush/pkg/env" +) + +// LoadReader config via io.Reader. +func LoadReader(fd io.Reader) (*Config, error) { + data, err := io.ReadAll(fd) + if err != nil { + return nil, err + } + + var config Config + err = json.Unmarshal(data, &config) + if err != nil { + return nil, err + } + return &config, err +} + +// Load loads the configuration from the default paths. +func Load(workingDir string, env env.Env) (*Config, error) { + // uses default config paths + configPaths := []string{ + globalConfig(), + globalConfigData(), + filepath.Join(workingDir, fmt.Sprintf("%s.json", appName)), + filepath.Join(workingDir, fmt.Sprintf(".%s.json", appName)), + } + cfg, err := loadFromConfigPaths(configPaths) + if err != nil { + return nil, fmt.Errorf("failed to load config: %w", err) + } + // TODO: maybe add a validation step here right after loading + // e.x validate the models + // e.x validate provider config + + setDefaults(workingDir, cfg) + + // Load known providers, this loads the config from fur + providers, err := LoadProviders(client.New()) + if err != nil { + return nil, fmt.Errorf("failed to load providers: %w", err) + } + + // Configure providers + valueResolver := NewShellVariableResolver(env) + if err := configureProviders(cfg, env, valueResolver, providers); err != nil { + return nil, fmt.Errorf("failed to configure providers: %w", err) + } + + return cfg, nil +} + +func configureProviders(cfg *Config, env env.Env, resolver VariableResolver, knownProviders []provider.Provider) error { + for _, p := range knownProviders { + + config, ok := cfg.Providers[string(p.ID)] + // if the user configured a known provider we need to allow it to override a couple of parameters + if ok { + if config.BaseURL != "" { + p.APIEndpoint = config.BaseURL + } + if config.APIKey != "" { + p.APIKey = config.APIKey + } + if len(config.Models) > 0 { + models := []provider.Model{} + seen := make(map[string]bool) + + for _, model := range config.Models { + if seen[model.ID] { + continue + } + seen[model.ID] = true + models = append(models, model) + } + for _, model := range p.Models { + if seen[model.ID] { + continue + } + seen[model.ID] = true + models = append(models, model) + } + + p.Models = models + } + } + prepared := ProviderConfig{ + BaseURL: p.APIEndpoint, + APIKey: p.APIKey, + Type: p.Type, + Disable: config.Disable, + ExtraHeaders: config.ExtraHeaders, + ExtraParams: make(map[string]string), + Models: p.Models, + } + + switch p.ID { + // Handle specific providers that require additional configuration + case provider.InferenceProviderVertexAI: + if !hasVertexCredentials(env) { + continue + } + prepared.ExtraParams["project"] = env.Get("GOOGLE_CLOUD_PROJECT") + prepared.ExtraParams["location"] = env.Get("GOOGLE_CLOUD_LOCATION") + case provider.InferenceProviderBedrock: + if !hasAWSCredentials(env) { + continue + } + for _, model := range p.Models { + if !strings.HasPrefix(model.ID, "anthropic.") { + return fmt.Errorf("bedrock provider only supports anthropic models for now, found: %s", model.ID) + } + } + default: + // if the provider api or endpoint are missing we skip them + v, err := resolver.ResolveValue(p.APIKey) + if v == "" || err != nil { + continue + } + v, err = resolver.ResolveValue(p.APIEndpoint) + if v == "" || err != nil { + continue + } + } + cfg.Providers[string(p.ID)] = prepared + } + return nil +} + +func hasVertexCredentials(env env.Env) bool { + useVertex := env.Get("GOOGLE_GENAI_USE_VERTEXAI") == "true" + hasProject := env.Get("GOOGLE_CLOUD_PROJECT") != "" + hasLocation := env.Get("GOOGLE_CLOUD_LOCATION") != "" + return useVertex && hasProject && hasLocation +} + +func hasAWSCredentials(env env.Env) bool { + if env.Get("AWS_ACCESS_KEY_ID") != "" && env.Get("AWS_SECRET_ACCESS_KEY") != "" { + return true + } + + if env.Get("AWS_PROFILE") != "" || env.Get("AWS_DEFAULT_PROFILE") != "" { + return true + } + + if env.Get("AWS_REGION") != "" || env.Get("AWS_DEFAULT_REGION") != "" { + return true + } + + if env.Get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" || + env.Get("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" { + return true + } + + return false +} + +func setDefaults(workingDir string, cfg *Config) { + cfg.workingDir = workingDir + if cfg.Options == nil { + cfg.Options = &Options{} + } + if cfg.Options.TUI == nil { + cfg.Options.TUI = &TUIOptions{} + } + if cfg.Options.ContextPaths == nil { + cfg.Options.ContextPaths = []string{} + } + if cfg.Options.DataDirectory == "" { + cfg.Options.DataDirectory = filepath.Join(workingDir, defaultDataDirectory) + } + if cfg.Providers == nil { + cfg.Providers = make(map[string]ProviderConfig) + } + if cfg.Models == nil { + cfg.Models = make(map[string]SelectedModel) + } + if cfg.MCP == nil { + cfg.MCP = make(map[string]MCP) + } + if cfg.LSP == nil { + cfg.LSP = make(map[string]LSPConfig) + } + + // Add the default context paths if they are not already present + cfg.Options.ContextPaths = append(defaultContextPaths, cfg.Options.ContextPaths...) + slices.Sort(cfg.Options.ContextPaths) + cfg.Options.ContextPaths = slices.Compact(cfg.Options.ContextPaths) +} + +func loadFromConfigPaths(configPaths []string) (*Config, error) { + var configs []io.Reader + + for _, path := range configPaths { + fd, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + continue + } + return nil, fmt.Errorf("failed to open config file %s: %w", path, err) + } + defer fd.Close() + + configs = append(configs, fd) + } + + return loadFromReaders(configs) +} + +func loadFromReaders(readers []io.Reader) (*Config, error) { + if len(readers) == 0 { + return nil, fmt.Errorf("no configuration readers provided") + } + + merged, err := Merge(readers) + if err != nil { + return nil, fmt.Errorf("failed to merge configuration readers: %w", err) + } + + return LoadReader(merged) +} + +func globalConfig() string { + xdgConfigHome := os.Getenv("XDG_CONFIG_HOME") + if xdgConfigHome != "" { + return filepath.Join(xdgConfigHome, "crush") + } + + // return the path to the main config directory + // for windows, it should be in `%LOCALAPPDATA%/crush/` + // for linux and macOS, it should be in `$HOME/.config/crush/` + if runtime.GOOS == "windows" { + localAppData := os.Getenv("LOCALAPPDATA") + if localAppData == "" { + localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local") + } + return filepath.Join(localAppData, appName) + } + + return filepath.Join(os.Getenv("HOME"), ".config", appName, fmt.Sprintf("%s.json", appName)) +} + +// globalConfigData returns the path to the main data directory for the application. +// this config is used when the app overrides configurations instead of updating the global config. +func globalConfigData() string { + xdgDataHome := os.Getenv("XDG_DATA_HOME") + if xdgDataHome != "" { + return filepath.Join(xdgDataHome, appName) + } + + // return the path to the main data directory + // for windows, it should be in `%LOCALAPPDATA%/crush/` + // for linux and macOS, it should be in `$HOME/.local/share/crush/` + if runtime.GOOS == "windows" { + localAppData := os.Getenv("LOCALAPPDATA") + if localAppData == "" { + localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local") + } + return filepath.Join(localAppData, appName) + } + + return filepath.Join(os.Getenv("HOME"), ".local", "share", appName, fmt.Sprintf("%s.json", appName)) +} diff --git a/pkg/config/load_test.go b/pkg/config/load_test.go new file mode 100644 index 0000000000000000000000000000000000000000..149258c0ff3ffeafb9db4744cbc6ee70afec33bc --- /dev/null +++ b/pkg/config/load_test.go @@ -0,0 +1,316 @@ +package config + +import ( + "io" + "strings" + "testing" + + "github.com/charmbracelet/crush/internal/fur/provider" + "github.com/charmbracelet/crush/pkg/env" + "github.com/stretchr/testify/assert" +) + +func TestConfig_LoadFromReaders(t *testing.T) { + data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`) + data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`) + data3 := strings.NewReader(`{"providers": {"openai": {}}}`) + + loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3}) + + assert.NoError(t, err) + assert.NotNil(t, loadedConfig) + assert.Len(t, loadedConfig.Providers, 1) + assert.Equal(t, "key2", loadedConfig.Providers["openai"].APIKey) + assert.Equal(t, "https://api.openai.com/v2", loadedConfig.Providers["openai"].BaseURL) +} + +func TestConfig_setDefaults(t *testing.T) { + cfg := &Config{} + + setDefaults("/tmp", cfg) + + assert.NotNil(t, cfg.Options) + assert.NotNil(t, cfg.Options.TUI) + assert.NotNil(t, cfg.Options.ContextPaths) + assert.NotNil(t, cfg.Providers) + assert.NotNil(t, cfg.Models) + assert.NotNil(t, cfg.LSP) + assert.NotNil(t, cfg.MCP) + assert.Equal(t, "/tmp/.crush", cfg.Options.DataDirectory) + for _, path := range defaultContextPaths { + assert.Contains(t, cfg.Options.ContextPaths, path) + } + assert.Equal(t, "/tmp", cfg.workingDir) +} + +func TestConfig_configureProviders(t *testing.T) { + knownProviders := []provider.Provider{ + { + ID: "openai", + APIKey: "$OPENAI_API_KEY", + APIEndpoint: "https://api.openai.com/v1", + Models: []provider.Model{{ + ID: "test-model", + }}, + }, + } + + cfg := &Config{} + setDefaults("/tmp", cfg) + env := env.NewFromMap(map[string]string{ + "OPENAI_API_KEY": "test-key", + }) + resolver := NewEnvironmentVariableResolver(env) + err := configureProviders(cfg, env, resolver, knownProviders) + assert.NoError(t, err) + assert.Len(t, cfg.Providers, 1) + + // We want to make sure that we keep the configured API key as a placeholder + assert.Equal(t, "$OPENAI_API_KEY", cfg.Providers["openai"].APIKey) +} + +func TestConfig_configureProvidersWithOverride(t *testing.T) { + knownProviders := []provider.Provider{ + { + ID: "openai", + APIKey: "$OPENAI_API_KEY", + APIEndpoint: "https://api.openai.com/v1", + Models: []provider.Model{{ + ID: "test-model", + }}, + }, + } + + cfg := &Config{ + Providers: map[string]ProviderConfig{ + "openai": { + APIKey: "xyz", + BaseURL: "https://api.openai.com/v2", + Models: []provider.Model{ + { + ID: "test-model", + Name: "Updated", + }, + { + ID: "another-model", + }, + }, + }, + }, + } + setDefaults("/tmp", cfg) + env := env.NewFromMap(map[string]string{ + "OPENAI_API_KEY": "test-key", + }) + resolver := NewEnvironmentVariableResolver(env) + err := configureProviders(cfg, env, resolver, knownProviders) + assert.NoError(t, err) + assert.Len(t, cfg.Providers, 1) + + // We want to make sure that we keep the configured API key as a placeholder + assert.Equal(t, "xyz", cfg.Providers["openai"].APIKey) + assert.Equal(t, "https://api.openai.com/v2", cfg.Providers["openai"].BaseURL) + assert.Len(t, cfg.Providers["openai"].Models, 2) + assert.Equal(t, "Updated", cfg.Providers["openai"].Models[0].Name) +} + +func TestConfig_configureProvidersWithNewProvider(t *testing.T) { + knownProviders := []provider.Provider{ + { + ID: "openai", + APIKey: "$OPENAI_API_KEY", + APIEndpoint: "https://api.openai.com/v1", + Models: []provider.Model{{ + ID: "test-model", + }}, + }, + } + + cfg := &Config{ + Providers: map[string]ProviderConfig{ + "custom": { + APIKey: "xyz", + BaseURL: "https://api.someendpoint.com/v2", + Models: []provider.Model{ + { + ID: "test-model", + }, + }, + }, + }, + } + setDefaults("/tmp", cfg) + env := env.NewFromMap(map[string]string{ + "OPENAI_API_KEY": "test-key", + }) + resolver := NewEnvironmentVariableResolver(env) + err := configureProviders(cfg, env, resolver, knownProviders) + assert.NoError(t, err) + // Should be to because of the env variable + assert.Len(t, cfg.Providers, 2) + + // We want to make sure that we keep the configured API key as a placeholder + assert.Equal(t, "xyz", cfg.Providers["custom"].APIKey) + assert.Equal(t, "https://api.someendpoint.com/v2", cfg.Providers["custom"].BaseURL) + assert.Len(t, cfg.Providers["custom"].Models, 1) + + _, ok := cfg.Providers["openai"] + assert.True(t, ok, "OpenAI provider should still be present") +} + +func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) { + knownProviders := []provider.Provider{ + { + ID: provider.InferenceProviderBedrock, + APIKey: "", + APIEndpoint: "", + Models: []provider.Model{{ + ID: "anthropic.claude-sonnet-4-20250514-v1:0", + }}, + }, + } + + cfg := &Config{} + setDefaults("/tmp", cfg) + env := env.NewFromMap(map[string]string{ + "AWS_ACCESS_KEY_ID": "test-key-id", + "AWS_SECRET_ACCESS_KEY": "test-secret-key", + }) + resolver := NewEnvironmentVariableResolver(env) + err := configureProviders(cfg, env, resolver, knownProviders) + assert.NoError(t, err) + assert.Len(t, cfg.Providers, 1) + + bedrockProvider, ok := cfg.Providers["bedrock"] + assert.True(t, ok, "Bedrock provider should be present") + assert.Len(t, bedrockProvider.Models, 1) + assert.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID) +} + +func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) { + knownProviders := []provider.Provider{ + { + ID: provider.InferenceProviderBedrock, + APIKey: "", + APIEndpoint: "", + Models: []provider.Model{{ + ID: "anthropic.claude-sonnet-4-20250514-v1:0", + }}, + }, + } + + cfg := &Config{} + setDefaults("/tmp", cfg) + env := env.NewFromMap(map[string]string{}) + resolver := NewEnvironmentVariableResolver(env) + err := configureProviders(cfg, env, resolver, knownProviders) + assert.NoError(t, err) + // Provider should not be configured without credentials + assert.Len(t, cfg.Providers, 0) +} + +func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) { + knownProviders := []provider.Provider{ + { + ID: provider.InferenceProviderBedrock, + APIKey: "", + APIEndpoint: "", + Models: []provider.Model{{ + ID: "some-random-model", + }}, + }, + } + + cfg := &Config{} + setDefaults("/tmp", cfg) + env := env.NewFromMap(map[string]string{ + "AWS_ACCESS_KEY_ID": "test-key-id", + "AWS_SECRET_ACCESS_KEY": "test-secret-key", + }) + resolver := NewEnvironmentVariableResolver(env) + err := configureProviders(cfg, env, resolver, knownProviders) + assert.Error(t, err) +} + +func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) { + knownProviders := []provider.Provider{ + { + ID: provider.InferenceProviderVertexAI, + APIKey: "", + APIEndpoint: "", + Models: []provider.Model{{ + ID: "gemini-pro", + }}, + }, + } + + cfg := &Config{} + setDefaults("/tmp", cfg) + env := env.NewFromMap(map[string]string{ + "GOOGLE_GENAI_USE_VERTEXAI": "true", + "GOOGLE_CLOUD_PROJECT": "test-project", + "GOOGLE_CLOUD_LOCATION": "us-central1", + }) + resolver := NewEnvironmentVariableResolver(env) + err := configureProviders(cfg, env, resolver, knownProviders) + assert.NoError(t, err) + assert.Len(t, cfg.Providers, 1) + + vertexProvider, ok := cfg.Providers["vertexai"] + assert.True(t, ok, "VertexAI provider should be present") + assert.Len(t, vertexProvider.Models, 1) + assert.Equal(t, "gemini-pro", vertexProvider.Models[0].ID) + assert.Equal(t, "test-project", vertexProvider.ExtraParams["project"]) + assert.Equal(t, "us-central1", vertexProvider.ExtraParams["location"]) +} + +func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) { + knownProviders := []provider.Provider{ + { + ID: provider.InferenceProviderVertexAI, + APIKey: "", + APIEndpoint: "", + Models: []provider.Model{{ + ID: "gemini-pro", + }}, + }, + } + + cfg := &Config{} + setDefaults("/tmp", cfg) + env := env.NewFromMap(map[string]string{ + "GOOGLE_GENAI_USE_VERTEXAI": "false", + "GOOGLE_CLOUD_PROJECT": "test-project", + "GOOGLE_CLOUD_LOCATION": "us-central1", + }) + resolver := NewEnvironmentVariableResolver(env) + err := configureProviders(cfg, env, resolver, knownProviders) + assert.NoError(t, err) + // Provider should not be configured without proper credentials + assert.Len(t, cfg.Providers, 0) +} + +func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) { + knownProviders := []provider.Provider{ + { + ID: provider.InferenceProviderVertexAI, + APIKey: "", + APIEndpoint: "", + Models: []provider.Model{{ + ID: "gemini-pro", + }}, + }, + } + + cfg := &Config{} + setDefaults("/tmp", cfg) + env := env.NewFromMap(map[string]string{ + "GOOGLE_GENAI_USE_VERTEXAI": "true", + "GOOGLE_CLOUD_LOCATION": "us-central1", + }) + resolver := NewEnvironmentVariableResolver(env) + err := configureProviders(cfg, env, resolver, knownProviders) + assert.NoError(t, err) + // Provider should not be configured without project + assert.Len(t, cfg.Providers, 0) +} diff --git a/pkg/config/merge.go b/pkg/config/merge.go new file mode 100644 index 0000000000000000000000000000000000000000..3c9b7d6283a193166ad50730b28853a909f5158a --- /dev/null +++ b/pkg/config/merge.go @@ -0,0 +1,16 @@ +package config + +import ( + "bytes" + "io" + + "github.com/qjebbs/go-jsons" +) + +func Merge(data []io.Reader) (io.Reader, error) { + got, err := jsons.Merge(data) + if err != nil { + return nil, err + } + return bytes.NewReader(got), nil +} diff --git a/pkg/config/merge_test.go b/pkg/config/merge_test.go new file mode 100644 index 0000000000000000000000000000000000000000..a00eb992a3edf97beb534353b4f0768c2b53a6d8 --- /dev/null +++ b/pkg/config/merge_test.go @@ -0,0 +1,27 @@ +package config + +import ( + "io" + "strings" + "testing" +) + +func TestMerge(t *testing.T) { + data1 := strings.NewReader(`{"foo": "bar"}`) + data2 := strings.NewReader(`{"baz": "qux"}`) + + merged, err := Merge([]io.Reader{data1, data2}) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + expected := `{"baz":"qux","foo":"bar"}` + got, err := io.ReadAll(merged) + if err != nil { + t.Fatalf("expected no error reading merged data, got %v", err) + } + + if string(got) != expected { + t.Errorf("expected %s, got %s", expected, string(got)) + } +} diff --git a/pkg/config/provider.go b/pkg/config/provider.go new file mode 100644 index 0000000000000000000000000000000000000000..953959dece9e0714c108fc9cff43267fdc2487bc --- /dev/null +++ b/pkg/config/provider.go @@ -0,0 +1,93 @@ +package config + +import ( + "encoding/json" + "os" + "path/filepath" + "runtime" + "sync" + + "github.com/charmbracelet/crush/internal/fur/provider" +) + +type ProviderClient interface { + GetProviders() ([]provider.Provider, error) +} + +var ( + providerOnce sync.Once + providerList []provider.Provider +) + +// file to cache provider data +func providerCacheFileData() string { + xdgDataHome := os.Getenv("XDG_DATA_HOME") + if xdgDataHome != "" { + return filepath.Join(xdgDataHome, appName) + } + + // return the path to the main data directory + // for windows, it should be in `%LOCALAPPDATA%/crush/` + // for linux and macOS, it should be in `$HOME/.local/share/crush/` + if runtime.GOOS == "windows" { + localAppData := os.Getenv("LOCALAPPDATA") + if localAppData == "" { + localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local") + } + return filepath.Join(localAppData, appName) + } + + return filepath.Join(os.Getenv("HOME"), ".local", "share", appName, "providers.json") +} + +func saveProvidersInCache(path string, providers []provider.Provider) error { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + + data, err := json.MarshalIndent(providers, "", " ") + if err != nil { + return err + } + + return os.WriteFile(path, data, 0o644) +} + +func loadProvidersFromCache(path string) ([]provider.Provider, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + var providers []provider.Provider + err = json.Unmarshal(data, &providers) + return providers, err +} + +func loadProviders(path string, client ProviderClient) ([]provider.Provider, error) { + providers, err := client.GetProviders() + if err != nil { + fallbackToCache, err := loadProvidersFromCache(path) + if err != nil { + return nil, err + } + providers = fallbackToCache + } else { + if err := saveProvidersInCache(path, providerList); err != nil { + return nil, err + } + } + return providers, nil +} + +func LoadProviders(client ProviderClient) ([]provider.Provider, error) { + var err error + providerOnce.Do(func() { + providerList, err = loadProviders(providerCacheFileData(), client) + }) + if err != nil { + return nil, err + } + return providerList, nil +} diff --git a/pkg/config/provider_test.go b/pkg/config/provider_test.go new file mode 100644 index 0000000000000000000000000000000000000000..d975140bafe4300c75bea2ad8652655cddcc62a0 --- /dev/null +++ b/pkg/config/provider_test.go @@ -0,0 +1,74 @@ +package config + +import ( + "encoding/json" + "errors" + "os" + "testing" + + "github.com/charmbracelet/crush/internal/fur/provider" + "github.com/stretchr/testify/assert" +) + +type mockProviderClient struct { + shouldFail bool +} + +func (m *mockProviderClient) GetProviders() ([]provider.Provider, error) { + if m.shouldFail { + return nil, errors.New("failed to load providers") + } + return []provider.Provider{ + { + Name: "Mock", + }, + }, nil +} + +func TestProvider_loadProvidersNoIssues(t *testing.T) { + client := &mockProviderClient{shouldFail: false} + tmpPath := t.TempDir() + "/providers.json" + providers, err := loadProviders(tmpPath, client) + assert.NoError(t, err) + assert.NotNil(t, providers) + assert.Len(t, providers, 1) + + // check if file got saved + fileInfo, err := os.Stat(tmpPath) + assert.NoError(t, err) + assert.False(t, fileInfo.IsDir(), "Expected a file, not a directory") +} + +func TestProvider_loadProvidersWithIssues(t *testing.T) { + client := &mockProviderClient{shouldFail: true} + tmpPath := t.TempDir() + "/providers.json" + // store providers to a temporary file + oldProviders := []provider.Provider{ + { + Name: "OldProvider", + }, + } + data, err := json.Marshal(oldProviders) + if err != nil { + t.Fatalf("Failed to marshal old providers: %v", err) + } + + err = os.WriteFile(tmpPath, data, 0o644) + if err != nil { + t.Fatalf("Failed to write old providers to file: %v", err) + } + providers, err := loadProviders(tmpPath, client) + assert.NoError(t, err) + assert.NotNil(t, providers) + assert.Len(t, providers, 1) + assert.Equal(t, "OldProvider", providers[0].Name, "Expected to keep old provider when loading fails") +} + +func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) { + client := &mockProviderClient{shouldFail: true} + tmpPath := t.TempDir() + "/providers.json" + providers, err := loadProviders(tmpPath, client) + assert.Error(t, err) + assert.Nil(t, providers, "Expected nil providers when loading fails and no cache exists") +} + diff --git a/pkg/config/resolve.go b/pkg/config/resolve.go new file mode 100644 index 0000000000000000000000000000000000000000..483d383b22fa15ee6d5fc6713388919c82f350cc --- /dev/null +++ b/pkg/config/resolve.go @@ -0,0 +1,89 @@ +package config + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/charmbracelet/crush/internal/shell" + "github.com/charmbracelet/crush/pkg/env" +) + +type VariableResolver interface { + ResolveValue(value string) (string, error) +} + +type Shell interface { + Exec(ctx context.Context, command string) (stdout, stderr string, err error) +} + +type shellVariableResolver struct { + shell Shell + env env.Env +} + +func NewShellVariableResolver(env env.Env) VariableResolver { + return &shellVariableResolver{ + shell: shell.NewShell( + &shell.Options{ + Env: env.Env(), + }, + ), + } +} + +// ResolveValue is a method for resolving values, such as environment variables. +// it will expect strings that start with `$` to be resolved as environment variables or shell commands. +// if the string does not start with `$`, it will return the string as is. +func (r *shellVariableResolver) ResolveValue(value string) (string, error) { + if !strings.HasPrefix(value, "$") { + return value, nil + } + + if strings.HasPrefix(value, "$(") && strings.HasSuffix(value, ")") { + command := strings.TrimSuffix(strings.TrimPrefix(value, "$("), ")") + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + stdout, _, err := r.shell.Exec(ctx, command) + if err != nil { + return "", fmt.Errorf("command execution failed: %w", err) + } + return strings.TrimSpace(stdout), nil + } + + if strings.HasPrefix(value, "$") { + varName := strings.TrimPrefix(value, "$") + value = r.env.Get(varName) + if value == "" { + return "", fmt.Errorf("environment variable %q not set", varName) + } + return value, nil + } + return "", fmt.Errorf("invalid value format: %s", value) +} + +type environmentVariableResolver struct { + env env.Env +} + +func NewEnvironmentVariableResolver(env env.Env) VariableResolver { + return &environmentVariableResolver{ + env: env, + } +} + +// ResolveValue resolves environment variables from the provided env.Env. +func (r *environmentVariableResolver) ResolveValue(value string) (string, error) { + if !strings.HasPrefix(value, "$") { + return value, nil + } + + varName := strings.TrimPrefix(value, "$") + resolvedValue := r.env.Get(varName) + if resolvedValue == "" { + return "", fmt.Errorf("environment variable %q not set", varName) + } + return resolvedValue, nil +} diff --git a/pkg/config/resolve_test.go b/pkg/config/resolve_test.go new file mode 100644 index 0000000000000000000000000000000000000000..8abfc68164610ff4059450dd8bd60ac9ed85c2cf --- /dev/null +++ b/pkg/config/resolve_test.go @@ -0,0 +1,178 @@ +package config + +import ( + "context" + "errors" + "testing" + + "github.com/charmbracelet/crush/pkg/env" + "github.com/stretchr/testify/assert" +) + +// mockShell implements the Shell interface for testing +type mockShell struct { + execFunc func(ctx context.Context, command string) (stdout, stderr string, err error) +} + +func (m *mockShell) Exec(ctx context.Context, command string) (stdout, stderr string, err error) { + if m.execFunc != nil { + return m.execFunc(ctx, command) + } + return "", "", nil +} + +func TestShellVariableResolver_ResolveValue(t *testing.T) { + tests := []struct { + name string + value string + envVars map[string]string + shellFunc func(ctx context.Context, command string) (stdout, stderr string, err error) + expected string + expectError bool + }{ + { + name: "non-variable string returns as-is", + value: "plain-string", + expected: "plain-string", + }, + { + name: "environment variable resolution", + value: "$HOME", + envVars: map[string]string{"HOME": "/home/user"}, + expected: "/home/user", + }, + { + name: "missing environment variable returns error", + value: "$MISSING_VAR", + envVars: map[string]string{}, + expectError: true, + }, + { + name: "shell command execution", + value: "$(echo hello)", + shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { + if command == "echo hello" { + return "hello\n", "", nil + } + return "", "", errors.New("unexpected command") + }, + expected: "hello", + }, + { + name: "shell command with whitespace trimming", + value: "$(echo ' spaced ')", + shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { + if command == "echo ' spaced '" { + return " spaced \n", "", nil + } + return "", "", errors.New("unexpected command") + }, + expected: "spaced", + }, + { + name: "shell command execution error", + value: "$(false)", + shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { + return "", "", errors.New("command failed") + }, + expectError: true, + }, + { + name: "invalid format returns error", + value: "$", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testEnv := env.NewFromMap(tt.envVars) + resolver := &shellVariableResolver{ + shell: &mockShell{execFunc: tt.shellFunc}, + env: testEnv, + } + + result, err := resolver.ResolveValue(tt.value) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestEnvironmentVariableResolver_ResolveValue(t *testing.T) { + tests := []struct { + name string + value string + envVars map[string]string + expected string + expectError bool + }{ + { + name: "non-variable string returns as-is", + value: "plain-string", + expected: "plain-string", + }, + { + name: "environment variable resolution", + value: "$HOME", + envVars: map[string]string{"HOME": "/home/user"}, + expected: "/home/user", + }, + { + name: "environment variable with complex value", + value: "$PATH", + envVars: map[string]string{"PATH": "/usr/bin:/bin:/usr/local/bin"}, + expected: "/usr/bin:/bin:/usr/local/bin", + }, + { + name: "missing environment variable returns error", + value: "$MISSING_VAR", + envVars: map[string]string{}, + expectError: true, + }, + { + name: "empty environment variable returns error", + value: "$EMPTY_VAR", + envVars: map[string]string{"EMPTY_VAR": ""}, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testEnv := env.NewFromMap(tt.envVars) + resolver := NewEnvironmentVariableResolver(testEnv) + + result, err := resolver.ResolveValue(tt.value) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestNewShellVariableResolver(t *testing.T) { + testEnv := env.NewFromMap(map[string]string{"TEST": "value"}) + resolver := NewShellVariableResolver(testEnv) + + assert.NotNil(t, resolver) + assert.Implements(t, (*VariableResolver)(nil), resolver) +} + +func TestNewEnvironmentVariableResolver(t *testing.T) { + testEnv := env.NewFromMap(map[string]string{"TEST": "value"}) + resolver := NewEnvironmentVariableResolver(testEnv) + + assert.NotNil(t, resolver) + assert.Implements(t, (*VariableResolver)(nil), resolver) +} + diff --git a/pkg/env/env.go b/pkg/env/env.go new file mode 100644 index 0000000000000000000000000000000000000000..f223bea50e465c28d924072b35dd042be12b0054 --- /dev/null +++ b/pkg/env/env.go @@ -0,0 +1,55 @@ +package env + +import "os" + +type Env interface { + Get(key string) string + Env() []string +} + +type osEnv struct{} + +// Get implements Env. +func (o *osEnv) Get(key string) string { + return os.Getenv(key) +} + +func (o *osEnv) Env() []string { + env := os.Environ() + if len(env) == 0 { + return nil + } + return env +} + +func New() Env { + return &osEnv{} +} + +type mapEnv struct { + m map[string]string +} + +// Get implements Env. +func (m *mapEnv) Get(key string) string { + if value, ok := m.m[key]; ok { + return value + } + return "" +} + +// Env implements Env. +func (m *mapEnv) Env() []string { + if len(m.m) == 0 { + return nil + } + env := make([]string, 0, len(m.m)) + for k, v := range m.m { + env = append(env, k+"="+v) + } + return env +} + +func NewFromMap(m map[string]string) Env { + return &mapEnv{m: m} +} diff --git a/todos.md b/todos.md deleted file mode 100644 index 080bf64df8dd6e4d5a496531ba5f8f2be5fcf8a4..0000000000000000000000000000000000000000 --- a/todos.md +++ /dev/null @@ -1,51 +0,0 @@ -## TODOs before release - -- [x] Implement help - - [x] Show full help - - [x] Make help dependent on the focused pane and page -- [x] Implement current model in the sidebar -- [x] Implement LSP errors -- [x] Implement changed files - - [x] Implement initial load - - [x] Implement realtime file changes -- [ ] Events when tool error -- [ ] Support bash commands -- [ ] Editor attachments fixes - - [ ] Reimplement removing attachments -- [ ] Fix the logs view - - [ ] Review the implementation - - [ ] The page lags - - [ ] Make the logs long lived ? -- [ ] Add all possible actions to the commands -- [ ] Parallel tool calls and permissions - - [ ] Run the tools in parallel and add results in parallel - - [ ] Show multiple permissions dialogs -- [ ] Add another space around buttons -- [ ] Completions - - [ ] Should change the help to show the completions stuff - - [ ] Should make it wider - - [ ] Tab and ctrl+y should accept - - [ ] Words should line up - - [ ] If there are no completions and cick tab/ctrl+y/enter it should close it -- [ ] Investigate messages issues - - [ ] Make the agent separator look like the - - [ ] Cleanup tool calls (watch all states) - - [ ] Weird behavior sometimes the message does not update - - [ ] Message length (I saw the message go beyond the correct length when there are errors) - - [ ] Address UX issues - - [ ] Fix issue with numbers (padding) view tool -- [x] Implement responsive mode -- [ ] Update interactive mode to use the spinner -- [ ] Revisit the core list component - - [ ] This component has become super complex we might need to fix this. -- [ ] Handle correct LSP and MCP status icon -- [x] Investigate ways to make the spinner less CPU intensive -- [ ] General cleanup and documentation -- [ ] Update the readme - -## Maybe - -- [ ] Revisit the provider/model/configs -- [ ] Implement correct persistent shell -- [ ] Store file read/write time somewhere so that the we can make sure that even if we restart we do not need to re-read the same file -- [ ] Send updates to the UI when new LSP diagnostics are available