refactor(shell): extract stateless run entrypoint

Christian Rocha created

This is necessary for sharing the builtin shell with hooks.

Change summary

internal/shell/run.go      | 152 ++++++++++++++++++++++++
internal/shell/run_test.go | 252 ++++++++++++++++++++++++++++++++++++++++
internal/shell/shell.go    |  65 ---------
3 files changed, 409 insertions(+), 60 deletions(-)

Detailed changes

internal/shell/run.go 🔗

@@ -0,0 +1,152 @@
+package shell
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"strings"
+
+	"mvdan.cc/sh/moreinterp/coreutils"
+	"mvdan.cc/sh/v3/expand"
+	"mvdan.cc/sh/v3/interp"
+	"mvdan.cc/sh/v3/syntax"
+)
+
+// RunOptions configures a single stateless shell execution via [Run].
+//
+// The zero value is not useful; at minimum Command must be set. Stdin,
+// Stdout, and Stderr may be nil (nil readers/writers are treated as
+// empty/discard). BlockFuncs may be nil to disable block-list enforcement —
+// hooks use this to run user-authored commands with the same trust level as
+// a shell alias.
+type RunOptions struct {
+	// Command is the shell source to parse and execute.
+	Command string
+	// Cwd is the working directory for the execution. Required: callers
+	// must supply a non-empty value. Run does not silently fall back to
+	// the Crush process cwd — hooks and the bash tool have different
+	// notions of "default" and each owns that decision.
+	Cwd string
+	// Env is the full environment visible to the command. The caller is
+	// responsible for inheriting from os.Environ() if that's desired.
+	Env []string
+	// Stdin is the command's standard input. nil is equivalent to an empty
+	// input stream.
+	Stdin io.Reader
+	// Stdout receives the command's standard output. nil discards output.
+	Stdout io.Writer
+	// Stderr receives the command's standard error. nil discards output.
+	Stderr io.Writer
+	// BlockFuncs is an optional list of deny-list matchers applied before
+	// each command reaches the exec layer. nil disables blocking entirely.
+	BlockFuncs []BlockFunc
+}
+
+// Run parses and executes a shell command using the same mvdan.cc/sh
+// interpreter stack that the stateful [Shell] type uses (builtins,
+// optional block list, optional Go coreutils). It is safe to call
+// concurrently from multiple goroutines: each call builds its own
+// [interp.Runner] and shares no state with other callers or with any
+// [Shell] instance.
+//
+// Errors returned from the command itself (non-zero exit, context
+// cancellation, parse failures) follow the same conventions as
+// [Shell.Exec]: inspect with [IsInterrupt] and [ExitCode].
+func Run(ctx context.Context, opts RunOptions) (err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			err = fmt.Errorf("command execution panic: %v", r)
+		}
+	}()
+
+	if opts.Cwd == "" {
+		return fmt.Errorf("shell.Run: Cwd is required")
+	}
+
+	stdout := opts.Stdout
+	if stdout == nil {
+		stdout = io.Discard
+	}
+	stderr := opts.Stderr
+	if stderr == nil {
+		stderr = io.Discard
+	}
+
+	line, err := syntax.NewParser().Parse(strings.NewReader(opts.Command), "")
+	if err != nil {
+		return fmt.Errorf("could not parse command: %w", err)
+	}
+
+	runner, err := newRunner(opts.Cwd, opts.Env, opts.Stdin, stdout, stderr, opts.BlockFuncs)
+	if err != nil {
+		return fmt.Errorf("could not run command: %w", err)
+	}
+
+	return runner.Run(ctx, line)
+}
+
+// newRunner constructs an [interp.Runner] configured with the standard
+// Crush handler stack. Shared by the stateless [Run] entrypoint and the
+// stateful [Shell] so the two surfaces cannot drift.
+func newRunner(cwd string, env []string, stdin io.Reader, stdout, stderr io.Writer, blockFuncs []BlockFunc) (*interp.Runner, error) {
+	return interp.New(
+		interp.StdIO(stdin, stdout, stderr),
+		interp.Interactive(false),
+		interp.Env(expand.ListEnviron(env...)),
+		interp.Dir(cwd),
+		interp.ExecHandlers(standardHandlers(blockFuncs)...),
+	)
+}
+
+// standardHandlers returns the exec-handler middleware chain used by both
+// [Run] and [Shell]. Order matters: builtins first (so Crush's in-process
+// jq wins over any PATH binary), then the block list, then optional Go
+// coreutils. Future middleware (shebang dispatch, etc.) inserts here.
+func standardHandlers(blockFuncs []BlockFunc) []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
+	handlers := []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc{
+		builtinHandler(),
+		blockHandler(blockFuncs),
+	}
+	if useGoCoreUtils {
+		handlers = append(handlers, coreutils.ExecHandler)
+	}
+	return handlers
+}
+
+// builtinHandler returns middleware that dispatches recognized Crush
+// builtins to their in-process Go implementations. Currently: jq.
+func builtinHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
+	return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
+		return func(ctx context.Context, args []string) error {
+			if len(args) == 0 {
+				return next(ctx, args)
+			}
+			switch args[0] {
+			case "jq":
+				hc := interp.HandlerCtx(ctx)
+				return handleJQ(args, hc.Stdin, hc.Stdout, hc.Stderr)
+			default:
+				return next(ctx, args)
+			}
+		}
+	}
+}
+
+// blockHandler returns middleware that rejects commands matched by any of
+// the provided [BlockFunc]s before they reach the underlying exec path.
+// A nil or empty blockFuncs slice is a no-op.
+func blockHandler(blockFuncs []BlockFunc) func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
+	return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
+		return func(ctx context.Context, args []string) error {
+			if len(args) == 0 {
+				return next(ctx, args)
+			}
+			for _, blockFunc := range blockFuncs {
+				if blockFunc(args) {
+					return fmt.Errorf("command is not allowed for security reasons: %q", args[0])
+				}
+			}
+			return next(ctx, args)
+		}
+	}
+}

