diff --git a/internal/shell/expand.go b/internal/shell/expand.go new file mode 100644 index 0000000000000000000000000000000000000000..ee110a1df7b6b60b9ac2d51b7ae5f6abcbf3ac79 --- /dev/null +++ b/internal/shell/expand.go @@ -0,0 +1,123 @@ +package shell + +import ( + "bytes" + "context" + "fmt" + "io" + "os" + "strings" + + "mvdan.cc/sh/v3/expand" + "mvdan.cc/sh/v3/interp" + "mvdan.cc/sh/v3/syntax" +) + +// maxInnerStderrBytes bounds how much stderr from a failing $(...) is +// surfaced in the returned error, to avoid leaking a secret that happened +// to be embedded in a failing inner command. +const maxInnerStderrBytes = 512 + +// ExpandValue expands shell-style substitutions in a single config value. +// +// Supported constructs match the bash tool: +// +// - $VAR and ${VAR} (unset is an error; see nounset below). +// - ${VAR:-default} / ${VAR:+alt} / ${VAR:?msg}. +// - $(command) with full quoting and nesting. +// - escaped and quoted strings ("...", '...'). +// +// Contract: +// +// - Returns exactly one string. No field splitting, no globbing, no +// pathname generation. Multi-word command output is preserved +// verbatim; it is never split into multiple values. +// - Nounset is on: unset variables produce an error instead of +// expanding to the empty string. Use ${VAR:-default} to opt in to +// an empty fallback. +// - Embedded whitespace and newlines in the input are preserved +// verbatim. Command substitution strips trailing newlines only +// (POSIX), never leading or internal whitespace. +// - Errors wrap the failing inner command's exit code and a bounded +// prefix of its stderr. Callers that surface the error to users +// should additionally scrub it for the original template text. +func ExpandValue(ctx context.Context, value string, env []string) (string, error) { + // Parse the value as a here-doc style word: no word splitting, no + // globbing, but full support for $VAR, ${VAR...}, $(...), and + // quoted/escaped strings. + word, err := syntax.NewParser().Document(strings.NewReader(value)) + if err != nil { + return "", fmt.Errorf("parse: %w", err) + } + + // Build a minimal Shell value purely to reuse its handler chain + // (builtins, block funcs, optional Go coreutils) inside $(...). + // We deliberately skip NewShell so the passed-in env is used + // verbatim, with no CRUSH/AGENT/AI_AGENT injection: callers of + // ExpandValue control the env, and nounset must treat any name + // not in env as unset. + cwd, _ := os.Getwd() + s := &Shell{ + cwd: cwd, + env: env, + logger: noopLogger{}, + } + + var stderrBuf bytes.Buffer + cfg := &expand.Config{ + Env: expand.ListEnviron(env...), + NoUnset: true, + CmdSubst: func(w io.Writer, cs *syntax.CmdSubst) error { + stderrBuf.Reset() + runner, rerr := interp.New( + interp.StdIO(nil, w, &stderrBuf), + interp.Interactive(false), + interp.Env(expand.ListEnviron(env...)), + interp.Dir(s.cwd), + interp.ExecHandlers(s.execHandlers()...), + // Match the outer NoUnset: an unset $VAR inside + // $(...) is also an error, not a silent empty. + interp.Params("-u"), + ) + if rerr != nil { + return rerr + } + if rerr := runner.Run(ctx, &syntax.File{Stmts: cs.Stmts}); rerr != nil { + return wrapCmdSubstErr(rerr, stderrBuf.Bytes()) + } + return nil + }, + // ReadDir / ReadDir2 left nil: globbing is disabled. + } + + return expand.Document(cfg, word) +} + +// wrapCmdSubstErr attaches a bounded prefix of the inner command's stderr +// to the original error, if any. +func wrapCmdSubstErr(err error, stderrBytes []byte) error { + msg := sanitizeStderr(stderrBytes) + if msg == "" { + return err + } + return fmt.Errorf("%w: %s", err, msg) +} + +// sanitizeStderr trims, bounds, and scrubs non-printable bytes from the +// stderr of a failing command so the result is safe to include in an +// error message shown to the user. +func sanitizeStderr(b []byte) string { + b = bytes.TrimRight(b, "\n") + if len(b) > maxInnerStderrBytes { + b = b[:maxInnerStderrBytes] + } + out := make([]byte, len(b)) + for i, c := range b { + if c == '\t' || c == '\n' || (c >= 0x20 && c < 0x7f) { + out[i] = c + } else { + out[i] = '?' + } + } + return string(out) +} diff --git a/internal/shell/expand_test.go b/internal/shell/expand_test.go new file mode 100644 index 0000000000000000000000000000000000000000..8661d6682370d31666185576db7edddb17a65b97 --- /dev/null +++ b/internal/shell/expand_test.go @@ -0,0 +1,199 @@ +package shell + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestExpandValue_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + value string + env []string + want string + }{ + { + name: "plain string round trip", + value: "hello world", + want: "hello world", + }, + { + name: "plain var from env", + value: "$FOO", + env: []string{"FOO=bar"}, + want: "bar", + }, + { + name: "braced var from env", + value: "pre-${FOO}-post", + env: []string{"FOO=bar"}, + want: "pre-bar-post", + }, + { + name: "default syntax on unset", + value: "${MISSING:-fallback}", + want: "fallback", + }, + { + name: "default syntax on set preserves value", + value: "${SET:-fallback}", + env: []string{"SET=used"}, + want: "used", + }, + { + name: "command substitution", + value: "$(echo hi)", + want: "hi", + }, + { + name: "command substitution preserves internal spaces", + value: `$(echo "a b")`, + want: "a b", + }, + { + name: "command substitution strips only trailing newline", + value: "$(printf 'a\\nb\\n')", + want: "a\nb", + }, + { + name: "literal spaces around cmdsubst are preserved", + value: " $(echo v) ", + want: " v ", + }, + { + name: "paren inside quoted arg to echo", + value: `$(echo ")")`, + want: ")", + }, + { + name: "nested command substitution", + value: "$(echo $(echo hi))", + want: "hi", + }, + { + name: "glob-like input round trips unchanged", + value: "*.go", + want: "*.go", + }, + { + name: "tilde round trips unchanged", + value: "~", + want: "~", + }, + { + name: "env var inside cmdsubst", + value: `$(printf '%s' "$FOO")`, + env: []string{"FOO=bar"}, + want: "bar", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got, err := ExpandValue(t.Context(), tc.value, tc.env) + require.NoError(t, err) + require.Equal(t, tc.want, got) + }) + } +} + +func TestExpandValue_Errors(t *testing.T) { + t.Parallel() + + t.Run("unset var is an error, not empty", func(t *testing.T) { + t.Parallel() + _, err := ExpandValue(t.Context(), "$MISSING", nil) + require.Error(t, err) + }) + + t.Run("unset var inside braces is an error", func(t *testing.T) { + t.Parallel() + _, err := ExpandValue(t.Context(), "${MISSING}", nil) + require.Error(t, err) + }) + + t.Run("unset var inside cmdsubst is an error", func(t *testing.T) { + t.Parallel() + _, err := ExpandValue(t.Context(), `$(printf '%s' "$MISSING")`, nil) + require.Error(t, err) + }) + + t.Run("bad syntax returns error", func(t *testing.T) { + t.Parallel() + _, err := ExpandValue(t.Context(), "$(", nil) + require.Error(t, err) + }) + + t.Run("inner non-zero exit returns error with exit code", func(t *testing.T) { + t.Parallel() + _, err := ExpandValue(t.Context(), "$(false)", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "exit status 1") + }) + + t.Run("inner explicit exit code is surfaced", func(t *testing.T) { + t.Parallel() + _, err := ExpandValue(t.Context(), "$(exit 7)", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "exit status 7") + }) + + t.Run("inner stderr is surfaced", func(t *testing.T) { + t.Parallel() + _, err := ExpandValue( + t.Context(), + `$(printf 'boom\n' 1>&2; exit 1)`, + nil, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "boom") + }) + + t.Run("inner stderr is truncated to byte budget", func(t *testing.T) { + t.Parallel() + // Emit more than maxInnerStderrBytes bytes of 'X' on stderr. + long := strings.Repeat("X", maxInnerStderrBytes*2) + _, err := ExpandValue( + t.Context(), + `$(printf '`+long+`' 1>&2; exit 1)`, + nil, + ) + require.Error(t, err) + require.NotContains( + t, + err.Error(), + strings.Repeat("X", maxInnerStderrBytes+1), + "stderr should be bounded", + ) + }) +} + +func TestSanitizeStderr(t *testing.T) { + t.Parallel() + + t.Run("trims trailing newlines", func(t *testing.T) { + t.Parallel() + require.Equal(t, "hi", sanitizeStderr([]byte("hi\n\n"))) + }) + + t.Run("preserves tabs and embedded newlines", func(t *testing.T) { + t.Parallel() + require.Equal(t, "a\tb\nc", sanitizeStderr([]byte("a\tb\nc"))) + }) + + t.Run("replaces control characters", func(t *testing.T) { + t.Parallel() + require.Equal(t, "a?b", sanitizeStderr([]byte{'a', 0x01, 'b'})) + }) + + t.Run("bounds output", func(t *testing.T) { + t.Parallel() + got := sanitizeStderr([]byte(strings.Repeat("x", maxInnerStderrBytes*2))) + require.Len(t, got, maxInnerStderrBytes) + }) +}