From 5dc30cfac53a4db94cc0f81b0d8d8879afd72196 Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Sat, 2 May 2026 09:38:29 -0400 Subject: [PATCH] refactor(config): resolve shell vars via shell.ExpandValue Rewires shellVariableResolver onto the embedded shell interpreter used by the bash tool and hooks, replacing the hand-rolled parser. Adds a bounded, scrubbed sanitizeResolveError so resolution failures surface safely without leaking oversized or non-printable inner stderr. --- internal/config/resolve.go | 233 +++++++++--------- internal/config/resolve_test.go | 419 +++++++++++++++----------------- 2 files changed, 308 insertions(+), 344 deletions(-) diff --git a/internal/config/resolve.go b/internal/config/resolve.go index b9e7753386bb8c95b877b99172897ea4fdb0a045..8c22a8abc7a516cd19be6fc32eaa101d60e416d4 100644 --- a/internal/config/resolve.go +++ b/internal/config/resolve.go @@ -3,13 +3,17 @@ package config import ( "context" "fmt" - "strings" "time" "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/shell" ) +// resolveTimeout bounds how long a single ResolveValue call may spend +// inside shell expansion (including any command substitution). Matches +// the timeout used by the previous hand-rolled parser. +const resolveTimeout = 5 * time.Minute + type VariableResolver interface { ResolveValue(value string) (string, error) } @@ -28,141 +32,140 @@ func IdentityResolver() VariableResolver { return identityResolver{} } -type Shell interface { - Exec(ctx context.Context, command string) (stdout, stderr string, err error) +// Expander is the single-value shell expansion seam used by +// shellVariableResolver. Production wires it to shell.ExpandValue; tests +// can inject a fake via WithExpander. +type Expander func(ctx context.Context, value string, env []string) (string, error) + +// ShellResolverOption customizes shell variable resolver construction. +type ShellResolverOption func(*shellVariableResolver) + +// WithExpander overrides the expansion function used by the resolver. +// Primarily intended for tests; production callers should not need this. +func WithExpander(e Expander) ShellResolverOption { + return func(r *shellVariableResolver) { + if e != nil { + r.expand = e + } + } } type shellVariableResolver struct { - shell Shell - env env.Env + env env.Env + expand Expander } -func NewShellVariableResolver(env env.Env) VariableResolver { - return &shellVariableResolver{ - env: env, - shell: shell.NewShell( - &shell.Options{ - Env: env.Env(), - }, - ), +// NewShellVariableResolver returns a VariableResolver that delegates to +// the embedded shell (the same interpreter used by the bash tool and +// hooks). Supported constructs match shell.ExpandValue: $VAR, ${VAR}, +// ${VAR:-default}, $(command), quoting, and escapes. Unset variables are +// an error; use ${VAR:-} to opt in to an empty fallback. +func NewShellVariableResolver(e env.Env, opts ...ShellResolverOption) VariableResolver { + r := &shellVariableResolver{ + env: e, + expand: shell.ExpandValue, } + for _, opt := range opts { + opt(r) + } + return r } -// ResolveValue is a method for resolving values, such as environment variables. -// it will resolve shell-like variable substitution anywhere in the string, including: -// - $(command) for command substitution -// - $VAR or ${VAR} for environment variables +// ResolveValue resolves shell-style substitution anywhere in the string: +// +// - $(command) for command substitution, with full quoting and nesting. +// - $VAR and ${VAR} for environment variables. +// - ${VAR:-default} / ${VAR:+alt} / ${VAR:?msg} for defaulting. +// +// Unset variables are a hard error (nounset), mirroring the historical +// behaviour of this resolver: silently expanding an unset variable to the +// empty string is exactly how broken credentials reach MCP servers. func (r *shellVariableResolver) ResolveValue(value string) (string, error) { - // Special case: lone $ is an error (backward compatibility) + // Preserve the historical backward-compat contract: a lone "$" is a + // malformed config value, not a legal literal. The underlying shell + // parser would accept it as a literal; we reject it here so existing + // configs that relied on this validation still fail early. if value == "$" { return "", fmt.Errorf("invalid value format: %s", value) } - // If no $ found, return as-is - if !strings.Contains(value, "$") { - return value, nil - } - - result := value + ctx, cancel := context.WithTimeout(context.Background(), resolveTimeout) + defer cancel() - // Handle command substitution: $(command) - for { - start := strings.Index(result, "$(") - if start == -1 { - break - } + out, err := r.expand(ctx, value, r.env.Env()) + if err != nil { + return "", sanitizeResolveError(value, err) + } + return out, nil +} - // Find matching closing parenthesis - depth := 0 - end := -1 - for i := start + 2; i < len(result); i++ { - if result[i] == '(' { - depth++ - } else if result[i] == ')' { - if depth == 0 { - end = i - break - } - depth-- - } - } +// maxResolveErrBytes bounds the size of the inner error message surfaced +// from a resolution failure. Defense-in-depth on top of shell.ExpandValue's +// own stderr budget: a custom Expander injected via WithExpander, or any +// future non-shell error path, must still produce a user-safe message. +const maxResolveErrBytes = 512 + +// sanitizeResolveError wraps an expansion error with the user-written +// template (the pre-expansion string — it is what they typed, safe to +// surface) and a bounded, scrubbed rendering of the inner error message. +// Contract: +// +// - Never includes the resolved (post-expansion) value. This helper +// only receives the template and err, so a successful expansion +// result cannot reach it. +// - May include the template verbatim. +// - Truncates the inner error's message to maxResolveErrBytes and +// replaces embedded NULs and other non-printables (except tab and +// newline) with '?'. +// +// The returned error still unwraps to the original for errors.Is/As so +// callers can inspect typed sentinels; only the rendered message is +// scrubbed. +func sanitizeResolveError(template string, err error) error { + if err == nil { + return nil + } + return &resolveError{ + template: template, + msg: scrubErrorMessage(err.Error()), + inner: err, + } +} - if end == -1 { - return "", fmt.Errorf("unmatched $( in value: %s", value) - } +// resolveError is the concrete type returned by sanitizeResolveError. +// Its Error() method returns the template + scrubbed inner message; +// Unwrap exposes the original error so errors.Is/As continue to work. +type resolveError struct { + template string + msg string + inner error +} - command := result[start+2 : end] - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) +func (e *resolveError) Error() string { + return fmt.Sprintf("resolving %q: %s", e.template, e.msg) +} - stdout, _, err := r.shell.Exec(ctx, command) - cancel() - if err != nil { - return "", fmt.Errorf("command execution failed for '%s': %w", command, err) - } +func (e *resolveError) Unwrap() error { return e.inner } - // Replace the $(command) with the output - replacement := strings.TrimSpace(stdout) - result = result[:start] + replacement + result[end+1:] +// scrubErrorMessage bounds the message to maxResolveErrBytes bytes and +// replaces non-printable bytes (anything outside ASCII printable, tab, or +// newline) with '?'. Mirrors shell.sanitizeStderr but operates on a +// string rather than raw command stderr and runs at the config layer, +// so arbitrary Expander error text is also sanitized. +func scrubErrorMessage(s string) string { + if len(s) > maxResolveErrBytes { + s = s[:maxResolveErrBytes] } - - // Handle environment variables: $VAR and ${VAR} - searchStart := 0 - for { - start := strings.Index(result[searchStart:], "$") - if start == -1 { - break - } - start += searchStart // Adjust for the offset - - // Skip if this is part of $( which we already handled - if start+1 < len(result) && result[start+1] == '(' { - // Skip past this $(...) - searchStart = start + 1 + out := make([]byte, len(s)) + for i := 0; i < len(s); i++ { + c := s[i] + if c == '\t' || c == '\n' || (c >= 0x20 && c < 0x7f) { + out[i] = c continue } - var varName string - var end int - - if start+1 < len(result) && result[start+1] == '{' { - // Handle ${VAR} format - closeIdx := strings.Index(result[start+2:], "}") - if closeIdx == -1 { - return "", fmt.Errorf("unmatched ${ in value: %s", value) - } - varName = result[start+2 : start+2+closeIdx] - end = start + 2 + closeIdx + 1 - } else { - // Handle $VAR format - variable names must start with letter or underscore - if start+1 >= len(result) { - return "", fmt.Errorf("incomplete variable reference at end of string: %s", value) - } - - if result[start+1] != '_' && - (result[start+1] < 'a' || result[start+1] > 'z') && - (result[start+1] < 'A' || result[start+1] > 'Z') { - return "", fmt.Errorf("invalid variable name starting with '%c' in: %s", result[start+1], value) - } - - end = start + 1 - for end < len(result) && (result[end] == '_' || - (result[end] >= 'a' && result[end] <= 'z') || - (result[end] >= 'A' && result[end] <= 'Z') || - (result[end] >= '0' && result[end] <= '9')) { - end++ - } - varName = result[start+1 : end] - } - - envValue := r.env.Get(varName) - if envValue == "" { - return "", fmt.Errorf("environment variable %q not set", varName) - } - - result = result[:start] + envValue + result[end:] - searchStart = start + len(envValue) // Continue searching after the replacement + out[i] = '?' } - - return result, nil + return string(out) } type environmentVariableResolver struct { @@ -177,11 +180,11 @@ func NewEnvironmentVariableResolver(env env.Env) VariableResolver { // ResolveValue resolves environment variables from the provided env.Env. func (r *environmentVariableResolver) ResolveValue(value string) (string, error) { - if !strings.HasPrefix(value, "$") { + if len(value) == 0 || value[0] != '$' { return value, nil } - varName := strings.TrimPrefix(value, "$") + varName := value[1:] resolvedValue := r.env.Get(varName) if resolvedValue == "" { return "", fmt.Errorf("environment variable %q not set", varName) diff --git a/internal/config/resolve_test.go b/internal/config/resolve_test.go index ec9b06c25bdc023acebffc71f043b54a8da21597..18b461b8a6bc88f73d6db935eebd302cc3514781 100644 --- a/internal/config/resolve_test.go +++ b/internal/config/resolve_test.go @@ -3,260 +3,221 @@ package config import ( "context" "errors" + "strings" "testing" "github.com/charmbracelet/crush/internal/env" "github.com/stretchr/testify/require" ) -// mockShell implements the Shell interface for testing -type mockShell struct { - execFunc func(ctx context.Context, command string) (stdout, stderr string, err error) +// fakeExpander returns a canned value/error for the last passed value and +// records the context, raw value, and env slice it was called with. It +// lets the config-layer tests assert on delegation behaviour without +// spinning up a real interpreter — real-shell coverage lives in +// internal/shell/expand_test.go and resolve_real_test.go. +type fakeExpander struct { + expand func(ctx context.Context, value string, env []string) (string, error) + lastValue string + lastEnv []string + calls int } -func (m *mockShell) Exec(ctx context.Context, command string) (stdout, stderr string, err error) { - if m.execFunc != nil { - return m.execFunc(ctx, command) +func (f *fakeExpander) Expand(ctx context.Context, value string, env []string) (string, error) { + f.calls++ + f.lastValue = value + f.lastEnv = env + if f.expand == nil { + return value, nil } - return "", "", nil + return f.expand(ctx, value, env) } -func TestShellVariableResolver_ResolveValue(t *testing.T) { - tests := []struct { - name string - value string - envVars map[string]string - shellFunc func(ctx context.Context, command string) (stdout, stderr string, err error) - expected string - expectError bool - }{ - { - name: "non-variable string returns as-is", - value: "plain-string", - expected: "plain-string", - }, - { - name: "environment variable resolution", - value: "$HOME", - envVars: map[string]string{"HOME": "/home/user"}, - expected: "/home/user", - }, - { - name: "missing environment variable returns error", - value: "$MISSING_VAR", - envVars: map[string]string{}, - expectError: true, - }, +func TestShellVariableResolver_DelegatesToExpander(t *testing.T) { + t.Parallel() - { - name: "shell command with whitespace trimming", - value: "$(echo ' spaced ')", - shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { - if command == "echo ' spaced '" { - return " spaced \n", "", nil - } - return "", "", errors.New("unexpected command") - }, - expected: "spaced", - }, - { - name: "shell command execution error", - value: "$(false)", - shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { - return "", "", errors.New("command failed") - }, - expectError: true, - }, - { - name: "invalid format returns error", - value: "$", - expectError: true, + fe := &fakeExpander{ + expand: func(_ context.Context, value string, _ []string) (string, error) { + if value == "hello $FOO" { + return "hello bar", nil + } + return value, nil }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - testEnv := env.NewFromMap(tt.envVars) - resolver := &shellVariableResolver{ - shell: &mockShell{execFunc: tt.shellFunc}, - env: testEnv, - } + e := env.NewFromMap(map[string]string{"FOO": "bar"}) + r := NewShellVariableResolver(e, WithExpander(fe.Expand)) - result, err := resolver.ResolveValue(tt.value) + got, err := r.ResolveValue("hello $FOO") + require.NoError(t, err) + require.Equal(t, "hello bar", got) + require.Equal(t, 1, fe.calls) + require.Equal(t, "hello $FOO", fe.lastValue) + require.Contains(t, fe.lastEnv, "FOO=bar") +} - if tt.expectError { - require.Error(t, err) - } else { - require.NoError(t, err) - require.Equal(t, tt.expected, result) - } - }) - } +func TestShellVariableResolver_LoneDollarIsError(t *testing.T) { + t.Parallel() + + // Lone "$" must short-circuit before reaching the expander: the + // underlying shell parser would accept it as a literal, but this + // resolver has historically rejected it and callers depend on + // that early-fail behaviour. + fe := &fakeExpander{} + r := NewShellVariableResolver(env.NewFromMap(nil), WithExpander(fe.Expand)) + + _, err := r.ResolveValue("$") + require.Error(t, err) + require.Equal(t, 0, fe.calls, "expander must not be called for lone $") } -func TestShellVariableResolver_EnhancedResolveValue(t *testing.T) { - tests := []struct { - name string - value string - envVars map[string]string - shellFunc func(ctx context.Context, command string) (stdout, stderr string, err error) - expected string - expectError bool - }{ - { - name: "command substitution within string", - value: "Bearer $(echo token123)", - shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { - if command == "echo token123" { - return "token123\n", "", nil - } - return "", "", errors.New("unexpected command") - }, - expected: "Bearer token123", - }, - { - name: "environment variable within string", - value: "Bearer $TOKEN", - envVars: map[string]string{"TOKEN": "sk-ant-123"}, - expected: "Bearer sk-ant-123", - }, - { - name: "environment variable with braces within string", - value: "Bearer ${TOKEN}", - envVars: map[string]string{"TOKEN": "sk-ant-456"}, - expected: "Bearer sk-ant-456", - }, - { - name: "mixed command and environment substitution", - value: "$USER-$(date +%Y)-$HOST", - envVars: map[string]string{ - "USER": "testuser", - "HOST": "localhost", - }, - shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { - if command == "date +%Y" { - return "2024\n", "", nil - } - return "", "", errors.New("unexpected command") - }, - expected: "testuser-2024-localhost", - }, - { - name: "multiple command substitutions", - value: "$(echo hello) $(echo world)", - shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { - switch command { - case "echo hello": - return "hello\n", "", nil - case "echo world": - return "world\n", "", nil - } - return "", "", errors.New("unexpected command") - }, - expected: "hello world", - }, - { - name: "nested parentheses in command", - value: "$(echo $(echo inner))", - shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { - if command == "echo $(echo inner)" { - return "nested\n", "", nil - } - return "", "", errors.New("unexpected command") - }, - expected: "nested", - }, - { - name: "lone dollar with non-variable chars", - value: "prefix$123suffix", // Numbers can't start variable names - expectError: true, - }, - { - name: "dollar with special chars", - value: "a$@b$#c", // Special chars aren't valid in variable names - expectError: true, - }, - { - name: "empty environment variable substitution", - value: "Bearer $EMPTY_VAR", - envVars: map[string]string{}, - expectError: true, - }, - { - name: "unmatched command substitution opening", - value: "Bearer $(echo test", - expectError: true, - }, - { - name: "unmatched environment variable braces", - value: "Bearer ${TOKEN", - expectError: true, - }, - { - name: "command substitution with error", - value: "Bearer $(false)", - shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { - return "", "", errors.New("command failed") - }, - expectError: true, - }, - { - name: "complex real-world example", - value: "Bearer $(cat /tmp/token.txt | base64 -w 0)", - shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { - if command == "cat /tmp/token.txt | base64 -w 0" { - return "c2stYW50LXRlc3Q=\n", "", nil - } - return "", "", errors.New("unexpected command") - }, - expected: "Bearer c2stYW50LXRlc3Q=", - }, - { - name: "environment variable with underscores and numbers", - value: "Bearer $API_KEY_V2", - envVars: map[string]string{"API_KEY_V2": "sk-test-123"}, - expected: "Bearer sk-test-123", - }, - { - name: "no substitution needed", - value: "Bearer sk-ant-static-token", - expected: "Bearer sk-ant-static-token", - }, - { - name: "incomplete variable at end", - value: "Bearer $", - expectError: true, - }, - { - name: "variable with invalid character", - value: "Bearer $VAR-NAME", // Hyphen not allowed in variable names - expectError: true, - }, - { - name: "multiple invalid variables", - value: "$1$2$3", - expectError: true, +func TestShellVariableResolver_PassesThroughLiterals(t *testing.T) { + t.Parallel() + + fe := &fakeExpander{ + expand: func(_ context.Context, value string, _ []string) (string, error) { + return value, nil }, } + r := NewShellVariableResolver(env.NewFromMap(nil), WithExpander(fe.Expand)) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - testEnv := env.NewFromMap(tt.envVars) - resolver := &shellVariableResolver{ - shell: &mockShell{execFunc: tt.shellFunc}, - env: testEnv, - } + got, err := r.ResolveValue("plain-string") + require.NoError(t, err) + require.Equal(t, "plain-string", got) +} - result, err := resolver.ResolveValue(tt.value) +func TestShellVariableResolver_WrapsErrorsWithTemplate(t *testing.T) { + t.Parallel() - if tt.expectError { - require.Error(t, err) - } else { - require.NoError(t, err) - require.Equal(t, tt.expected, result) - } - }) + inner := errors.New("cat: /run/secrets/x: permission denied") + fe := &fakeExpander{ + expand: func(_ context.Context, _ string, _ []string) (string, error) { + return "", inner + }, } + r := NewShellVariableResolver(env.NewFromMap(nil), WithExpander(fe.Expand)) + + _, err := r.ResolveValue("$(cat /run/secrets/x)") + require.Error(t, err) + require.ErrorIs(t, err, inner) + require.Contains(t, err.Error(), "$(cat /run/secrets/x)") + require.Contains(t, err.Error(), "permission denied") +} + +func TestSanitizeResolveError(t *testing.T) { + t.Parallel() + + t.Run("nil passes through", func(t *testing.T) { + t.Parallel() + require.NoError(t, sanitizeResolveError("anything", nil)) + }) + + t.Run("includes template and wraps inner", func(t *testing.T) { + t.Parallel() + inner := errors.New("cat: /run/secrets/x: permission denied") + got := sanitizeResolveError("$(cat /run/secrets/x)", inner) + require.Error(t, got) + require.ErrorIs(t, got, inner) + require.Contains(t, got.Error(), "$(cat /run/secrets/x)") + require.Contains(t, got.Error(), "permission denied") + }) + + t.Run("unwrap preserves original for errors.Is", func(t *testing.T) { + t.Parallel() + inner := errors.New("sentinel") + got := sanitizeResolveError("$FOO", inner) + require.ErrorIs(t, got, inner) + }) + + t.Run("truncates over-budget inner message", func(t *testing.T) { + t.Parallel() + // Inner message holds far more than the budget. After + // sanitization the rendered inner portion must not exceed + // maxResolveErrBytes, and the characters beyond the budget + // (marked by a distinct tail sentinel) must be gone. + const tailSentinel = "TAIL_SENTINEL_BEYOND_BUDGET" + body := strings.Repeat("x", maxResolveErrBytes) + inner := errors.New(body + tailSentinel) + + got := sanitizeResolveError("$TEMPLATE", inner) + require.Error(t, got) + + prefix := `resolving "$TEMPLATE": ` + rendered := got.Error() + require.True( + t, + strings.HasPrefix(rendered, prefix), + "rendered error must start with template prefix", + ) + innerRendered := strings.TrimPrefix(rendered, prefix) + require.LessOrEqual( + t, + len(innerRendered), + maxResolveErrBytes, + "inner message must be bounded to maxResolveErrBytes", + ) + require.NotContains( + t, + rendered, + tailSentinel, + "content past the budget must not leak", + ) + }) + + t.Run("replaces non-printable bytes", func(t *testing.T) { + t.Parallel() + // NUL, BEL, ESC, DEL, and a UTF-8 high byte should all be + // scrubbed to '?'. Tab and newline are preserved because + // they show up legitimately in command stderr. + inner := errors.New("ok\x00bad\x07\x1b\x7f\xffend\ttab\nline") + got := sanitizeResolveError("$T", inner) + rendered := got.Error() + + require.NotContains(t, rendered, "\x00") + require.NotContains(t, rendered, "\x07") + require.NotContains(t, rendered, "\x1b") + require.NotContains(t, rendered, "\x7f") + require.NotContains(t, rendered, "\xff") + require.Contains(t, rendered, "ok?bad????end\ttab\nline") + }) + + t.Run("scrubbing does not depend on shell.ExpandValue upstream", func(t *testing.T) { + t.Parallel() + // A custom Expander can inject arbitrary error text. The + // config-layer helper is the single chokepoint; it must + // bound + scrub regardless of the error source. + nasty := strings.Repeat("A", maxResolveErrBytes+64) + "\x00BEYOND" + fe := &fakeExpander{ + expand: func(_ context.Context, _ string, _ []string) (string, error) { + return "", errors.New(nasty) + }, + } + r := NewShellVariableResolver(env.NewFromMap(nil), WithExpander(fe.Expand)) + + _, err := r.ResolveValue("$T") + require.Error(t, err) + require.NotContains(t, err.Error(), "BEYOND", "over-budget tail must not leak") + require.NotContains(t, err.Error(), "\x00", "non-printables must be scrubbed") + }) +} + +func TestScrubErrorMessage(t *testing.T) { + t.Parallel() + + t.Run("bounds output to maxResolveErrBytes", func(t *testing.T) { + t.Parallel() + got := scrubErrorMessage(strings.Repeat("a", maxResolveErrBytes*3)) + require.Len(t, got, maxResolveErrBytes) + }) + + t.Run("preserves printable ASCII tab and newline", func(t *testing.T) { + t.Parallel() + require.Equal(t, "a\tb\nc d!", scrubErrorMessage("a\tb\nc d!")) + }) + + t.Run("replaces control and non-ASCII bytes", func(t *testing.T) { + t.Parallel() + require.Equal(t, "a?b??c", scrubErrorMessage("a\x01b\x1b\xe2c")) + }) } func TestEnvironmentVariableResolver_ResolveValue(t *testing.T) {