@@ -3,13 +3,17 @@ package config
import (
"context"
"fmt"
- "strings"
"time"
"github.com/charmbracelet/crush/internal/env"
"github.com/charmbracelet/crush/internal/shell"
)
+// resolveTimeout bounds how long a single ResolveValue call may spend
+// inside shell expansion (including any command substitution). Matches
+// the timeout used by the previous hand-rolled parser.
+const resolveTimeout = 5 * time.Minute
+
type VariableResolver interface {
ResolveValue(value string) (string, error)
}
@@ -28,141 +32,140 @@ func IdentityResolver() VariableResolver {
return identityResolver{}
}
-type Shell interface {
- Exec(ctx context.Context, command string) (stdout, stderr string, err error)
+// Expander is the single-value shell expansion seam used by
+// shellVariableResolver. Production wires it to shell.ExpandValue; tests
+// can inject a fake via WithExpander.
+type Expander func(ctx context.Context, value string, env []string) (string, error)
+
+// ShellResolverOption customizes shell variable resolver construction.
+type ShellResolverOption func(*shellVariableResolver)
+
+// WithExpander overrides the expansion function used by the resolver.
+// Primarily intended for tests; production callers should not need this.
+func WithExpander(e Expander) ShellResolverOption {
+ return func(r *shellVariableResolver) {
+ if e != nil {
+ r.expand = e
+ }
+ }
}
type shellVariableResolver struct {
- shell Shell
- env env.Env
+ env env.Env
+ expand Expander
}
-func NewShellVariableResolver(env env.Env) VariableResolver {
- return &shellVariableResolver{
- env: env,
- shell: shell.NewShell(
- &shell.Options{
- Env: env.Env(),
- },
- ),
+// NewShellVariableResolver returns a VariableResolver that delegates to
+// the embedded shell (the same interpreter used by the bash tool and
+// hooks). Supported constructs match shell.ExpandValue: $VAR, ${VAR},
+// ${VAR:-default}, $(command), quoting, and escapes. Unset variables are
+// an error; use ${VAR:-} to opt in to an empty fallback.
+func NewShellVariableResolver(e env.Env, opts ...ShellResolverOption) VariableResolver {
+ r := &shellVariableResolver{
+ env: e,
+ expand: shell.ExpandValue,
}
+ for _, opt := range opts {
+ opt(r)
+ }
+ return r
}
-// ResolveValue is a method for resolving values, such as environment variables.
-// it will resolve shell-like variable substitution anywhere in the string, including:
-// - $(command) for command substitution
-// - $VAR or ${VAR} for environment variables
+// ResolveValue resolves shell-style substitution anywhere in the string:
+//
+// - $(command) for command substitution, with full quoting and nesting.
+// - $VAR and ${VAR} for environment variables.
+// - ${VAR:-default} / ${VAR:+alt} / ${VAR:?msg} for defaulting.
+//
+// Unset variables are a hard error (nounset), mirroring the historical
+// behaviour of this resolver: silently expanding an unset variable to the
+// empty string is exactly how broken credentials reach MCP servers.
func (r *shellVariableResolver) ResolveValue(value string) (string, error) {
- // Special case: lone $ is an error (backward compatibility)
+ // Preserve the historical backward-compat contract: a lone "$" is a
+ // malformed config value, not a legal literal. The underlying shell
+ // parser would accept it as a literal; we reject it here so existing
+ // configs that relied on this validation still fail early.
if value == "$" {
return "", fmt.Errorf("invalid value format: %s", value)
}
- // If no $ found, return as-is
- if !strings.Contains(value, "$") {
- return value, nil
- }
-
- result := value
+ ctx, cancel := context.WithTimeout(context.Background(), resolveTimeout)
+ defer cancel()
- // Handle command substitution: $(command)
- for {
- start := strings.Index(result, "$(")
- if start == -1 {
- break
- }
+ out, err := r.expand(ctx, value, r.env.Env())
+ if err != nil {
+ return "", sanitizeResolveError(value, err)
+ }
+ return out, nil
+}
- // Find matching closing parenthesis
- depth := 0
- end := -1
- for i := start + 2; i < len(result); i++ {
- if result[i] == '(' {
- depth++
- } else if result[i] == ')' {
- if depth == 0 {
- end = i
- break
- }
- depth--
- }
- }
+// maxResolveErrBytes bounds the size of the inner error message surfaced
+// from a resolution failure. Defense-in-depth on top of shell.ExpandValue's
+// own stderr budget: a custom Expander injected via WithExpander, or any
+// future non-shell error path, must still produce a user-safe message.
+const maxResolveErrBytes = 512
+
+// sanitizeResolveError wraps an expansion error with the user-written
+// template (the pre-expansion string — it is what they typed, safe to
+// surface) and a bounded, scrubbed rendering of the inner error message.
+// Contract:
+//
+// - Never includes the resolved (post-expansion) value. This helper
+// only receives the template and err, so a successful expansion
+// result cannot reach it.
+// - May include the template verbatim.
+// - Truncates the inner error's message to maxResolveErrBytes and
+// replaces embedded NULs and other non-printables (except tab and
+// newline) with '?'.
+//
+// The returned error still unwraps to the original for errors.Is/As so
+// callers can inspect typed sentinels; only the rendered message is
+// scrubbed.
+func sanitizeResolveError(template string, err error) error {
+ if err == nil {
+ return nil
+ }
+ return &resolveError{
+ template: template,
+ msg: scrubErrorMessage(err.Error()),
+ inner: err,
+ }
+}
- if end == -1 {
- return "", fmt.Errorf("unmatched $( in value: %s", value)
- }
+// resolveError is the concrete type returned by sanitizeResolveError.
+// Its Error() method returns the template + scrubbed inner message;
+// Unwrap exposes the original error so errors.Is/As continue to work.
+type resolveError struct {
+ template string
+ msg string
+ inner error
+}
- command := result[start+2 : end]
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
+func (e *resolveError) Error() string {
+ return fmt.Sprintf("resolving %q: %s", e.template, e.msg)
+}
- stdout, _, err := r.shell.Exec(ctx, command)
- cancel()
- if err != nil {
- return "", fmt.Errorf("command execution failed for '%s': %w", command, err)
- }
+func (e *resolveError) Unwrap() error { return e.inner }
- // Replace the $(command) with the output
- replacement := strings.TrimSpace(stdout)
- result = result[:start] + replacement + result[end+1:]
+// scrubErrorMessage bounds the message to maxResolveErrBytes bytes and
+// replaces non-printable bytes (anything outside ASCII printable, tab, or
+// newline) with '?'. Mirrors shell.sanitizeStderr but operates on a
+// string rather than raw command stderr and runs at the config layer,
+// so arbitrary Expander error text is also sanitized.
+func scrubErrorMessage(s string) string {
+ if len(s) > maxResolveErrBytes {
+ s = s[:maxResolveErrBytes]
}
-
- // Handle environment variables: $VAR and ${VAR}
- searchStart := 0
- for {
- start := strings.Index(result[searchStart:], "$")
- if start == -1 {
- break
- }
- start += searchStart // Adjust for the offset
-
- // Skip if this is part of $( which we already handled
- if start+1 < len(result) && result[start+1] == '(' {
- // Skip past this $(...)
- searchStart = start + 1
+ out := make([]byte, len(s))
+ for i := 0; i < len(s); i++ {
+ c := s[i]
+ if c == '\t' || c == '\n' || (c >= 0x20 && c < 0x7f) {
+ out[i] = c
continue
}
- var varName string
- var end int
-
- if start+1 < len(result) && result[start+1] == '{' {
- // Handle ${VAR} format
- closeIdx := strings.Index(result[start+2:], "}")
- if closeIdx == -1 {
- return "", fmt.Errorf("unmatched ${ in value: %s", value)
- }
- varName = result[start+2 : start+2+closeIdx]
- end = start + 2 + closeIdx + 1
- } else {
- // Handle $VAR format - variable names must start with letter or underscore
- if start+1 >= len(result) {
- return "", fmt.Errorf("incomplete variable reference at end of string: %s", value)
- }
-
- if result[start+1] != '_' &&
- (result[start+1] < 'a' || result[start+1] > 'z') &&
- (result[start+1] < 'A' || result[start+1] > 'Z') {
- return "", fmt.Errorf("invalid variable name starting with '%c' in: %s", result[start+1], value)
- }
-
- end = start + 1
- for end < len(result) && (result[end] == '_' ||
- (result[end] >= 'a' && result[end] <= 'z') ||
- (result[end] >= 'A' && result[end] <= 'Z') ||
- (result[end] >= '0' && result[end] <= '9')) {
- end++
- }
- varName = result[start+1 : end]
- }
-
- envValue := r.env.Get(varName)
- if envValue == "" {
- return "", fmt.Errorf("environment variable %q not set", varName)
- }
-
- result = result[:start] + envValue + result[end:]
- searchStart = start + len(envValue) // Continue searching after the replacement
+ out[i] = '?'
}
-
- return result, nil
+ return string(out)
}
type environmentVariableResolver struct {
@@ -177,11 +180,11 @@ func NewEnvironmentVariableResolver(env env.Env) VariableResolver {
// ResolveValue resolves environment variables from the provided env.Env.
func (r *environmentVariableResolver) ResolveValue(value string) (string, error) {
- if !strings.HasPrefix(value, "$") {
+ if len(value) == 0 || value[0] != '$' {
return value, nil
}
- varName := strings.TrimPrefix(value, "$")
+ varName := value[1:]
resolvedValue := r.env.Get(varName)
if resolvedValue == "" {
return "", fmt.Errorf("environment variable %q not set", varName)
@@ -3,260 +3,221 @@ package config
import (
"context"
"errors"
+ "strings"
"testing"
"github.com/charmbracelet/crush/internal/env"
"github.com/stretchr/testify/require"
)
-// mockShell implements the Shell interface for testing
-type mockShell struct {
- execFunc func(ctx context.Context, command string) (stdout, stderr string, err error)
+// fakeExpander returns a canned value/error for the last passed value and
+// records the context, raw value, and env slice it was called with. It
+// lets the config-layer tests assert on delegation behaviour without
+// spinning up a real interpreter — real-shell coverage lives in
+// internal/shell/expand_test.go and resolve_real_test.go.
+type fakeExpander struct {
+ expand func(ctx context.Context, value string, env []string) (string, error)
+ lastValue string
+ lastEnv []string
+ calls int
}
-func (m *mockShell) Exec(ctx context.Context, command string) (stdout, stderr string, err error) {
- if m.execFunc != nil {
- return m.execFunc(ctx, command)
+func (f *fakeExpander) Expand(ctx context.Context, value string, env []string) (string, error) {
+ f.calls++
+ f.lastValue = value
+ f.lastEnv = env
+ if f.expand == nil {
+ return value, nil
}
- return "", "", nil
+ return f.expand(ctx, value, env)
}
-func TestShellVariableResolver_ResolveValue(t *testing.T) {
- tests := []struct {
- name string
- value string
- envVars map[string]string
- shellFunc func(ctx context.Context, command string) (stdout, stderr string, err error)
- expected string
- expectError bool
- }{
- {
- name: "non-variable string returns as-is",
- value: "plain-string",
- expected: "plain-string",
- },
- {
- name: "environment variable resolution",
- value: "$HOME",
- envVars: map[string]string{"HOME": "/home/user"},
- expected: "/home/user",
- },
- {
- name: "missing environment variable returns error",
- value: "$MISSING_VAR",
- envVars: map[string]string{},
- expectError: true,
- },
+func TestShellVariableResolver_DelegatesToExpander(t *testing.T) {
+ t.Parallel()
- {
- name: "shell command with whitespace trimming",
- value: "$(echo ' spaced ')",
- shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
- if command == "echo ' spaced '" {
- return " spaced \n", "", nil
- }
- return "", "", errors.New("unexpected command")
- },
- expected: "spaced",
- },
- {
- name: "shell command execution error",
- value: "$(false)",
- shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
- return "", "", errors.New("command failed")
- },
- expectError: true,
- },
- {
- name: "invalid format returns error",
- value: "$",
- expectError: true,
+ fe := &fakeExpander{
+ expand: func(_ context.Context, value string, _ []string) (string, error) {
+ if value == "hello $FOO" {
+ return "hello bar", nil
+ }
+ return value, nil
},
}
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- testEnv := env.NewFromMap(tt.envVars)
- resolver := &shellVariableResolver{
- shell: &mockShell{execFunc: tt.shellFunc},
- env: testEnv,
- }
+ e := env.NewFromMap(map[string]string{"FOO": "bar"})
+ r := NewShellVariableResolver(e, WithExpander(fe.Expand))
- result, err := resolver.ResolveValue(tt.value)
+ got, err := r.ResolveValue("hello $FOO")
+ require.NoError(t, err)
+ require.Equal(t, "hello bar", got)
+ require.Equal(t, 1, fe.calls)
+ require.Equal(t, "hello $FOO", fe.lastValue)
+ require.Contains(t, fe.lastEnv, "FOO=bar")
+}
- if tt.expectError {
- require.Error(t, err)
- } else {
- require.NoError(t, err)
- require.Equal(t, tt.expected, result)
- }
- })
- }
+func TestShellVariableResolver_LoneDollarIsError(t *testing.T) {
+ t.Parallel()
+
+ // Lone "$" must short-circuit before reaching the expander: the
+ // underlying shell parser would accept it as a literal, but this
+ // resolver has historically rejected it and callers depend on
+ // that early-fail behaviour.
+ fe := &fakeExpander{}
+ r := NewShellVariableResolver(env.NewFromMap(nil), WithExpander(fe.Expand))
+
+ _, err := r.ResolveValue("$")
+ require.Error(t, err)
+ require.Equal(t, 0, fe.calls, "expander must not be called for lone $")
}
-func TestShellVariableResolver_EnhancedResolveValue(t *testing.T) {
- tests := []struct {
- name string
- value string
- envVars map[string]string
- shellFunc func(ctx context.Context, command string) (stdout, stderr string, err error)
- expected string
- expectError bool
- }{
- {
- name: "command substitution within string",
- value: "Bearer $(echo token123)",
- shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
- if command == "echo token123" {
- return "token123\n", "", nil
- }
- return "", "", errors.New("unexpected command")
- },
- expected: "Bearer token123",
- },
- {
- name: "environment variable within string",
- value: "Bearer $TOKEN",
- envVars: map[string]string{"TOKEN": "sk-ant-123"},
- expected: "Bearer sk-ant-123",
- },
- {
- name: "environment variable with braces within string",
- value: "Bearer ${TOKEN}",
- envVars: map[string]string{"TOKEN": "sk-ant-456"},
- expected: "Bearer sk-ant-456",
- },
- {
- name: "mixed command and environment substitution",
- value: "$USER-$(date +%Y)-$HOST",
- envVars: map[string]string{
- "USER": "testuser",
- "HOST": "localhost",
- },
- shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
- if command == "date +%Y" {
- return "2024\n", "", nil
- }
- return "", "", errors.New("unexpected command")
- },
- expected: "testuser-2024-localhost",
- },
- {
- name: "multiple command substitutions",
- value: "$(echo hello) $(echo world)",
- shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
- switch command {
- case "echo hello":
- return "hello\n", "", nil
- case "echo world":
- return "world\n", "", nil
- }
- return "", "", errors.New("unexpected command")
- },
- expected: "hello world",
- },
- {
- name: "nested parentheses in command",
- value: "$(echo $(echo inner))",
- shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
- if command == "echo $(echo inner)" {
- return "nested\n", "", nil
- }
- return "", "", errors.New("unexpected command")
- },
- expected: "nested",
- },
- {
- name: "lone dollar with non-variable chars",
- value: "prefix$123suffix", // Numbers can't start variable names
- expectError: true,
- },
- {
- name: "dollar with special chars",
- value: "a$@b$#c", // Special chars aren't valid in variable names
- expectError: true,
- },
- {
- name: "empty environment variable substitution",
- value: "Bearer $EMPTY_VAR",
- envVars: map[string]string{},
- expectError: true,
- },
- {
- name: "unmatched command substitution opening",
- value: "Bearer $(echo test",
- expectError: true,
- },
- {
- name: "unmatched environment variable braces",
- value: "Bearer ${TOKEN",
- expectError: true,
- },
- {
- name: "command substitution with error",
- value: "Bearer $(false)",
- shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
- return "", "", errors.New("command failed")
- },
- expectError: true,
- },
- {
- name: "complex real-world example",
- value: "Bearer $(cat /tmp/token.txt | base64 -w 0)",
- shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
- if command == "cat /tmp/token.txt | base64 -w 0" {
- return "c2stYW50LXRlc3Q=\n", "", nil
- }
- return "", "", errors.New("unexpected command")
- },
- expected: "Bearer c2stYW50LXRlc3Q=",
- },
- {
- name: "environment variable with underscores and numbers",
- value: "Bearer $API_KEY_V2",
- envVars: map[string]string{"API_KEY_V2": "sk-test-123"},
- expected: "Bearer sk-test-123",
- },
- {
- name: "no substitution needed",
- value: "Bearer sk-ant-static-token",
- expected: "Bearer sk-ant-static-token",
- },
- {
- name: "incomplete variable at end",
- value: "Bearer $",
- expectError: true,
- },
- {
- name: "variable with invalid character",
- value: "Bearer $VAR-NAME", // Hyphen not allowed in variable names
- expectError: true,
- },
- {
- name: "multiple invalid variables",
- value: "$1$2$3",
- expectError: true,
+func TestShellVariableResolver_PassesThroughLiterals(t *testing.T) {
+ t.Parallel()
+
+ fe := &fakeExpander{
+ expand: func(_ context.Context, value string, _ []string) (string, error) {
+ return value, nil
},
}
+ r := NewShellVariableResolver(env.NewFromMap(nil), WithExpander(fe.Expand))
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- testEnv := env.NewFromMap(tt.envVars)
- resolver := &shellVariableResolver{
- shell: &mockShell{execFunc: tt.shellFunc},
- env: testEnv,
- }
+ got, err := r.ResolveValue("plain-string")
+ require.NoError(t, err)
+ require.Equal(t, "plain-string", got)
+}
- result, err := resolver.ResolveValue(tt.value)
+func TestShellVariableResolver_WrapsErrorsWithTemplate(t *testing.T) {
+ t.Parallel()
- if tt.expectError {
- require.Error(t, err)
- } else {
- require.NoError(t, err)
- require.Equal(t, tt.expected, result)
- }
- })
+ inner := errors.New("cat: /run/secrets/x: permission denied")
+ fe := &fakeExpander{
+ expand: func(_ context.Context, _ string, _ []string) (string, error) {
+ return "", inner
+ },
}
+ r := NewShellVariableResolver(env.NewFromMap(nil), WithExpander(fe.Expand))
+
+ _, err := r.ResolveValue("$(cat /run/secrets/x)")
+ require.Error(t, err)
+ require.ErrorIs(t, err, inner)
+ require.Contains(t, err.Error(), "$(cat /run/secrets/x)")
+ require.Contains(t, err.Error(), "permission denied")
+}
+
+func TestSanitizeResolveError(t *testing.T) {
+ t.Parallel()
+
+ t.Run("nil passes through", func(t *testing.T) {
+ t.Parallel()
+ require.NoError(t, sanitizeResolveError("anything", nil))
+ })
+
+ t.Run("includes template and wraps inner", func(t *testing.T) {
+ t.Parallel()
+ inner := errors.New("cat: /run/secrets/x: permission denied")
+ got := sanitizeResolveError("$(cat /run/secrets/x)", inner)
+ require.Error(t, got)
+ require.ErrorIs(t, got, inner)
+ require.Contains(t, got.Error(), "$(cat /run/secrets/x)")
+ require.Contains(t, got.Error(), "permission denied")
+ })
+
+ t.Run("unwrap preserves original for errors.Is", func(t *testing.T) {
+ t.Parallel()
+ inner := errors.New("sentinel")
+ got := sanitizeResolveError("$FOO", inner)
+ require.ErrorIs(t, got, inner)
+ })
+
+ t.Run("truncates over-budget inner message", func(t *testing.T) {
+ t.Parallel()
+ // Inner message holds far more than the budget. After
+ // sanitization the rendered inner portion must not exceed
+ // maxResolveErrBytes, and the characters beyond the budget
+ // (marked by a distinct tail sentinel) must be gone.
+ const tailSentinel = "TAIL_SENTINEL_BEYOND_BUDGET"
+ body := strings.Repeat("x", maxResolveErrBytes)
+ inner := errors.New(body + tailSentinel)
+
+ got := sanitizeResolveError("$TEMPLATE", inner)
+ require.Error(t, got)
+
+ prefix := `resolving "$TEMPLATE": `
+ rendered := got.Error()
+ require.True(
+ t,
+ strings.HasPrefix(rendered, prefix),
+ "rendered error must start with template prefix",
+ )
+ innerRendered := strings.TrimPrefix(rendered, prefix)
+ require.LessOrEqual(
+ t,
+ len(innerRendered),
+ maxResolveErrBytes,
+ "inner message must be bounded to maxResolveErrBytes",
+ )
+ require.NotContains(
+ t,
+ rendered,
+ tailSentinel,
+ "content past the budget must not leak",
+ )
+ })
+
+ t.Run("replaces non-printable bytes", func(t *testing.T) {
+ t.Parallel()
+ // NUL, BEL, ESC, DEL, and a UTF-8 high byte should all be
+ // scrubbed to '?'. Tab and newline are preserved because
+ // they show up legitimately in command stderr.
+ inner := errors.New("ok\x00bad\x07\x1b\x7f\xffend\ttab\nline")
+ got := sanitizeResolveError("$T", inner)
+ rendered := got.Error()
+
+ require.NotContains(t, rendered, "\x00")
+ require.NotContains(t, rendered, "\x07")
+ require.NotContains(t, rendered, "\x1b")
+ require.NotContains(t, rendered, "\x7f")
+ require.NotContains(t, rendered, "\xff")
+ require.Contains(t, rendered, "ok?bad????end\ttab\nline")
+ })
+
+ t.Run("scrubbing does not depend on shell.ExpandValue upstream", func(t *testing.T) {
+ t.Parallel()
+ // A custom Expander can inject arbitrary error text. The
+ // config-layer helper is the single chokepoint; it must
+ // bound + scrub regardless of the error source.
+ nasty := strings.Repeat("A", maxResolveErrBytes+64) + "\x00BEYOND"
+ fe := &fakeExpander{
+ expand: func(_ context.Context, _ string, _ []string) (string, error) {
+ return "", errors.New(nasty)
+ },
+ }
+ r := NewShellVariableResolver(env.NewFromMap(nil), WithExpander(fe.Expand))
+
+ _, err := r.ResolveValue("$T")
+ require.Error(t, err)
+ require.NotContains(t, err.Error(), "BEYOND", "over-budget tail must not leak")
+ require.NotContains(t, err.Error(), "\x00", "non-printables must be scrubbed")
+ })
+}
+
+func TestScrubErrorMessage(t *testing.T) {
+ t.Parallel()
+
+ t.Run("bounds output to maxResolveErrBytes", func(t *testing.T) {
+ t.Parallel()
+ got := scrubErrorMessage(strings.Repeat("a", maxResolveErrBytes*3))
+ require.Len(t, got, maxResolveErrBytes)
+ })
+
+ t.Run("preserves printable ASCII tab and newline", func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, "a\tb\nc d!", scrubErrorMessage("a\tb\nc d!"))
+ })
+
+ t.Run("replaces control and non-ASCII bytes", func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, "a?b??c", scrubErrorMessage("a\x01b\x1b\xe2c"))
+ })
}
func TestEnvironmentVariableResolver_ResolveValue(t *testing.T) {