feat: pass client env vars to server and use it for shell variable resolution

Ayman Bagabas created

Change summary

internal/app/lsp.go                              |  2 
internal/cmd/root.go                             |  1 
internal/config/config.go                        | 18 ++----
internal/config/load.go                          |  1 
internal/config/resolve.go                       |  8 +-
internal/config/resolve_test.go                  |  4 
internal/llm/agent/mcp-tools.go                  |  4 
internal/llm/provider/anthropic.go               |  4 
internal/llm/provider/gemini.go                  |  2 
internal/llm/provider/openai.go                  |  4 
internal/llm/provider/provider.go                | 49 +++++++++++------
internal/proto/proto.go                          |  7 ++
internal/server/proto.go                         | 15 +++++
internal/server/server.go                        |  2 
internal/tui/components/chat/splash/splash.go    |  2 
internal/tui/components/dialogs/models/models.go | 17 +++--
internal/tui/tui.go                              |  2 
17 files changed, 87 insertions(+), 55 deletions(-)

Detailed changes

internal/app/lsp.go 🔗

@@ -36,7 +36,7 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, lspCfg
 	app.updateLSPState(name, lsp.StateStarting, nil, 0)
 
 	// Create LSP client.
-	lspClient, err := lsp.New(ctx, app.config, name, lspCfg, app.config.Resolver())
+	lspClient, err := lsp.New(ctx, app.config, name, lspCfg, config.OsShellResolver)
 	if err != nil {
 		slog.Error("Failed to create LSP client for", name, err)
 		app.updateLSPState(name, lsp.StateError, err, 0)

internal/cmd/root.go 🔗

@@ -220,6 +220,7 @@ func setupApp(cmd *cobra.Command, hostURL *url.URL) (*client.Client, *proto.Inst
 		DataDir: dataDir,
 		Debug:   debug,
 		YOLO:    yolo,
+		Env:     os.Environ(),
 	})
 	if err != nil {
 		return nil, nil, fmt.Errorf("failed to create or connect to instance: %v", err)

internal/config/config.go 🔗

@@ -16,6 +16,12 @@ import (
 	"github.com/tidwall/sjson"
 )
 
+// OsShellResolver is a shell resolver that uses the current process environment
+// variables.
+//
+// Deprecated: pass resolver explicitly instead.
+var OsShellResolver = NewShellVariableResolver(os.Environ())
+
 const (
 	defaultDataDirectory = ".crush"
 )
@@ -267,7 +273,6 @@ type Config struct {
 	// TODO: most likely remove this concept when I come back to it
 	Agents map[string]Agent `json:"-"`
 	// TODO: find a better way to do this this should probably not be part of the config
-	resolver       VariableResolver
 	dataConfigDir  string             `json:"-"`
 	knownProviders []catwalk.Provider `json:"-"`
 }
@@ -345,13 +350,6 @@ func (c *Config) SetCompactMode(enabled bool) error {
 	return c.SetConfigField("options.tui.compact_mode", enabled)
 }
 
-func (c *Config) Resolve(key string) (string, error) {
-	if c.resolver == nil {
-		return "", fmt.Errorf("no variable resolver configured")
-	}
-	return c.resolver.ResolveValue(key)
-}
-
 func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
 	c.Models[modelType] = model
 	if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
@@ -494,10 +492,6 @@ func (c *Config) SetupAgents() {
 	c.Agents = agents
 }
 
-func (c *Config) Resolver() VariableResolver {
-	return c.resolver
-}
-
 func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
 	testURL := ""
 	headers := make(map[string]string)

internal/config/load.go 🔗

@@ -67,7 +67,6 @@ func Load(workingDir, dataDir string, debug bool, envs []string) (*Config, error
 
 	// Configure providers
 	valueResolver := NewShellVariableResolver(envs)
-	cfg.resolver = valueResolver
 	if err := cfg.configureProviders(envs, valueResolver, cfg.knownProviders); err != nil {
 		return nil, fmt.Errorf("failed to configure providers: %w", err)
 	}

internal/config/resolve.go 🔗

@@ -17,13 +17,13 @@ type Shell interface {
 	Exec(ctx context.Context, command string) (stdout, stderr string, err error)
 }
 
-type shellVariableResolver struct {
+type ShellVariableResolver struct {
 	shell Shell
 	env   []string
 }
 
-func NewShellVariableResolver(env []string) VariableResolver {
-	return &shellVariableResolver{
+func NewShellVariableResolver(env []string) *ShellVariableResolver {
+	return &ShellVariableResolver{
 		env: env,
 		shell: shell.NewShell(
 			&shell.Options{
@@ -38,7 +38,7 @@ func NewShellVariableResolver(env []string) VariableResolver {
 // - $(command) for command substitution
 // - $VAR or ${VAR} for environment variables
 // TODO: can we replace this with [os.Expand](https://pkg.go.dev/os#Expand) somehow?
-func (r *shellVariableResolver) ResolveValue(value string) (string, error) {
+func (r *ShellVariableResolver) ResolveValue(value string) (string, error) {
 	// Special case: lone $ is an error (backward compatibility)
 	if value == "$" {
 		return "", fmt.Errorf("invalid value format: %s", value)

internal/config/resolve_test.go 🔗

@@ -76,7 +76,7 @@ func TestShellVariableResolver_ResolveValue(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			testEnv := environ(tt.envVars)
-			resolver := &shellVariableResolver{
+			resolver := &ShellVariableResolver{
 				shell: &mockShell{execFunc: tt.shellFunc},
 				env:   testEnv,
 			}
@@ -241,7 +241,7 @@ func TestShellVariableResolver_EnhancedResolveValue(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			testEnv := environ(tt.envVars)
-			resolver := &shellVariableResolver{
+			resolver := &ShellVariableResolver{
 				shell: &mockShell{execFunc: tt.shellFunc},
 				env:   testEnv,
 			}

internal/llm/agent/mcp-tools.go 🔗

@@ -136,7 +136,7 @@ func getOrRenewClient(ctx context.Context, cfg *config.Config, name string) (*cl
 	}
 	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
 
-	c, err = createAndInitializeClient(ctx, name, m, cfg.Resolver())
+	c, err = createAndInitializeClient(ctx, name, m, config.OsShellResolver)
 	if err != nil {
 		return nil, err
 	}
@@ -288,7 +288,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
 
 			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
 			defer cancel()
-			c, err := createAndInitializeClient(ctx, name, m, cfg.Resolver())
+			c, err := createAndInitializeClient(ctx, name, m, config.OsShellResolver)
 			if err != nil {
 				return
 			}

internal/llm/provider/anthropic.go 🔗

@@ -81,7 +81,7 @@ func createAnthropicClient(opts providerClientOptions, tp AnthropicClientType) a
 	}
 
 	if opts.baseURL != "" {
-		resolvedBaseURL, err := opts.cfg.Resolve(opts.baseURL)
+		resolvedBaseURL, err := opts.resolver.ResolveValue(opts.baseURL)
 		if err == nil && resolvedBaseURL != "" {
 			anthropicClientOptions = append(anthropicClientOptions, option.WithBaseURL(resolvedBaseURL))
 		}
@@ -496,7 +496,7 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err
 	if apiErr.StatusCode == http.StatusUnauthorized {
 		prev := a.providerOptions.apiKey
 		// in case the key comes from a script, we try to re-evaluate it.
-		a.providerOptions.apiKey, err = a.providerOptions.cfg.Resolve(a.providerOptions.config.APIKey)
+		a.providerOptions.apiKey, err = a.providerOptions.resolver.ResolveValue(a.providerOptions.config.APIKey)
 		if err != nil {
 			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
 		}

internal/llm/provider/gemini.go 🔗

@@ -438,7 +438,7 @@ func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error)
 	if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
 		prev := g.providerOptions.apiKey
 		// in case the key comes from a script, we try to re-evaluate it.
-		g.providerOptions.apiKey, err = g.providerOptions.cfg.Resolve(g.providerOptions.config.APIKey)
+		g.providerOptions.apiKey, err = g.providerOptions.resolver.ResolveValue(g.providerOptions.config.APIKey)
 		if err != nil {
 			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
 		}

internal/llm/provider/openai.go 🔗

@@ -43,7 +43,7 @@ func createOpenAIClient(opts providerClientOptions) openai.Client {
 		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
 	}
 	if opts.baseURL != "" {
-		resolvedBaseURL, err := opts.cfg.Resolve(opts.baseURL)
+		resolvedBaseURL, err := opts.resolver.ResolveValue(opts.baseURL)
 		if err == nil && resolvedBaseURL != "" {
 			openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(resolvedBaseURL))
 		}
@@ -517,7 +517,7 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error)
 		if apiErr.StatusCode == http.StatusUnauthorized {
 			prev := o.providerOptions.apiKey
 			// in case the key comes from a script, we try to re-evaluate it.
-			o.providerOptions.apiKey, err = o.providerOptions.cfg.Resolve(o.providerOptions.config.APIKey)
+			o.providerOptions.apiKey, err = o.providerOptions.resolver.ResolveValue(o.providerOptions.config.APIKey)
 			if err != nil {
 				return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
 			}

internal/llm/provider/provider.go 🔗

@@ -64,6 +64,8 @@ type Provider interface {
 type providerClientOptions struct {
 	cfg *config.Config
 
+	resolver config.VariableResolver
+
 	baseURL            string
 	config             config.ProviderConfig
 	apiKey             string
@@ -141,30 +143,17 @@ func WithMaxTokens(maxTokens int64) ProviderClientOption {
 	}
 }
 
-func NewProvider(cfg *config.Config, pcfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
-	restore := config.PushPopCrushEnv()
-	defer restore()
-	resolvedAPIKey, err := cfg.Resolve(pcfg.APIKey)
-	if err != nil {
-		return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", pcfg.ID, err)
-	}
-
-	// Resolve extra headers
-	resolvedExtraHeaders := make(map[string]string)
-	for key, value := range pcfg.ExtraHeaders {
-		resolvedValue, err := cfg.Resolve(value)
-		if err != nil {
-			return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, pcfg.ID, err)
-		}
-		resolvedExtraHeaders[key] = resolvedValue
+func WithResolver(resolver config.VariableResolver) ProviderClientOption {
+	return func(options *providerClientOptions) {
+		options.resolver = resolver
 	}
+}
 
+func NewProvider(cfg *config.Config, pcfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
 	clientOptions := providerClientOptions{
 		cfg:                cfg,
 		baseURL:            pcfg.BaseURL,
 		config:             pcfg,
-		apiKey:             resolvedAPIKey,
-		extraHeaders:       resolvedExtraHeaders,
 		extraBody:          pcfg.ExtraBody,
 		extraParams:        pcfg.ExtraParams,
 		systemPromptPrefix: pcfg.SystemPromptPrefix,
@@ -175,6 +164,30 @@ func NewProvider(cfg *config.Config, pcfg config.ProviderConfig, opts ...Provide
 	for _, o := range opts {
 		o(&clientOptions)
 	}
+	if clientOptions.resolver == nil {
+		clientOptions.resolver = config.OsShellResolver
+	}
+
+	restore := config.PushPopCrushEnv()
+	defer restore()
+	resolvedAPIKey, err := clientOptions.resolver.ResolveValue(pcfg.APIKey)
+	if err != nil {
+		return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", pcfg.ID, err)
+	}
+
+	// Resolve extra headers
+	resolvedExtraHeaders := make(map[string]string)
+	for key, value := range pcfg.ExtraHeaders {
+		resolvedValue, err := clientOptions.resolver.ResolveValue(value)
+		if err != nil {
+			return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, pcfg.ID, err)
+		}
+		resolvedExtraHeaders[key] = resolvedValue
+	}
+
+	clientOptions.apiKey = resolvedAPIKey
+	clientOptions.extraHeaders = resolvedExtraHeaders
+
 	switch pcfg.Type {
 	case catwalk.TypeAnthropic:
 		return &baseProvider[AnthropicClient]{

internal/proto/proto.go 🔗

@@ -17,6 +17,13 @@ type Instance struct {
 	Debug   bool           `json:"debug,omitempty"`
 	DataDir string         `json:"data_dir,omitempty"`
 	Config  *config.Config `json:"config,omitempty"`
+	Env     []string       `json:"env,omitempty"`
+}
+
+// ShellResolver returns a new [config.ShellResolver] based on the instance's
+// environment variables.
+func (i Instance) ShellResolver() *config.ShellVariableResolver {
+	return config.NewShellVariableResolver(i.Env)
 }
 
 // Error represents an error response.

internal/server/proto.go 🔗

@@ -460,6 +460,19 @@ func (c *controllerV1) handleGetInstancePermissionsSkip(w http.ResponseWriter, r
 	jsonEncode(w, proto.PermissionSkipRequest{Skip: skip})
 }
 
+func (c *controllerV1) handleGetInstanceProviders(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+	ins, ok := c.instances.Get(id)
+	if !ok {
+		c.logError(r, "instance not found", "id", id)
+		jsonError(w, http.StatusNotFound, "instance not found")
+		return
+	}
+
+	providers, _ := config.Providers(ins.cfg)
+	jsonEncode(w, providers)
+}
+
 func (c *controllerV1) handleGetInstanceEvents(w http.ResponseWriter, r *http.Request) {
 	flusher := http.NewResponseController(w)
 	id := r.PathValue("id")
@@ -587,6 +600,7 @@ func (c *controllerV1) handlePostInstances(w http.ResponseWriter, r *http.Reques
 		id:    id,
 		path:  args.Path,
 		cfg:   cfg,
+		env:   args.Env,
 	}
 
 	c.instances.Set(id, ins)
@@ -597,6 +611,7 @@ func (c *controllerV1) handlePostInstances(w http.ResponseWriter, r *http.Reques
 		Debug:   cfg.Options.Debug,
 		YOLO:    cfg.Permissions.SkipRequests,
 		Config:  cfg,
+		Env:     args.Env,
 	})
 }
 

internal/server/server.go 🔗

@@ -40,6 +40,7 @@ type Instance struct {
 	cfg   *config.Config
 	id    string
 	path  string
+	env   []string
 }
 
 // ParseHostURL parses a host URL into a [url.URL].
@@ -134,6 +135,7 @@ func NewServer(cfg *config.Config, network, address string) *Server {
 	mux.HandleFunc("GET /v1/instances/{id}", c.handleGetInstance)
 	mux.HandleFunc("GET /v1/instances/{id}/config", c.handleGetInstanceConfig)
 	mux.HandleFunc("GET /v1/instances/{id}/events", c.handleGetInstanceEvents)
+	mux.HandleFunc("GET /v1/instances/{id}/providers", c.handleGetInstanceProviders)
 	mux.HandleFunc("GET /v1/instances/{id}/sessions", c.handleGetInstanceSessions)
 	mux.HandleFunc("POST /v1/instances/{id}/sessions", c.handlePostInstanceSessions)
 	mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}", c.handleGetInstanceSession)

internal/tui/components/chat/splash/splash.go 🔗

@@ -218,7 +218,7 @@ func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 					}),
 					func() tea.Msg {
 						start := time.Now()
-						err := providerConfig.TestConnection(s.ins.Config.Resolver())
+						err := providerConfig.TestConnection(s.ins.ShellResolver())
 						// intentionally wait for at least 750ms to make sure the user sees the spinner
 						elapsed := time.Since(start)
 						if elapsed < 750*time.Millisecond {

internal/tui/components/dialogs/models/models.go 🔗

@@ -10,6 +10,7 @@ import (
 	tea "github.com/charmbracelet/bubbletea/v2"
 	"github.com/charmbracelet/catwalk/pkg/catwalk"
 	"github.com/charmbracelet/crush/internal/config"
+	"github.com/charmbracelet/crush/internal/proto"
 	"github.com/charmbracelet/crush/internal/tui/components/core"
 	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 	"github.com/charmbracelet/crush/internal/tui/exp/list"
@@ -68,10 +69,10 @@ type modelDialogCmp struct {
 	isAPIKeyValid     bool
 	apiKeyValue       string
 
-	cfg *config.Config
+	ins *proto.Instance
 }
 
-func NewModelDialogCmp(cfg *config.Config) ModelDialog {
+func NewModelDialogCmp(ins *proto.Instance) ModelDialog {
 	keyMap := DefaultKeyMap()
 
 	listKeyMap := list.DefaultKeyMap()
@@ -81,7 +82,7 @@ func NewModelDialogCmp(cfg *config.Config) ModelDialog {
 	listKeyMap.UpOneItem = keyMap.Previous
 
 	t := styles.CurrentTheme()
-	modelList := NewModelListComponent(cfg, listKeyMap, largeModelInputPlaceholder, true)
+	modelList := NewModelListComponent(ins.Config, listKeyMap, largeModelInputPlaceholder, true)
 	apiKeyInput := NewAPIKeyInput()
 	apiKeyInput.SetShowTitle(false)
 	help := help.New()
@@ -93,7 +94,7 @@ func NewModelDialogCmp(cfg *config.Config) ModelDialog {
 		width:       defaultWidth,
 		keyMap:      DefaultKeyMap(),
 		help:        help,
-		cfg:         cfg,
+		ins:         ins,
 	}
 }
 
@@ -139,7 +140,7 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 					}),
 					func() tea.Msg {
 						start := time.Now()
-						err := providerConfig.TestConnection(m.cfg.Resolver())
+						err := providerConfig.TestConnection(m.ins.ShellResolver())
 						// intentionally wait for at least 750ms to make sure the user sees the spinner
 						elapsed := time.Since(start)
 						if elapsed < 750*time.Millisecond {
@@ -347,14 +348,14 @@ func (m *modelDialogCmp) modelTypeRadio() string {
 }
 
 func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
-	if _, ok := m.cfg.Providers.Get(providerID); ok {
+	if _, ok := m.ins.Config.Providers.Get(providerID); ok {
 		return true
 	}
 	return false
 }
 
 func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
-	providers, err := config.Providers(m.cfg)
+	providers, err := config.Providers(m.ins.Config)
 	if err != nil {
 		return nil, err
 	}
@@ -371,7 +372,7 @@ func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
 		return util.ReportError(fmt.Errorf("no model selected"))
 	}
 
-	err := m.cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
+	err := m.ins.Config.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
 	if err != nil {
 		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
 	}

internal/tui/tui.go 🔗

@@ -178,7 +178,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 	case commands.SwitchModelMsg:
 		return a, util.CmdHandler(
 			dialogs.OpenDialogMsg{
-				Model: models.NewModelDialogCmp(a.ins.Config),
+				Model: models.NewModelDialogCmp(a.ins),
 			},
 		)
 	// Compact