diff --git a/internal/config/load.go b/internal/config/load.go index 5b81fa3085b94cbfb051ddfdf9887e44c6fe5540..e2dfcdbbf9dbd60dd5afc007338bad0e3e410050 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -93,8 +93,38 @@ func Load(workingDir string, debug bool) (*Config, error) { return cfg, nil } +func PushPopCrushEnv() func() { + found := []string{} + for _, ev := range os.Environ() { + if strings.HasPrefix(ev, "CRUSH_") { + pair := strings.SplitN(ev, "=", 2) + if len(pair) != 2 { + continue + } + found = append(found, strings.TrimPrefix(pair[0], "CRUSH_")) + } + } + backups := make(map[string]string) + for _, ev := range found { + backups[ev] = os.Getenv(ev) + } + + for _, ev := range found { + os.Setenv(ev, os.Getenv("CRUSH_"+ev)) + } + + restore := func() { + for k, v := range backups { + os.Setenv(k, v) + } + } + return restore +} + func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error { knownProviderNames := make(map[string]bool) + restore := PushPopCrushEnv() + defer restore() for _, p := range knownProviders { knownProviderNames[string(p.ID)] = true config, configExists := c.Providers.Get(string(p.ID)) @@ -495,7 +525,6 @@ func hasAWSCredentials(env env.Env) bool { env.Get("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" { return true } - return false } diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 6376561aa437c0dfcd4abeb8f7ed2fd2b182e936..28562f2f484a75c445d9eaa21ce90af4ef5ca613 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" @@ -139,6 +140,8 @@ func WithMaxTokens(maxTokens int64) ProviderClientOption { } func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) { + restore := config.PushPopCrushEnv() + defer restore() resolvedAPIKey, err := config.Get().Resolve(cfg.APIKey) if err != nil { return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)