From 4487507bdf3ecc6dae22efb4ab83a115c91dd451 Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Thu, 25 Sep 2025 15:17:55 -0400 Subject: [PATCH] refactor(env): remove env pkg --- internal/config/config.go | 6 +- internal/config/load.go | 61 +++++++------- internal/config/load_test.go | 103 ++++++++++++----------- internal/config/provider.go | 27 ++---- internal/config/resolve.go | 16 ++-- internal/config/resolve_test.go | 43 +++++----- internal/env/env.go | 58 ------------- internal/env/env_test.go | 140 -------------------------------- internal/llm/prompt/prompt.go | 3 +- internal/lsp/client_test.go | 7 +- internal/version/version.go | 2 + 11 files changed, 128 insertions(+), 338 deletions(-) delete mode 100644 internal/env/env.go delete mode 100644 internal/env/env_test.go diff --git a/internal/config/config.go b/internal/config/config.go index 450029399eac38cf4cb4d5868d7cff5e09e59e9f..a7adf2c23889762d0f4250b18491cb85d6b74f1c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -13,12 +13,10 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/csync" - "github.com/charmbracelet/crush/internal/env" "github.com/tidwall/sjson" ) const ( - appName = "crush" defaultDataDirectory = ".crush" ) @@ -207,7 +205,7 @@ func (m MCPConfig) ResolvedEnv() []string { } func (m MCPConfig) ResolvedHeaders() map[string]string { - resolver := NewShellVariableResolver(env.New()) + resolver := NewShellVariableResolver(os.Environ()) for e, v := range m.Headers { var err error m.Headers[e], err = resolver.ResolveValue(v) @@ -563,7 +561,7 @@ func (c *ProviderConfig) TestConnection(resolver VariableResolver) error { } func resolveEnvs(envs map[string]string) []string { - resolver := NewShellVariableResolver(env.New()) + resolver := NewShellVariableResolver(os.Environ()) for e, v := range envs { var err error envs[e], err = resolver.ResolveValue(v) diff --git a/internal/config/load.go b/internal/config/load.go index ed12cddef00357509844344b37921a7abec1b6ef..ce7982b24dd7431e6f8f4a0eeb37c240a616917c 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -15,12 +15,15 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/csync" - "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/fsext" "github.com/charmbracelet/crush/internal/home" + "github.com/charmbracelet/crush/internal/version" + uv "github.com/charmbracelet/ultraviolet" powernapConfig "github.com/charmbracelet/x/powernap/pkg/config" ) +type environ = uv.Environ + const defaultCatwalkURL = "https://catwalk.charm.sh" // LoadReader config via io.Reader. @@ -62,11 +65,11 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) { } cfg.knownProviders = providers - env := env.New() + envs := os.Environ() // Configure providers - valueResolver := NewShellVariableResolver(env) + valueResolver := NewShellVariableResolver(envs) cfg.resolver = valueResolver - if err := cfg.configureProviders(env, valueResolver, cfg.knownProviders); err != nil { + if err := cfg.configureProviders(envs, valueResolver, cfg.knownProviders); err != nil { return nil, fmt.Errorf("failed to configure providers: %w", err) } @@ -110,7 +113,7 @@ func PushPopCrushEnv() func() { return restore } -func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error { +func (c *Config) configureProviders(env environ, resolver VariableResolver, knownProviders []catwalk.Provider) error { knownProviderNames := make(map[string]bool) restore := PushPopCrushEnv() defer restore() @@ -185,8 +188,8 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know } continue } - prepared.ExtraParams["project"] = env.Get("VERTEXAI_PROJECT") - prepared.ExtraParams["location"] = env.Get("VERTEXAI_LOCATION") + prepared.ExtraParams["project"] = env.Getenv("VERTEXAI_PROJECT") + prepared.ExtraParams["location"] = env.Getenv("VERTEXAI_LOCATION") case catwalk.InferenceProviderAzure: endpoint, err := resolver.ResolveValue(p.APIEndpoint) if err != nil || endpoint == "" { @@ -197,7 +200,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know continue } prepared.BaseURL = endpoint - prepared.ExtraParams["apiVersion"] = env.Get("AZURE_OPENAI_API_VERSION") + prepared.ExtraParams["apiVersion"] = env.Getenv("AZURE_OPENAI_API_VERSION") case catwalk.InferenceProviderBedrock: if !hasAWSCredentials(env) { if configExists { @@ -206,9 +209,9 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know } continue } - prepared.ExtraParams["region"] = env.Get("AWS_REGION") + prepared.ExtraParams["region"] = env.Getenv("AWS_REGION") if prepared.ExtraParams["region"] == "" { - prepared.ExtraParams["region"] = env.Get("AWS_DEFAULT_REGION") + prepared.ExtraParams["region"] = env.Getenv("AWS_DEFAULT_REGION") } for _, model := range p.Models { if !strings.HasPrefix(model.ID, "anthropic.") { @@ -521,7 +524,7 @@ func lookupConfigs(cwd string) []string { return configPaths } - configNames := []string{appName + ".json", "." + appName + ".json"} + configNames := []string{version.AppName + ".json", "." + version.AppName + ".json"} foundConfigs, err := fsext.Lookup(cwd, configNames...) if err != nil { @@ -567,27 +570,27 @@ func loadFromReaders(readers []io.Reader) (*Config, error) { return LoadReader(merged) } -func hasVertexCredentials(env env.Env) bool { - hasProject := env.Get("VERTEXAI_PROJECT") != "" - hasLocation := env.Get("VERTEXAI_LOCATION") != "" +func hasVertexCredentials(env environ) bool { + hasProject := env.Getenv("VERTEXAI_PROJECT") != "" + hasLocation := env.Getenv("VERTEXAI_LOCATION") != "" return hasProject && hasLocation } -func hasAWSCredentials(env env.Env) bool { - if env.Get("AWS_ACCESS_KEY_ID") != "" && env.Get("AWS_SECRET_ACCESS_KEY") != "" { +func hasAWSCredentials(env environ) bool { + if env.Getenv("AWS_ACCESS_KEY_ID") != "" && env.Getenv("AWS_SECRET_ACCESS_KEY") != "" { return true } - if env.Get("AWS_PROFILE") != "" || env.Get("AWS_DEFAULT_PROFILE") != "" { + if env.Getenv("AWS_PROFILE") != "" || env.Getenv("AWS_DEFAULT_PROFILE") != "" { return true } - if env.Get("AWS_REGION") != "" || env.Get("AWS_DEFAULT_REGION") != "" { + if env.Getenv("AWS_REGION") != "" || env.Getenv("AWS_DEFAULT_REGION") != "" { return true } - if env.Get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" || - env.Get("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" { + if env.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" || + env.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" { return true } return false @@ -596,7 +599,7 @@ func hasAWSCredentials(env env.Env) bool { func globalConfig() string { xdgConfigHome := os.Getenv("XDG_CONFIG_HOME") if xdgConfigHome != "" { - return filepath.Join(xdgConfigHome, appName, fmt.Sprintf("%s.json", appName)) + return filepath.Join(xdgConfigHome, version.AppName, fmt.Sprintf("%s.json", version.AppName)) } // return the path to the main config directory @@ -607,10 +610,10 @@ func globalConfig() string { if localAppData == "" { localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local") } - return filepath.Join(localAppData, appName, fmt.Sprintf("%s.json", appName)) + return filepath.Join(localAppData, version.AppName, fmt.Sprintf("%s.json", version.AppName)) } - return filepath.Join(home.Dir(), ".config", appName, fmt.Sprintf("%s.json", appName)) + return filepath.Join(home.Dir(), ".config", version.AppName, fmt.Sprintf("%s.json", version.AppName)) } // GlobalConfigData returns the path to the main data directory for the application. @@ -618,7 +621,7 @@ func globalConfig() string { func GlobalConfigData() string { xdgDataHome := os.Getenv("XDG_DATA_HOME") if xdgDataHome != "" { - return filepath.Join(xdgDataHome, appName, fmt.Sprintf("%s.json", appName)) + return filepath.Join(xdgDataHome, version.AppName, fmt.Sprintf("%s.json", version.AppName)) } // return the path to the main data directory @@ -629,17 +632,17 @@ func GlobalConfigData() string { if localAppData == "" { localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local") } - return filepath.Join(localAppData, appName, fmt.Sprintf("%s.json", appName)) + return filepath.Join(localAppData, version.AppName, fmt.Sprintf("%s.json", version.AppName)) } - return filepath.Join(home.Dir(), ".local", "share", appName, fmt.Sprintf("%s.json", appName)) + return filepath.Join(home.Dir(), ".local", "share", version.AppName, fmt.Sprintf("%s.json", version.AppName)) } // GlobalCacheDir returns the path to the main cache directory for the application. func GlobalCacheDir() string { xdgCacheHome := os.Getenv("XDG_CACHE_HOME") if xdgCacheHome != "" { - return filepath.Join(xdgCacheHome, appName) + return filepath.Join(xdgCacheHome, version.AppName) } // return the path to the main cache directory @@ -650,8 +653,8 @@ func GlobalCacheDir() string { if localAppData == "" { localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local") } - return filepath.Join(localAppData, appName, "Cache") + return filepath.Join(localAppData, version.AppName, "Cache") } - return filepath.Join(home.Dir(), ".cache", appName) + return filepath.Join(home.Dir(), ".cache", version.AppName) } diff --git a/internal/config/load_test.go b/internal/config/load_test.go index 406fe07d523c8b0d5d7f038f8d94cc74a0b58f89..e4777b8d823f5caf28246f58b2cdeab3e12c7d14 100644 --- a/internal/config/load_test.go +++ b/internal/config/load_test.go @@ -10,7 +10,6 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/csync" - "github.com/charmbracelet/crush/internal/env" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -23,8 +22,8 @@ func TestMain(m *testing.M) { } func TestConfig_LoadFromReaders(t *testing.T) { - data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`) - data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`) + data1 := strings.NewReader(`{"providers": {"openai": {"api_key=key1", "base_url": "https://api.openai.com/v1"}}}`) + data2 := strings.NewReader(`{"providers": {"openai": {"api_key=key2", "base_url": "https://api.openai.com/v2"}}}`) data3 := strings.NewReader(`{"providers": {"openai": {}}}`) loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3}) @@ -70,8 +69,8 @@ func TestConfig_configureProviders(t *testing.T) { cfg := &Config{} cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{ - "OPENAI_API_KEY": "test-key", + env := environ([]string{ + "OPENAI_API_KEY=test-key", }) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) @@ -113,8 +112,8 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) { }) cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{ - "OPENAI_API_KEY": "test-key", + env := environ([]string{ + "OPENAI_API_KEY=test-key", }) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) @@ -155,8 +154,8 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { }), } cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{ - "OPENAI_API_KEY": "test-key", + env := environ([]string{ + "OPENAI_API_KEY=test-key", }) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) @@ -190,9 +189,9 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) { cfg := &Config{} cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{ - "AWS_ACCESS_KEY_ID": "test-key-id", - "AWS_SECRET_ACCESS_KEY": "test-secret-key", + env := environ([]string{ + "AWS_ACCESS_KEY_ID=test-key-id", + "AWS_SECRET_ACCESS_KEY=test-secret-key", }) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) @@ -219,7 +218,7 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) { cfg := &Config{} cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{}) + env := environ([]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) require.NoError(t, err) @@ -241,9 +240,9 @@ func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) { cfg := &Config{} cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{ - "AWS_ACCESS_KEY_ID": "test-key-id", - "AWS_SECRET_ACCESS_KEY": "test-secret-key", + env := environ([]string{ + "AWS_ACCESS_KEY_ID=test-key-id", + "AWS_SECRET_ACCESS_KEY=test-secret-key", }) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) @@ -264,9 +263,9 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) { cfg := &Config{} cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{ - "VERTEXAI_PROJECT": "test-project", - "VERTEXAI_LOCATION": "us-central1", + env := environ([]string{ + "VERTEXAI_PROJECT=test-project", + "VERTEXAI_LOCATION=us-central1", }) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) @@ -295,10 +294,10 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) { cfg := &Config{} cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{ - "GOOGLE_GENAI_USE_VERTEXAI": "false", - "GOOGLE_CLOUD_PROJECT": "test-project", - "GOOGLE_CLOUD_LOCATION": "us-central1", + env := environ([]string{ + "GOOGLE_GENAI_USE_VERTEXAI=false", + "GOOGLE_CLOUD_PROJECT=test-project", + "GOOGLE_CLOUD_LOCATION=us-central1", }) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) @@ -321,9 +320,9 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) { cfg := &Config{} cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{ - "GOOGLE_GENAI_USE_VERTEXAI": "true", - "GOOGLE_CLOUD_LOCATION": "us-central1", + env := environ([]string{ + "GOOGLE_GENAI_USE_VERTEXAI=true", + "GOOGLE_CLOUD_LOCATION=us-central1", }) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) @@ -346,8 +345,8 @@ func TestConfig_configureProvidersSetProviderID(t *testing.T) { cfg := &Config{} cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{ - "OPENAI_API_KEY": "test-key", + env := environ([]string{ + "OPENAI_API_KEY=test-key", }) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) @@ -536,8 +535,8 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) { } cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{ - "OPENAI_API_KEY": "test-key", + env := environ([]string{ + "OPENAI_API_KEY=test-key", }) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) @@ -566,7 +565,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { } cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{}) + env := environ([]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) require.NoError(t, err) @@ -589,7 +588,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { } cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{}) + env := environ([]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) require.NoError(t, err) @@ -611,7 +610,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { } cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{}) + env := environ([]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) require.NoError(t, err) @@ -636,7 +635,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { } cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{}) + env := environ([]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) require.NoError(t, err) @@ -661,7 +660,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { } cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{}) + env := environ([]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) require.NoError(t, err) @@ -689,7 +688,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { } cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{}) + env := environ([]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) require.NoError(t, err) @@ -719,7 +718,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { } cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{}) + env := environ([]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) require.NoError(t, err) @@ -752,8 +751,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { } cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{ - "GOOGLE_GENAI_USE_VERTEXAI": "false", + env := environ([]string{ + "GOOGLE_GENAI_USE_VERTEXAI=false", }) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) @@ -785,7 +784,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { } cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{}) + env := environ([]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) require.NoError(t, err) @@ -816,7 +815,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { } cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{}) + env := environ([]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) require.NoError(t, err) @@ -847,8 +846,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { } cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{ - "OPENAI_API_KEY": "test-key", + env := environ([]string{ + "OPENAI_API_KEY=test-key", }) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) @@ -883,7 +882,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { cfg := &Config{} cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{}) + env := environ([]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) require.NoError(t, err) @@ -919,7 +918,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { cfg := &Config{} cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{}) + env := environ([]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) require.NoError(t, err) @@ -949,7 +948,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { cfg := &Config{} cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{}) + env := environ([]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) require.NoError(t, err) @@ -992,7 +991,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { }), } cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{}) + env := environ([]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) require.NoError(t, err) @@ -1036,7 +1035,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { }), } cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{}) + env := environ([]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) require.NoError(t, err) @@ -1078,7 +1077,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { }), } cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{}) + env := environ([]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) require.NoError(t, err) @@ -1126,7 +1125,7 @@ func TestConfig_configureSelectedModels(t *testing.T) { }, } cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{}) + env := environ([]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) require.NoError(t, err) @@ -1188,7 +1187,7 @@ func TestConfig_configureSelectedModels(t *testing.T) { }, } cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{}) + env := environ([]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) require.NoError(t, err) @@ -1233,7 +1232,7 @@ func TestConfig_configureSelectedModels(t *testing.T) { }, } cfg.setDefaults("/tmp", "") - env := env.NewFromMap(map[string]string{}) + env := environ([]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) require.NoError(t, err) diff --git a/internal/config/provider.go b/internal/config/provider.go index 671c348f71da3a79f65c14c624bdaf2adc011411..df75c25a3ab9d4b8e5f7f79a34d84d05e8362c83 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -9,29 +9,23 @@ import ( "path/filepath" "runtime" "strings" - "sync" "time" "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/catwalk/pkg/embedded" "github.com/charmbracelet/crush/internal/home" + "github.com/charmbracelet/crush/internal/version" ) type ProviderClient interface { GetProviders() ([]catwalk.Provider, error) } -var ( - providerOnce sync.Once - providerList []catwalk.Provider - providerErr error -) - // file to cache provider data func providerCacheFileData() string { xdgDataHome := os.Getenv("XDG_DATA_HOME") if xdgDataHome != "" { - return filepath.Join(xdgDataHome, appName, "providers.json") + return filepath.Join(xdgDataHome, version.AppName, "providers.json") } // return the path to the main data directory @@ -42,10 +36,10 @@ func providerCacheFileData() string { if localAppData == "" { localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local") } - return filepath.Join(localAppData, appName, "providers.json") + return filepath.Join(localAppData, version.AppName, "providers.json") } - return filepath.Join(home.Dir(), ".local", "share", appName, "providers.json") + return filepath.Join(home.Dir(), ".local", "share", version.AppName, "providers.json") } func saveProvidersInCache(path string, providers []catwalk.Provider) error { @@ -114,15 +108,10 @@ func UpdateProviders(pathOrUrl string) error { } func Providers(cfg *Config) ([]catwalk.Provider, error) { - providerOnce.Do(func() { - catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL) - client := catwalk.NewWithURL(catwalkURL) - path := providerCacheFileData() - - autoUpdateDisabled := cfg.Options.DisableProviderAutoUpdate - providerList, providerErr = loadProviders(autoUpdateDisabled, client, path) - }) - return providerList, providerErr + catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL) + client := catwalk.NewWithURL(catwalkURL) + path := providerCacheFileData() + return loadProviders(cfg.Options.DisableProviderAutoUpdate, client, path) } func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string) ([]catwalk.Provider, error) { diff --git a/internal/config/resolve.go b/internal/config/resolve.go index 3ef3522b09e504d3c57105e8bbe393b0f7c38b2b..9ebe425ba95f98996f0bbae0368710d755182fd2 100644 --- a/internal/config/resolve.go +++ b/internal/config/resolve.go @@ -6,7 +6,6 @@ import ( "strings" "time" - "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/shell" ) @@ -20,15 +19,15 @@ type Shell interface { type shellVariableResolver struct { shell Shell - env env.Env + env []string } -func NewShellVariableResolver(env env.Env) VariableResolver { +func NewShellVariableResolver(env []string) VariableResolver { return &shellVariableResolver{ env: env, shell: shell.NewShell( &shell.Options{ - Env: env.Env(), + Env: env, }, ), } @@ -38,6 +37,7 @@ func NewShellVariableResolver(env env.Env) VariableResolver { // it will resolve shell-like variable substitution anywhere in the string, including: // - $(command) for command substitution // - $VAR or ${VAR} for environment variables +// TODO: can we replace this with [os.Expand](https://pkg.go.dev/os#Expand) somehow? func (r *shellVariableResolver) ResolveValue(value string) (string, error) { // Special case: lone $ is an error (backward compatibility) if value == "$" { @@ -139,7 +139,7 @@ func (r *shellVariableResolver) ResolveValue(value string) (string, error) { varName = result[start+1 : end] } - envValue := r.env.Get(varName) + envValue := environ(r.env).Getenv(varName) if envValue == "" { return "", fmt.Errorf("environment variable %q not set", varName) } @@ -152,10 +152,10 @@ func (r *shellVariableResolver) ResolveValue(value string) (string, error) { } type environmentVariableResolver struct { - env env.Env + env []string } -func NewEnvironmentVariableResolver(env env.Env) VariableResolver { +func NewEnvironmentVariableResolver(env []string) VariableResolver { return &environmentVariableResolver{ env: env, } @@ -168,7 +168,7 @@ func (r *environmentVariableResolver) ResolveValue(value string) (string, error) } varName := strings.TrimPrefix(value, "$") - resolvedValue := r.env.Get(varName) + resolvedValue := environ(r.env).Getenv(varName) if resolvedValue == "" { return "", fmt.Errorf("environment variable %q not set", varName) } diff --git a/internal/config/resolve_test.go b/internal/config/resolve_test.go index ec9b06c25bdc023acebffc71f043b54a8da21597..9d4ccb92114446c392e593cd57083b8f71d2df1a 100644 --- a/internal/config/resolve_test.go +++ b/internal/config/resolve_test.go @@ -5,7 +5,6 @@ import ( "errors" "testing" - "github.com/charmbracelet/crush/internal/env" "github.com/stretchr/testify/require" ) @@ -25,7 +24,7 @@ func TestShellVariableResolver_ResolveValue(t *testing.T) { tests := []struct { name string value string - envVars map[string]string + envVars []string shellFunc func(ctx context.Context, command string) (stdout, stderr string, err error) expected string expectError bool @@ -38,13 +37,13 @@ func TestShellVariableResolver_ResolveValue(t *testing.T) { { name: "environment variable resolution", value: "$HOME", - envVars: map[string]string{"HOME": "/home/user"}, + envVars: []string{"HOME=/home/user"}, expected: "/home/user", }, { name: "missing environment variable returns error", value: "$MISSING_VAR", - envVars: map[string]string{}, + envVars: []string{}, expectError: true, }, @@ -76,7 +75,7 @@ func TestShellVariableResolver_ResolveValue(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - testEnv := env.NewFromMap(tt.envVars) + testEnv := environ(tt.envVars) resolver := &shellVariableResolver{ shell: &mockShell{execFunc: tt.shellFunc}, env: testEnv, @@ -98,7 +97,7 @@ func TestShellVariableResolver_EnhancedResolveValue(t *testing.T) { tests := []struct { name string value string - envVars map[string]string + envVars []string shellFunc func(ctx context.Context, command string) (stdout, stderr string, err error) expected string expectError bool @@ -117,21 +116,21 @@ func TestShellVariableResolver_EnhancedResolveValue(t *testing.T) { { name: "environment variable within string", value: "Bearer $TOKEN", - envVars: map[string]string{"TOKEN": "sk-ant-123"}, + envVars: []string{"TOKEN=sk-ant-123"}, expected: "Bearer sk-ant-123", }, { name: "environment variable with braces within string", value: "Bearer ${TOKEN}", - envVars: map[string]string{"TOKEN": "sk-ant-456"}, + envVars: []string{"TOKEN=sk-ant-456"}, expected: "Bearer sk-ant-456", }, { name: "mixed command and environment substitution", value: "$USER-$(date +%Y)-$HOST", - envVars: map[string]string{ - "USER": "testuser", - "HOST": "localhost", + envVars: []string{ + "USER=testuser", + "HOST=localhost", }, shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { if command == "date +%Y" { @@ -179,7 +178,7 @@ func TestShellVariableResolver_EnhancedResolveValue(t *testing.T) { { name: "empty environment variable substitution", value: "Bearer $EMPTY_VAR", - envVars: map[string]string{}, + envVars: []string{}, expectError: true, }, { @@ -214,7 +213,7 @@ func TestShellVariableResolver_EnhancedResolveValue(t *testing.T) { { name: "environment variable with underscores and numbers", value: "Bearer $API_KEY_V2", - envVars: map[string]string{"API_KEY_V2": "sk-test-123"}, + envVars: []string{"API_KEY_V2=sk-test-123"}, expected: "Bearer sk-test-123", }, { @@ -241,7 +240,7 @@ func TestShellVariableResolver_EnhancedResolveValue(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - testEnv := env.NewFromMap(tt.envVars) + testEnv := environ(tt.envVars) resolver := &shellVariableResolver{ shell: &mockShell{execFunc: tt.shellFunc}, env: testEnv, @@ -263,7 +262,7 @@ func TestEnvironmentVariableResolver_ResolveValue(t *testing.T) { tests := []struct { name string value string - envVars map[string]string + envVars []string expected string expectError bool }{ @@ -275,32 +274,32 @@ func TestEnvironmentVariableResolver_ResolveValue(t *testing.T) { { name: "environment variable resolution", value: "$HOME", - envVars: map[string]string{"HOME": "/home/user"}, + envVars: []string{"HOME=/home/user"}, expected: "/home/user", }, { name: "environment variable with complex value", value: "$PATH", - envVars: map[string]string{"PATH": "/usr/bin:/bin:/usr/local/bin"}, + envVars: []string{"PATH=/usr/bin:/bin:/usr/local/bin"}, expected: "/usr/bin:/bin:/usr/local/bin", }, { name: "missing environment variable returns error", value: "$MISSING_VAR", - envVars: map[string]string{}, + envVars: []string{}, expectError: true, }, { name: "empty environment variable returns error", value: "$EMPTY_VAR", - envVars: map[string]string{"EMPTY_VAR": ""}, + envVars: []string{"EMPTY_VAR="}, expectError: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - testEnv := env.NewFromMap(tt.envVars) + testEnv := environ(tt.envVars) resolver := NewEnvironmentVariableResolver(testEnv) result, err := resolver.ResolveValue(tt.value) @@ -316,7 +315,7 @@ func TestEnvironmentVariableResolver_ResolveValue(t *testing.T) { } func TestNewShellVariableResolver(t *testing.T) { - testEnv := env.NewFromMap(map[string]string{"TEST": "value"}) + testEnv := environ([]string{"TEST=value"}) resolver := NewShellVariableResolver(testEnv) require.NotNil(t, resolver) @@ -324,7 +323,7 @@ func TestNewShellVariableResolver(t *testing.T) { } func TestNewEnvironmentVariableResolver(t *testing.T) { - testEnv := env.NewFromMap(map[string]string{"TEST": "value"}) + testEnv := environ([]string{"TEST=value"}) resolver := NewEnvironmentVariableResolver(testEnv) require.NotNil(t, resolver) diff --git a/internal/env/env.go b/internal/env/env.go deleted file mode 100644 index 24d44d10fca5a374732283d0aca4ddc8166b879b..0000000000000000000000000000000000000000 --- a/internal/env/env.go +++ /dev/null @@ -1,58 +0,0 @@ -package env - -import "os" - -type Env interface { - Get(key string) string - Env() []string -} - -type osEnv struct{} - -// Get implements Env. -func (o *osEnv) Get(key string) string { - return os.Getenv(key) -} - -func (o *osEnv) Env() []string { - env := os.Environ() - if len(env) == 0 { - return nil - } - return env -} - -func New() Env { - return &osEnv{} -} - -type mapEnv struct { - m map[string]string -} - -// Get implements Env. -func (m *mapEnv) Get(key string) string { - if value, ok := m.m[key]; ok { - return value - } - return "" -} - -// Env implements Env. -func (m *mapEnv) Env() []string { - if len(m.m) == 0 { - return nil - } - env := make([]string, 0, len(m.m)) - for k, v := range m.m { - env = append(env, k+"="+v) - } - return env -} - -func NewFromMap(m map[string]string) Env { - if m == nil { - m = make(map[string]string) - } - return &mapEnv{m: m} -} diff --git a/internal/env/env_test.go b/internal/env/env_test.go deleted file mode 100644 index 6bd323e0cb169c2fd06397ed7b015de98145b105..0000000000000000000000000000000000000000 --- a/internal/env/env_test.go +++ /dev/null @@ -1,140 +0,0 @@ -package env - -import ( - "strings" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestOsEnv_Get(t *testing.T) { - env := New() - - // Test getting an existing environment variable - t.Setenv("TEST_VAR", "test_value") - - value := env.Get("TEST_VAR") - require.Equal(t, "test_value", value) - - // Test getting a non-existent environment variable - value = env.Get("NON_EXISTENT_VAR") - require.Equal(t, "", value) -} - -func TestOsEnv_Env(t *testing.T) { - env := New() - - envVars := env.Env() - - // Environment should not be empty in normal circumstances - require.NotNil(t, envVars) - require.Greater(t, len(envVars), 0) - - // Each environment variable should be in key=value format - for _, envVar := range envVars { - require.Contains(t, envVar, "=") - } -} - -func TestNewFromMap(t *testing.T) { - testMap := map[string]string{ - "KEY1": "value1", - "KEY2": "value2", - } - - env := NewFromMap(testMap) - require.NotNil(t, env) - require.IsType(t, &mapEnv{}, env) -} - -func TestMapEnv_Get(t *testing.T) { - testMap := map[string]string{ - "KEY1": "value1", - "KEY2": "value2", - } - - env := NewFromMap(testMap) - - // Test getting existing keys - require.Equal(t, "value1", env.Get("KEY1")) - require.Equal(t, "value2", env.Get("KEY2")) - - // Test getting non-existent key - require.Equal(t, "", env.Get("NON_EXISTENT")) -} - -func TestMapEnv_Env(t *testing.T) { - t.Run("with values", func(t *testing.T) { - testMap := map[string]string{ - "KEY1": "value1", - "KEY2": "value2", - } - - env := NewFromMap(testMap) - envVars := env.Env() - - require.Len(t, envVars, 2) - - // Convert to map for easier testing (order is not guaranteed) - envMap := make(map[string]string) - for _, envVar := range envVars { - parts := strings.SplitN(envVar, "=", 2) - require.Len(t, parts, 2) - envMap[parts[0]] = parts[1] - } - - require.Equal(t, "value1", envMap["KEY1"]) - require.Equal(t, "value2", envMap["KEY2"]) - }) - - t.Run("empty map", func(t *testing.T) { - env := NewFromMap(map[string]string{}) - envVars := env.Env() - require.Nil(t, envVars) - }) - - t.Run("nil map", func(t *testing.T) { - env := NewFromMap(nil) - envVars := env.Env() - require.Nil(t, envVars) - }) -} - -func TestMapEnv_GetEmptyValue(t *testing.T) { - testMap := map[string]string{ - "EMPTY_KEY": "", - "NORMAL_KEY": "value", - } - - env := NewFromMap(testMap) - - // Test that empty values are returned correctly - require.Equal(t, "", env.Get("EMPTY_KEY")) - require.Equal(t, "value", env.Get("NORMAL_KEY")) -} - -func TestMapEnv_EnvFormat(t *testing.T) { - testMap := map[string]string{ - "KEY_WITH_EQUALS": "value=with=equals", - "KEY_WITH_SPACES": "value with spaces", - } - - env := NewFromMap(testMap) - envVars := env.Env() - - require.Len(t, envVars, 2) - - // Check that the format is correct even with special characters - found := make(map[string]bool) - for _, envVar := range envVars { - if envVar == "KEY_WITH_EQUALS=value=with=equals" { - found["equals"] = true - } - if envVar == "KEY_WITH_SPACES=value with spaces" { - found["spaces"] = true - } - } - - require.True(t, found["equals"], "Should handle values with equals signs") - require.True(t, found["spaces"], "Should handle values with spaces") -} diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go index 54a8b446fb0ff8411228722bc55c4bbb627723b8..28ab4e4f98a1ae197c12b6b174f71b58dfe1de3b 100644 --- a/internal/llm/prompt/prompt.go +++ b/internal/llm/prompt/prompt.go @@ -8,7 +8,6 @@ import ( "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" - "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/home" ) @@ -49,7 +48,7 @@ func expandPath(path string) string { // Handle environment variable expansion using the same pattern as config if strings.HasPrefix(path, "$") { - resolver := config.NewEnvironmentVariableResolver(env.New()) + resolver := config.NewEnvironmentVariableResolver(os.Environ()) if expanded, err := resolver.ResolveValue(path); err == nil { path = expanded } diff --git a/internal/lsp/client_test.go b/internal/lsp/client_test.go index be87b679cf9b1b1839e4120af023f42d9c97afa1..33cb0eb7310b7fae3e9345999b8125e95f771d88 100644 --- a/internal/lsp/client_test.go +++ b/internal/lsp/client_test.go @@ -5,7 +5,6 @@ import ( "testing" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/env" ) func TestClient(t *testing.T) { @@ -22,9 +21,9 @@ func TestClient(t *testing.T) { // Test creating a powernap client - this will likely fail with echo // but we can still test the basic structure - client, err := New(ctx, &cfg, "test", lspCfg, config.NewEnvironmentVariableResolver(env.NewFromMap(map[string]string{ - "THE_CMD": "echo", - }))) + client, err := New(ctx, &cfg, "test", lspCfg, config.NewEnvironmentVariableResolver([]string{ + "THE_CMD=echo", + })) if err != nil { // Expected to fail with echo command, skip the rest t.Skipf("Powernap client creation failed as expected with dummy command: %v", err) diff --git a/internal/version/version.go b/internal/version/version.go index 430412e050668fd598f206bf7073c12f27c8d004..6ea7f203d22c789e09d11c921895c7f9e704cec8 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -2,6 +2,8 @@ package version import "runtime/debug" +const AppName = "crush" + // Build-time parameters set via -ldflags var (