@@ -93,8 +93,38 @@ func Load(workingDir string, debug bool) (*Config, error) {
return cfg, nil
}
+func PushPopCrushEnv() func() {
+ found := []string{}
+ for _, ev := range os.Environ() {
+ if strings.HasPrefix(ev, "CRUSH_") {
+ pair := strings.SplitN(ev, "=", 2)
+ if len(pair) != 2 {
+ continue
+ }
+ found = append(found, strings.TrimPrefix(pair[0], "CRUSH_"))
+ }
+ }
+ backups := make(map[string]string)
+ for _, ev := range found {
+ backups[ev] = os.Getenv(ev)
+ }
+
+ for _, ev := range found {
+ os.Setenv(ev, os.Getenv("CRUSH_"+ev))
+ }
+
+ restore := func() {
+ for k, v := range backups {
+ os.Setenv(k, v)
+ }
+ }
+ return restore
+}
+
func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error {
knownProviderNames := make(map[string]bool)
+ restore := PushPopCrushEnv()
+ defer restore()
for _, p := range knownProviders {
knownProviderNames[string(p.ID)] = true
config, configExists := c.Providers.Get(string(p.ID))
@@ -495,7 +525,6 @@ func hasAWSCredentials(env env.Env) bool {
env.Get("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
return true
}
-
return false
}
@@ -5,6 +5,7 @@ import (
"fmt"
"github.com/charmbracelet/catwalk/pkg/catwalk"
+
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/message"
@@ -139,6 +140,8 @@ func WithMaxTokens(maxTokens int64) ProviderClientOption {
}
func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
+ restore := config.PushPopCrushEnv()
+ defer restore()
resolvedAPIKey, err := config.Get().Resolve(cfg.APIKey)
if err != nil {
return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)