internal/shell/run_test.go 🔗

@@ -0,0 +1,252 @@
+package shell
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"fmt"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+)
+
+func TestRun_Echo(t *testing.T) {
+	var stdout, stderr bytes.Buffer
+	err := Run(t.Context(), RunOptions{
+		Command: "echo hi",
+		Cwd:     t.TempDir(),
+		Stdout:  &stdout,
+		Stderr:  &stderr,
+	})
+	if err != nil {
+		t.Fatalf("Run returned error: %v (stderr=%q)", err, stderr.String())
+	}
+	if got := stdout.String(); got != "hi\n" {
+		t.Fatalf("stdout = %q, want %q", got, "hi\n")
+	}
+}
+
+func TestRun_ExitCode(t *testing.T) {
+	err := Run(t.Context(), RunOptions{
+		Command: "exit 7",
+		Cwd:     t.TempDir(),
+	})
+	if err == nil {
+		t.Fatal("expected error for exit 7, got nil")
+	}
+	if code := ExitCode(err); code != 7 {
+		t.Fatalf("ExitCode = %d, want 7", code)
+	}
+}
+
+func TestRun_Stdin(t *testing.T) {
+	// Use the `read` shell builtin so the test doesn't depend on any
+	// external binary being on PATH (we pass an empty Env here).
+	var stdout bytes.Buffer
+	err := Run(t.Context(), RunOptions{
+		Command: "read line; echo got:$line",
+		Cwd:     t.TempDir(),
+		Stdin:   strings.NewReader("hello\n"),
+		Stdout:  &stdout,
+	})
+	if err != nil {
+		t.Fatalf("Run returned error: %v", err)
+	}
+	if got := stdout.String(); got != "got:hello\n" {
+		t.Fatalf("stdout = %q, want %q", got, "got:hello\n")
+	}
+}
+
+func TestRun_Env(t *testing.T) {
+	var stdout bytes.Buffer
+	err := Run(t.Context(), RunOptions{
+		Command: `echo "$FOO"`,
+		Cwd:     t.TempDir(),
+		Env:     []string{"FOO=bar"},
+		Stdout:  &stdout,
+	})
+	if err != nil {
+		t.Fatalf("Run returned error: %v", err)
+	}
+	if got := stdout.String(); got != "bar\n" {
+		t.Fatalf("stdout = %q, want %q", got, "bar\n")
+	}
+}
+
+func TestRun_Cwd(t *testing.T) {
+	dir := t.TempDir()
+	var stdout bytes.Buffer
+	err := Run(t.Context(), RunOptions{
+		Command: "pwd",
+		Cwd:     dir,
+		Stdout:  &stdout,
+	})
+	if err != nil {
+		t.Fatalf("Run returned error: %v", err)
+	}
+	// mvdan's pwd builtin resolves symlinks (e.g. /var -> /private/var on
+	// macOS). Compare against a suffix so we don't get bitten by that.
+	got := strings.TrimRight(stdout.String(), "\n")
+	if !strings.HasSuffix(got, dir) && !strings.HasSuffix(dir, got) {
+		t.Fatalf("pwd = %q, want it to match %q", got, dir)
+	}
+}
+
+func TestRun_JqBuiltin(t *testing.T) {
+	var stdout bytes.Buffer
+	err := Run(t.Context(), RunOptions{
+		Command: `echo '{"a":1}' | jq .a`,
+		Cwd:     t.TempDir(),
+		Stdout:  &stdout,
+	})
+	if err != nil {
+		t.Fatalf("Run returned error: %v", err)
+	}
+	if got := stdout.String(); got != "1\n" {
+		t.Fatalf("stdout = %q, want %q", got, "1\n")
+	}
+}
+
+func TestRun_ParallelIsolation(t *testing.T) {
+	const n = 10
+	var wg sync.WaitGroup
+	wg.Add(n)
+	errs := make([]error, n)
+	outs := make([]string, n)
+	dirs := make([]string, n)
+	for i := range n {
+		dirs[i] = t.TempDir()
+		go func(i int) {
+			defer wg.Done()
+			var stdout bytes.Buffer
+			errs[i] = Run(t.Context(), RunOptions{
+				Command: `echo "$MARKER"`,
+				Cwd:     dirs[i],
+				Env:     []string{fmt.Sprintf("MARKER=id-%d", i)},
+				Stdout:  &stdout,
+			})
+			outs[i] = stdout.String()
+		}(i)
+	}
+	wg.Wait()
+	for i := range n {
+		if errs[i] != nil {
+			t.Errorf("goroutine %d: err = %v", i, errs[i])
+			continue
+		}
+		want := fmt.Sprintf("id-%d\n", i)
+		if outs[i] != want {
+			t.Errorf("goroutine %d: stdout = %q, want %q", i, outs[i], want)
+		}
+	}
+}
+
+// TestRun_CtxCancel_BusyLoop verifies that a pure-shell loop respects ctx
+// cancellation. mvdan's interpreter checks ctx between statements, so this
+// should return quickly even without any external command. The test bounds
+// its own wait via a select so a regression can't hang CI.
+func TestRun_CtxCancel_BusyLoop(t *testing.T) {
+	ctx, cancel := context.WithTimeout(t.Context(), 500*time.Millisecond)
+	t.Cleanup(cancel)
+
+	done := make(chan error, 1)
+	go func() {
+		done <- Run(ctx, RunOptions{
+			Command: "while true; do :; done",
+			Cwd:     t.TempDir(),
+		})
+	}()
+
+	select {
+	case err := <-done:
+		if !IsInterrupt(err) && !errors.Is(err, context.DeadlineExceeded) {
+			t.Fatalf("expected interrupt/deadline error, got: %v", err)
+		}
+	case <-time.After(1500 * time.Millisecond):
+		t.Fatal("Run did not return within 1.5s after ctx cancel")
+	}
+}
+
+// TestRun_CtxCancel_ExternalSleep verifies ctx cancellation reaches an
+// external process via mvdan's default exec. Uses sleep, which lives in
+// coreutils on Windows and /bin on Unix.
+func TestRun_CtxCancel_ExternalSleep(t *testing.T) {
+	ctx, cancel := context.WithTimeout(t.Context(), 200*time.Millisecond)
+	t.Cleanup(cancel)
+
+	done := make(chan error, 1)
+	start := time.Now()
+	go func() {
+		done <- Run(ctx, RunOptions{
+			Command: "sleep 30",
+			Cwd:     t.TempDir(),
+		})
+	}()
+
+	select {
+	case err := <-done:
+		elapsed := time.Since(start)
+		if elapsed > time.Second {
+			t.Fatalf("sleep took too long to cancel: %v", elapsed)
+		}
+		if err == nil {
+			t.Fatal("expected non-nil error from cancelled sleep")
+		}
+	case <-time.After(time.Second):
+		t.Fatal("Run did not return within 1s after ctx cancel")
+	}
+}
+
+func TestRun_ParseError(t *testing.T) {
+	err := Run(t.Context(), RunOptions{
+		Command: "echo 'unterminated",
+		Cwd:     t.TempDir(),
+	})
+	if err == nil {
+		t.Fatal("expected parse error, got nil")
+	}
+	if !strings.Contains(err.Error(), "parse") {
+		t.Fatalf("error should mention parse: %v", err)
+	}
+}
+
+func TestRun_BlockFuncs(t *testing.T) {
+	block := CommandsBlocker([]string{"forbidden"})
+	var stderr bytes.Buffer
+	err := Run(t.Context(), RunOptions{
+		Command:    "forbidden",
+		Cwd:        t.TempDir(),
+		Stderr:     &stderr,
+		BlockFuncs: []BlockFunc{block},
+	})
+	if err == nil {
+		t.Fatal("expected error when running blocked command")
+	}
+	if !strings.Contains(err.Error(), "not allowed") {
+		t.Fatalf("expected 'not allowed' error, got: %v", err)
+	}
+}
+
+func TestRun_RequiresCwd(t *testing.T) {
+	err := Run(t.Context(), RunOptions{
+		Command: "echo hi",
+	})
+	if err == nil {
+		t.Fatal("expected error when Cwd is empty, got nil")
+	}
+	if !strings.Contains(err.Error(), "Cwd is required") {
+		t.Fatalf("error should mention Cwd requirement: %v", err)
+	}
+}
+
+func TestRun_DiscardsNilWriters(t *testing.T) {
+	// No panic when Stdout/Stderr are nil.
+	err := Run(t.Context(), RunOptions{
+		Command: "echo hi; echo err >&2",
+		Cwd:     t.TempDir(),
+	})
+	if err != nil {
+		t.Fatalf("Run returned error: %v", err)
+	}
+}

