diff --git a/internal/agent/tools/mcp/init.go b/internal/agent/tools/mcp/init.go index 36ac4ec1a0e3fac68d7995f230899ba534141b03..4b2c4f8691053676da3395524b82a55a62aad8f8 100644 --- a/internal/agent/tools/mcp/init.go +++ b/internal/agent/tools/mcp/init.go @@ -447,8 +447,12 @@ func createTransport(ctx context.Context, m config.MCPConfig, resolver config.Va if strings.TrimSpace(command) == "" { return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field") } + envs, err := m.ResolvedEnv() + if err != nil { + return nil, err + } cmd := exec.CommandContext(ctx, home.Long(command), m.Args...) - cmd.Env = append(os.Environ(), m.ResolvedEnv()...) + cmd.Env = append(os.Environ(), envs...) return &mcp.CommandTransport{ Command: cmd, }, nil @@ -456,9 +460,13 @@ func createTransport(ctx context.Context, m config.MCPConfig, resolver config.Va if strings.TrimSpace(m.URL) == "" { return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field") } + headers, err := m.ResolvedHeaders() + if err != nil { + return nil, err + } client := &http.Client{ Transport: &headerRoundTripper{ - headers: m.ResolvedHeaders(), + headers: headers, }, } return &mcp.StreamableClientTransport{ @@ -469,9 +477,13 @@ func createTransport(ctx context.Context, m config.MCPConfig, resolver config.Va if strings.TrimSpace(m.URL) == "" { return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field") } + headers, err := m.ResolvedHeaders() + if err != nil { + return nil, err + } client := &http.Client{ Transport: &headerRoundTripper{ - headers: m.ResolvedHeaders(), + headers: headers, }, } return &mcp.SSEClientTransport{ diff --git a/internal/config/config.go b/internal/config/config.go index 33251615d44252aaeb2b5db58577759b5dfdff51..3f6daaef71742d1ad30e6d66d8de68a3c79977ee 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -5,7 +5,6 @@ import ( "context" "errors" "fmt" - "log/slog" "maps" "net/http" "net/url" @@ -301,25 +300,46 @@ func (l LSPs) Sorted() []LSP { return sorted } -func (l LSPConfig) ResolvedEnv() []string { - return resolveEnvs(l.Env) -} - -func (m MCPConfig) ResolvedEnv() []string { +// ResolvedEnv returns m.Env with every value expanded through the +// shell resolver. The returned slice is of the form "KEY=value" sorted +// by key so callers get deterministic output; the receiver's Env map is +// not mutated. On the first resolution failure it returns nil and an +// error that identifies the offending key; the inner resolver error is +// already sanitized by ResolveValue and is wrapped with %w so +// errors.Is/As continues to work. Callers are expected to surface it +// (for MCP, via StateError on the status card) rather than silently +// spawn the server with an empty credential. +func (m MCPConfig) ResolvedEnv() ([]string, error) { return resolveEnvs(m.Env) } -func (m MCPConfig) ResolvedHeaders() map[string]string { +// ResolvedHeaders returns m.Headers with every value expanded through +// the shell resolver. A fresh map is allocated; m.Headers is never +// mutated. On the first resolution failure it returns nil and an error +// identifying the offending header name; the inner resolver error is +// already sanitized by ResolveValue and is wrapped with %w so +// errors.Is/As continues to work. +func (m MCPConfig) ResolvedHeaders() (map[string]string, error) { + if len(m.Headers) == 0 { + return map[string]string{}, nil + } resolver := NewShellVariableResolver(env.New()) - for e, v := range m.Headers { - var err error - m.Headers[e], err = resolver.ResolveValue(v) + out := make(map[string]string, len(m.Headers)) + // Sort keys so failures are reported deterministically when more + // than one header would fail. + keys := make([]string, 0, len(m.Headers)) + for k := range m.Headers { + keys = append(keys, k) + } + slices.Sort(keys) + for _, k := range keys { + v, err := resolver.ResolveValue(m.Headers[k]) if err != nil { - slog.Error("Error resolving header variable", "error", err, "variable", e, "value", v) - continue + return nil, fmt.Errorf("header %s: %w", k, err) } + out[k] = v } - return m.Headers + return out, nil } type Agent struct { @@ -659,22 +679,30 @@ func (c *ProviderConfig) TestConnection(resolver VariableResolver) error { return nil } -func resolveEnvs(envs map[string]string) []string { +// resolveEnvs expands every value in envs through the shell resolver +// and returns a fresh "KEY=value" slice sorted by key. The input map is +// not mutated. On the first resolution failure it returns nil and an +// error identifying the offending variable; the inner resolver error is +// already sanitized by ResolveValue and is wrapped with %w. +func resolveEnvs(envs map[string]string) ([]string, error) { + if len(envs) == 0 { + return nil, nil + } resolver := NewShellVariableResolver(env.New()) - for e, v := range envs { - var err error - envs[e], err = resolver.ResolveValue(v) - if err != nil { - slog.Error("Error resolving environment variable", "error", err, "variable", e, "value", v) - continue - } + keys := make([]string, 0, len(envs)) + for k := range envs { + keys = append(keys, k) } - + slices.Sort(keys) res := make([]string, 0, len(envs)) - for k, v := range envs { + for _, k := range keys { + v, err := resolver.ResolveValue(envs[k]) + if err != nil { + return nil, fmt.Errorf("env %s: %w", k, err) + } res = append(res, fmt.Sprintf("%s=%s", k, v)) } - return res + return res, nil } func ptrValOr[T any](t *T, el T) T {