diff --git a/internal/agent/tools/mcp/init.go b/internal/agent/tools/mcp/init.go index 4b2c4f8691053676da3395524b82a55a62aad8f8..7dd90c37206a018ae173e5e0d977cb7c4e46b4c7 100644 --- a/internal/agent/tools/mcp/init.go +++ b/internal/agent/tools/mcp/init.go @@ -447,11 +447,15 @@ func createTransport(ctx context.Context, m config.MCPConfig, resolver config.Va if strings.TrimSpace(command) == "" { return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field") } - envs, err := m.ResolvedEnv() + args, err := m.ResolvedArgs(resolver) if err != nil { return nil, err } - cmd := exec.CommandContext(ctx, home.Long(command), m.Args...) + envs, err := m.ResolvedEnv(resolver) + if err != nil { + return nil, err + } + cmd := exec.CommandContext(ctx, home.Long(command), args...) cmd.Env = append(os.Environ(), envs...) return &mcp.CommandTransport{ Command: cmd, @@ -460,7 +464,7 @@ func createTransport(ctx context.Context, m config.MCPConfig, resolver config.Va if strings.TrimSpace(m.URL) == "" { return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field") } - headers, err := m.ResolvedHeaders() + headers, err := m.ResolvedHeaders(resolver) if err != nil { return nil, err } @@ -477,7 +481,7 @@ func createTransport(ctx context.Context, m config.MCPConfig, resolver config.Va if strings.TrimSpace(m.URL) == "" { return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field") } - headers, err := m.ResolvedHeaders() + headers, err := m.ResolvedHeaders(resolver) if err != nil { return nil, err } diff --git a/internal/config/config.go b/internal/config/config.go index 3f6daaef71742d1ad30e6d66d8de68a3c79977ee..40db4aa68890b68f8199ce89b03280ea61257f99 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -14,7 +14,6 @@ import ( "charm.land/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/csync" - "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/oauth" "github.com/charmbracelet/crush/internal/oauth/copilot" "github.com/invopop/jsonschema" @@ -301,7 +300,7 @@ func (l LSPs) Sorted() []LSP { } // ResolvedEnv returns m.Env with every value expanded through the -// shell resolver. The returned slice is of the form "KEY=value" sorted +// given resolver. The returned slice is of the form "KEY=value" sorted // by key so callers get deterministic output; the receiver's Env map is // not mutated. On the first resolution failure it returns nil and an // error that identifies the offending key; the inner resolver error is @@ -309,21 +308,49 @@ func (l LSPs) Sorted() []LSP { // errors.Is/As continues to work. Callers are expected to surface it // (for MCP, via StateError on the status card) rather than silently // spawn the server with an empty credential. -func (m MCPConfig) ResolvedEnv() ([]string, error) { - return resolveEnvs(m.Env) +// +// The resolver choice matters: in server mode pass the shell resolver +// so $VAR / $(cmd) expand; in client mode pass IdentityResolver so the +// template is forwarded verbatim and expansion happens on the server. +func (m MCPConfig) ResolvedEnv(r VariableResolver) ([]string, error) { + return resolveEnvs(m.Env, r) +} + +// ResolvedArgs returns m.Args with every element expanded through the +// given resolver. A fresh slice is allocated; m.Args is never mutated. +// On the first resolution failure it returns nil and an error +// identifying the offending positional index; the inner resolver error +// is already sanitized by ResolveValue and is wrapped with %w so +// errors.Is/As continues to work. +// +// See ResolvedEnv for guidance on picking a resolver. +func (m MCPConfig) ResolvedArgs(r VariableResolver) ([]string, error) { + if len(m.Args) == 0 { + return nil, nil + } + out := make([]string, len(m.Args)) + for i, a := range m.Args { + v, err := r.ResolveValue(a) + if err != nil { + return nil, fmt.Errorf("arg %d: %w", i, err) + } + out[i] = v + } + return out, nil } // ResolvedHeaders returns m.Headers with every value expanded through -// the shell resolver. A fresh map is allocated; m.Headers is never +// the given resolver. A fresh map is allocated; m.Headers is never // mutated. On the first resolution failure it returns nil and an error // identifying the offending header name; the inner resolver error is // already sanitized by ResolveValue and is wrapped with %w so // errors.Is/As continues to work. -func (m MCPConfig) ResolvedHeaders() (map[string]string, error) { +// +// See ResolvedEnv for guidance on picking a resolver. +func (m MCPConfig) ResolvedHeaders(r VariableResolver) (map[string]string, error) { if len(m.Headers) == 0 { return map[string]string{}, nil } - resolver := NewShellVariableResolver(env.New()) out := make(map[string]string, len(m.Headers)) // Sort keys so failures are reported deterministically when more // than one header would fail. @@ -333,7 +360,7 @@ func (m MCPConfig) ResolvedHeaders() (map[string]string, error) { } slices.Sort(keys) for _, k := range keys { - v, err := resolver.ResolveValue(m.Headers[k]) + v, err := r.ResolveValue(m.Headers[k]) if err != nil { return nil, fmt.Errorf("header %s: %w", k, err) } @@ -679,16 +706,15 @@ func (c *ProviderConfig) TestConnection(resolver VariableResolver) error { return nil } -// resolveEnvs expands every value in envs through the shell resolver +// resolveEnvs expands every value in envs through the given resolver // and returns a fresh "KEY=value" slice sorted by key. The input map is // not mutated. On the first resolution failure it returns nil and an // error identifying the offending variable; the inner resolver error is // already sanitized by ResolveValue and is wrapped with %w. -func resolveEnvs(envs map[string]string) ([]string, error) { +func resolveEnvs(envs map[string]string, r VariableResolver) ([]string, error) { if len(envs) == 0 { return nil, nil } - resolver := NewShellVariableResolver(env.New()) keys := make([]string, 0, len(envs)) for k := range envs { keys = append(keys, k) @@ -696,7 +722,7 @@ func resolveEnvs(envs map[string]string) ([]string, error) { slices.Sort(keys) res := make([]string, 0, len(envs)) for _, k := range keys { - v, err := resolver.ResolveValue(envs[k]) + v, err := r.ResolveValue(envs[k]) if err != nil { return nil, fmt.Errorf("env %s: %w", k, err) }