internal/shell/shell.go 🔗

@@ -21,8 +21,6 @@ import (
 	"sync"
 
 	"github.com/charmbracelet/x/exp/slice"
-	"mvdan.cc/sh/moreinterp/coreutils"
-	"mvdan.cc/sh/v3/expand"
 	"mvdan.cc/sh/v3/interp"
 	"mvdan.cc/sh/v3/syntax"
 )
@@ -226,52 +224,10 @@ func splitArgsFlags(parts []string) (args []string, flags []string) {
 	return args, flags
 }
 
-func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
-	return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
-		return func(ctx context.Context, args []string) error {
-			if len(args) == 0 {
-				return next(ctx, args)
-			}
-
-			for _, blockFunc := range s.blockFuncs {
-				if blockFunc(args) {
-					return fmt.Errorf("command is not allowed for security reasons: %q", args[0])
-				}
-			}
-
-			return next(ctx, args)
-		}
-	}
-}
-
-func (s *Shell) builtinHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
-	return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
-		return func(ctx context.Context, args []string) error {
-			if len(args) == 0 {
-				return next(ctx, args)
-			}
-
-			// Builtins.
-			switch args[0] {
-			case "jq":
-				hc := interp.HandlerCtx(ctx)
-				return handleJQ(args, hc.Stdin, hc.Stdout, hc.Stderr)
-			default:
-				return next(ctx, args)
-			}
-		}
-	}
-}
-
-// newInterp creates a new interpreter with the current shell state
-func (s *Shell) newInterp(stdout, stderr io.Writer) (*interp.Runner, error) {
-	return interp.New(
-		interp.StdIO(nil, stdout, stderr),
-		interp.Interactive(false),
-		interp.Env(expand.ListEnviron(s.env...)),
-		interp.Dir(s.cwd),
-		interp.ExecHandlers(s.execHandlers()...),
-	)
+// newInterp creates a new interpreter with the current shell state. A nil
+// stdin is equivalent to an empty input stream.
+func (s *Shell) newInterp(stdin io.Reader, stdout, stderr io.Writer) (*interp.Runner, error) {
+	return newRunner(s.cwd, s.env, stdin, stdout, stderr, s.blockFuncs)
 }
 
 // updateShellFromRunner updates the shell from the interpreter after execution.
@@ -303,7 +259,7 @@ func (s *Shell) execCommon(ctx context.Context, command string, stdout, stderr i
 		return fmt.Errorf("could not parse command: %w", err)
 	}
 
-	runner, err = s.newInterp(stdout, stderr)
+	runner, err = s.newInterp(nil, stdout, stderr)
 	if err != nil {
 		return fmt.Errorf("could not run command: %w", err)
 	}
@@ -324,17 +280,6 @@ func (s *Shell) execStream(ctx context.Context, command string, stdout, stderr i
 	return s.execCommon(ctx, command, stdout, stderr)
 }
 
-func (s *Shell) execHandlers() []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
-	handlers := []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc{
-		s.builtinHandler(),
-		s.blockHandler(),
-	}
-	if useGoCoreUtils {
-		handlers = append(handlers, coreutils.ExecHandler)
-	}
-	return handlers
-}
-
 // IsInterrupt checks if an error is due to interruption
 func IsInterrupt(err error) bool {
 	return errors.Is(err, context.Canceled) ||