From f6ef43e9de46e6cf04a44a901389806a57e61b3e Mon Sep 17 00:00:00 2001
From: Christian Rocha
Date: Sun, 26 Apr 2026 22:01:32 -0400
Subject: [PATCH] refactor(shell): extract stateless run entrypoint

This is necessary for sharing the builtin shell with hooks.
---
 internal/shell/run.go      | 152 ++++++++++++++++++++++
 internal/shell/run_test.go | 252 +++++++++++++++++++++++++++++++++++++
 internal/shell/shell.go    |  65 +--------
 3 files changed, 409 insertions(+), 60 deletions(-)
 create mode 100644 internal/shell/run.go
 create mode 100644 internal/shell/run_test.go

diff --git a/internal/shell/run.go b/internal/shell/run.go
new file mode 100644
index 0000000000000000000000000000000000000000..fb785a7b04dee6a64d6ad8e704ae19ff3525561e
--- /dev/null
+++ b/internal/shell/run.go
@@ -0,0 +1,152 @@
+package shell
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"strings"
+
+	"mvdan.cc/sh/moreinterp/coreutils"
+	"mvdan.cc/sh/v3/expand"
+	"mvdan.cc/sh/v3/interp"
+	"mvdan.cc/sh/v3/syntax"
+)
+
+// RunOptions configures a single stateless shell execution via [Run].
+//
+// The zero value is not useful: Cwd is required and Command should be set.
+// Stdin, Stdout, and Stderr may be nil (nil readers/writers are treated as
+// empty/discard). BlockFuncs may be nil to disable block-list enforcement —
+// hooks use this to run user-authored commands with the same trust level as
+// a shell alias.
+type RunOptions struct {
+	// Command is the shell source to parse and execute.
+	Command string
+	// Cwd is the working directory for the execution. Required: callers
+	// must supply a non-empty value. Run does not silently fall back to
+	// the Crush process cwd — hooks and the bash tool have different
+	// notions of "default" and each owns that decision.
+	Cwd string
+	// Env is the full environment visible to the command. The caller is
+	// responsible for inheriting from os.Environ() if that's desired.
+	Env []string
+	// Stdin is the command's standard input. nil is equivalent to an empty
+	// input stream.
+	Stdin io.Reader
+	// Stdout receives the command's standard output. nil discards output.
+	Stdout io.Writer
+	// Stderr receives the command's standard error. nil discards output.
+	Stderr io.Writer
+	// BlockFuncs is an optional list of deny-list matchers applied before
+	// each command reaches the exec layer. nil disables blocking entirely.
+	BlockFuncs []BlockFunc
+}
+
+// Run parses and executes a shell command using the same mvdan.cc/sh
+// interpreter stack that the stateful [Shell] type uses (builtins, optional
+// block list, optional Go coreutils). It is safe to call concurrently from
+// multiple goroutines: each call builds its own [interp.Runner] and shares
+// no state with other callers or with any [Shell] instance.
+//
+// An empty Cwd is rejected before parsing, and a panic anywhere in the call
+// is recovered and returned as an error. Errors from the command itself
+// (non-zero exit, context cancellation, parse failures) follow the same
+// conventions as [Shell.Exec]: inspect with [IsInterrupt] and [ExitCode].
+func Run(ctx context.Context, opts RunOptions) (err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			err = fmt.Errorf("command execution panic: %v", r)
+		}
+	}()
+
+	if opts.Cwd == "" {
+		return fmt.Errorf("shell.Run: Cwd is required")
+	}
+
+	stdout := opts.Stdout
+	if stdout == nil {
+		stdout = io.Discard
+	}
+	stderr := opts.Stderr
+	if stderr == nil {
+		stderr = io.Discard
+	}
+
+	line, err := syntax.NewParser().Parse(strings.NewReader(opts.Command), "")
+	if err != nil {
+		return fmt.Errorf("could not parse command: %w", err)
+	}
+
+	runner, err := newRunner(opts.Cwd, opts.Env, opts.Stdin, stdout, stderr, opts.BlockFuncs)
+	if err != nil {
+		return fmt.Errorf("could not run command: %w", err)
+	}
+
+	return runner.Run(ctx, line)
+}
+
+// newRunner constructs an [interp.Runner] configured with the standard
+// Crush handler stack. Shared by the stateless [Run] entrypoint and the
+// stateful [Shell] so the two surfaces cannot drift.
+func newRunner(cwd string, env []string, stdin io.Reader, stdout, stderr io.Writer, blockFuncs []BlockFunc) (*interp.Runner, error) {
+	return interp.New(
+		interp.StdIO(stdin, stdout, stderr),
+		interp.Interactive(false),
+		interp.Env(expand.ListEnviron(env...)),
+		interp.Dir(cwd),
+		interp.ExecHandlers(standardHandlers(blockFuncs)...),
+	)
+}
+
+// standardHandlers returns the exec-handler middleware chain used by both [Run] and [Shell].
+// Order matters: builtins first (so Crush's in-process jq wins over any PATH binary and is
+// dispatched before the block list is consulted), then the block list, then Go coreutils when
+// useGoCoreUtils is set. Future middleware (shebang dispatch, etc.) inserts here.
+func standardHandlers(blockFuncs []BlockFunc) []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
+	handlers := []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc{
+		builtinHandler(),
+		blockHandler(blockFuncs),
+	}
+	if useGoCoreUtils {
+		handlers = append(handlers, coreutils.ExecHandler)
+	}
+	return handlers
+}
+
+// builtinHandler returns middleware that dispatches recognized Crush builtins (currently: jq)
+// to their in-process Go implementations; every other command is forwarded to next unchanged.
+func builtinHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
+	return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
+		return func(ctx context.Context, args []string) error {
+			if len(args) == 0 {
+				return next(ctx, args)
+			}
+			switch args[0] {
+			case "jq":
+				hc := interp.HandlerCtx(ctx)
+				return handleJQ(args, hc.Stdin, hc.Stdout, hc.Stderr)
+			default:
+				return next(ctx, args)
+			}
+		}
+	}
+}
+
+// blockHandler returns middleware that rejects commands matched by any of
+// the provided [BlockFunc]s before they reach the underlying exec path.
+// A nil or empty blockFuncs slice is a no-op.
+func blockHandler(blockFuncs []BlockFunc) func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
+	return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
+		return func(ctx context.Context, args []string) error {
+			if len(args) == 0 {
+				return next(ctx, args)
+			}
+			for _, blockFunc := range blockFuncs {
+				if blockFunc(args) {
+					return fmt.Errorf("command is not allowed for security reasons: %q", args[0])
+				}
+			}
+			return next(ctx, args)
+		}
+	}
+}
diff --git a/internal/shell/run_test.go b/internal/shell/run_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..15253a1f624bda0316c935ebe52f0d1519d60e63
--- /dev/null
+++ b/internal/shell/run_test.go
@@ -0,0 +1,252 @@
+package shell
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"fmt"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+)
+
+func TestRun_Echo(t *testing.T) {
+	var stdout, stderr bytes.Buffer
+	err := Run(t.Context(), RunOptions{
+		Command: "echo hi",
+		Cwd:     t.TempDir(),
+		Stdout:  &stdout,
+		Stderr:  &stderr,
+	})
+	if err != nil {
+		t.Fatalf("Run returned error: %v (stderr=%q)", err, stderr.String())
+	}
+	if got := stdout.String(); got != "hi\n" {
+		t.Fatalf("stdout = %q, want %q", got, "hi\n")
+	}
+}
+
+func TestRun_ExitCode(t *testing.T) {
+	err := Run(t.Context(), RunOptions{
+		Command: "exit 7",
+		Cwd:     t.TempDir(),
+	})
+	if err == nil {
+		t.Fatal("expected error for exit 7, got nil")
+	}
+	if code := ExitCode(err); code != 7 {
+		t.Fatalf("ExitCode = %d, want 7", code)
+	}
+}
+
+func TestRun_Stdin(t *testing.T) {
+	// Use the `read` shell builtin so the test doesn't depend on any
+	// external binary being on PATH (we pass an empty Env here).
+	var stdout bytes.Buffer
+	err := Run(t.Context(), RunOptions{
+		Command: "read line; echo got:$line",
+		Cwd:     t.TempDir(),
+		Stdin:   strings.NewReader("hello\n"),
+		Stdout:  &stdout,
+	})
+	if err != nil {
+		t.Fatalf("Run returned error: %v", err)
+	}
+	if got := stdout.String(); got != "got:hello\n" {
+		t.Fatalf("stdout = %q, want %q", got, "got:hello\n")
+	}
+}
+
+func TestRun_Env(t *testing.T) {
+	var stdout bytes.Buffer
+	err := Run(t.Context(), RunOptions{
+		Command: `echo "$FOO"`,
+		Cwd:     t.TempDir(),
+		Env:     []string{"FOO=bar"},
+		Stdout:  &stdout,
+	})
+	if err != nil {
+		t.Fatalf("Run returned error: %v", err)
+	}
+	if got := stdout.String(); got != "bar\n" {
+		t.Fatalf("stdout = %q, want %q", got, "bar\n")
+	}
+}
+
+func TestRun_Cwd(t *testing.T) {
+	dir := t.TempDir()
+	var stdout bytes.Buffer
+	err := Run(t.Context(), RunOptions{
+		Command: "pwd",
+		Cwd:     dir,
+		Stdout:  &stdout,
+	})
+	if err != nil {
+		t.Fatalf("Run returned error: %v", err)
+	}
+	// mvdan's pwd builtin resolves symlinks (e.g. /var -> /private/var on macOS), so compare by
+	// suffix. Reject empty output first: every string has "" as a suffix, so it would match.
+	got := strings.TrimRight(stdout.String(), "\n")
+	if got == "" || (!strings.HasSuffix(got, dir) && !strings.HasSuffix(dir, got)) {
+		t.Fatalf("pwd = %q, want it to match %q", got, dir)
+	}
+}
+
+func TestRun_JqBuiltin(t *testing.T) {
+	var stdout bytes.Buffer
+	err := Run(t.Context(), RunOptions{
+		Command: `echo '{"a":1}' | jq .a`,
+		Cwd:     t.TempDir(),
+		Stdout:  &stdout,
+	})
+	if err != nil {
+		t.Fatalf("Run returned error: %v", err)
+	}
+	if got := stdout.String(); got != "1\n" {
+		t.Fatalf("stdout = %q, want %q", got, "1\n")
+	}
+}
+
+func TestRun_ParallelIsolation(t *testing.T) {
+	const n = 10
+	var wg sync.WaitGroup
+	wg.Add(n)
+	errs := make([]error, n)
+	outs := make([]string, n)
+	dirs := make([]string, n)
+	for i := range n {
+		dirs[i] = t.TempDir()
+		go func(i int) {
+			defer wg.Done()
+			var stdout bytes.Buffer
+			errs[i] = Run(t.Context(), RunOptions{
+				Command: `echo "$MARKER"`,
+				Cwd:     dirs[i],
+				Env:     []string{fmt.Sprintf("MARKER=id-%d", i)},
+				Stdout:  &stdout,
+			})
+			outs[i] = stdout.String()
+		}(i)
+	}
+	wg.Wait()
+	for i := range n {
+		if errs[i] != nil {
+			t.Errorf("goroutine %d: err = %v", i, errs[i])
+			continue
+		}
+		want := fmt.Sprintf("id-%d\n", i)
+		if outs[i] != want {
+			t.Errorf("goroutine %d: stdout = %q, want %q", i, outs[i], want)
+		}
+	}
+}
+
+// TestRun_CtxCancel_BusyLoop verifies that a pure-shell loop respects ctx
+// cancellation. mvdan's interpreter checks ctx between statements, so this
+// should return quickly even without any external command. The test bounds
+// its own wait via a select so a regression can't hang CI.
+func TestRun_CtxCancel_BusyLoop(t *testing.T) {
+	ctx, cancel := context.WithTimeout(t.Context(), 500*time.Millisecond)
+	t.Cleanup(cancel)
+
+	done := make(chan error, 1)
+	go func() {
+		done <- Run(ctx, RunOptions{
+			Command: "while true; do :; done",
+			Cwd:     t.TempDir(),
+		})
+	}()
+
+	select {
+	case err := <-done:
+		if !IsInterrupt(err) && !errors.Is(err, context.DeadlineExceeded) {
+			t.Fatalf("expected interrupt/deadline error, got: %v", err)
+		}
+	case <-time.After(1500 * time.Millisecond):
+		t.Fatal("Run did not return within 1.5s after ctx cancel")
+	}
+}
+
+// TestRun_CtxCancel_ExternalSleep verifies ctx cancellation reaches an external process via
+// mvdan's default exec. Uses sleep, which lives in coreutils on Windows and /bin on Unix.
+// NOTE(review): Env is empty here and any non-nil error passes — confirm sleep is really found.
+func TestRun_CtxCancel_ExternalSleep(t *testing.T) {
+	ctx, cancel := context.WithTimeout(t.Context(), 200*time.Millisecond)
+	t.Cleanup(cancel)
+
+	done := make(chan error, 1)
+	start := time.Now()
+	go func() {
+		done <- Run(ctx, RunOptions{
+			Command: "sleep 30",
+			Cwd:     t.TempDir(),
+		})
+	}()
+
+	select {
+	case err := <-done:
+		elapsed := time.Since(start)
+		if elapsed > time.Second {
+			t.Fatalf("sleep took too long to cancel: %v", elapsed)
+		}
+		if err == nil {
+			t.Fatal("expected non-nil error from cancelled sleep")
+		}
+	case <-time.After(time.Second):
+		t.Fatal("Run did not return within 1s after ctx cancel")
+	}
+}
+
+func TestRun_ParseError(t *testing.T) {
+	err := Run(t.Context(), RunOptions{
+		Command: "echo 'unterminated",
+		Cwd:     t.TempDir(),
+	})
+	if err == nil {
+		t.Fatal("expected parse error, got nil")
+	}
+	if !strings.Contains(err.Error(), "parse") {
+		t.Fatalf("error should mention parse: %v", err)
+	}
+}
+
+func TestRun_BlockFuncs(t *testing.T) {
+	block := CommandsBlocker([]string{"forbidden"})
+	var stderr bytes.Buffer
+	err := Run(t.Context(), RunOptions{
+		Command:    "forbidden",
+		Cwd:        t.TempDir(),
+		Stderr:     &stderr,
+		BlockFuncs: []BlockFunc{block},
+	})
+	if err == nil {
+		t.Fatal("expected error when running blocked command")
+	}
+	if !strings.Contains(err.Error(), "not allowed") {
+		t.Fatalf("expected 'not allowed' error, got: %v", err)
+	}
+}
+
+func TestRun_RequiresCwd(t *testing.T) {
+	err := Run(t.Context(), RunOptions{
+		Command: "echo hi",
+	})
+	if err == nil {
+		t.Fatal("expected error when Cwd is empty, got nil")
+	}
+	if !strings.Contains(err.Error(), "Cwd is required") {
+		t.Fatalf("error should mention Cwd requirement: %v", err)
+	}
+}
+
+func TestRun_DiscardsNilWriters(t *testing.T) {
+	// No panic when Stdout/Stderr are nil.
+	err := Run(t.Context(), RunOptions{
+		Command: "echo hi; echo err >&2",
+		Cwd:     t.TempDir(),
+	})
+	if err != nil {
+		t.Fatalf("Run returned error: %v", err)
+	}
+}
diff --git a/internal/shell/shell.go b/internal/shell/shell.go
index 5ca0a25b57d5e80778a6fa95bdd3eb991638db4b..6cb9d9f34301cca82ac35075010c199990300eb8 100644
--- a/internal/shell/shell.go
+++ b/internal/shell/shell.go
@@ -21,8 +21,6 @@ import (
 	"sync"
 
 	"github.com/charmbracelet/x/exp/slice"
-	"mvdan.cc/sh/moreinterp/coreutils"
-	"mvdan.cc/sh/v3/expand"
 	"mvdan.cc/sh/v3/interp"
 	"mvdan.cc/sh/v3/syntax"
 )
