feat: allow for using CRUSH_ prefixed env-vars without clobbering default env vars (#391)

Tai Groot created

* feat: allow for using CRUSH_ prefixed env-vars for bedrock without clobbering AWS env vars

* feat: make the CRUSH_ prefix generic

Change summary

internal/config/load.go           | 31 ++++++++++++++++++++++++++++++-
internal/llm/provider/provider.go |  3 +++
2 files changed, 33 insertions(+), 1 deletion(-)

Detailed changes

internal/config/load.go 🔗

@@ -93,8 +93,38 @@ func Load(workingDir string, debug bool) (*Config, error) {
 	return cfg, nil
 }
 
+func PushPopCrushEnv() func() {
+	found := []string{}
+	for _, ev := range os.Environ() {
+		if strings.HasPrefix(ev, "CRUSH_") {
+			pair := strings.SplitN(ev, "=", 2)
+			if len(pair) != 2 {
+				continue
+			}
+			found = append(found, strings.TrimPrefix(pair[0], "CRUSH_"))
+		}
+	}
+	backups := make(map[string]string)
+	for _, ev := range found {
+		backups[ev] = os.Getenv(ev)
+	}
+
+	for _, ev := range found {
+		os.Setenv(ev, os.Getenv("CRUSH_"+ev))
+	}
+
+	restore := func() {
+		for k, v := range backups {
+			os.Setenv(k, v)
+		}
+	}
+	return restore
+}
+
 func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error {
 	knownProviderNames := make(map[string]bool)
+	restore := PushPopCrushEnv()
+	defer restore()
 	for _, p := range knownProviders {
 		knownProviderNames[string(p.ID)] = true
 		config, configExists := c.Providers.Get(string(p.ID))
@@ -495,7 +525,6 @@ func hasAWSCredentials(env env.Env) bool {
 		env.Get("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
 		return true
 	}
-
 	return false
 }
 

internal/llm/provider/provider.go 🔗

@@ -5,6 +5,7 @@ import (
 	"fmt"
 
 	"github.com/charmbracelet/catwalk/pkg/catwalk"
+
 	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/llm/tools"
 	"github.com/charmbracelet/crush/internal/message"
@@ -139,6 +140,8 @@ func WithMaxTokens(maxTokens int64) ProviderClientOption {
 }
 
 func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
+	restore := config.PushPopCrushEnv()
+	defer restore()
 	resolvedAPIKey, err := config.Get().Resolve(cfg.APIKey)
 	if err != nil {
 		return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)