From efb092fc8c53197702f5ec1ed3b1a830fc19ce2a Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Thu, 25 Sep 2025 15:51:44 -0400 Subject: [PATCH] feat: pass client env vars to server and use it for shell variable resolution --- internal/app/lsp.go | 2 +- internal/cmd/root.go | 1 + internal/config/config.go | 18 +++---- internal/config/load.go | 1 - internal/config/resolve.go | 8 +-- internal/config/resolve_test.go | 4 +- internal/llm/agent/mcp-tools.go | 4 +- internal/llm/provider/anthropic.go | 4 +- internal/llm/provider/gemini.go | 2 +- internal/llm/provider/openai.go | 4 +- internal/llm/provider/provider.go | 49 ++++++++++++------- internal/proto/proto.go | 7 +++ internal/server/proto.go | 15 ++++++ internal/server/server.go | 2 + internal/tui/components/chat/splash/splash.go | 2 +- .../tui/components/dialogs/models/models.go | 17 ++++--- internal/tui/tui.go | 2 +- 17 files changed, 87 insertions(+), 55 deletions(-) diff --git a/internal/app/lsp.go b/internal/app/lsp.go index 2994f1f688c22d3b85d224c1389b38822bddb24e..7e0341fb85865b85dc50654d616175d9a833a4be 100644 --- a/internal/app/lsp.go +++ b/internal/app/lsp.go @@ -36,7 +36,7 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, lspCfg app.updateLSPState(name, lsp.StateStarting, nil, 0) // Create LSP client. - lspClient, err := lsp.New(ctx, app.config, name, lspCfg, app.config.Resolver()) + lspClient, err := lsp.New(ctx, app.config, name, lspCfg, config.OsShellResolver) if err != nil { slog.Error("Failed to create LSP client for", name, err) app.updateLSPState(name, lsp.StateError, err, 0) diff --git a/internal/cmd/root.go b/internal/cmd/root.go index e3cae2f6e4c4b073806686eb94537793308fd3e7..65195a079aa84833813212f3f19457a4a975e7f5 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -220,6 +220,7 @@ func setupApp(cmd *cobra.Command, hostURL *url.URL) (*client.Client, *proto.Inst DataDir: dataDir, Debug: debug, YOLO: yolo, + Env: os.Environ(), }) if err != nil { return nil, nil, fmt.Errorf("failed to create or connect to instance: %v", err) diff --git a/internal/config/config.go b/internal/config/config.go index a7adf2c23889762d0f4250b18491cb85d6b74f1c..3e79e49584044c32a41027742c1a30c284b6c05f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -16,6 +16,12 @@ import ( "github.com/tidwall/sjson" ) +// OsShellResolver is a shell resolver that uses the current process environment +// variables. +// +// Deprecated: pass resolver explicitly instead. +var OsShellResolver = NewShellVariableResolver(os.Environ()) + const ( defaultDataDirectory = ".crush" ) @@ -267,7 +273,6 @@ type Config struct { // TODO: most likely remove this concept when I come back to it Agents map[string]Agent `json:"-"` // TODO: find a better way to do this this should probably not be part of the config - resolver VariableResolver dataConfigDir string `json:"-"` knownProviders []catwalk.Provider `json:"-"` } @@ -345,13 +350,6 @@ func (c *Config) SetCompactMode(enabled bool) error { return c.SetConfigField("options.tui.compact_mode", enabled) } -func (c *Config) Resolve(key string) (string, error) { - if c.resolver == nil { - return "", fmt.Errorf("no variable resolver configured") - } - return c.resolver.ResolveValue(key) -} - func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error { c.Models[modelType] = model if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil { @@ -494,10 +492,6 @@ func (c *Config) SetupAgents() { c.Agents = agents } -func (c *Config) Resolver() VariableResolver { - return c.resolver -} - func (c *ProviderConfig) TestConnection(resolver VariableResolver) error { testURL := "" headers := make(map[string]string) diff --git a/internal/config/load.go b/internal/config/load.go index 2fe18058ae83b0cfe4802d6c501f349a27768af7..873e9b6d506a26fba0892b5c67ace6b080f934ec 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -67,7 +67,6 @@ func Load(workingDir, dataDir string, debug bool, envs []string) (*Config, error // Configure providers valueResolver := NewShellVariableResolver(envs) - cfg.resolver = valueResolver if err := cfg.configureProviders(envs, valueResolver, cfg.knownProviders); err != nil { return nil, fmt.Errorf("failed to configure providers: %w", err) } diff --git a/internal/config/resolve.go b/internal/config/resolve.go index 9ebe425ba95f98996f0bbae0368710d755182fd2..6eb6cc5d8b14ec0c3b1d465662cc39eb97d4fdf8 100644 --- a/internal/config/resolve.go +++ b/internal/config/resolve.go @@ -17,13 +17,13 @@ type Shell interface { Exec(ctx context.Context, command string) (stdout, stderr string, err error) } -type shellVariableResolver struct { +type ShellVariableResolver struct { shell Shell env []string } -func NewShellVariableResolver(env []string) VariableResolver { - return &shellVariableResolver{ +func NewShellVariableResolver(env []string) *ShellVariableResolver { + return &ShellVariableResolver{ env: env, shell: shell.NewShell( &shell.Options{ @@ -38,7 +38,7 @@ func NewShellVariableResolver(env []string) VariableResolver { // - $(command) for command substitution // - $VAR or ${VAR} for environment variables // TODO: can we replace this with [os.Expand](https://pkg.go.dev/os#Expand) somehow? -func (r *shellVariableResolver) ResolveValue(value string) (string, error) { +func (r *ShellVariableResolver) ResolveValue(value string) (string, error) { // Special case: lone $ is an error (backward compatibility) if value == "$" { return "", fmt.Errorf("invalid value format: %s", value) diff --git a/internal/config/resolve_test.go b/internal/config/resolve_test.go index 9d4ccb92114446c392e593cd57083b8f71d2df1a..2cb81aa5a194cb256287187814e263952559d4c0 100644 --- a/internal/config/resolve_test.go +++ b/internal/config/resolve_test.go @@ -76,7 +76,7 @@ func TestShellVariableResolver_ResolveValue(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { testEnv := environ(tt.envVars) - resolver := &shellVariableResolver{ + resolver := &ShellVariableResolver{ shell: &mockShell{execFunc: tt.shellFunc}, env: testEnv, } @@ -241,7 +241,7 @@ func TestShellVariableResolver_EnhancedResolveValue(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { testEnv := environ(tt.envVars) - resolver := &shellVariableResolver{ + resolver := &ShellVariableResolver{ shell: &mockShell{execFunc: tt.shellFunc}, env: testEnv, } diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index 41415c474d9b236f2e451b526396a6168012efa7..b36c6d85734d37b7792ea4bbf3556ab37c0eb4d1 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -136,7 +136,7 @@ func getOrRenewClient(ctx context.Context, cfg *config.Config, name string) (*cl } updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount) - c, err = createAndInitializeClient(ctx, name, m, cfg.Resolver()) + c, err = createAndInitializeClient(ctx, name, m, config.OsShellResolver) if err != nil { return nil, err } @@ -288,7 +288,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m)) defer cancel() - c, err := createAndInitializeClient(ctx, name, m, cfg.Resolver()) + c, err := createAndInitializeClient(ctx, name, m, config.OsShellResolver) if err != nil { return } diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index f0fb6b06368afe465522ab20f10a9d49d5f315d4..0b2afacab18777bd364d24e44555eab04aa58cb7 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -81,7 +81,7 @@ func createAnthropicClient(opts providerClientOptions, tp AnthropicClientType) a } if opts.baseURL != "" { - resolvedBaseURL, err := opts.cfg.Resolve(opts.baseURL) + resolvedBaseURL, err := opts.resolver.ResolveValue(opts.baseURL) if err == nil && resolvedBaseURL != "" { anthropicClientOptions = append(anthropicClientOptions, option.WithBaseURL(resolvedBaseURL)) } @@ -496,7 +496,7 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err if apiErr.StatusCode == http.StatusUnauthorized { prev := a.providerOptions.apiKey // in case the key comes from a script, we try to re-evaluate it. - a.providerOptions.apiKey, err = a.providerOptions.cfg.Resolve(a.providerOptions.config.APIKey) + a.providerOptions.apiKey, err = a.providerOptions.resolver.ResolveValue(a.providerOptions.config.APIKey) if err != nil { return false, 0, fmt.Errorf("failed to resolve API key: %w", err) } diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index b0c53e75d3a2cad501e25184a8174391e0b2dbbd..1dee7f70577ec8c5dfbd7ec5d1c112f5332ba7f9 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -438,7 +438,7 @@ func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) if contains(errMsg, "unauthorized", "invalid api key", "api key expired") { prev := g.providerOptions.apiKey // in case the key comes from a script, we try to re-evaluate it. - g.providerOptions.apiKey, err = g.providerOptions.cfg.Resolve(g.providerOptions.config.APIKey) + g.providerOptions.apiKey, err = g.providerOptions.resolver.ResolveValue(g.providerOptions.config.APIKey) if err != nil { return false, 0, fmt.Errorf("failed to resolve API key: %w", err) } diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 5ff2535ef112397e155c02bafd866da4b0bc8b6a..6aa33a80d42b1e1b710ab400c45ae59afb499acf 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -43,7 +43,7 @@ func createOpenAIClient(opts providerClientOptions) openai.Client { openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey)) } if opts.baseURL != "" { - resolvedBaseURL, err := opts.cfg.Resolve(opts.baseURL) + resolvedBaseURL, err := opts.resolver.ResolveValue(opts.baseURL) if err == nil && resolvedBaseURL != "" { openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(resolvedBaseURL)) } @@ -517,7 +517,7 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) if apiErr.StatusCode == http.StatusUnauthorized { prev := o.providerOptions.apiKey // in case the key comes from a script, we try to re-evaluate it. - o.providerOptions.apiKey, err = o.providerOptions.cfg.Resolve(o.providerOptions.config.APIKey) + o.providerOptions.apiKey, err = o.providerOptions.resolver.ResolveValue(o.providerOptions.config.APIKey) if err != nil { return false, 0, fmt.Errorf("failed to resolve API key: %w", err) } diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 37cae1e7aa2b021aaa3c06a2f1fb6054fbe91cb9..965d76d164bcbcc79e5884786d9ca9d83e3dc73c 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -64,6 +64,8 @@ type Provider interface { type providerClientOptions struct { cfg *config.Config + resolver config.VariableResolver + baseURL string config config.ProviderConfig apiKey string @@ -141,30 +143,17 @@ func WithMaxTokens(maxTokens int64) ProviderClientOption { } } -func NewProvider(cfg *config.Config, pcfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) { - restore := config.PushPopCrushEnv() - defer restore() - resolvedAPIKey, err := cfg.Resolve(pcfg.APIKey) - if err != nil { - return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", pcfg.ID, err) - } - - // Resolve extra headers - resolvedExtraHeaders := make(map[string]string) - for key, value := range pcfg.ExtraHeaders { - resolvedValue, err := cfg.Resolve(value) - if err != nil { - return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, pcfg.ID, err) - } - resolvedExtraHeaders[key] = resolvedValue +func WithResolver(resolver config.VariableResolver) ProviderClientOption { + return func(options *providerClientOptions) { + options.resolver = resolver } +} +func NewProvider(cfg *config.Config, pcfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) { clientOptions := providerClientOptions{ cfg: cfg, baseURL: pcfg.BaseURL, config: pcfg, - apiKey: resolvedAPIKey, - extraHeaders: resolvedExtraHeaders, extraBody: pcfg.ExtraBody, extraParams: pcfg.ExtraParams, systemPromptPrefix: pcfg.SystemPromptPrefix, @@ -175,6 +164,30 @@ func NewProvider(cfg *config.Config, pcfg config.ProviderConfig, opts ...Provide for _, o := range opts { o(&clientOptions) } + if clientOptions.resolver == nil { + clientOptions.resolver = config.OsShellResolver + } + + restore := config.PushPopCrushEnv() + defer restore() + resolvedAPIKey, err := clientOptions.resolver.ResolveValue(pcfg.APIKey) + if err != nil { + return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", pcfg.ID, err) + } + + // Resolve extra headers + resolvedExtraHeaders := make(map[string]string) + for key, value := range pcfg.ExtraHeaders { + resolvedValue, err := clientOptions.resolver.ResolveValue(value) + if err != nil { + return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, pcfg.ID, err) + } + resolvedExtraHeaders[key] = resolvedValue + } + + clientOptions.apiKey = resolvedAPIKey + clientOptions.extraHeaders = resolvedExtraHeaders + switch pcfg.Type { case catwalk.TypeAnthropic: return &baseProvider[AnthropicClient]{ diff --git a/internal/proto/proto.go b/internal/proto/proto.go index 37a4ac9aa78b88bfd6dc6b2f0d8911467cf21af2..6bade2972a9fd030948fec550798fca5876e1e38 100644 --- a/internal/proto/proto.go +++ b/internal/proto/proto.go @@ -17,6 +17,13 @@ type Instance struct { Debug bool `json:"debug,omitempty"` DataDir string `json:"data_dir,omitempty"` Config *config.Config `json:"config,omitempty"` + Env []string `json:"env,omitempty"` +} + +// ShellResolver returns a new [config.ShellResolver] based on the instance's +// environment variables. +func (i Instance) ShellResolver() *config.ShellVariableResolver { + return config.NewShellVariableResolver(i.Env) } // Error represents an error response. diff --git a/internal/server/proto.go b/internal/server/proto.go index 7cc626b238c86890ef6034294e80eacc7643bc4c..5caa59a47a6116c2a3223b206c04879a8e1f59aa 100644 --- a/internal/server/proto.go +++ b/internal/server/proto.go @@ -460,6 +460,19 @@ func (c *controllerV1) handleGetInstancePermissionsSkip(w http.ResponseWriter, r jsonEncode(w, proto.PermissionSkipRequest{Skip: skip}) } +func (c *controllerV1) handleGetInstanceProviders(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ins, ok := c.instances.Get(id) + if !ok { + c.logError(r, "instance not found", "id", id) + jsonError(w, http.StatusNotFound, "instance not found") + return + } + + providers, _ := config.Providers(ins.cfg) + jsonEncode(w, providers) +} + func (c *controllerV1) handleGetInstanceEvents(w http.ResponseWriter, r *http.Request) { flusher := http.NewResponseController(w) id := r.PathValue("id") @@ -587,6 +600,7 @@ func (c *controllerV1) handlePostInstances(w http.ResponseWriter, r *http.Reques id: id, path: args.Path, cfg: cfg, + env: args.Env, } c.instances.Set(id, ins) @@ -597,6 +611,7 @@ func (c *controllerV1) handlePostInstances(w http.ResponseWriter, r *http.Reques Debug: cfg.Options.Debug, YOLO: cfg.Permissions.SkipRequests, Config: cfg, + Env: args.Env, }) } diff --git a/internal/server/server.go b/internal/server/server.go index afce3d66d5db6164b1f1440b04e1950c3d4edd5b..ce136573e8411ea485aed885f6c9de6c6e36181f 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -40,6 +40,7 @@ type Instance struct { cfg *config.Config id string path string + env []string } // ParseHostURL parses a host URL into a [url.URL]. @@ -134,6 +135,7 @@ func NewServer(cfg *config.Config, network, address string) *Server { mux.HandleFunc("GET /v1/instances/{id}", c.handleGetInstance) mux.HandleFunc("GET /v1/instances/{id}/config", c.handleGetInstanceConfig) mux.HandleFunc("GET /v1/instances/{id}/events", c.handleGetInstanceEvents) + mux.HandleFunc("GET /v1/instances/{id}/providers", c.handleGetInstanceProviders) mux.HandleFunc("GET /v1/instances/{id}/sessions", c.handleGetInstanceSessions) mux.HandleFunc("POST /v1/instances/{id}/sessions", c.handlePostInstanceSessions) mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}", c.handleGetInstanceSession) diff --git a/internal/tui/components/chat/splash/splash.go b/internal/tui/components/chat/splash/splash.go index 88585a78bfcdefb435e3a4bf02b84e45b52c8352..9bbe3b8e030768a69c139e2a9f6d03320e84f669 100644 --- a/internal/tui/components/chat/splash/splash.go +++ b/internal/tui/components/chat/splash/splash.go @@ -218,7 +218,7 @@ func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { }), func() tea.Msg { start := time.Now() - err := providerConfig.TestConnection(s.ins.Config.Resolver()) + err := providerConfig.TestConnection(s.ins.ShellResolver()) // intentionally wait for at least 750ms to make sure the user sees the spinner elapsed := time.Since(start) if elapsed < 750*time.Millisecond { diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index 1301b7ffb0b18661ac0d710c54d92469c45e2765..b4c92f9f12c8aad864a5d8fc0aea6f938bfb0306 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -10,6 +10,7 @@ import ( tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/proto" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/dialogs" "github.com/charmbracelet/crush/internal/tui/exp/list" @@ -68,10 +69,10 @@ type modelDialogCmp struct { isAPIKeyValid bool apiKeyValue string - cfg *config.Config + ins *proto.Instance } -func NewModelDialogCmp(cfg *config.Config) ModelDialog { +func NewModelDialogCmp(ins *proto.Instance) ModelDialog { keyMap := DefaultKeyMap() listKeyMap := list.DefaultKeyMap() @@ -81,7 +82,7 @@ func NewModelDialogCmp(cfg *config.Config) ModelDialog { listKeyMap.UpOneItem = keyMap.Previous t := styles.CurrentTheme() - modelList := NewModelListComponent(cfg, listKeyMap, largeModelInputPlaceholder, true) + modelList := NewModelListComponent(ins.Config, listKeyMap, largeModelInputPlaceholder, true) apiKeyInput := NewAPIKeyInput() apiKeyInput.SetShowTitle(false) help := help.New() @@ -93,7 +94,7 @@ func NewModelDialogCmp(cfg *config.Config) ModelDialog { width: defaultWidth, keyMap: DefaultKeyMap(), help: help, - cfg: cfg, + ins: ins, } } @@ -139,7 +140,7 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { }), func() tea.Msg { start := time.Now() - err := providerConfig.TestConnection(m.cfg.Resolver()) + err := providerConfig.TestConnection(m.ins.ShellResolver()) // intentionally wait for at least 750ms to make sure the user sees the spinner elapsed := time.Since(start) if elapsed < 750*time.Millisecond { @@ -347,14 +348,14 @@ func (m *modelDialogCmp) modelTypeRadio() string { } func (m *modelDialogCmp) isProviderConfigured(providerID string) bool { - if _, ok := m.cfg.Providers.Get(providerID); ok { + if _, ok := m.ins.Config.Providers.Get(providerID); ok { return true } return false } func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) { - providers, err := config.Providers(m.cfg) + providers, err := config.Providers(m.ins.Config) if err != nil { return nil, err } @@ -371,7 +372,7 @@ func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd { return util.ReportError(fmt.Errorf("no model selected")) } - err := m.cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey) + err := m.ins.Config.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey) if err != nil { return util.ReportError(fmt.Errorf("failed to save API key: %w", err)) } diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 7bef2e6bc59f80fadba64db1a31c9f7a6455854f..5fffd75547e7b13b394790cb08a49fce8fbf4d02 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -178,7 +178,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case commands.SwitchModelMsg: return a, util.CmdHandler( dialogs.OpenDialogMsg{ - Model: models.NewModelDialogCmp(a.ins.Config), + Model: models.NewModelDialogCmp(a.ins), }, ) // Compact