@@ -226,52 +224,10 @@ func splitArgsFlags(parts []string) (args []string, flags []string) {
 	return args, flags
 }
 
-func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
-	return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
-		return func(ctx context.Context, args []string) error {
-			if len(args) == 0 {
-				return next(ctx, args)
-			}
-
-			for _, blockFunc := range s.blockFuncs {
-				if blockFunc(args) {
-					return fmt.Errorf("command is not allowed for security reasons: %q", args[0])
-				}
-			}
-
-			return next(ctx, args)
-		}
-	}
-}
-
-func (s *Shell) builtinHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
-	return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
-		return func(ctx context.Context, args []string) error {
-			if len(args) == 0 {
-				return next(ctx, args)
-			}
-
-			// Builtins.
-			switch args[0] {
-			case "jq":
-				hc := interp.HandlerCtx(ctx)
-				return handleJQ(args, hc.Stdin, hc.Stdout, hc.Stderr)
-			default:
-				return next(ctx, args)
-			}
-		}
-	}
-}
-
-// newInterp creates a new interpreter with the current shell state
-func (s *Shell) newInterp(stdout, stderr io.Writer) (*interp.Runner, error) {
-	return interp.New(
-		interp.StdIO(nil, stdout, stderr),
-		interp.Interactive(false),
-		interp.Env(expand.ListEnviron(s.env...)),
-		interp.Dir(s.cwd),
-		interp.ExecHandlers(s.execHandlers()...),
-	)
+// newInterp creates a new interpreter with the current shell state. A nil
+// stdin is equivalent to an empty input stream.
+func (s *Shell) newInterp(stdin io.Reader, stdout, stderr io.Writer) (*interp.Runner, error) {
+	return newRunner(s.cwd, s.env, stdin, stdout, stderr, s.blockFuncs)
 }
 
 // updateShellFromRunner updates the shell from the interpreter after execution.
@@ -303,7 +259,7 @@ func (s *Shell) execCommon(ctx context.Context, command string, stdout, stderr i
 		return fmt.Errorf("could not parse command: %w", err)
 	}
 
-	runner, err = s.newInterp(stdout, stderr)
+	runner, err = s.newInterp(nil, stdout, stderr)
 	if err != nil {
 		return fmt.Errorf("could not run command: %w", err)
 	}
@@ -324,17 +280,6 @@ func (s *Shell) execStream(ctx context.Context, command string, stdout, stderr i
 	return s.execCommon(ctx, command, stdout, stderr)
 }
 
-func (s *Shell) execHandlers() []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
-	handlers := []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc{
-		s.builtinHandler(),
-		s.blockHandler(),
-	}
-	if useGoCoreUtils {
-		handlers = append(handlers, coreutils.ExecHandler)
-	}
-	return handlers
-}
-
 // IsInterrupt checks if an error is due to interruption
 func IsInterrupt(err error) bool {
 	return errors.Is(err, context.Canceled) ||