diff --git a/README.md b/README.md index 722697b73ed0c554d19643b8b799520025f8c7d5..7aed61d09a6e07cc2a9402cf7efd845a88919f6a 100644 --- a/README.md +++ b/README.md @@ -293,10 +293,17 @@ like you would. LSPs can be added manually like so: ### MCPs -Crush also supports Model Context Protocol (MCP) servers through three -transport types: `stdio` for command-line servers, `http` for HTTP endpoints, -and `sse` for Server-Sent Events. Environment variable expansion is supported -using `$(echo $VAR)` syntax. +Crush also supports Model Context Protocol (MCP) servers through three transport +types: `stdio` for command-line servers, `http` for HTTP endpoints, and `sse` +for Server-Sent Events. + +Shell-style value expansion (`$VAR`, `${VAR:-default}`, `$(command)`, quoting, +and nesting (works in `command`, `args`, `env`, `headers`, and `url`, so +file-based secrets like work out of the box, so you can use values like +"$TOKEN"` and `"$(cat /path/to/secret/token)"``. Expansion runs through Crush's +embedded shell, so the same syntax works on all supported systems, including +Windows. Unset variables are an error; use `${VAR:-fallback}` to opt in to +a default. ```json { diff --git a/internal/agent/tools/mcp/init.go b/internal/agent/tools/mcp/init.go index 36ac4ec1a0e3fac68d7995f230899ba534141b03..7284bb063789baa6ec18f488583f781d9ddd5cc5 100644 --- a/internal/agent/tools/mcp/init.go +++ b/internal/agent/tools/mcp/init.go @@ -447,35 +447,59 @@ func createTransport(ctx context.Context, m config.MCPConfig, resolver config.Va if strings.TrimSpace(command) == "" { return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field") } - cmd := exec.CommandContext(ctx, home.Long(command), m.Args...) - cmd.Env = append(os.Environ(), m.ResolvedEnv()...) + args, err := m.ResolvedArgs(resolver) + if err != nil { + return nil, err + } + envs, err := m.ResolvedEnv(resolver) + if err != nil { + return nil, err + } + cmd := exec.CommandContext(ctx, home.Long(command), args...) + cmd.Env = append(os.Environ(), envs...) return &mcp.CommandTransport{ Command: cmd, }, nil case config.MCPHttp: - if strings.TrimSpace(m.URL) == "" { + url, err := m.ResolvedURL(resolver) + if err != nil { + return nil, err + } + if strings.TrimSpace(url) == "" { return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field") } + headers, err := m.ResolvedHeaders(resolver) + if err != nil { + return nil, err + } client := &http.Client{ Transport: &headerRoundTripper{ - headers: m.ResolvedHeaders(), + headers: headers, }, } return &mcp.StreamableClientTransport{ - Endpoint: m.URL, + Endpoint: url, HTTPClient: client, }, nil case config.MCPSSE: - if strings.TrimSpace(m.URL) == "" { + url, err := m.ResolvedURL(resolver) + if err != nil { + return nil, err + } + if strings.TrimSpace(url) == "" { return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field") } + headers, err := m.ResolvedHeaders(resolver) + if err != nil { + return nil, err + } client := &http.Client{ Transport: &headerRoundTripper{ - headers: m.ResolvedHeaders(), + headers: headers, }, } return &mcp.SSEClientTransport{ - Endpoint: m.URL, + Endpoint: url, HTTPClient: client, }, nil default: diff --git a/internal/agent/tools/mcp/init_test.go b/internal/agent/tools/mcp/init_test.go index 94958593750852d30ff96734ada23671252e508e..49ceb0410931fb6d20531f440960fb226ab2d768 100644 --- a/internal/agent/tools/mcp/init_test.go +++ b/internal/agent/tools/mcp/init_test.go @@ -2,13 +2,28 @@ package mcp import ( "context" + "maps" + "os" "testing" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/env" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/stretchr/testify/require" "go.uber.org/goleak" ) +// shellResolverWithPath builds a shell resolver whose env carries PATH +// plus any caller-supplied overrides. Without PATH, $(cat), $(echo), +// etc. can't find their binaries in a test process where the shell env +// is otherwise empty. +func shellResolverWithPath(t *testing.T, overrides map[string]string) config.VariableResolver { + t.Helper() + m := map[string]string{"PATH": os.Getenv("PATH")} + maps.Copy(m, overrides) + return config.NewShellVariableResolver(env.NewFromMap(m)) +} + func TestMCPSession_CancelOnClose(t *testing.T) { defer goleak.VerifyNone(t) @@ -36,3 +51,379 @@ func TestMCPSession_CancelOnClose(t *testing.T) { // After Close, the context must be cancelled. require.ErrorIs(t, ctx.Err(), context.Canceled) } + +// TestCreateTransport_URLResolution pins that m.URL goes through the +// same resolver seam as command, args, env, and headers. Covers both +// the HTTP and SSE branches, success and failure, so a regression in +// ResolvedURL wiring is caught at the transport layer rather than only +// at the config layer. +func TestCreateTransport_URLResolution(t *testing.T) { + t.Parallel() + + shell := config.NewShellVariableResolver(env.NewFromMap(map[string]string{ + "MCP_HOST": "mcp.example.com", + })) + + t.Run("http success expands $VAR", func(t *testing.T) { + t.Parallel() + m := config.MCPConfig{ + Type: config.MCPHttp, + URL: "https://$MCP_HOST/api", + } + tr, err := createTransport(t.Context(), m, shell) + require.NoError(t, err) + require.NotNil(t, tr) + sct, ok := tr.(*mcp.StreamableClientTransport) + require.True(t, ok, "expected StreamableClientTransport, got %T", tr) + require.Equal(t, "https://mcp.example.com/api", sct.Endpoint) + }) + + t.Run("sse success expands $(cmd)", func(t *testing.T) { + t.Parallel() + m := config.MCPConfig{ + Type: config.MCPSSE, + URL: "https://$(echo mcp.example.com)/events", + } + tr, err := createTransport(t.Context(), m, shell) + require.NoError(t, err) + sse, ok := tr.(*mcp.SSEClientTransport) + require.True(t, ok, "expected SSEClientTransport, got %T", tr) + require.Equal(t, "https://mcp.example.com/events", sse.Endpoint) + }) + + t.Run("http unset var surfaces error, no transport created", func(t *testing.T) { + t.Parallel() + m := config.MCPConfig{ + Type: config.MCPHttp, + URL: "https://$MCP_MISSING_HOST/api", + } + tr, err := createTransport(t.Context(), m, shell) + require.Error(t, err) + require.Nil(t, tr) + require.Contains(t, err.Error(), "url:") + require.Contains(t, err.Error(), "$MCP_MISSING_HOST") + }) + + t.Run("sse failing $(cmd) surfaces error, no transport created", func(t *testing.T) { + t.Parallel() + m := config.MCPConfig{ + Type: config.MCPSSE, + URL: "https://$(false)/events", + } + tr, err := createTransport(t.Context(), m, shell) + require.Error(t, err) + require.Nil(t, tr) + require.Contains(t, err.Error(), "url:") + require.Contains(t, err.Error(), "$(false)") + }) + + t.Run("http empty-after-resolve still fails the non-empty guard", func(t *testing.T) { + t.Parallel() + // ${MCP_EMPTY:-} resolves to the empty string (no error), + // then the existing TrimSpace guard in createTransport must + // reject it so we never spawn a transport against "". + m := config.MCPConfig{ + Type: config.MCPHttp, + URL: "${MCP_EMPTY:-}", + } + tr, err := createTransport(t.Context(), m, shell) + require.Error(t, err) + require.Nil(t, tr) + require.Contains(t, err.Error(), "non-empty 'url'") + }) + + t.Run("identity resolver round-trips template verbatim", func(t *testing.T) { + t.Parallel() + // Client mode forwards the template to the server; no local + // expansion, no error on unset vars. + tmpl := "https://$MCP_MISSING_HOST/api" + m := config.MCPConfig{Type: config.MCPHttp, URL: tmpl} + tr, err := createTransport(t.Context(), m, config.IdentityResolver()) + require.NoError(t, err) + sct, ok := tr.(*mcp.StreamableClientTransport) + require.True(t, ok) + require.Equal(t, tmpl, sct.Endpoint) + }) +} + +// TestCreateTransport_StdioResolution pins that command, args, and env +// for stdio MCPs go through the same resolver seam as the other +// transports. Covers both success (expansion produced the expected +// exec.Cmd) and failure (any one field erroring prevents transport +// creation). +func TestCreateTransport_StdioResolution(t *testing.T) { + t.Parallel() + + t.Run("success expands command, args, and env", func(t *testing.T) { + t.Parallel() + r := shellResolverWithPath(t, map[string]string{ + "MY_TOKEN": "hunter2", + }) + m := config.MCPConfig{ + Type: config.MCPStdio, + Command: "forgejo-mcp", + Args: []string{"--token", "$MY_TOKEN", "--host", "$(echo example.com)"}, + Env: map[string]string{ + "SECRET": "$(echo shh)", + "PLAIN": "literal", + "REFERENCE": "$MY_TOKEN", + }, + } + tr, err := createTransport(t.Context(), m, r) + require.NoError(t, err) + require.NotNil(t, tr) + + ct, ok := tr.(*mcp.CommandTransport) + require.True(t, ok, "expected CommandTransport, got %T", tr) + + // exec.Cmd.Args[0] is the command name; the rest are positional + // args as passed. + require.Equal(t, []string{"forgejo-mcp", "--token", "hunter2", "--host", "example.com"}, ct.Command.Args) + + // Env is os.Environ() + resolved entries (sorted). Check the + // resolved entries are present with their expanded values. + require.Contains(t, ct.Command.Env, "SECRET=shh") + require.Contains(t, ct.Command.Env, "PLAIN=literal") + require.Contains(t, ct.Command.Env, "REFERENCE=hunter2") + }) + + t.Run("env resolution failure surfaces error, no transport created", func(t *testing.T) { + t.Parallel() + r := shellResolverWithPath(t, nil) + m := config.MCPConfig{ + Type: config.MCPStdio, + Command: "forgejo-mcp", + Env: map[string]string{"TOKEN": "$(false)"}, + } + tr, err := createTransport(t.Context(), m, r) + require.Error(t, err) + require.Nil(t, tr) + require.Contains(t, err.Error(), "env TOKEN") + }) + + t.Run("unset env var is a hard error, not silent empty", func(t *testing.T) { + t.Parallel() + // The regression at the heart of PLAN.md: unset vars used to + // silently substitute "" and hand an empty credential to the + // child process. Now they must error out before exec. + r := shellResolverWithPath(t, nil) + m := config.MCPConfig{ + Type: config.MCPStdio, + Command: "forgejo-mcp", + Env: map[string]string{"FORGEJO_ACCESS_TOKEN": "$FORGJO_TOKEN"}, + } + tr, err := createTransport(t.Context(), m, r) + require.Error(t, err) + require.Nil(t, tr) + require.Contains(t, err.Error(), "env FORGEJO_ACCESS_TOKEN") + }) + + t.Run("args resolution failure surfaces error, no transport created", func(t *testing.T) { + t.Parallel() + r := shellResolverWithPath(t, nil) + m := config.MCPConfig{ + Type: config.MCPStdio, + Command: "forgejo-mcp", + Args: []string{"--token", "$(false)"}, + } + tr, err := createTransport(t.Context(), m, r) + require.Error(t, err) + require.Nil(t, tr) + require.Contains(t, err.Error(), "arg 1") + }) + + t.Run("command resolution failure surfaces error, no transport created", func(t *testing.T) { + t.Parallel() + r := shellResolverWithPath(t, nil) + m := config.MCPConfig{ + Type: config.MCPStdio, + Command: "$(false)", + } + tr, err := createTransport(t.Context(), m, r) + require.Error(t, err) + require.Nil(t, tr) + require.Contains(t, err.Error(), "invalid mcp command") + }) + + t.Run("identity resolver round-trips templates verbatim", func(t *testing.T) { + t.Parallel() + // Client mode: no local expansion, no error on unset vars. + m := config.MCPConfig{ + Type: config.MCPStdio, + Command: "forgejo-mcp", + Args: []string{"--token", "$MCP_MISSING"}, + Env: map[string]string{"TOKEN": "$(vault read -f token)"}, + } + tr, err := createTransport(t.Context(), m, config.IdentityResolver()) + require.NoError(t, err) + ct, ok := tr.(*mcp.CommandTransport) + require.True(t, ok) + require.Equal(t, []string{"forgejo-mcp", "--token", "$MCP_MISSING"}, ct.Command.Args) + require.Contains(t, ct.Command.Env, "TOKEN=$(vault read -f token)") + }) +} + +// TestCreateTransport_HeadersResolution pins that a single failing +// header aborts HTTP/SSE transport creation and that the successful +// resolver passes every expanded header through to the round tripper. +func TestCreateTransport_HeadersResolution(t *testing.T) { + t.Parallel() + + t.Run("http headers success expands $(cmd)", func(t *testing.T) { + t.Parallel() + r := shellResolverWithPath(t, map[string]string{ + "GITHUB_TOKEN": "gh-secret", + }) + m := config.MCPConfig{ + Type: config.MCPHttp, + URL: "https://mcp.example.com/api", + Headers: map[string]string{ + "Authorization": "$(echo Bearer $GITHUB_TOKEN)", + "X-Static": "kept", + }, + } + tr, err := createTransport(t.Context(), m, r) + require.NoError(t, err) + + sct, ok := tr.(*mcp.StreamableClientTransport) + require.True(t, ok) + rt, ok := sct.HTTPClient.Transport.(*headerRoundTripper) + require.True(t, ok, "expected headerRoundTripper, got %T", sct.HTTPClient.Transport) + require.Equal(t, map[string]string{ + "Authorization": "Bearer gh-secret", + "X-Static": "kept", + }, rt.headers) + }) + + t.Run("http failing header surfaces error, no transport", func(t *testing.T) { + t.Parallel() + r := shellResolverWithPath(t, nil) + m := config.MCPConfig{ + Type: config.MCPHttp, + URL: "https://mcp.example.com/api", + Headers: map[string]string{"Authorization": "$(false)"}, + } + tr, err := createTransport(t.Context(), m, r) + require.Error(t, err) + require.Nil(t, tr) + require.Contains(t, err.Error(), "header Authorization") + }) + + t.Run("sse failing header surfaces error, no transport", func(t *testing.T) { + t.Parallel() + r := shellResolverWithPath(t, nil) + m := config.MCPConfig{ + Type: config.MCPSSE, + URL: "https://mcp.example.com/events", + Headers: map[string]string{"Authorization": "Bearer $MISSING_TOKEN"}, + } + tr, err := createTransport(t.Context(), m, r) + require.Error(t, err) + require.Nil(t, tr) + require.Contains(t, err.Error(), "header Authorization") + require.Contains(t, err.Error(), "$MISSING_TOKEN") + }) +} + +// TestCreateSession_ResolutionFailureUpdatesState pins the user-visible +// half of the regression fix: when any of command/args/env/headers/url +// fails to resolve, createSession must publish StateError to the state +// map so crush_info and the TUI's MCP status card can render a real +// error instead of the MCP silently sitting in "starting" or being +// spawned with an empty credential. +// +// These subtests cannot run in parallel: `states` is a package-level +// csync.Map and each assertion reads the entry written by the call +// under test. They do use unique MCP names per subtest to keep them +// independent regardless of ordering. +func TestCreateSession_ResolutionFailureUpdatesState(t *testing.T) { + r := shellResolverWithPath(t, nil) + + tests := []struct { + name string + mcpName string + cfg config.MCPConfig + wantErrContains string + }{ + { + name: "stdio env failure", + mcpName: "test-stdio-env-fail", + cfg: config.MCPConfig{ + Type: config.MCPStdio, + Command: "echo", + Env: map[string]string{"FORGEJO_ACCESS_TOKEN": "$(false)"}, + }, + wantErrContains: "env FORGEJO_ACCESS_TOKEN", + }, + { + name: "stdio args failure", + mcpName: "test-stdio-args-fail", + cfg: config.MCPConfig{ + Type: config.MCPStdio, + Command: "echo", + Args: []string{"--token", "$MCP_UNSET_TOKEN"}, + }, + wantErrContains: "arg 1", + }, + { + name: "http url failure", + mcpName: "test-http-url-fail", + cfg: config.MCPConfig{ + Type: config.MCPHttp, + URL: "https://$MCP_MISSING_HOST/api", + }, + wantErrContains: "url:", + }, + { + name: "http header failure", + mcpName: "test-http-header-fail", + cfg: config.MCPConfig{ + Type: config.MCPHttp, + URL: "https://mcp.example.com/api", + Headers: map[string]string{"Authorization": "$(false)"}, + }, + wantErrContains: "header Authorization", + }, + { + name: "sse url failure", + mcpName: "test-sse-url-fail", + cfg: config.MCPConfig{ + Type: config.MCPSSE, + URL: "https://$(false)/events", + }, + wantErrContains: "url:", + }, + { + name: "sse header failure", + mcpName: "test-sse-header-fail", + cfg: config.MCPConfig{ + Type: config.MCPSSE, + URL: "https://mcp.example.com/events", + Headers: map[string]string{"Authorization": "Bearer $MISSING_SSE_TOKEN"}, + }, + wantErrContains: "header Authorization", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Guarantee a clean slate on the shared state map so a + // stale entry from another test can't satisfy the + // assertion. + states.Del(tc.mcpName) + t.Cleanup(func() { states.Del(tc.mcpName) }) + + sess, err := createSession(t.Context(), tc.mcpName, tc.cfg, r) + require.Error(t, err) + require.Nil(t, sess) + require.Contains(t, err.Error(), tc.wantErrContains) + + info, ok := GetState(tc.mcpName) + require.True(t, ok, "state entry must be written for %q", tc.mcpName) + require.Equal(t, StateError, info.State, "expected StateError, got %s", info.State) + require.Error(t, info.Error, "state must carry the failure error") + require.Contains(t, info.Error.Error(), tc.wantErrContains) + require.Nil(t, info.Client, "no client session on failure") + }) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index d4870e7e063c87a070e8643c8c327f135ded7125..eeb6921c59af3ca960d1360ddff39b2b2babe7a3 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -5,7 +5,6 @@ import ( "context" "errors" "fmt" - "log/slog" "maps" "net/http" "net/url" @@ -15,7 +14,6 @@ import ( "charm.land/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/csync" - "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/oauth" "github.com/charmbracelet/crush/internal/oauth/copilot" "github.com/invopop/jsonschema" @@ -304,25 +302,96 @@ func (l LSPs) Sorted() []LSP { return sorted } -func (l LSPConfig) ResolvedEnv() []string { - return resolveEnvs(l.Env) -} - -func (m MCPConfig) ResolvedEnv() []string { - return resolveEnvs(m.Env) -} - -func (m MCPConfig) ResolvedHeaders() map[string]string { - resolver := NewShellVariableResolver(env.New()) - for e, v := range m.Headers { - var err error - m.Headers[e], err = resolver.ResolveValue(v) +// ResolvedEnv returns m.Env with every value expanded through the +// given resolver. The returned slice is of the form "KEY=value" sorted +// by key so callers get deterministic output; the receiver's Env map is +// not mutated. On the first resolution failure it returns nil and an +// error that identifies the offending key; the inner resolver error is +// already sanitized by ResolveValue and is wrapped with %w so +// errors.Is/As continues to work. Callers are expected to surface it +// (for MCP, via StateError on the status card) rather than silently +// spawn the server with an empty credential. +// +// The resolver choice matters: in server mode pass the shell resolver +// so $VAR / $(cmd) expand; in client mode pass IdentityResolver so the +// template is forwarded verbatim and expansion happens on the server. +func (m MCPConfig) ResolvedEnv(r VariableResolver) ([]string, error) { + return resolveEnvs(m.Env, r) +} + +// ResolvedArgs returns m.Args with every element expanded through the +// given resolver. A fresh slice is allocated; m.Args is never mutated. +// On the first resolution failure it returns nil and an error +// identifying the offending positional index; the inner resolver error +// is already sanitized by ResolveValue and is wrapped with %w so +// errors.Is/As continues to work. +// +// See ResolvedEnv for guidance on picking a resolver. +func (m MCPConfig) ResolvedArgs(r VariableResolver) ([]string, error) { + if len(m.Args) == 0 { + return nil, nil + } + out := make([]string, len(m.Args)) + for i, a := range m.Args { + v, err := r.ResolveValue(a) if err != nil { - slog.Error("Error resolving header variable", "error", err, "variable", e, "value", v) - continue + return nil, fmt.Errorf("arg %d: %w", i, err) } + out[i] = v + } + return out, nil +} + +// ResolvedURL returns m.URL expanded through the given resolver. The +// receiver is not mutated. Errors from the resolver are already +// sanitized by ResolveValue and are wrapped with %w for errors.Is/As. +// +// URLs run through the same shell-expansion pipeline as the other +// fields, so a literal '$' (e.g. OData query strings containing +// $filter/$select) must be escaped as '\$' or '${DOLLAR:-$}' to avoid +// being interpreted as a variable reference. Same constraint already +// applies to command, args, env, and headers. +// +// See ResolvedEnv for guidance on picking a resolver. +func (m MCPConfig) ResolvedURL(r VariableResolver) (string, error) { + if m.URL == "" { + return "", nil + } + v, err := r.ResolveValue(m.URL) + if err != nil { + return "", fmt.Errorf("url: %w", err) + } + return v, nil +} + +// ResolvedHeaders returns m.Headers with every value expanded through +// the given resolver. A fresh map is allocated; m.Headers is never +// mutated. On the first resolution failure it returns nil and an error +// identifying the offending header name; the inner resolver error is +// already sanitized by ResolveValue and is wrapped with %w so +// errors.Is/As continues to work. +// +// See ResolvedEnv for guidance on picking a resolver. +func (m MCPConfig) ResolvedHeaders(r VariableResolver) (map[string]string, error) { + if len(m.Headers) == 0 { + return map[string]string{}, nil + } + out := make(map[string]string, len(m.Headers)) + // Sort keys so failures are reported deterministically when more + // than one header would fail. + keys := make([]string, 0, len(m.Headers)) + for k := range m.Headers { + keys = append(keys, k) } - return m.Headers + slices.Sort(keys) + for _, k := range keys { + v, err := r.ResolveValue(m.Headers[k]) + if err != nil { + return nil, fmt.Errorf("header %s: %w", k, err) + } + out[k] = v + } + return out, nil } type Agent struct { @@ -662,22 +731,29 @@ func (c *ProviderConfig) TestConnection(resolver VariableResolver) error { return nil } -func resolveEnvs(envs map[string]string) []string { - resolver := NewShellVariableResolver(env.New()) - for e, v := range envs { - var err error - envs[e], err = resolver.ResolveValue(v) - if err != nil { - slog.Error("Error resolving environment variable", "error", err, "variable", e, "value", v) - continue - } +// resolveEnvs expands every value in envs through the given resolver +// and returns a fresh "KEY=value" slice sorted by key. The input map is +// not mutated. On the first resolution failure it returns nil and an +// error identifying the offending variable; the inner resolver error is +// already sanitized by ResolveValue and is wrapped with %w. +func resolveEnvs(envs map[string]string, r VariableResolver) ([]string, error) { + if len(envs) == 0 { + return nil, nil } - + keys := make([]string, 0, len(envs)) + for k := range envs { + keys = append(keys, k) + } + slices.Sort(keys) res := make([]string, 0, len(envs)) - for k, v := range envs { + for _, k := range keys { + v, err := r.ResolveValue(envs[k]) + if err != nil { + return nil, fmt.Errorf("env %s: %w", k, err) + } res = append(res, fmt.Sprintf("%s=%s", k, v)) } - return res + return res, nil } func ptrValOr[T any](t *T, el T) T { diff --git a/internal/config/load.go b/internal/config/load.go index 967af2de8cd6bffd6a8bd08e77b562998a8d1913..b10d6fa643a54ed0ad64174cea92bebd6284a7d8 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -222,6 +222,12 @@ func (c *Config) configureProviders(store *ConfigStore, env env.Env, resolver Va if len(config.ExtraHeaders) > 0 { maps.Copy(headers, config.ExtraHeaders) } + // Intentional divergence from MCP env/headers/args/url resolution: + // provider headers that fail to resolve log and keep their literal + // template so providers with optional, env-gated headers still + // load on hosts where those vars are unset. Changing this to a + // hard error would break those configs. See PLAN.md "Design + // decisions" item 4 for the full rationale. for k, v := range headers { resolved, err := resolver.ResolveValue(v) if err != nil { diff --git a/internal/config/load_test.go b/internal/config/load_test.go index 68d52da39fae5000433a47dea2401fd46c193ba3..80551a7749014b6f3262d4b0f2b70c26aabde34d 100644 --- a/internal/config/load_test.go +++ b/internal/config/load_test.go @@ -1589,3 +1589,61 @@ func TestConfig_configureProviders_HyperAPIKeyFromConfigOverrides(t *testing.T) require.True(t, ok, "Hyper provider should be configured") require.Equal(t, "env-api-key", pc.APIKey) } + +// TestConfig_configureProviders_ProviderHeaderResolveFailure pins the +// intentional divergence at load.go:225: a provider header whose +// expansion fails must NOT fail the whole config load. It logs an +// error, keeps the literal template in the resolved header map, and +// moves on. The MCP contract (hard error on failed expansion) does not +// apply here because many providers ship DefaultHeaders that reference +// env vars which are legitimately unset on most hosts. +// +// If this test ever flips to requiring an error, read PLAN.md "Design +// decisions" item 4 before changing the production code — the +// divergence is deliberate. +func TestConfig_configureProviders_ProviderHeaderResolveFailure(t *testing.T) { + knownProviders := []catwalk.Provider{ + { + ID: "openai", + APIKey: "$OPENAI_API_KEY", + APIEndpoint: "https://api.openai.com/v1", + Models: []catwalk.Model{{ID: "test-model"}}, + }, + } + + cfg := &Config{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ + "openai": { + ExtraHeaders: map[string]string{ + // Failing $(...) — inner command exits 1, no stdout. + "X-Broken": "$(false)", + // Unset var — nounset makes this an error too. + "X-Missing": "$PROVIDER_HEADER_NEVER_SET", + // Happy path: must still be resolved, proving the + // failure in the other two did not abort the loop. + "X-Static": "kept-literal", + }, + }, + }), + } + cfg.setDefaults("/tmp", "") + + testEnv := env.NewFromMap(map[string]string{ + "OPENAI_API_KEY": "test-key", + "PATH": os.Getenv("PATH"), + }) + resolver := NewShellVariableResolver(testEnv) + + err := cfg.configureProviders(testStore(cfg), testEnv, resolver, knownProviders) + require.NoError(t, err, "provider load must tolerate failing header expansion") + + pc, ok := cfg.Providers.Get("openai") + require.True(t, ok, "openai provider must still be configured") + + // Literal template preserved for the two failing headers; the + // happy-path header is resolved through the shell (pass-through + // for a literal value). + require.Equal(t, "$(false)", pc.ExtraHeaders["X-Broken"]) + require.Equal(t, "$PROVIDER_HEADER_NEVER_SET", pc.ExtraHeaders["X-Missing"]) + require.Equal(t, "kept-literal", pc.ExtraHeaders["X-Static"]) +} diff --git a/internal/config/mcp_resolved_url_test.go b/internal/config/mcp_resolved_url_test.go new file mode 100644 index 0000000000000000000000000000000000000000..97c0d9df4725164a0ca05933da5df35239fdb691 --- /dev/null +++ b/internal/config/mcp_resolved_url_test.go @@ -0,0 +1,88 @@ +package config + +import ( + "errors" + "testing" + + "github.com/charmbracelet/crush/internal/env" + "github.com/stretchr/testify/require" +) + +func TestMCPConfig_ResolvedURL(t *testing.T) { + t.Parallel() + + t.Run("empty url short-circuits without calling resolver", func(t *testing.T) { + t.Parallel() + m := MCPConfig{Type: MCPHttp} + got, err := m.ResolvedURL(stubResolver{err: errors.New("should not be called")}) + require.NoError(t, err) + require.Empty(t, got) + }) + + t.Run("literal url passes through unchanged", func(t *testing.T) { + t.Parallel() + m := MCPConfig{Type: MCPHttp, URL: "https://mcp.example.com/api"} + got, err := m.ResolvedURL(NewShellVariableResolver(env.NewFromMap(nil))) + require.NoError(t, err) + require.Equal(t, "https://mcp.example.com/api", got) + }) + + t.Run("expands $VAR with shell resolver", func(t *testing.T) { + t.Parallel() + m := MCPConfig{Type: MCPHttp, URL: "https://$MCP_HOST/api"} + r := NewShellVariableResolver(env.NewFromMap(map[string]string{"MCP_HOST": "mcp.example.com"})) + got, err := m.ResolvedURL(r) + require.NoError(t, err) + require.Equal(t, "https://mcp.example.com/api", got) + }) + + t.Run("expands $(cmd) with shell resolver", func(t *testing.T) { + t.Parallel() + m := MCPConfig{Type: MCPSSE, URL: "https://$(echo mcp.example.com)/events"} + got, err := m.ResolvedURL(NewShellVariableResolver(env.NewFromMap(nil))) + require.NoError(t, err) + require.Equal(t, "https://mcp.example.com/events", got) + }) + + t.Run("unset var is an error wrapping the template", func(t *testing.T) { + t.Parallel() + m := MCPConfig{Type: MCPHttp, URL: "https://$MCP_MISSING_HOST/api"} + _, err := m.ResolvedURL(NewShellVariableResolver(env.NewFromMap(nil))) + require.Error(t, err) + require.Contains(t, err.Error(), "url:") + require.Contains(t, err.Error(), "$MCP_MISSING_HOST") + require.Contains(t, err.Error(), "unbound") + }) + + t.Run("failing command substitution is an error", func(t *testing.T) { + t.Parallel() + m := MCPConfig{Type: MCPHttp, URL: "https://$(false)/api"} + _, err := m.ResolvedURL(NewShellVariableResolver(env.NewFromMap(nil))) + require.Error(t, err) + require.Contains(t, err.Error(), "url:") + require.Contains(t, err.Error(), "$(false)") + }) + + t.Run("identity resolver round-trips template verbatim", func(t *testing.T) { + t.Parallel() + // In client mode expansion happens server-side; the client must + // forward the template without touching it and without erroring + // on unset vars. + tmpl := "https://$MCP_HOST/$(vault read -f url)" + m := MCPConfig{Type: MCPHttp, URL: tmpl} + got, err := m.ResolvedURL(IdentityResolver()) + require.NoError(t, err) + require.Equal(t, tmpl, got) + }) +} + +// stubResolver returns ("", err) for every call. Paired with a non-nil +// err the empty-URL test asserts ResolvedURL short-circuits before +// reaching ResolveValue: if it didn't, the test would fail with err. +type stubResolver struct { + err error +} + +func (s stubResolver) ResolveValue(v string) (string, error) { + return "", s.err +} diff --git a/internal/config/resolve.go b/internal/config/resolve.go index b9e7753386bb8c95b877b99172897ea4fdb0a045..8c22a8abc7a516cd19be6fc32eaa101d60e416d4 100644 --- a/internal/config/resolve.go +++ b/internal/config/resolve.go @@ -3,13 +3,17 @@ package config import ( "context" "fmt" - "strings" "time" "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/shell" ) +// resolveTimeout bounds how long a single ResolveValue call may spend +// inside shell expansion (including any command substitution). Matches +// the timeout used by the previous hand-rolled parser. +const resolveTimeout = 5 * time.Minute + type VariableResolver interface { ResolveValue(value string) (string, error) } @@ -28,141 +32,140 @@ func IdentityResolver() VariableResolver { return identityResolver{} } -type Shell interface { - Exec(ctx context.Context, command string) (stdout, stderr string, err error) +// Expander is the single-value shell expansion seam used by +// shellVariableResolver. Production wires it to shell.ExpandValue; tests +// can inject a fake via WithExpander. +type Expander func(ctx context.Context, value string, env []string) (string, error) + +// ShellResolverOption customizes shell variable resolver construction. +type ShellResolverOption func(*shellVariableResolver) + +// WithExpander overrides the expansion function used by the resolver. +// Primarily intended for tests; production callers should not need this. +func WithExpander(e Expander) ShellResolverOption { + return func(r *shellVariableResolver) { + if e != nil { + r.expand = e + } + } } type shellVariableResolver struct { - shell Shell - env env.Env + env env.Env + expand Expander } -func NewShellVariableResolver(env env.Env) VariableResolver { - return &shellVariableResolver{ - env: env, - shell: shell.NewShell( - &shell.Options{ - Env: env.Env(), - }, - ), +// NewShellVariableResolver returns a VariableResolver that delegates to +// the embedded shell (the same interpreter used by the bash tool and +// hooks). Supported constructs match shell.ExpandValue: $VAR, ${VAR}, +// ${VAR:-default}, $(command), quoting, and escapes. Unset variables are +// an error; use ${VAR:-} to opt in to an empty fallback. +func NewShellVariableResolver(e env.Env, opts ...ShellResolverOption) VariableResolver { + r := &shellVariableResolver{ + env: e, + expand: shell.ExpandValue, } + for _, opt := range opts { + opt(r) + } + return r } -// ResolveValue is a method for resolving values, such as environment variables. -// it will resolve shell-like variable substitution anywhere in the string, including: -// - $(command) for command substitution -// - $VAR or ${VAR} for environment variables +// ResolveValue resolves shell-style substitution anywhere in the string: +// +// - $(command) for command substitution, with full quoting and nesting. +// - $VAR and ${VAR} for environment variables. +// - ${VAR:-default} / ${VAR:+alt} / ${VAR:?msg} for defaulting. +// +// Unset variables are a hard error (nounset), mirroring the historical +// behaviour of this resolver: silently expanding an unset variable to the +// empty string is exactly how broken credentials reach MCP servers. func (r *shellVariableResolver) ResolveValue(value string) (string, error) { - // Special case: lone $ is an error (backward compatibility) + // Preserve the historical backward-compat contract: a lone "$" is a + // malformed config value, not a legal literal. The underlying shell + // parser would accept it as a literal; we reject it here so existing + // configs that relied on this validation still fail early. if value == "$" { return "", fmt.Errorf("invalid value format: %s", value) } - // If no $ found, return as-is - if !strings.Contains(value, "$") { - return value, nil - } - - result := value + ctx, cancel := context.WithTimeout(context.Background(), resolveTimeout) + defer cancel() - // Handle command substitution: $(command) - for { - start := strings.Index(result, "$(") - if start == -1 { - break - } + out, err := r.expand(ctx, value, r.env.Env()) + if err != nil { + return "", sanitizeResolveError(value, err) + } + return out, nil +} - // Find matching closing parenthesis - depth := 0 - end := -1 - for i := start + 2; i < len(result); i++ { - if result[i] == '(' { - depth++ - } else if result[i] == ')' { - if depth == 0 { - end = i - break - } - depth-- - } - } +// maxResolveErrBytes bounds the size of the inner error message surfaced +// from a resolution failure. Defense-in-depth on top of shell.ExpandValue's +// own stderr budget: a custom Expander injected via WithExpander, or any +// future non-shell error path, must still produce a user-safe message. +const maxResolveErrBytes = 512 + +// sanitizeResolveError wraps an expansion error with the user-written +// template (the pre-expansion string — it is what they typed, safe to +// surface) and a bounded, scrubbed rendering of the inner error message. +// Contract: +// +// - Never includes the resolved (post-expansion) value. This helper +// only receives the template and err, so a successful expansion +// result cannot reach it. +// - May include the template verbatim. +// - Truncates the inner error's message to maxResolveErrBytes and +// replaces embedded NULs and other non-printables (except tab and +// newline) with '?'. +// +// The returned error still unwraps to the original for errors.Is/As so +// callers can inspect typed sentinels; only the rendered message is +// scrubbed. +func sanitizeResolveError(template string, err error) error { + if err == nil { + return nil + } + return &resolveError{ + template: template, + msg: scrubErrorMessage(err.Error()), + inner: err, + } +} - if end == -1 { - return "", fmt.Errorf("unmatched $( in value: %s", value) - } +// resolveError is the concrete type returned by sanitizeResolveError. +// Its Error() method returns the template + scrubbed inner message; +// Unwrap exposes the original error so errors.Is/As continue to work. +type resolveError struct { + template string + msg string + inner error +} - command := result[start+2 : end] - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) +func (e *resolveError) Error() string { + return fmt.Sprintf("resolving %q: %s", e.template, e.msg) +} - stdout, _, err := r.shell.Exec(ctx, command) - cancel() - if err != nil { - return "", fmt.Errorf("command execution failed for '%s': %w", command, err) - } +func (e *resolveError) Unwrap() error { return e.inner } - // Replace the $(command) with the output - replacement := strings.TrimSpace(stdout) - result = result[:start] + replacement + result[end+1:] +// scrubErrorMessage bounds the message to maxResolveErrBytes bytes and +// replaces non-printable bytes (anything outside ASCII printable, tab, or +// newline) with '?'. Mirrors shell.sanitizeStderr but operates on a +// string rather than raw command stderr and runs at the config layer, +// so arbitrary Expander error text is also sanitized. +func scrubErrorMessage(s string) string { + if len(s) > maxResolveErrBytes { + s = s[:maxResolveErrBytes] } - - // Handle environment variables: $VAR and ${VAR} - searchStart := 0 - for { - start := strings.Index(result[searchStart:], "$") - if start == -1 { - break - } - start += searchStart // Adjust for the offset - - // Skip if this is part of $( which we already handled - if start+1 < len(result) && result[start+1] == '(' { - // Skip past this $(...) - searchStart = start + 1 + out := make([]byte, len(s)) + for i := 0; i < len(s); i++ { + c := s[i] + if c == '\t' || c == '\n' || (c >= 0x20 && c < 0x7f) { + out[i] = c continue } - var varName string - var end int - - if start+1 < len(result) && result[start+1] == '{' { - // Handle ${VAR} format - closeIdx := strings.Index(result[start+2:], "}") - if closeIdx == -1 { - return "", fmt.Errorf("unmatched ${ in value: %s", value) - } - varName = result[start+2 : start+2+closeIdx] - end = start + 2 + closeIdx + 1 - } else { - // Handle $VAR format - variable names must start with letter or underscore - if start+1 >= len(result) { - return "", fmt.Errorf("incomplete variable reference at end of string: %s", value) - } - - if result[start+1] != '_' && - (result[start+1] < 'a' || result[start+1] > 'z') && - (result[start+1] < 'A' || result[start+1] > 'Z') { - return "", fmt.Errorf("invalid variable name starting with '%c' in: %s", result[start+1], value) - } - - end = start + 1 - for end < len(result) && (result[end] == '_' || - (result[end] >= 'a' && result[end] <= 'z') || - (result[end] >= 'A' && result[end] <= 'Z') || - (result[end] >= '0' && result[end] <= '9')) { - end++ - } - varName = result[start+1 : end] - } - - envValue := r.env.Get(varName) - if envValue == "" { - return "", fmt.Errorf("environment variable %q not set", varName) - } - - result = result[:start] + envValue + result[end:] - searchStart = start + len(envValue) // Continue searching after the replacement + out[i] = '?' } - - return result, nil + return string(out) } type environmentVariableResolver struct { @@ -177,11 +180,11 @@ func NewEnvironmentVariableResolver(env env.Env) VariableResolver { // ResolveValue resolves environment variables from the provided env.Env. func (r *environmentVariableResolver) ResolveValue(value string) (string, error) { - if !strings.HasPrefix(value, "$") { + if len(value) == 0 || value[0] != '$' { return value, nil } - varName := strings.TrimPrefix(value, "$") + varName := value[1:] resolvedValue := r.env.Get(varName) if resolvedValue == "" { return "", fmt.Errorf("environment variable %q not set", varName) diff --git a/internal/config/resolve_real_test.go b/internal/config/resolve_real_test.go new file mode 100644 index 0000000000000000000000000000000000000000..f33435e9615f44a09c9e2ce1692c3b3aa3ebf276 --- /dev/null +++ b/internal/config/resolve_real_test.go @@ -0,0 +1,310 @@ +package config + +import ( + "fmt" + "maps" + "os" + "path/filepath" + "slices" + "strings" + "testing" + + "github.com/charmbracelet/crush/internal/env" + "github.com/stretchr/testify/require" +) + +// These tests exercise the full shell-expansion path (no mocks, +// no injected Expander) to catch regressions that only surface when +// internal/shell actually runs the value. Table-level unit tests with +// fake expanders live in resolve_test.go. + +// realShellResolver builds a resolver backed by a shell env that +// contains PATH + the caller-supplied overrides. Production callers +// get PATH for free via env.New(); these tests need it so $(cat ...) +// and similar inner commands can resolve. +func realShellResolver(vars map[string]string) VariableResolver { + m := map[string]string{"PATH": os.Getenv("PATH")} + maps.Copy(m, vars) + return NewShellVariableResolver(env.NewFromMap(m)) +} + +func writeTempFile(t *testing.T, content string) string { + t.Helper() + dir := t.TempDir() + p := filepath.Join(dir, "secret") + require.NoError(t, os.WriteFile(p, []byte(content), 0o600)) + return p +} + +// TestResolvedEnv_RealShell_Success covers the shell constructs the +// PLAN calls out: $(cat tempfile) with and without trailing newline, +// ${VAR:-default} for unset vars, literal-space preservation around +// $(...), nested parens, quoted args inside $(echo ...), and a +// glob-like literal round-tripping unchanged. +func TestResolvedEnv_RealShell_Success(t *testing.T) { + t.Parallel() + + // filepath.ToSlash so Windows temp paths (C:\Users\...) survive + // being injected into a shell command string — the embedded shell + // treats backslashes as escapes, forward slashes work on every OS. + withNL := filepath.ToSlash(writeTempFile(t, "token-with-nl\n")) + noNL := filepath.ToSlash(writeTempFile(t, "token-no-nl")) + + m := MCPConfig{ + Env: map[string]string{ + // POSIX strips exactly one trailing newline from $(...) + // output, so both forms land on the same value. + "TOK_NL": fmt.Sprintf("$(cat %s)", withNL), + "TOK_NO": fmt.Sprintf("$(cat %s)", noNL), + + // ${VAR:-default} must not error on unset: this is the + // opt-in escape hatch for "empty is fine". + "FALLBACK": "${MCP_MISSING:-fallback}", + + // Leading/trailing literal spaces around $(...) must be + // preserved — single-value contract, no field splitting. + "PADDED": " $(echo v) ", + + // ")" inside a quoted arg to echo is a regression guard + // for the old hand-rolled paren matcher. + "PAREN": `$(echo ")")`, + + // Embedded space inside a quoted arg must survive + // verbatim; no word-splitting side effect. + "SPACEY": `$(echo "a b")`, + + // Glob-like literals must not expand. + "GLOB": "*.go", + }, + } + + got, err := m.ResolvedEnv(realShellResolver(nil)) + require.NoError(t, err) + + // ResolvedEnv returns "KEY=value" sorted by key. + want := []string{ + "FALLBACK=fallback", + "GLOB=*.go", + "PADDED= v ", + "PAREN=)", + "SPACEY=a b", + "TOK_NL=token-with-nl", + "TOK_NO=token-no-nl", + } + require.Equal(t, want, got) +} + +// TestResolvedEnv_RealShell_DoesNotMutate pins that both success and +// failure paths leave m.Env untouched. Prior behaviour rewrote the +// value in place on error; that was the exact mechanism that shipped +// empty credentials to MCP servers. +func TestResolvedEnv_RealShell_DoesNotMutate(t *testing.T) { + t.Parallel() + + t.Run("success path leaves Env untouched", func(t *testing.T) { + t.Parallel() + m := MCPConfig{Env: map[string]string{"TOKEN": "$(echo shh)"}} + orig := maps.Clone(m.Env) + + _, err := m.ResolvedEnv(realShellResolver(nil)) + require.NoError(t, err) + require.Equal(t, orig, m.Env) + }) + + t.Run("failure path leaves Env untouched", func(t *testing.T) { + t.Parallel() + m := MCPConfig{Env: map[string]string{"BROKEN": "$(false)"}} + orig := maps.Clone(m.Env) + + _, err := m.ResolvedEnv(realShellResolver(nil)) + require.Error(t, err) + require.Equal(t, orig, m.Env, "map must be preserved on error") + }) +} + +// TestResolvedEnv_RealShell_Idempotent pins the pure-function contract: +// two calls on the same config return deeply-equal slices. +func TestResolvedEnv_RealShell_Idempotent(t *testing.T) { + t.Parallel() + + m := MCPConfig{ + Env: map[string]string{ + "A": "$(echo one)", + "B": "$(echo two)", + "C": "literal", + }, + } + r := realShellResolver(nil) + + first, err := m.ResolvedEnv(r) + require.NoError(t, err) + second, err := m.ResolvedEnv(r) + require.NoError(t, err) + require.Equal(t, first, second) +} + +// TestResolvedEnv_RealShell_Deterministic guards against Go's +// randomized map iteration leaking into the returned slice order. +func TestResolvedEnv_RealShell_Deterministic(t *testing.T) { + t.Parallel() + + m := MCPConfig{Env: map[string]string{ + "Z": "z", + "A": "a", + "M": "m", + }} + + got, err := m.ResolvedEnv(realShellResolver(nil)) + require.NoError(t, err) + require.True(t, slices.IsSorted(got), "env slice must be sorted; got %v", got) +} + +// TestResolvedEnv_RealShell_NounsetRegression is the single most +// important assertion in this file: an unset variable is an error, not +// an empty expansion. Swapping the hand-rolled parser for mvdan's +// default-expansion behaviour without nounset would re-introduce +// Defect 1 via a typo'd variable name. +func TestResolvedEnv_RealShell_NounsetRegression(t *testing.T) { + t.Parallel() + + m := MCPConfig{Env: map[string]string{ + // Intentional typo: user meant $FORGEJO_TOKEN. + "FORGEJO_ACCESS_TOKEN": "$FORGJO_TOKEN", + }} + got, err := m.ResolvedEnv(realShellResolver(nil)) + require.Error(t, err, "unset var must not silently expand to empty") + require.Nil(t, got) + require.Contains(t, err.Error(), "FORGEJO_ACCESS_TOKEN") + require.Contains(t, err.Error(), "$FORGJO_TOKEN") +} + +// TestResolvedEnv_RealShell_FailureDetail pins that a failing inner +// command surfaces enough detail (exit code + stderr on POSIX, the +// underlying OS error on Windows where coreutils runs in-process) to +// diagnose without forcing the user to re-run the command by hand. +// Also verifies the template is included so they know which Env +// entry blew up. +func TestResolvedEnv_RealShell_FailureDetail(t *testing.T) { + t.Parallel() + + // Forward slashes so the path survives shell-string injection on + // Windows; see TestResolvedEnv_RealShell_Success for the same note. + missing := filepath.ToSlash(filepath.Join(t.TempDir(), "definitely-not-here")) + m := MCPConfig{Env: map[string]string{ + "FORGEJO_ACCESS_TOKEN": fmt.Sprintf("$(cat %s)", missing), + }} + + _, err := m.ResolvedEnv(realShellResolver(nil)) + require.Error(t, err) + msg := err.Error() + require.Contains(t, msg, "FORGEJO_ACCESS_TOKEN", "must identify the failing env var") + require.Contains(t, msg, missing, "must include the template so users see what failed") + + // Inner diagnostic detail must survive. POSIX surfaces "exit + // status N" + stderr; Windows' in-process coreutils surfaces the + // Go OS error instead. Accept either shape so the test is + // portable without weakening the intent. + lower := strings.ToLower(msg) + hasDetail := strings.Contains(lower, "exit status") || + strings.Contains(lower, "no such file") || + strings.Contains(lower, "cannot find") + require.True(t, hasDetail, "must surface inner error detail: %s", msg) +} + +// TestResolvedHeaders_RealShell_FailurePreservesOriginal pins two +// invariants simultaneously: on failure the returned map is nil (not +// a partially-populated map) and the receiver's Headers map is +// unchanged. A test that only asserted on the returned value could +// hide an in-place mutation regression. +func TestResolvedHeaders_RealShell_FailurePreservesOriginal(t *testing.T) { + t.Parallel() + + m := MCPConfig{Headers: map[string]string{ + "Authorization": "Bearer $(false)", + "X-Static": "kept", + }} + orig := maps.Clone(m.Headers) + + got, err := m.ResolvedHeaders(realShellResolver(nil)) + require.Error(t, err) + require.Nil(t, got, "headers map must be nil on failure") + require.Contains(t, err.Error(), "header Authorization") + require.Equal(t, orig, m.Headers, "receiver Headers must be preserved") +} + +// TestResolvedArgs_RealShell exercises both success and failure for +// m.Args symmetrically with Env. Args are ordered so error messages +// must identify a positional index, not a key. +func TestResolvedArgs_RealShell(t *testing.T) { + t.Parallel() + + t.Run("success expands each element", func(t *testing.T) { + t.Parallel() + m := MCPConfig{Args: []string{"--token", "$(echo shh)", "--host", "example.com"}} + got, err := m.ResolvedArgs(realShellResolver(nil)) + require.NoError(t, err) + require.Equal(t, []string{"--token", "shh", "--host", "example.com"}, got) + }) + + t.Run("failure identifies offending index", func(t *testing.T) { + t.Parallel() + m := MCPConfig{Args: []string{"--token", "$(false)"}} + orig := slices.Clone(m.Args) + + got, err := m.ResolvedArgs(realShellResolver(nil)) + require.Error(t, err) + require.Nil(t, got) + require.Contains(t, err.Error(), "arg 1") + require.Equal(t, orig, m.Args, "receiver Args must be preserved") + }) + + t.Run("nil args returns nil, no error", func(t *testing.T) { + t.Parallel() + m := MCPConfig{} + got, err := m.ResolvedArgs(realShellResolver(nil)) + require.NoError(t, err) + require.Nil(t, got) + }) +} + +// TestMCPConfig_IdentityResolver pins the client-mode contract: every +// Resolved* method round-trips the template verbatim and never errors +// on unset variables. Local expansion would double-expand when the +// server does its own — this has to stay a pure pass-through. +func TestMCPConfig_IdentityResolver(t *testing.T) { + t.Parallel() + + m := MCPConfig{ + Command: "$CMD", + Args: []string{"--token", "$MCP_MISSING_TOKEN", "$(vault read -f secret)"}, + Env: map[string]string{ + "TOKEN": "$(cat /run/secrets/x)", + "HOST": "$MCP_MISSING_HOST", + }, + Headers: map[string]string{ + "Authorization": "Bearer $(vault read -f token)", + }, + URL: "https://$MCP_HOST/$(vault read -f path)", + } + r := IdentityResolver() + + args, err := m.ResolvedArgs(r) + require.NoError(t, err) + require.Equal(t, m.Args, args) + + envs, err := m.ResolvedEnv(r) + require.NoError(t, err) + // Sorted "KEY=value". + require.Equal(t, []string{ + "HOST=$MCP_MISSING_HOST", + "TOKEN=$(cat /run/secrets/x)", + }, envs) + + headers, err := m.ResolvedHeaders(r) + require.NoError(t, err) + require.Equal(t, m.Headers, headers) + + u, err := m.ResolvedURL(r) + require.NoError(t, err) + require.Equal(t, m.URL, u) +} diff --git a/internal/config/resolve_test.go b/internal/config/resolve_test.go index ec9b06c25bdc023acebffc71f043b54a8da21597..18b461b8a6bc88f73d6db935eebd302cc3514781 100644 --- a/internal/config/resolve_test.go +++ b/internal/config/resolve_test.go @@ -3,260 +3,221 @@ package config import ( "context" "errors" + "strings" "testing" "github.com/charmbracelet/crush/internal/env" "github.com/stretchr/testify/require" ) -// mockShell implements the Shell interface for testing -type mockShell struct { - execFunc func(ctx context.Context, command string) (stdout, stderr string, err error) +// fakeExpander returns a canned value/error for the last passed value and +// records the context, raw value, and env slice it was called with. It +// lets the config-layer tests assert on delegation behaviour without +// spinning up a real interpreter — real-shell coverage lives in +// internal/shell/expand_test.go and resolve_real_test.go. +type fakeExpander struct { + expand func(ctx context.Context, value string, env []string) (string, error) + lastValue string + lastEnv []string + calls int } -func (m *mockShell) Exec(ctx context.Context, command string) (stdout, stderr string, err error) { - if m.execFunc != nil { - return m.execFunc(ctx, command) +func (f *fakeExpander) Expand(ctx context.Context, value string, env []string) (string, error) { + f.calls++ + f.lastValue = value + f.lastEnv = env + if f.expand == nil { + return value, nil } - return "", "", nil + return f.expand(ctx, value, env) } -func TestShellVariableResolver_ResolveValue(t *testing.T) { - tests := []struct { - name string - value string - envVars map[string]string - shellFunc func(ctx context.Context, command string) (stdout, stderr string, err error) - expected string - expectError bool - }{ - { - name: "non-variable string returns as-is", - value: "plain-string", - expected: "plain-string", - }, - { - name: "environment variable resolution", - value: "$HOME", - envVars: map[string]string{"HOME": "/home/user"}, - expected: "/home/user", - }, - { - name: "missing environment variable returns error", - value: "$MISSING_VAR", - envVars: map[string]string{}, - expectError: true, - }, +func TestShellVariableResolver_DelegatesToExpander(t *testing.T) { + t.Parallel() - { - name: "shell command with whitespace trimming", - value: "$(echo ' spaced ')", - shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { - if command == "echo ' spaced '" { - return " spaced \n", "", nil - } - return "", "", errors.New("unexpected command") - }, - expected: "spaced", - }, - { - name: "shell command execution error", - value: "$(false)", - shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { - return "", "", errors.New("command failed") - }, - expectError: true, - }, - { - name: "invalid format returns error", - value: "$", - expectError: true, + fe := &fakeExpander{ + expand: func(_ context.Context, value string, _ []string) (string, error) { + if value == "hello $FOO" { + return "hello bar", nil + } + return value, nil }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - testEnv := env.NewFromMap(tt.envVars) - resolver := &shellVariableResolver{ - shell: &mockShell{execFunc: tt.shellFunc}, - env: testEnv, - } + e := env.NewFromMap(map[string]string{"FOO": "bar"}) + r := NewShellVariableResolver(e, WithExpander(fe.Expand)) - result, err := resolver.ResolveValue(tt.value) + got, err := r.ResolveValue("hello $FOO") + require.NoError(t, err) + require.Equal(t, "hello bar", got) + require.Equal(t, 1, fe.calls) + require.Equal(t, "hello $FOO", fe.lastValue) + require.Contains(t, fe.lastEnv, "FOO=bar") +} - if tt.expectError { - require.Error(t, err) - } else { - require.NoError(t, err) - require.Equal(t, tt.expected, result) - } - }) - } +func TestShellVariableResolver_LoneDollarIsError(t *testing.T) { + t.Parallel() + + // Lone "$" must short-circuit before reaching the expander: the + // underlying shell parser would accept it as a literal, but this + // resolver has historically rejected it and callers depend on + // that early-fail behaviour. + fe := &fakeExpander{} + r := NewShellVariableResolver(env.NewFromMap(nil), WithExpander(fe.Expand)) + + _, err := r.ResolveValue("$") + require.Error(t, err) + require.Equal(t, 0, fe.calls, "expander must not be called for lone $") } -func TestShellVariableResolver_EnhancedResolveValue(t *testing.T) { - tests := []struct { - name string - value string - envVars map[string]string - shellFunc func(ctx context.Context, command string) (stdout, stderr string, err error) - expected string - expectError bool - }{ - { - name: "command substitution within string", - value: "Bearer $(echo token123)", - shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { - if command == "echo token123" { - return "token123\n", "", nil - } - return "", "", errors.New("unexpected command") - }, - expected: "Bearer token123", - }, - { - name: "environment variable within string", - value: "Bearer $TOKEN", - envVars: map[string]string{"TOKEN": "sk-ant-123"}, - expected: "Bearer sk-ant-123", - }, - { - name: "environment variable with braces within string", - value: "Bearer ${TOKEN}", - envVars: map[string]string{"TOKEN": "sk-ant-456"}, - expected: "Bearer sk-ant-456", - }, - { - name: "mixed command and environment substitution", - value: "$USER-$(date +%Y)-$HOST", - envVars: map[string]string{ - "USER": "testuser", - "HOST": "localhost", - }, - shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { - if command == "date +%Y" { - return "2024\n", "", nil - } - return "", "", errors.New("unexpected command") - }, - expected: "testuser-2024-localhost", - }, - { - name: "multiple command substitutions", - value: "$(echo hello) $(echo world)", - shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { - switch command { - case "echo hello": - return "hello\n", "", nil - case "echo world": - return "world\n", "", nil - } - return "", "", errors.New("unexpected command") - }, - expected: "hello world", - }, - { - name: "nested parentheses in command", - value: "$(echo $(echo inner))", - shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { - if command == "echo $(echo inner)" { - return "nested\n", "", nil - } - return "", "", errors.New("unexpected command") - }, - expected: "nested", - }, - { - name: "lone dollar with non-variable chars", - value: "prefix$123suffix", // Numbers can't start variable names - expectError: true, - }, - { - name: "dollar with special chars", - value: "a$@b$#c", // Special chars aren't valid in variable names - expectError: true, - }, - { - name: "empty environment variable substitution", - value: "Bearer $EMPTY_VAR", - envVars: map[string]string{}, - expectError: true, - }, - { - name: "unmatched command substitution opening", - value: "Bearer $(echo test", - expectError: true, - }, - { - name: "unmatched environment variable braces", - value: "Bearer ${TOKEN", - expectError: true, - }, - { - name: "command substitution with error", - value: "Bearer $(false)", - shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { - return "", "", errors.New("command failed") - }, - expectError: true, - }, - { - name: "complex real-world example", - value: "Bearer $(cat /tmp/token.txt | base64 -w 0)", - shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { - if command == "cat /tmp/token.txt | base64 -w 0" { - return "c2stYW50LXRlc3Q=\n", "", nil - } - return "", "", errors.New("unexpected command") - }, - expected: "Bearer c2stYW50LXRlc3Q=", - }, - { - name: "environment variable with underscores and numbers", - value: "Bearer $API_KEY_V2", - envVars: map[string]string{"API_KEY_V2": "sk-test-123"}, - expected: "Bearer sk-test-123", - }, - { - name: "no substitution needed", - value: "Bearer sk-ant-static-token", - expected: "Bearer sk-ant-static-token", - }, - { - name: "incomplete variable at end", - value: "Bearer $", - expectError: true, - }, - { - name: "variable with invalid character", - value: "Bearer $VAR-NAME", // Hyphen not allowed in variable names - expectError: true, - }, - { - name: "multiple invalid variables", - value: "$1$2$3", - expectError: true, +func TestShellVariableResolver_PassesThroughLiterals(t *testing.T) { + t.Parallel() + + fe := &fakeExpander{ + expand: func(_ context.Context, value string, _ []string) (string, error) { + return value, nil }, } + r := NewShellVariableResolver(env.NewFromMap(nil), WithExpander(fe.Expand)) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - testEnv := env.NewFromMap(tt.envVars) - resolver := &shellVariableResolver{ - shell: &mockShell{execFunc: tt.shellFunc}, - env: testEnv, - } + got, err := r.ResolveValue("plain-string") + require.NoError(t, err) + require.Equal(t, "plain-string", got) +} - result, err := resolver.ResolveValue(tt.value) +func TestShellVariableResolver_WrapsErrorsWithTemplate(t *testing.T) { + t.Parallel() - if tt.expectError { - require.Error(t, err) - } else { - require.NoError(t, err) - require.Equal(t, tt.expected, result) - } - }) + inner := errors.New("cat: /run/secrets/x: permission denied") + fe := &fakeExpander{ + expand: func(_ context.Context, _ string, _ []string) (string, error) { + return "", inner + }, } + r := NewShellVariableResolver(env.NewFromMap(nil), WithExpander(fe.Expand)) + + _, err := r.ResolveValue("$(cat /run/secrets/x)") + require.Error(t, err) + require.ErrorIs(t, err, inner) + require.Contains(t, err.Error(), "$(cat /run/secrets/x)") + require.Contains(t, err.Error(), "permission denied") +} + +func TestSanitizeResolveError(t *testing.T) { + t.Parallel() + + t.Run("nil passes through", func(t *testing.T) { + t.Parallel() + require.NoError(t, sanitizeResolveError("anything", nil)) + }) + + t.Run("includes template and wraps inner", func(t *testing.T) { + t.Parallel() + inner := errors.New("cat: /run/secrets/x: permission denied") + got := sanitizeResolveError("$(cat /run/secrets/x)", inner) + require.Error(t, got) + require.ErrorIs(t, got, inner) + require.Contains(t, got.Error(), "$(cat /run/secrets/x)") + require.Contains(t, got.Error(), "permission denied") + }) + + t.Run("unwrap preserves original for errors.Is", func(t *testing.T) { + t.Parallel() + inner := errors.New("sentinel") + got := sanitizeResolveError("$FOO", inner) + require.ErrorIs(t, got, inner) + }) + + t.Run("truncates over-budget inner message", func(t *testing.T) { + t.Parallel() + // Inner message holds far more than the budget. After + // sanitization the rendered inner portion must not exceed + // maxResolveErrBytes, and the characters beyond the budget + // (marked by a distinct tail sentinel) must be gone. + const tailSentinel = "TAIL_SENTINEL_BEYOND_BUDGET" + body := strings.Repeat("x", maxResolveErrBytes) + inner := errors.New(body + tailSentinel) + + got := sanitizeResolveError("$TEMPLATE", inner) + require.Error(t, got) + + prefix := `resolving "$TEMPLATE": ` + rendered := got.Error() + require.True( + t, + strings.HasPrefix(rendered, prefix), + "rendered error must start with template prefix", + ) + innerRendered := strings.TrimPrefix(rendered, prefix) + require.LessOrEqual( + t, + len(innerRendered), + maxResolveErrBytes, + "inner message must be bounded to maxResolveErrBytes", + ) + require.NotContains( + t, + rendered, + tailSentinel, + "content past the budget must not leak", + ) + }) + + t.Run("replaces non-printable bytes", func(t *testing.T) { + t.Parallel() + // NUL, BEL, ESC, DEL, and a UTF-8 high byte should all be + // scrubbed to '?'. Tab and newline are preserved because + // they show up legitimately in command stderr. + inner := errors.New("ok\x00bad\x07\x1b\x7f\xffend\ttab\nline") + got := sanitizeResolveError("$T", inner) + rendered := got.Error() + + require.NotContains(t, rendered, "\x00") + require.NotContains(t, rendered, "\x07") + require.NotContains(t, rendered, "\x1b") + require.NotContains(t, rendered, "\x7f") + require.NotContains(t, rendered, "\xff") + require.Contains(t, rendered, "ok?bad????end\ttab\nline") + }) + + t.Run("scrubbing does not depend on shell.ExpandValue upstream", func(t *testing.T) { + t.Parallel() + // A custom Expander can inject arbitrary error text. The + // config-layer helper is the single chokepoint; it must + // bound + scrub regardless of the error source. + nasty := strings.Repeat("A", maxResolveErrBytes+64) + "\x00BEYOND" + fe := &fakeExpander{ + expand: func(_ context.Context, _ string, _ []string) (string, error) { + return "", errors.New(nasty) + }, + } + r := NewShellVariableResolver(env.NewFromMap(nil), WithExpander(fe.Expand)) + + _, err := r.ResolveValue("$T") + require.Error(t, err) + require.NotContains(t, err.Error(), "BEYOND", "over-budget tail must not leak") + require.NotContains(t, err.Error(), "\x00", "non-printables must be scrubbed") + }) +} + +func TestScrubErrorMessage(t *testing.T) { + t.Parallel() + + t.Run("bounds output to maxResolveErrBytes", func(t *testing.T) { + t.Parallel() + got := scrubErrorMessage(strings.Repeat("a", maxResolveErrBytes*3)) + require.Len(t, got, maxResolveErrBytes) + }) + + t.Run("preserves printable ASCII tab and newline", func(t *testing.T) { + t.Parallel() + require.Equal(t, "a\tb\nc d!", scrubErrorMessage("a\tb\nc d!")) + }) + + t.Run("replaces control and non-ASCII bytes", func(t *testing.T) { + t.Parallel() + require.Equal(t, "a?b??c", scrubErrorMessage("a\x01b\x1b\xe2c")) + }) } func TestEnvironmentVariableResolver_ResolveValue(t *testing.T) { diff --git a/internal/shell/expand.go b/internal/shell/expand.go new file mode 100644 index 0000000000000000000000000000000000000000..ee110a1df7b6b60b9ac2d51b7ae5f6abcbf3ac79 --- /dev/null +++ b/internal/shell/expand.go @@ -0,0 +1,123 @@ +package shell + +import ( + "bytes" + "context" + "fmt" + "io" + "os" + "strings" + + "mvdan.cc/sh/v3/expand" + "mvdan.cc/sh/v3/interp" + "mvdan.cc/sh/v3/syntax" +) + +// maxInnerStderrBytes bounds how much stderr from a failing $(...) is +// surfaced in the returned error, to avoid leaking a secret that happened +// to be embedded in a failing inner command. +const maxInnerStderrBytes = 512 + +// ExpandValue expands shell-style substitutions in a single config value. +// +// Supported constructs match the bash tool: +// +// - $VAR and ${VAR} (unset is an error; see nounset below). +// - ${VAR:-default} / ${VAR:+alt} / ${VAR:?msg}. +// - $(command) with full quoting and nesting. +// - escaped and quoted strings ("...", '...'). +// +// Contract: +// +// - Returns exactly one string. No field splitting, no globbing, no +// pathname generation. Multi-word command output is preserved +// verbatim; it is never split into multiple values. +// - Nounset is on: unset variables produce an error instead of +// expanding to the empty string. Use ${VAR:-default} to opt in to +// an empty fallback. +// - Embedded whitespace and newlines in the input are preserved +// verbatim. Command substitution strips trailing newlines only +// (POSIX), never leading or internal whitespace. +// - Errors wrap the failing inner command's exit code and a bounded +// prefix of its stderr. Callers that surface the error to users +// should additionally scrub it for the original template text. +func ExpandValue(ctx context.Context, value string, env []string) (string, error) { + // Parse the value as a here-doc style word: no word splitting, no + // globbing, but full support for $VAR, ${VAR...}, $(...), and + // quoted/escaped strings. + word, err := syntax.NewParser().Document(strings.NewReader(value)) + if err != nil { + return "", fmt.Errorf("parse: %w", err) + } + + // Build a minimal Shell value purely to reuse its handler chain + // (builtins, block funcs, optional Go coreutils) inside $(...). + // We deliberately skip NewShell so the passed-in env is used + // verbatim, with no CRUSH/AGENT/AI_AGENT injection: callers of + // ExpandValue control the env, and nounset must treat any name + // not in env as unset. + cwd, _ := os.Getwd() + s := &Shell{ + cwd: cwd, + env: env, + logger: noopLogger{}, + } + + var stderrBuf bytes.Buffer + cfg := &expand.Config{ + Env: expand.ListEnviron(env...), + NoUnset: true, + CmdSubst: func(w io.Writer, cs *syntax.CmdSubst) error { + stderrBuf.Reset() + runner, rerr := interp.New( + interp.StdIO(nil, w, &stderrBuf), + interp.Interactive(false), + interp.Env(expand.ListEnviron(env...)), + interp.Dir(s.cwd), + interp.ExecHandlers(s.execHandlers()...), + // Match the outer NoUnset: an unset $VAR inside + // $(...) is also an error, not a silent empty. + interp.Params("-u"), + ) + if rerr != nil { + return rerr + } + if rerr := runner.Run(ctx, &syntax.File{Stmts: cs.Stmts}); rerr != nil { + return wrapCmdSubstErr(rerr, stderrBuf.Bytes()) + } + return nil + }, + // ReadDir / ReadDir2 left nil: globbing is disabled. + } + + return expand.Document(cfg, word) +} + +// wrapCmdSubstErr attaches a bounded prefix of the inner command's stderr +// to the original error, if any. +func wrapCmdSubstErr(err error, stderrBytes []byte) error { + msg := sanitizeStderr(stderrBytes) + if msg == "" { + return err + } + return fmt.Errorf("%w: %s", err, msg) +} + +// sanitizeStderr trims, bounds, and scrubs non-printable bytes from the +// stderr of a failing command so the result is safe to include in an +// error message shown to the user. +func sanitizeStderr(b []byte) string { + b = bytes.TrimRight(b, "\n") + if len(b) > maxInnerStderrBytes { + b = b[:maxInnerStderrBytes] + } + out := make([]byte, len(b)) + for i, c := range b { + if c == '\t' || c == '\n' || (c >= 0x20 && c < 0x7f) { + out[i] = c + } else { + out[i] = '?' + } + } + return string(out) +} diff --git a/internal/shell/expand_test.go b/internal/shell/expand_test.go new file mode 100644 index 0000000000000000000000000000000000000000..8661d6682370d31666185576db7edddb17a65b97 --- /dev/null +++ b/internal/shell/expand_test.go @@ -0,0 +1,199 @@ +package shell + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestExpandValue_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + value string + env []string + want string + }{ + { + name: "plain string round trip", + value: "hello world", + want: "hello world", + }, + { + name: "plain var from env", + value: "$FOO", + env: []string{"FOO=bar"}, + want: "bar", + }, + { + name: "braced var from env", + value: "pre-${FOO}-post", + env: []string{"FOO=bar"}, + want: "pre-bar-post", + }, + { + name: "default syntax on unset", + value: "${MISSING:-fallback}", + want: "fallback", + }, + { + name: "default syntax on set preserves value", + value: "${SET:-fallback}", + env: []string{"SET=used"}, + want: "used", + }, + { + name: "command substitution", + value: "$(echo hi)", + want: "hi", + }, + { + name: "command substitution preserves internal spaces", + value: `$(echo "a b")`, + want: "a b", + }, + { + name: "command substitution strips only trailing newline", + value: "$(printf 'a\\nb\\n')", + want: "a\nb", + }, + { + name: "literal spaces around cmdsubst are preserved", + value: " $(echo v) ", + want: " v ", + }, + { + name: "paren inside quoted arg to echo", + value: `$(echo ")")`, + want: ")", + }, + { + name: "nested command substitution", + value: "$(echo $(echo hi))", + want: "hi", + }, + { + name: "glob-like input round trips unchanged", + value: "*.go", + want: "*.go", + }, + { + name: "tilde round trips unchanged", + value: "~", + want: "~", + }, + { + name: "env var inside cmdsubst", + value: `$(printf '%s' "$FOO")`, + env: []string{"FOO=bar"}, + want: "bar", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got, err := ExpandValue(t.Context(), tc.value, tc.env) + require.NoError(t, err) + require.Equal(t, tc.want, got) + }) + } +} + +func TestExpandValue_Errors(t *testing.T) { + t.Parallel() + + t.Run("unset var is an error, not empty", func(t *testing.T) { + t.Parallel() + _, err := ExpandValue(t.Context(), "$MISSING", nil) + require.Error(t, err) + }) + + t.Run("unset var inside braces is an error", func(t *testing.T) { + t.Parallel() + _, err := ExpandValue(t.Context(), "${MISSING}", nil) + require.Error(t, err) + }) + + t.Run("unset var inside cmdsubst is an error", func(t *testing.T) { + t.Parallel() + _, err := ExpandValue(t.Context(), `$(printf '%s' "$MISSING")`, nil) + require.Error(t, err) + }) + + t.Run("bad syntax returns error", func(t *testing.T) { + t.Parallel() + _, err := ExpandValue(t.Context(), "$(", nil) + require.Error(t, err) + }) + + t.Run("inner non-zero exit returns error with exit code", func(t *testing.T) { + t.Parallel() + _, err := ExpandValue(t.Context(), "$(false)", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "exit status 1") + }) + + t.Run("inner explicit exit code is surfaced", func(t *testing.T) { + t.Parallel() + _, err := ExpandValue(t.Context(), "$(exit 7)", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "exit status 7") + }) + + t.Run("inner stderr is surfaced", func(t *testing.T) { + t.Parallel() + _, err := ExpandValue( + t.Context(), + `$(printf 'boom\n' 1>&2; exit 1)`, + nil, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "boom") + }) + + t.Run("inner stderr is truncated to byte budget", func(t *testing.T) { + t.Parallel() + // Emit more than maxInnerStderrBytes bytes of 'X' on stderr. + long := strings.Repeat("X", maxInnerStderrBytes*2) + _, err := ExpandValue( + t.Context(), + `$(printf '`+long+`' 1>&2; exit 1)`, + nil, + ) + require.Error(t, err) + require.NotContains( + t, + err.Error(), + strings.Repeat("X", maxInnerStderrBytes+1), + "stderr should be bounded", + ) + }) +} + +func TestSanitizeStderr(t *testing.T) { + t.Parallel() + + t.Run("trims trailing newlines", func(t *testing.T) { + t.Parallel() + require.Equal(t, "hi", sanitizeStderr([]byte("hi\n\n"))) + }) + + t.Run("preserves tabs and embedded newlines", func(t *testing.T) { + t.Parallel() + require.Equal(t, "a\tb\nc", sanitizeStderr([]byte("a\tb\nc"))) + }) + + t.Run("replaces control characters", func(t *testing.T) { + t.Parallel() + require.Equal(t, "a?b", sanitizeStderr([]byte{'a', 0x01, 'b'})) + }) + + t.Run("bounds output", func(t *testing.T) { + t.Parallel() + got := sanitizeStderr([]byte(strings.Repeat("x", maxInnerStderrBytes*2))) + require.Len(t, got, maxInnerStderrBytes) + }) +}