diff --git a/.agents/skills/shell-builtins/SKILL.md b/.agents/skills/shell-builtins/SKILL.md index a9914a08aa1b506006fcb7ef35ded092efc7f895..f5953d9ff1f00631361e8598d9401a1f8b1a809f 100644 --- a/.agents/skills/shell-builtins/SKILL.md +++ b/.agents/skills/shell-builtins/SKILL.md @@ -11,38 +11,59 @@ emulation. Commands can be intercepted before they reach the OS by adding ## How Builtins Work -Builtins live in `Shell.builtinHandler()` in `internal/shell/shell.go`. -This is an `interp.ExecHandlerFunc` middleware registered in -`execHandlers()` **before** the block handler, so builtins run even for -commands that would otherwise be blocked. +Builtins live in `builtinHandler()` in `internal/shell/run.go`. This is an +`interp.ExecHandlerFunc` middleware registered in `standardHandlers()` +**before** the block handler, so builtins run even for commands that would +otherwise be blocked. The same handler chain is shared by the stateful +`Shell` type and the stateless `Run` entrypoint used by the hook runner, +so builtins are available identically in the `bash` tool and in hooks. The handler is a switch on `args[0]`. Each case either handles the command inline or delegates to a helper function. ## Adding a New Builtin -1. **Add the case** to the switch in `builtinHandler()` in `shell.go`. +1. **Add the case** to the switch in `builtinHandler()` in `run.go`. 2. **Get I/O from the handler context**, not from `os.Stdin`/`os.Stdout`. This ensures the builtin works with pipes and redirections: ```go case "mycommand": hc := interp.HandlerCtx(ctx) - return handleMyCommand(args, hc.Stdin, hc.Stdout, hc.Stderr) + return handleMyCommand(ctx, args, hc.Stdin, hc.Stdout, hc.Stderr) ``` 3. **Implement the handler** in its own file (e.g., - `internal/shell/mycommand.go`). The function signature should accept - args, stdin, stdout, and stderr: + `internal/shell/mycommand.go`). 
The function signature must accept a + `context.Context` as the first parameter, plus args, stdin, stdout, and + stderr: ```go - func handleMyCommand(args []string, stdin io.Reader, stdout, stderr io.Writer) error { + func handleMyCommand(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) error { // args[0] is the command name ("mycommand"), args[1:] are arguments. // Write output to stdout, errors to stderr. // Return nil on success, or interp.ExitStatus(n) for non-zero exit codes. } ``` -4. **Return values**: return `nil` for success, `interp.ExitStatus(n)` for - non-zero exit codes. Write error messages to `stderr` before returning. -5. **No extra wiring needed** — `builtinHandler()` is already registered - in `execHandlers()`. +4. **Poll `ctx` in every unbounded loop.** Builtins that iterate over + input, emit values in a generator-style loop, or do any other work + that can exceed a few milliseconds MUST check `ctx.Err()` on each + iteration and return it verbatim when non-nil. Hook timeouts rely on + this: an unbounded builtin that never polls ctx cannot be interrupted + by a hook's `timeout_sec`, and the hook runner will have to abandon + the goroutine (see `internal/hooks/runner.go`). Returning `ctx.Err()` + (not `interp.ExitStatus(n)`) lets callers distinguish "command exited + non-zero" from "we ran out of time". + ```go + for _, item := range items { + if err := ctx.Err(); err != nil { + return err + } + // ... process item + } + ``` +5. **Return values**: return `nil` for success, `interp.ExitStatus(n)` for + non-zero exit codes, or `ctx.Err()` on cancellation. Write error + messages to `stderr` before returning. +6. **No extra wiring needed** — `builtinHandler()` is already registered + in `standardHandlers()`. 
## Existing Builtins diff --git a/docs/hooks/FUTURE.md b/docs/hooks/FUTURE.md index 0ddbfe2905a9387fcc37316dba631bd1d2039a02..bb63185785e6a28c8d0f2c8dda99c93b1845fed4 100644 --- a/docs/hooks/FUTURE.md +++ b/docs/hooks/FUTURE.md @@ -256,121 +256,5 @@ Reuses the universal rules: ## Cross-platform shell (Windows support) -**Status:** not implemented. - -### Problem - -Today the hook runner uses `exec.Command("sh", "-c", hook.Command)`. On Windows -this fails without WSL or Git Bash on PATH. Even with `sh.exe` available, -Windows has no kernel shebang handling — `./hooks/foo.sh` can't be exec'd -directly the way it can on Unix. Hooks are effectively Unix-only. - -### Approach - -Keep the `command` field as a string. Tokenize it shell-style, examine -`argv[0]`, and branch: - -- If `argv[0]` starts with `./`, `../`, `/`, or `~/` — treat it as a **file - invocation**. Read the first ≤128 bytes, parse a shebang if present, and - dispatch to the named interpreter via `os/exec`. Extra args from the command - string pass through to the interpreter. -- Otherwise — treat the whole string as **shell code** and hand it to mvdan's - in-process interpreter. mvdan resolves `node`, `bash`, `jq`, builtins, - pipelines, redirects, etc. via its own exec handler. - -No sentinel: a script with no shebang defaults to mvdan. A script with an -explicit shebang (`#!/bin/bash`, `#!/usr/bin/env python3`, etc.) uses the named -interpreter, which the user is responsible for having on PATH. Same contract on -every platform. 
- -### Dispatch examples - -| `command` | `argv[0]` | Route | -| ---------------------------------------- | -------------- | ------------------------ | -| `ls -la` | `ls` | mvdan | -| `bash -c 'ls'` | `bash` | mvdan (which execs bash) | -| `node ./script.js` | `node` | mvdan (which execs node) | -| `./script.sh` (no shebang) | `./script.sh` | mvdan, fed the file | -| `./script.sh` (`#!/bin/bash`) | `./script.sh` | `bash ./script.sh` | -| `./script.py` (`#!/usr/bin/env python3`) | `./script.py` | `python3 ./script.py` | -| `./script.exe` | `./script.exe` | `os/exec` direct | - -### Contract on Windows - -- Inline shell runs through mvdan natively. No external dependency. -- Shebang-dispatched scripts require the named interpreter on PATH (`bash.exe`, - `python.exe`, `node.exe`, etc.). Crush does the dispatch that the Windows - kernel won't. -- Shebang-less scripts run through mvdan regardless of extension. CRLF line - endings are tolerated. - -### Implementation sketch - -- New function - `dispatch(ctx, cmd string, env []string, stdin io.Reader) (stdout, stderr string, exitCode int, err error)` - in `internal/hooks/`. -- Tokenize using mvdan's parser (already a dep) for consistent quoting/escape - behavior with shell intuition. -- Path-prefix check on `argv[0]`; if path, read shebang with a bounded - `io.LimitReader` and parse. Support: - - `#!/absolute/interpreter args…` - - `#!/usr/bin/env NAME` → resolve `NAME` on PATH - - `#!/usr/bin/env -S NAME args…` → treat as above; `-S` is common enough to - handle. Other `env` flags can error. -- Unified exit-code helper. mvdan's `interp.ExitStatus` and `os/exec`'s - `ProcessState.ExitCode()` both become a single `int`. -- Context cancellation: mvdan's exec handler uses `exec.CommandContext` for its - children, so a cancelled hook kills both the interpreter and any children. - Verify with a `sleep 60` test. -- One fresh `interp.Runner` per hook invocation (parallel hooks must not share - state). 
- -### Swap the call site - -`Runner.runOne` in `internal/hooks/runner.go` replaces its -`exec.Command("sh", "-c", …)` with a call to `dispatch(…)`. Everything -downstream (exit-code 2 / 49 / other dispatch, stdout JSON parsing, -stderr-as-reason) stays identical. - -### Tests - -Cross-platform matrix: - -- Inline: `echo hi`; `exit 2`; pipelines; redirections. -- File, no shebang: treated as shell source through mvdan. -- File, `#!/bin/bash` on Unix — invokes system bash. -- File, `#!/usr/bin/env python3` — invokes python if present, skips if not. -- File, `#!/usr/bin/env -S node --foo` — extra flags preserved. -- File with CRLF line endings in the shebang. -- `./missing-file` — non-blocking error, hook proceeds as "no opinion". -- Timeout: hook that sleeps past its timeout gets killed; context cancellation - kills the interpreter and its children. -- Concurrency: 10 hooks in parallel don't leak env/cwd/state between runners. -- Windows-specific: `./script.exe` exec'd directly; bash-shebang script fails - gracefully when bash isn't on PATH. - -### Pitfalls to watch - -- **Userland shebang parsing is now our problem.** Edge cases around `env` - flags, args with spaces, CRLF, missing interpreter. Well-trodden but needs - real tests. -- **The path-prefix heuristic is a heuristic.** `relative/path.sh` (no leading - `./`) gets mvdan'd, not file-dispatched. Matches shell intuition — at a bash - prompt, `relative/path.sh` doesn't run unless `.` is on PATH — but worth - documenting. -- **Kernel shebang handling is bypassed on Unix.** Today a chmod+x'd script is - exec'd by the kernel; after this change, by our parser. Behavior should be - byte-identical; verify with tests. -- **Two code paths.** mvdan vs direct-exec. Exit codes, stdin, signal - propagation, env inheritance must be aligned. Discipline, not cleverness. - -### Explicit non-goals - -- No bundled `bash.exe` or `python.exe`. Users bring their own interpreters. -- No custom mvdan builtins (`crush_approve` etc.). 
Hooks stay portable and - testable under bare `bash`. -- No `.sh`-extension filter on discovery. Hook file shape is driven by shebang, - not filename. -- No Crush-as-script-interpreter mode (users can't write `#!/usr/bin/env crush` - and have it mean something). If we want that later, it's an additive feature, - not a dependency of this work. +**Status:** implemented. See the [Execution model](README.md#execution-model) +section in `README.md` for the current behavior and contract. diff --git a/docs/hooks/README.md b/docs/hooks/README.md index f2584b33dcbc42664fe51c05ef3ae5107da3ea98..c732187575d6d7cf62b124511bbb8e1228d7fb04 100644 --- a/docs/hooks/README.md +++ b/docs/hooks/README.md @@ -88,6 +88,45 @@ That's basically it. For the full guide on how hooks work, however, read on. --- +## Execution model + +Hooks run through Crush's embedded POSIX shell (`mvdan.cc/sh`) — the same +interpreter the `bash` tool uses. Inline commands and shebang-less scripts +execute in-process; scripts with a `#!` shebang dispatch to the named +interpreter via `os/exec`. This contract is identical on macOS, Linux, and +Windows. + +What this means in practice: + +- **Windows without Unix tooling**: inline shell (`echo`, pipelines, `jq`, + `grep`), shebang-less `.sh` scripts, inline PowerShell + (`powershell -Command …`), and `.exe` invocations all work out of the box + with no WSL, Git Bash, Cygwin, or MSYS required. +- **PowerShell scripts** (`.ps1`) are not auto-dispatched by extension. + Invoke them explicitly: `powershell -File ./audit.ps1` (or + `pwsh -File ./audit.ps1`). +- **Shebang'd scripts** require the named interpreter on `PATH`. Git for + Windows ships `bash.exe`, which makes `#!/bin/bash` and + `#!/usr/bin/env bash` scripts work on Windows the same way they do on + Unix. CRLF line endings in the shebang line are tolerated. +- **Permissive shebang fallback**: if the absolute path in a shebang + doesn't exist (e.g. 
`#!/bin/bash` on Windows), Crush falls back to a + `PATH` lookup of the base name (`bash`) before giving up. A debug-level + log records the fallback. If the interpreter isn't on `PATH` either, the + hook fails cleanly as a non-blocking warning and the agent proceeds as + "no opinion". +- **Environment**: every hook sees `CRUSH=1`, `AGENT=crush`, and + `AI_AGENT=crush` on top of the `CRUSH_*` hook-specific variables. These + three markers are guaranteed and match what the `bash` tool sets, so + scripts that detect "am I being run by an AI agent?" behave the same in + both contexts. +- **Timeout behavior**: when a hook exceeds its timeout, Crush cancels the + context and waits a short grace period (~1s) for the interpreter to + yield. If the hook still hasn't returned, Crush abandons it, logs a + warning, and treats the result as "no opinion" so the agent can proceed. + Long-running work should honor context cancellation or run in a + subprocess via a shebang. + ## Configuration Hooks can be added to your `crush.json` (or `.crush.json`) at both the global @@ -143,7 +182,8 @@ When a hook fires, Crush: 1. Filters hooks whose `matcher` regex matches the tool name (no matcher = match all). 2. Deduplicates by `command` (identical commands run once). -3. Runs all matching hooks **in parallel** as subprocesses. +3. Runs all matching hooks **in parallel** through Crush's embedded POSIX + shell (see [Execution model](#execution-model)). 4. Waits for all to finish (or time out), then aggregates results **in config order**: deny wins over allow, allow wins over none; `updated_input` patches shallow-merge in order. @@ -153,8 +193,9 @@ When a hook fires, Crush: also skips the prompt. Silence (no decision) falls through to the normal permission flow. -Note that you can omit `matcher` and match in your shell script instead, however -you'll incur some additional overhead as Crush will `exec` each script. 
+Note that you can omit `matcher` and match in your shell script instead, +however you'll incur some additional overhead as Crush will still parse and +run each hook. ### Input @@ -168,6 +209,9 @@ The available environment variables are: | Variable | Description | | ---------------------------- | ---------------------------------------------- | +| `CRUSH` | Always `1` when running under Crush. | +| `AGENT` | Always `crush`. | +| `AI_AGENT` | Always `crush`. | | `CRUSH_EVENT` | The hook event name (e.g. `PreToolUse`). | | `CRUSH_TOOL_NAME` | The tool being called (e.g. `bash`). | | `CRUSH_SESSION_ID` | Current session ID. | @@ -176,6 +220,10 @@ The available environment variables are: | `CRUSH_TOOL_INPUT_COMMAND` | For `bash` calls: the shell command being run. | | `CRUSH_TOOL_INPUT_FILE_PATH` | For file tools: the target file path. | +The `CRUSH`, `AGENT`, and `AI_AGENT` markers are also set by the `bash` +tool, so a script can detect "am I running under Crush?" the same way in +either context. + #### JSON Standard input provides the full context as JSON: @@ -326,9 +374,12 @@ When multiple hooks match the same tool call: ### Timeouts -If a hook exceeds its timeout, the process is killed and treated as a -non-blocking error and the tool call proceeds. The default timeout is 30 -seconds. +If a hook exceeds its timeout, Crush cancels its context and treats the +result as a non-blocking error so the tool call proceeds. The default +timeout is 30 seconds. Shebang-dispatched subprocesses are killed through +`exec.CommandContext`; in-process hooks get a short grace period to yield +and are then abandoned (the agent moves on regardless). Long-running work +should honor context cancellation or run out-of-process via a shebang. 
## Examples diff --git a/internal/hooks/hooks_test.go b/internal/hooks/hooks_test.go index c0fbec7edf97a1ca9645a4d1efc1403508354ca3..77217e77354367a07d8a2a8e172152b2281e3643 100644 --- a/internal/hooks/hooks_test.go +++ b/internal/hooks/hooks_test.go @@ -2,11 +2,14 @@ package hooks import ( "context" + "io" "strings" + "sync" "testing" "time" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/shell" "github.com/stretchr/testify/require" ) @@ -192,6 +195,13 @@ func TestBuildEnv(t *testing.T) { require.Equal(t, "/project", envMap["CRUSH_PROJECT_DIR"]) require.Equal(t, "ls", envMap["CRUSH_TOOL_INPUT_COMMAND"]) require.Equal(t, "/tmp/f.txt", envMap["CRUSH_TOOL_INPUT_FILE_PATH"]) + + // Shared Crush markers must be present so hook-authored scripts can + // detect they're running under Crush the same way bash-tool-invoked + // scripts can. + require.Equal(t, "1", envMap["CRUSH"]) + require.Equal(t, "crush", envMap["AGENT"]) + require.Equal(t, "crush", envMap["AI_AGENT"]) } func splitFirst(s, sep string) []string { @@ -602,6 +612,63 @@ func TestAggregationUpdatedInput(t *testing.T) { }) } +// TestRunnerAbandonRaceSafety verifies that if a hook's shell execution +// does not yield to ctx cancellation within abandonGrace, runOne returns +// promptly and never touches the shared stdout/stderr buffers again — +// even while the abandoned goroutine continues to write to them. +// +// The substitute shell executor ignores ctx entirely, writes to Stdout +// both before and after the abandon deadline, and only then returns. +// Under -race this catches any code path in runOne that reads those +// buffers after returning the DecisionNone abandon result. +func TestRunnerAbandonRaceSafety(t *testing.T) { + origRunShell := runShell + t.Cleanup(func() { runShell = origRunShell }) + + // Synchronize shutdown with the abandoned goroutine so the test + // exits cleanly even under -race. 
+ var wg sync.WaitGroup + release := make(chan struct{}) + t.Cleanup(func() { + close(release) + wg.Wait() + }) + + runShell = func(_ context.Context, opts shell.RunOptions) error { + wg.Add(1) + defer wg.Done() + // Write before the caller observes ctx.Done(); the caller will + // not read the buffer while we still own it. + _, _ = io.WriteString(opts.Stdout, "before\n") + // Hold past ctx deadline + abandonGrace so the caller takes + // the abandon branch, then continue writing. If the caller + // reads these buffers after abandoning, -race will flag it. + select { + case <-time.After(5 * time.Second): + case <-release: + } + _, _ = io.WriteString(opts.Stdout, "after\n") + return nil + } + + hookCfg := config.HookConfig{ + Command: "# irrelevant; runShell is stubbed", + Timeout: 1, + } + r := NewRunner([]config.HookConfig{hookCfg}, t.TempDir(), t.TempDir()) + + start := time.Now() + result, err := r.Run(context.Background(), EventPreToolUse, "sess", "bash", `{}`) + elapsed := time.Since(start) + + require.NoError(t, err) + require.Equal(t, DecisionNone, result.Decision) + // Abandon must happen at ~timeout + abandonGrace. Allow generous + // slack so CI noise doesn't flake the test. + require.Less(t, elapsed, 3500*time.Millisecond, + "runOne should return within timeout+abandonGrace+slack") +} + func TestRunnerUpdatedInput(t *testing.T) { t.Parallel() hookCfg := config.HookConfig{ diff --git a/internal/hooks/input.go b/internal/hooks/input.go index af77d5d149bcf85cd3747a831d5e841ad9050fe1..dfe1a52a42abb26eb81b8db1ffcbbf13c17544f8 100644 --- a/internal/hooks/input.go +++ b/internal/hooks/input.go @@ -7,6 +7,7 @@ import ( "os" "strings" + "github.com/charmbracelet/crush/internal/shell" "github.com/tidwall/gjson" ) @@ -51,6 +52,7 @@ func BuildPayload(eventName, sessionID, cwd, toolName, toolInputJSON string) []b // It includes all current process env vars plus hook-specific ones. 
func BuildEnv(eventName, toolName, sessionID, cwd, projectDir, toolInputJSON string) []string { env := os.Environ() + env = append(env, shell.CrushEnvMarkers()...) env = append(env, fmt.Sprintf("CRUSH_EVENT=%s", eventName), fmt.Sprintf("CRUSH_TOOL_NAME=%s", toolName), diff --git a/internal/hooks/runner.go b/internal/hooks/runner.go index 0e6ba6abba8e564353579b2b024c11527ab965e9..be625b0f10f3d97b4c67220dce3d96720797eb52 100644 --- a/internal/hooks/runner.go +++ b/internal/hooks/runner.go @@ -4,15 +4,27 @@ import ( "bytes" "context" "log/slog" - "os/exec" "regexp" "strings" "sync" "time" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/shell" ) +// abandonGrace is how long runOne waits after ctx cancellation for the +// shell goroutine to yield before returning control to the caller and +// letting the goroutine finish on its own. Mirrors the historical +// cmd.WaitDelay = time.Second behavior of the previous os/exec path. +const abandonGrace = time.Second + +// runShell is the shell executor used by runOne. It is a package-level +// variable so tests can substitute a blocking or non-yielding +// implementation to exercise the abandon-on-timeout path without +// depending on the scheduling behavior of the real interpreter. +var runShell = shell.Run + // compiledHook pairs a HookConfig with its compiled matcher regex. A nil // matcher means "match every tool". type compiledHook struct { @@ -140,24 +152,59 @@ func (r *Runner) matchingHooks(toolName string) []config.HookConfig { } // runOne executes a single hook command and returns its result. +// +// Execution goes through Crush's embedded POSIX shell (shell.Run) so the +// same interpreter, builtins, and coreutils are visible to hooks as to +// the bash tool. BlockFuncs are intentionally omitted: hooks are +// user-authored config that carry the same trust as a shell alias. 
+// +// A hook that fails to yield after its deadline has passed is abandoned +// after abandonGrace so the caller never blocks longer than +// timeout + abandonGrace. Ownership of the stdout and stderr buffers is +// strictly single-goroutine: +// - before receiving from `done`, only the goroutine writes to them; +// - after `done` delivers a value, the goroutine is finished and the +// outer frame reads them; +// - on the abandon path, the goroutine may still be writing and the +// outer frame must not touch them again. func (r *Runner) runOne(parentCtx context.Context, hook config.HookConfig, envVars []string, payload []byte) HookResult { timeout := hook.TimeoutDuration() ctx, cancel := context.WithTimeout(parentCtx, timeout) defer cancel() - cmd := exec.CommandContext(ctx, "sh", "-c", hook.Command) - cmd.WaitDelay = time.Second - cmd.Env = envVars - cmd.Dir = r.cwd - cmd.Stdin = bytes.NewReader(payload) - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr + done := make(chan error, 1) + go func() { + done <- runShell(ctx, shell.RunOptions{ + Command: hook.Command, + Cwd: r.cwd, + Env: envVars, + Stdin: bytes.NewReader(payload), + Stdout: &stdout, + Stderr: &stderr, + }) + }() - err := cmd.Run() + var err error + select { + case err = <-done: + // Normal path: goroutine has finished, buffers are safe to read. + case <-ctx.Done(): + select { + case err = <-done: + // Interpreter yielded within the grace period; safe to read. + case <-time.After(abandonGrace): + slog.Warn("Hook did not yield after cancel; abandoning goroutine", + "command", hook.Command, + "timeout", timeout, + ) + // The goroutine may still be writing to stdout/stderr; do + // not read either buffer below this point. + return HookResult{Decision: DecisionNone} + } + } - if ctx.Err() != nil { + if shell.IsInterrupt(err) { // Distinguish timeout from parent cancellation. 
if parentCtx.Err() != nil { slog.Debug("Hook cancelled by parent context", "command", hook.Command) @@ -168,10 +215,7 @@ func (r *Runner) runOne(parentCtx context.Context, hook config.HookConfig, envVa } if err != nil { - exitCode := -1 - if cmd.ProcessState != nil { - exitCode = cmd.ProcessState.ExitCode() - } + exitCode := shell.ExitCode(err) switch exitCode { case 2: // Exit code 2 = block this tool call. Stderr is the reason. @@ -200,6 +244,7 @@ func (r *Runner) runOne(parentCtx context.Context, hook config.HookConfig, envVa "command", hook.Command, "exit_code", exitCode, "stderr", strings.TrimSpace(stderr.String()), + "error", err, ) return HookResult{Decision: DecisionNone} } diff --git a/internal/shell/dispatch.go b/internal/shell/dispatch.go new file mode 100644 index 0000000000000000000000000000000000000000..869970639a5d3ba597ddb59e4486ea484972fe5d --- /dev/null +++ b/internal/shell/dispatch.go @@ -0,0 +1,426 @@ +package shell + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "io/fs" + "log/slog" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + + "mvdan.cc/sh/v3/expand" + "mvdan.cc/sh/v3/interp" + "mvdan.cc/sh/v3/syntax" +) + +// probeWindow is how many bytes we read from the head of a file to decide +// how to dispatch it. 128 is plenty for a shebang line and for magic-byte +// inspection, while small enough to make the probe cheap for users whose +// hooks invoke many scripts. +const probeWindow = 128 + +// scriptDispatchHandler returns middleware that intercepts exec of a +// path-prefixed argv[0] (e.g. ./foo.sh, /opt/bin/tool, C:\foo\bar.exe) and +// dispatches based on the file's contents: +// +// 1. Shebang line (#!...) → exec the named interpreter via os/exec. The +// interpreter is resolved literally first, then via PATH on the +// basename as a permissive fallback (so #!/bin/bash works on Windows +// boxes where Git for Windows puts bash.exe on PATH). +// 2. 
Known binary magic (MZ, ELF, Mach-O) or a NUL byte in the probe +// window → pass through to the next handler (mvdan's default exec). +// 3. Otherwise → treat the file as shell source and run it in-process via +// a nested interp.Runner that reuses the same handler stack. +// +// Non-path-prefixed argv[0] and empty args are passed straight through; this +// handler is a no-op for ordinary commands like `echo` or `jq`. +// +// blockFuncs is the block list used when building the nested runner for the +// shell-source case, so deny rules apply recursively to commands invoked +// from in-process scripts. +func scriptDispatchHandler(blockFuncs []BlockFunc) func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { + return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { + return func(ctx context.Context, args []string) error { + if len(args) == 0 || !isPathPrefixed(args[0]) { + return next(ctx, args) + } + + scriptPath := args[0] + // Resolve relative paths against the interpreter's cwd, not + // the process cwd — hook commands are authored with the hook + // Runner's cwd in mind and sub-shells can cd before an exec. + if !filepath.IsAbs(scriptPath) { + scriptPath = filepath.Join(interp.HandlerCtx(ctx).Dir, scriptPath) + } + probe, err := probeFile(scriptPath) + if err != nil { + return err + } + + switch { + case hasShebang(probe): + return dispatchShebang(ctx, scriptPath, probe, args) + case isBinary(probe): + return next(ctx, args) + default: + return runShellSource(ctx, scriptPath, args, blockFuncs) + } + } + } +} + +// isPathPrefixed reports whether argv[0] is a file reference (as opposed +// to a bare command to be resolved via PATH). A path reference starts with +// `./`, `../`, `/`, or — on Windows — a drive-letter prefix. +// +// Note: mvdan already performs tilde expansion during word expansion, so +// `~/script.sh` arrives here as an absolute path. 
Note that this helper does
+// NOT itself match a literal `~/` prefix; a change that bypassed tilde
+// expansion would need a `~/` case here — cover it with a regression test.
+func isPathPrefixed(arg string) bool {
+	switch {
+	case strings.HasPrefix(arg, "./"),
+		strings.HasPrefix(arg, "../"),
+		strings.HasPrefix(arg, "/"):
+		return true
+	}
+	if runtime.GOOS == "windows" {
+		// Drive-letter paths: C:\foo or C:/foo (length check avoids
+		// accidentally matching a single letter followed by a colon).
+		if len(arg) >= 3 && isDriveLetter(arg[0]) && arg[1] == ':' &&
+			(arg[2] == '\\' || arg[2] == '/') {
+			return true
+		}
+		// Also treat backslash-prefixed UNC-like paths as path-prefixed.
+		if strings.HasPrefix(arg, "\\") {
+			return true
+		}
+	}
+	return false
+}
+
+func isDriveLetter(b byte) bool {
+	return (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z')
+}
+
+// probeFile reads the first probeWindow bytes of the target path. It
+// deliberately does not slurp the whole file: callers that need the full
+// contents (only the shell-source branch) re-open via os.ReadFile. This
+// keeps memory bounded when argv[0] turns out to be a large binary.
+//
+// Returns errors surfaced by os.Open/os.Stat directly so callers see the
+// real reason: ENOENT, EACCES, EISDIR, ELOOP, etc.
+func probeFile(path string) ([]byte, error) {
+	f, err := os.Open(path)
+	if err != nil {
+		return nil, err
+	}
+	defer f.Close()
+	fi, err := f.Stat()
+	if err != nil {
+		return nil, err
+	}
+	if fi.IsDir() {
+		return nil, fmt.Errorf("%s: is a directory", path)
+	}
+	probe := make([]byte, probeWindow)
+	n, err := io.ReadFull(f, probe)
+	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
+		return nil, err
+	}
+	return probe[:n], nil
+}
+
+// hasShebang reports whether probe starts with the `#!` marker. A
+// one-byte file that happens to be `#` is not a shebang.
+func hasShebang(probe []byte) bool {
+	return len(probe) >= 2 && probe[0] == '#' && probe[1] == '!' 
+} + +// isBinary heuristically classifies probe as an executable or otherwise +// non-text file. A NUL byte in the first probeWindow bytes is the classic +// Unix-y text-vs-binary signal; we additionally recognize known magic +// numbers so we can fast-path well-formed binaries that happen to have no +// NUL in the first 128 bytes (rare but possible for small binaries). +func isBinary(probe []byte) bool { + if bytes.IndexByte(probe, 0) >= 0 { + return true + } + magics := [][]byte{ + {'M', 'Z'}, // Windows PE / DOS MZ. + {0x7F, 'E', 'L', 'F'}, // ELF. + {0xFE, 0xED, 0xFA, 0xCE}, // Mach-O 32-bit BE. + {0xFE, 0xED, 0xFA, 0xCF}, // Mach-O 64-bit BE. + {0xCF, 0xFA, 0xED, 0xFE}, // Mach-O 64-bit LE. + {0xCE, 0xFA, 0xED, 0xFE}, // Mach-O 32-bit LE. + {0xCA, 0xFE, 0xBA, 0xBE}, // Mach-O fat binary. + } + for _, m := range magics { + if bytes.HasPrefix(probe, m) { + return true + } + } + return false +} + +// dispatchShebang parses probe's shebang line and execs the resolved +// interpreter via os/exec, inheriting the parent runner's cwd, env, and +// stdio. Returns interp.ExitStatus on non-zero interpreter exit so the +// parent interpreter sees it as a normal non-zero status. +func dispatchShebang(ctx context.Context, scriptPath string, probe []byte, args []string) error { + sb, err := parseShebang(probe) + if err != nil { + hc := interp.HandlerCtx(ctx) + fmt.Fprintf(hc.Stderr, "crush: %s: %s\n", scriptPath, err) + return interp.ExitStatus(126) + } + + interpreter, err := resolveInterpreter(sb.interpreter) + if err != nil { + hc := interp.HandlerCtx(ctx) + fmt.Fprintf(hc.Stderr, "crush: %s: %s\n", scriptPath, err) + return interp.ExitStatus(127) + } + + cmdArgs := append([]string{}, sb.args...) + cmdArgs = append(cmdArgs, scriptPath) + cmdArgs = append(cmdArgs, args[1:]...) + + cmd := exec.CommandContext(ctx, interpreter, cmdArgs...) 
+ hc := interp.HandlerCtx(ctx) + cmd.Dir = hc.Dir + cmd.Env = execEnvList(hc.Env) + cmd.Stdin = hc.Stdin + cmd.Stdout = hc.Stdout + cmd.Stderr = hc.Stderr + + if err := cmd.Run(); err != nil { + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + code := exitErr.ExitCode() + if code < 0 { + code = 1 + } + return interp.ExitStatus(uint8(code)) + } + return err + } + return nil +} + +// resolveInterpreter tries the literal shebang path first, then falls back +// to PATH-lookup on its basename — but only when the literal path is +// genuinely missing. A file that exists but fails stat for another reason +// (EACCES, ELOOP, etc.) surfaces the real error: silently resolving a +// different binary off PATH in that case would hide a real problem and +// produce surprising behavior for the user. +// +// The permissive fallback is what makes #!/bin/bash portable to Windows +// boxes where Git for Windows puts bash.exe on PATH but there is no +// /bin/bash on disk. +func resolveInterpreter(path string) (string, error) { + _, statErr := os.Stat(path) + if statErr == nil { + return path, nil + } + if !errors.Is(statErr, fs.ErrNotExist) { + return "", statErr + } + + base := filepath.Base(path) + if base == "" || base == path && !strings.ContainsAny(path, `/\`) { + // Already a bare name — just do a PATH lookup. + resolved, err := exec.LookPath(path) + if err != nil { + return "", fmt.Errorf("interpreter %q not found in PATH", path) + } + return resolved, nil + } + resolved, err := exec.LookPath(base) + if err != nil { + return "", fmt.Errorf("interpreter %q not found and %q not in PATH", path, base) + } + slog.Debug("Shebang interpreter not found; falling back to PATH", + "requested", path, "resolved", resolved) + return resolved, nil +} + +// shebang captures the parsed `#!` line. interpreter is the program to +// invoke; args is the list of extra arguments to pass before the script +// path. 
The kernel's single-arg semantics (for literal paths and for env +// without `-S`) is encoded by returning a single-element args slice +// containing the un-tokenized remainder. +type shebang struct { + interpreter string + args []string +} + +// parseShebang extracts the interpreter invocation from probe. It tolerates +// CRLF line endings and a single leading space between `#!` and the path. +// env special-cases: `/usr/bin/env NAME [args...]` unwraps to NAME with +// kernel single-arg semantics; `-S` enables tokenized argument splitting. +func parseShebang(probe []byte) (*shebang, error) { + if !hasShebang(probe) { + return nil, errors.New("not a shebang") + } + line := probe[2:] + // Take up to the first newline. + if idx := bytes.IndexByte(line, '\n'); idx >= 0 { + line = line[:idx] + } + // Strip trailing CR (CRLF-authored scripts). + line = bytes.TrimRight(line, "\r") + // Strip leading whitespace ("#! /usr/bin/env bash" is legal). + line = bytes.TrimLeft(line, " \t") + if len(line) == 0 { + return nil, errors.New("empty shebang") + } + + var pathStr, rest string + if idx := bytes.IndexAny(line, " \t"); idx >= 0 { + pathStr = string(line[:idx]) + rest = strings.TrimLeft(string(line[idx+1:]), " \t") + } else { + pathStr = string(line) + } + + if isEnvShebang(pathStr) { + return parseEnvShebang(rest) + } + + // Literal-path shebang: kernel semantics pass the remainder as a + // single argv[1], not tokenized. + sb := &shebang{interpreter: pathStr} + if rest != "" { + sb.args = []string{rest} + } + return sb, nil +} + +// isEnvShebang reports whether the shebang path targets `env`. We accept +// both common absolute paths and a bare `env` so that unusual setups +// (NixOS, BSDs) still work. +func isEnvShebang(p string) bool { + if p == "/usr/bin/env" || p == "/bin/env" { + return true + } + return filepath.Base(p) == "env" +} + +// parseEnvShebang handles `/usr/bin/env` rewriting. 
Without `-S`, the +// remainder after the program name is a single argv[1] (kernel +// single-arg semantics via env, even though real env would fail to find a +// program named "bash -x"). With `-S`, the remainder is tokenized on +// whitespace. Any other `env` flag is rejected — forwarding unknown flags +// to a /usr/bin/env on disk is a subtle portability footgun we don't want. +func parseEnvShebang(rest string) (*shebang, error) { + if rest == "" { + return nil, errors.New("env: missing program name") + } + + useSplit := false + if strings.HasPrefix(rest, "-") { + var flag, after string + if idx := strings.IndexAny(rest, " \t"); idx >= 0 { + flag = rest[:idx] + after = strings.TrimLeft(rest[idx+1:], " \t") + } else { + flag = rest + after = "" + } + if flag != "-S" { + return nil, fmt.Errorf("unsupported env flag: %s", flag) + } + useSplit = true + rest = after + if rest == "" { + return nil, errors.New("env -S requires a program") + } + } + + if rest == "" { + return nil, errors.New("env: missing program name") + } + + var prog, remainder string + if idx := strings.IndexAny(rest, " \t"); idx >= 0 { + prog = rest[:idx] + remainder = strings.TrimLeft(rest[idx+1:], " \t") + } else { + prog = rest + } + + sb := &shebang{interpreter: prog} + if remainder != "" { + if useSplit { + sb.args = strings.Fields(remainder) + } else { + sb.args = []string{remainder} + } + } + return sb, nil +} + +// runShellSource parses path's contents as POSIX shell and runs it +// in-process via a nested interp.Runner. It reuses the parent runner's cwd, +// env, and stdio, and rebuilds the Crush handler stack so builtins and the +// dispatch handler itself remain available to anything the script invokes. +// Positional parameters ($1, $2, …) come from args[1:]. +// +// This is the only branch that reads the full file; probeFile keeps its +// read to probeWindow bytes so the binary/shebang paths never touch more +// than 128 bytes of I/O. 
+func runShellSource(ctx context.Context, path string, args []string, blockFuncs []BlockFunc) error { + data, err := os.ReadFile(path) + if err != nil { + return err + } + + file, err := syntax.NewParser().Parse(bytes.NewReader(data), path) + if err != nil { + return fmt.Errorf("could not parse %s: %w", path, err) + } + + hc := interp.HandlerCtx(ctx) + + opts := []interp.RunnerOption{ + interp.StdIO(hc.Stdin, hc.Stdout, hc.Stderr), + interp.Interactive(false), + interp.Env(hc.Env), + interp.Dir(hc.Dir), + interp.ExecHandlers(standardHandlers(blockFuncs)...), + } + if len(args) > 1 { + // Params with a leading "--" avoids any of args[1:] being + // misinterpreted as set-options (e.g. a user passing "-e" as + // a positional arg to their script). + params := append([]string{"--"}, args[1:]...) + opts = append(opts, interp.Params(params...)) + } + + runner, err := interp.New(opts...) + if err != nil { + return fmt.Errorf("could not build runner for %s: %w", path, err) + } + return runner.Run(ctx, file) +} + +// execEnvList converts an expand.Environ to the []string form that +// os/exec.Cmd.Env expects. Only exported string variables are included, +// matching what a real shell would pass to a child process. 
+func execEnvList(env expand.Environ) []string { + var out []string + env.Each(func(name string, vr expand.Variable) bool { + if vr.Exported && vr.Kind == expand.String { + out = append(out, name+"="+vr.Str) + } + return true + }) + return out +} diff --git a/internal/shell/dispatch_test.go b/internal/shell/dispatch_test.go new file mode 100644 index 0000000000000000000000000000000000000000..1ada7643b5049d1675ea031d8fb921c9c0941324 --- /dev/null +++ b/internal/shell/dispatch_test.go @@ -0,0 +1,594 @@ +package shell + +import ( + "bytes" + "crypto/rand" + "encoding/hex" + "os" + "os/exec" + "path/filepath" + "reflect" + "runtime" + "strings" + "testing" +) + +// writeScript is a small helper that drops a file with the given contents +// and executable mode into dir. Tests that need exec semantics rely on the +// 0o755 mode on Unix; Windows ignores file modes but doesn't need them +// because dispatch decides what to do from file contents, not permissions. +func writeScript(t *testing.T, dir, name, contents string) string { + t.Helper() + path := filepath.Join(dir, name) + if err := os.WriteFile(path, []byte(contents), 0o755); err != nil { + t.Fatalf("write %s: %v", name, err) + } + return filepath.ToSlash(path) +} + +// randSuffix returns a short random hex string, used to build +// intentionally-unique paths that won't collide with anything on disk. +func randSuffix() string { + var b [4]byte + _, _ = rand.Read(b[:]) + return hex.EncodeToString(b[:]) +} + +// TestIsPathPrefixed covers the classification rules used by the dispatch +// handler to decide whether argv[0] is a file reference. 
+func TestIsPathPrefixed(t *testing.T) { + cases := []struct { + in string + want bool + }{ + {"./foo.sh", true}, + {"../foo.sh", true}, + {"/usr/bin/foo", true}, + {"foo", false}, + {"foo.sh", false}, + {"jq", false}, + {"", false}, + } + for _, c := range cases { + if got := isPathPrefixed(c.in); got != c.want { + t.Errorf("isPathPrefixed(%q) = %v, want %v", c.in, got, c.want) + } + } + + if runtime.GOOS == "windows" { + winCases := []struct { + in string + want bool + }{ + {`C:\foo\bar.exe`, true}, + {`C:/foo/bar.exe`, true}, + {`c:\foo`, true}, + {`Z:/x`, true}, + {`C:`, false}, // just a drive, no path. + {`\\server\share`, true}, + } + for _, c := range winCases { + if got := isPathPrefixed(c.in); got != c.want { + t.Errorf("isPathPrefixed(%q) = %v, want %v", c.in, got, c.want) + } + } + } +} + +// TestParseShebang covers the shebang grammar: literal paths, env, +// env -S, kernel single-arg semantics, CRLF tolerance, and every +// enumerated error case. +func TestParseShebang(t *testing.T) { + type want struct { + interp string + args []string + errSub string // substring expected in error message (empty → no error) + } + cases := []struct { + name string + in string + want want + }{ + { + name: "literal-no-args", + in: "#!/bin/bash\necho body\n", + want: want{interp: "/bin/bash"}, + }, + { + name: "literal-kernel-single-arg", + in: "#!/bin/bash -x -y\n", + want: want{interp: "/bin/bash", args: []string{"-x -y"}}, + }, + { + name: "env-basic", + in: "#!/usr/bin/env bash\n", + want: want{interp: "bash"}, + }, + { + name: "env-kernel-single-arg", + in: "#!/usr/bin/env bash -x\n", + want: want{interp: "bash", args: []string{"-x"}}, + }, + { + name: "env-dash-S-splits", + in: "#!/usr/bin/env -S bash -x\n", + want: want{interp: "bash", args: []string{"-x"}}, + }, + { + name: "env-dash-S-multi-args", + in: "#!/usr/bin/env -S bash -x --noprofile\n", + want: want{interp: "bash", args: []string{"-x", "--noprofile"}}, + }, + { + name: "leading-space", + in: "#! 
/usr/bin/env bash\n", + want: want{interp: "bash"}, + }, + { + name: "crlf", + in: "#!/bin/bash\r\n", + want: want{interp: "/bin/bash"}, + }, + { + name: "bare-env-name", + in: "#!env bash\n", + want: want{interp: "bash"}, + }, + { + name: "empty-after-hashbang", + in: "#!\n", + want: want{errSub: "empty shebang"}, + }, + { + name: "env-alone", + in: "#!/usr/bin/env\n", + want: want{errSub: "missing program name"}, + }, + { + name: "env-dash-S-alone", + in: "#!/usr/bin/env -S\n", + want: want{errSub: "env -S requires a program"}, + }, + { + name: "env-unknown-flag", + in: "#!/usr/bin/env -x bash\n", + want: want{errSub: "unsupported env flag"}, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + sb, err := parseShebang([]byte(c.in)) + if c.want.errSub != "" { + if err == nil || !strings.Contains(err.Error(), c.want.errSub) { + t.Fatalf("expected error containing %q, got: %v", c.want.errSub, err) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if sb.interpreter != c.want.interp { + t.Errorf("interpreter = %q, want %q", sb.interpreter, c.want.interp) + } + if !equalStringSlice(sb.args, c.want.args) { + t.Errorf("args = %v, want %v", sb.args, c.want.args) + } + }) + } +} + +func equalStringSlice(a, b []string) bool { + if len(a) == 0 && len(b) == 0 { + return true + } + return reflect.DeepEqual(a, b) +} + +// TestIsBinary covers the NUL-byte and magic-byte classification used to +// keep compiled executables off the in-process shell-source path. 
+func TestIsBinary(t *testing.T) { + cases := []struct { + name string + in []byte + want bool + }{ + {"shell", []byte("echo hi\n"), false}, + {"nul", []byte("hello\x00world"), true}, + {"elf", []byte{0x7F, 'E', 'L', 'F', 0x02, 0x01}, true}, + {"mz", []byte("MZ\x90\x00"), true}, + {"macho-64-le", []byte{0xCF, 0xFA, 0xED, 0xFE}, true}, + {"short-non-binary", []byte("a"), false}, + } + for _, c := range cases { + if got := isBinary(c.in); got != c.want { + t.Errorf("%s: isBinary = %v, want %v", c.name, got, c.want) + } + } +} + +// TestDispatch_ShellSourceNoShebang exercises the in-process shell-source +// branch: a file without a shebang runs via a nested runner and sees +// positional params from argv[1:]. +func TestDispatch_ShellSourceNoShebang(t *testing.T) { + dir := t.TempDir() + script := writeScript(t, dir, "args.sh", `echo "$1 $2"`) + + var stdout bytes.Buffer + err := Run(t.Context(), RunOptions{ + Command: script + " alpha beta", + Cwd: dir, + Stdout: &stdout, + }) + if err != nil { + t.Fatalf("Run returned error: %v", err) + } + if got := stdout.String(); got != "alpha beta\n" { + t.Fatalf("stdout = %q, want %q", got, "alpha beta\n") + } +} + +// TestDispatch_EmptyFile confirms a zero-byte script runs as empty shell +// source (exit 0, no output). +func TestDispatch_EmptyFile(t *testing.T) { + dir := t.TempDir() + script := writeScript(t, dir, "empty.sh", "") + + var stdout, stderr bytes.Buffer + err := Run(t.Context(), RunOptions{ + Command: script, + Cwd: dir, + Stdout: &stdout, + Stderr: &stderr, + }) + if err != nil { + t.Fatalf("Run returned error: %v (stderr=%q)", err, stderr.String()) + } + if stdout.Len() != 0 || stderr.Len() != 0 { + t.Fatalf("expected empty output, got stdout=%q stderr=%q", stdout.String(), stderr.String()) + } +} + +// TestDispatch_ShellSourceComposesWithPipe confirms the dispatch handler +// plays nicely with mvdan's pipeline logic: a shell-source script on the +// left feeds the jq builtin on the right. 
+func TestDispatch_ShellSourceComposesWithPipe(t *testing.T) { + dir := t.TempDir() + script := writeScript(t, dir, "emit.sh", `printf '"value"'`) + + var stdout bytes.Buffer + err := Run(t.Context(), RunOptions{ + Command: script + ` | jq -r .`, + Cwd: dir, + Stdout: &stdout, + }) + if err != nil { + t.Fatalf("Run returned error: %v", err) + } + if got := stdout.String(); got != "value\n" { + t.Fatalf("stdout = %q, want %q", got, "value\n") + } +} + +// TestDispatch_MissingFile returns a clean error for a non-existent path. +func TestDispatch_MissingFile(t *testing.T) { + dir := t.TempDir() + missing := filepath.Join(dir, "nope.sh") + err := Run(t.Context(), RunOptions{ + Command: missing, + Cwd: dir, + }) + if err == nil { + t.Fatal("expected error for missing script, got nil") + } +} + +// TestDispatch_DirectoryNotFile surfaces a distinct error when the path +// resolves to a directory. +func TestDispatch_DirectoryNotFile(t *testing.T) { + dir := t.TempDir() + subDir := filepath.Join(dir, "adir") + if err := os.MkdirAll(subDir, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + + var stderr bytes.Buffer + err := Run(t.Context(), RunOptions{ + Command: "./adir", + Cwd: dir, + Stderr: &stderr, + }) + if err == nil { + t.Fatal("expected error when invoking a directory, got nil") + } + if !strings.Contains(err.Error(), "is a directory") { + t.Fatalf("expected 'is a directory' in error, got: %v", err) + } +} + +// TestDispatch_BashShebang runs a #!/bin/bash script via os/exec. Skipped +// if bash isn't available (rare in CI, but keep the test robust). 
+func TestDispatch_BashShebang(t *testing.T) { + bash, err := exec.LookPath("bash") + if err != nil { + t.Skipf("bash not in PATH: %v", err) + } + _ = bash + + dir := t.TempDir() + script := writeScript(t, dir, "bash-echo.sh", "#!/usr/bin/env bash\necho bashout\n") + + var stdout, stderr bytes.Buffer + err = Run(t.Context(), RunOptions{ + Command: script, + Cwd: dir, + Stdout: &stdout, + Stderr: &stderr, + }) + if err != nil { + t.Fatalf("Run returned error: %v (stderr=%q)", err, stderr.String()) + } + if got := stdout.String(); got != "bashout\n" { + t.Fatalf("stdout = %q, want %q", got, "bashout\n") + } +} + +// TestDispatch_ShebangPassesExitCode maps interpreter exit codes through to +// interp.ExitStatus so the caller can inspect them with ExitCode. +func TestDispatch_ShebangPassesExitCode(t *testing.T) { + if _, err := exec.LookPath("bash"); err != nil { + t.Skipf("bash not in PATH: %v", err) + } + dir := t.TempDir() + script := writeScript(t, dir, "fail.sh", "#!/usr/bin/env bash\nexit 5\n") + + err := Run(t.Context(), RunOptions{ + Command: script, + Cwd: dir, + }) + if err == nil { + t.Fatal("expected non-nil error from exit 5") + } + if code := ExitCode(err); code != 5 { + t.Fatalf("ExitCode = %d, want 5", code) + } +} + +// TestDispatch_MissingInterpreter surfaces a clear error (and non-zero +// exit) when the shebang points to a binary that doesn't exist and has +// no PATH fallback. 
+func TestDispatch_MissingInterpreter(t *testing.T) { + dir := t.TempDir() + script := writeScript(t, dir, "bad.sh", "#!/no/such/interpreter-"+randSuffix()+"\n:\n") + + var stderr bytes.Buffer + err := Run(t.Context(), RunOptions{ + Command: script, + Cwd: dir, + Stderr: &stderr, + }) + if err == nil { + t.Fatal("expected error for missing interpreter, got nil") + } + if ExitCode(err) == 0 { + t.Fatalf("expected non-zero exit code, got 0") + } + if !strings.Contains(stderr.String(), "not found") { + t.Fatalf("expected 'not found' in stderr, got: %q", stderr.String()) + } +} + +// TestDispatch_BarePathNotHandled confirms the handler ignores +// non-path-prefixed argv[0] entirely: a benign bare `true` command must +// not try to open a file in cwd. If dispatch were (incorrectly) firing +// on bare commands, this test would see probeFile's ENOENT. +func TestDispatch_BarePathNotHandled(t *testing.T) { + dir := t.TempDir() + err := Run(t.Context(), RunOptions{ + Command: "true", + Cwd: dir, + }) + if err != nil { + t.Fatalf("bare `true` should not trigger dispatch: %v", err) + } +} + +// TestDispatch_ProbeWindowClassifiesByHead confirms that classification is +// done on the first probeWindow bytes even when the file is much larger; +// a file whose head is shell source but whose tail contains NUL bytes is +// classified as shell source, not binary. +func TestDispatch_ProbeWindowClassifiesByHead(t *testing.T) { + dir := t.TempDir() + head := "echo prefix\n" + // Pad past probeWindow, then append some NULs. 
+ padding := strings.Repeat(" ", probeWindow) + contents := head + padding + "\x00\x00\x00" + script := writeScript(t, dir, "long.sh", contents) + + var stdout bytes.Buffer + err := Run(t.Context(), RunOptions{ + Command: script, + Cwd: dir, + Stdout: &stdout, + }) + if err != nil { + t.Fatalf("Run returned error: %v", err) + } + if got := stdout.String(); !strings.HasPrefix(got, "prefix\n") { + t.Fatalf("stdout = %q, want prefix %q", got, "prefix\n") + } +} + +// TestDispatch_BinaryPassthroughExecutes copies a real binary from PATH +// into a tempdir, invokes it via a path-prefixed argv[0], and verifies it +// ran — i.e. the binary branch correctly returns through `next` to the +// default exec handler. We use whichever of `true`/`echo` is available on +// PATH so the test works on any Unix-y system; it skips on Windows where +// the stock binaries don't share names and the Go test binary approach +// is heavier than this test deserves. +func TestDispatch_BinaryPassthroughExecutes(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("relies on a Unix-style PATH binary") + } + src, err := exec.LookPath("true") + if err != nil { + t.Skipf("no `true` binary on PATH: %v", err) + } + data, err := os.ReadFile(src) + if err != nil { + t.Fatalf("read %s: %v", src, err) + } + dir := t.TempDir() + dst := filepath.Join(dir, "copied-true") + if err := os.WriteFile(dst, data, 0o755); err != nil { + t.Fatalf("write %s: %v", dst, err) + } + + runErr := Run(t.Context(), RunOptions{ + Command: dst, + Cwd: dir, + // Default handler needs PATH to resolve dynamic linker / loader + // helpers on some systems; inherit the process env so the copy + // can actually start. + Env: os.Environ(), + }) + if runErr != nil { + t.Fatalf("expected copy of /bin/true to exit 0, got: %v", runErr) + } +} + +// TestDispatch_UnreadableFile confirms an EACCES on the script surfaces +// as a clean error rather than a silent fallback or a mis-classified +// shell-source attempt. 
POSIX-only: Windows doesn't have the same +// permission model and running as root would bypass the check anyway. +func TestDispatch_UnreadableFile(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("POSIX permission model") + } + if os.Geteuid() == 0 { + t.Skip("root bypasses file mode permission checks") + } + dir := t.TempDir() + script := writeScript(t, dir, "unreadable.sh", "echo nope\n") + if err := os.Chmod(script, 0o000); err != nil { + t.Fatalf("chmod: %v", err) + } + t.Cleanup(func() { _ = os.Chmod(script, 0o644) }) + + err := Run(t.Context(), RunOptions{ + Command: script, + Cwd: dir, + }) + if err == nil { + t.Fatal("expected permission error, got nil") + } + if !strings.Contains(err.Error(), "permission") { + t.Fatalf("expected 'permission' in error, got: %v", err) + } +} + +// TestDispatch_SymlinkLoop confirms that an ELOOP-returning path surfaces +// cleanly. POSIX-only: creating symlinks reliably on Windows requires +// elevated privileges or developer mode, and neither is guaranteed in CI. +func TestDispatch_SymlinkLoop(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("symlink creation requires special privileges on Windows") + } + dir := t.TempDir() + a := filepath.Join(dir, "a") + b := filepath.Join(dir, "b") + if err := os.Symlink(b, a); err != nil { + t.Fatalf("symlink a→b: %v", err) + } + if err := os.Symlink(a, b); err != nil { + t.Fatalf("symlink b→a: %v", err) + } + + err := Run(t.Context(), RunOptions{ + Command: a, + Cwd: dir, + }) + if err == nil { + t.Fatal("expected loop error, got nil") + } + // The exact error varies by OS; any of these message fragments is + // acceptable evidence that the loop was detected. 
+ msg := err.Error() + if !strings.Contains(msg, "too many") && + !strings.Contains(msg, "loop") && + !strings.Contains(msg, "level") { + t.Fatalf("expected symlink-loop-ish error, got: %v", err) + } +} + +// TestResolveInterpreter_PermissiveFallback confirms the key portability +// behavior: a literal shebang path that doesn't exist falls back to a +// PATH-lookup on its basename. This is what makes #!/bin/bash work on a +// Windows box where bash.exe lives somewhere else on PATH. We construct a +// fake PATH in a tempdir rather than depending on what the host has +// installed so the test is deterministic everywhere. +func TestResolveInterpreter_PermissiveFallback(t *testing.T) { + if runtime.GOOS == "windows" { + // exec.LookPath on Windows requires a recognized extension + // (.exe/.bat/.cmd). Producing one of those without a compiler + // run is more ceremony than this smoke test deserves; the + // logic under test is exercised by the Unix run. + t.Skip("Windows PATH lookup requires an extension-matched binary") + } + dir := t.TempDir() + fake := filepath.Join(dir, "bash") + if err := os.WriteFile(fake, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil { + t.Fatalf("write fake bash: %v", err) + } + t.Setenv("PATH", dir) + + // Basename must match the fake we planted on PATH; the directory + // prefix must not exist so the literal stat fails. + missingDir := filepath.Join(dir, "definitely-not-here-"+randSuffix()) + resolved, err := resolveInterpreter(filepath.Join(missingDir, "bash")) + if err != nil { + t.Fatalf("expected fallback to succeed, got: %v", err) + } + if resolved != fake { + t.Fatalf("resolved = %q, want %q", resolved, fake) + } +} + +// TestResolveInterpreter_NonENOENTErrorsSurface guards against silently +// falling back to PATH when stat fails for a reason other than the file +// being missing. 
With a directory at the shebang path, os.Stat succeeds +// (no fallback needed), but with an EACCES'd file it fails with a non- +// ENOENT error that must be surfaced — otherwise we'd silently resolve a +// different binary off PATH and hide the real problem. +func TestResolveInterpreter_NonENOENTErrorsSurface(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("POSIX permission model") + } + if os.Geteuid() == 0 { + t.Skip("root bypasses dir mode permission checks") + } + dir := t.TempDir() + // Put a candidate interpreter inside an unreadable/untraversable dir. + inner := filepath.Join(dir, "private") + if err := os.Mkdir(inner, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + interp := filepath.Join(inner, "bash") + if err := os.WriteFile(interp, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil { + t.Fatalf("write interpreter: %v", err) + } + // Drop search permission on inner so os.Stat(interp) returns EACCES. + if err := os.Chmod(inner, 0o000); err != nil { + t.Fatalf("chmod: %v", err) + } + t.Cleanup(func() { _ = os.Chmod(inner, 0o755) }) + + _, err := resolveInterpreter(interp) + if err == nil { + t.Fatal("expected error for unreadable interpreter, got nil") + } + // Must NOT have silently fallen back — the returned path shouldn't + // be a valid resolution; either way, the error has to surface. + if !strings.Contains(err.Error(), "permission") { + t.Fatalf("expected permission-denied error to surface, got: %v", err) + } +} diff --git a/internal/shell/dispatch_windows_test.go b/internal/shell/dispatch_windows_test.go new file mode 100644 index 0000000000000000000000000000000000000000..529945ad5a33eac5d943bdd2a562a46d8d4ff000 --- /dev/null +++ b/internal/shell/dispatch_windows_test.go @@ -0,0 +1,40 @@ +//go:build windows + +package shell + +import ( + "os" + "path/filepath" + "testing" +) + +// TestResolveInterpreter_PermissiveFallback_Windows is the Windows-native +// counterpart to the POSIX permissive-fallback test. 
It proves the one +// behavior that makes `#!/bin/bash` hooks work on a stock Windows box +// with Git for Windows installed: when the literal interpreter path does +// not exist, we fall back to a PATH-lookup on the basename and that +// lookup accepts any executable extension Windows honors (here, `.bat`). +// +// We plant a bash.bat in a tempdir rather than a .exe because producing +// a .exe would require a toolchain step; LookPath on Windows resolves +// PATHEXT extensions, so .bat is just as valid for the lookup codepath. +func TestResolveInterpreter_PermissiveFallback_Windows(t *testing.T) { + dir := t.TempDir() + fake := filepath.Join(dir, "bash.bat") + contents := "@echo off\r\nexit /b 0\r\n" + if err := os.WriteFile(fake, []byte(contents), 0o755); err != nil { + t.Fatalf("write fake bash.bat: %v", err) + } + t.Setenv("PATH", dir) + t.Setenv("PATHEXT", ".BAT;.CMD;.EXE") + + // Literal path must be absent so the stat fails with ENOENT. + missing := filepath.Join(dir, "definitely-not-here-"+randSuffix(), "bash") + resolved, err := resolveInterpreter(missing) + if err != nil { + t.Fatalf("expected fallback to succeed, got: %v", err) + } + if resolved != fake { + t.Fatalf("resolved = %q, want %q", resolved, fake) + } +} diff --git a/internal/shell/jq.go b/internal/shell/jq.go index ceac574df13c97befa05e8817b91a8d928a96a11..4204e5722f4f5ce6e6e6e2f9f8e1acf26ba5438a 100644 --- a/internal/shell/jq.go +++ b/internal/shell/jq.go @@ -1,6 +1,7 @@ package shell import ( + "context" "encoding/json" "fmt" "io" @@ -36,10 +37,16 @@ Options: // flags: -r (raw output), -c (compact output), -s (slurp), -n (null input), // -e (exit status), -R (raw input), and --arg name value. // +// ctx is polled at each iteration of the output loop and at each reader in +// [readInputs] so that hook timeouts or other cancellations can interrupt +// long-running queries. A cancelled context surfaces as ctx.Err(), not an +// [interp.ExitStatus], so callers (e.g. 
the hook runner) can distinguish +// "filter exited non-zero" from "we ran out of time". +// // Note that this is somewhat of a reimplmentation of the CLI of the glorious // github.com/itchyny/gojq, and we'd ideally get the CLI exposed upstream to // avoid this falling out of sync. -func handleJQ(args []string, stdin io.Reader, stdout, stderr io.Writer) error { +func handleJQ(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) error { var ( rawOutput bool compact bool @@ -143,8 +150,13 @@ func handleJQ(args []string, stdin io.Reader, stdout, stderr io.Writer) error { } // Build input values. - inputs, err := readInputs(stdin, fileArgs, nullInput, rawInput, slurp) + inputs, err := readInputs(ctx, stdin, fileArgs, nullInput, rawInput, slurp) if err != nil { + // Prefer surfacing ctx cancellation verbatim so timeouts are + // distinguishable from user input errors. + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } fmt.Fprintf(stderr, "jq: %s\n", err) return interp.ExitStatus(2) } @@ -153,6 +165,12 @@ func handleJQ(args []string, stdin io.Reader, stdout, stderr io.Writer) error { for _, input := range inputs { iter := code.Run(input, argValues...) for { + // Poll ctx on every value so a long-running filter (e.g. a + // generator over a slurped array) can be interrupted by hook + // timeouts without waiting for iter.Next to yield. + if err := ctx.Err(); err != nil { + return err + } v, ok := iter.Next() if !ok { break @@ -177,7 +195,22 @@ func handleJQ(args []string, stdin io.Reader, stdout, stderr io.Writer) error { } // readInputs reads JSON (or raw) input values from stdin or files. 
-func readInputs(stdin io.Reader, files []string, nullInput, rawInput, slurp bool) ([]any, error) { +// +// ctx is polled in three places so that a cancellation observed mid-read +// short-circuits promptly: +// - between readers (before opening the next file / consuming stdin); +// - on every io.Read call via ctxReader, so io.ReadAll on a large but +// non-blocking source (e.g. the bytes.NewReader payload the hook +// runner supplies) returns ctx.Err() on the next chunk boundary; +// - inside the post-read value accumulation loops (raw-input line +// split and JSON stream decode), which are otherwise unbounded in +// the size of the input. +// +// A reader that blocks forever in Read (e.g. an unterminated pipe) can +// still outlast ctx; the outer abandon-goroutine path in the hook +// runner (internal/hooks/runner.go) is the authoritative enforcer for +// that case. +func readInputs(ctx context.Context, stdin io.Reader, files []string, nullInput, rawInput, slurp bool) ([]any, error) { if nullInput { return []any{nil}, nil } @@ -198,8 +231,16 @@ func readInputs(stdin io.Reader, files []string, nullInput, rawInput, slurp bool var vals []any for _, r := range readers { - data, err := io.ReadAll(r) + if err := ctx.Err(); err != nil { + return nil, err + } + data, err := io.ReadAll(ctxReader{ctx: ctx, r: r}) if err != nil { + // ctxReader surfaces ctx.Err() verbatim; preserve it so the + // caller can distinguish cancellation from a parse error. 
+ if ctxErr := ctx.Err(); ctxErr != nil { + return nil, ctxErr + } return nil, err } @@ -209,6 +250,9 @@ func readInputs(stdin io.Reader, files []string, nullInput, rawInput, slurp bool vals = append(vals, strings.Join(lines, "\n")) } else { for _, line := range lines { + if err := ctx.Err(); err != nil { + return nil, err + } if line != "" || !slurp { vals = append(vals, line) } @@ -221,6 +265,9 @@ func readInputs(stdin io.Reader, files []string, nullInput, rawInput, slurp bool dec := json.NewDecoder(strings.NewReader(string(data))) var streamVals []any for { + if err := ctx.Err(); err != nil { + return nil, err + } var v any if err := dec.Decode(&v); err != nil { if err == io.EOF { @@ -244,6 +291,24 @@ func readInputs(stdin io.Reader, files []string, nullInput, rawInput, slurp bool return vals, nil } +// ctxReader wraps an io.Reader so that each Read call checks ctx first. +// This makes io.ReadAll over a large but non-blocking source (e.g. a +// bytes.Reader of the hook stdin payload) cancellable on the next chunk +// boundary. A reader that itself blocks in Read will still outlast ctx — +// the hook runner's abandon-goroutine path is the enforcer of last resort +// for that case. +type ctxReader struct { + ctx context.Context + r io.Reader +} + +func (cr ctxReader) Read(p []byte) (int, error) { + if err := cr.ctx.Err(); err != nil { + return 0, err + } + return cr.r.Read(p) +} + // writeValue writes a single jq output value. 
func writeValue(w io.Writer, v any, raw, compact, join bool) error { if raw { diff --git a/internal/shell/jq_test.go b/internal/shell/jq_test.go new file mode 100644 index 0000000000000000000000000000000000000000..e449047b49b27da9b2d587b6defe9d8f81fe2888 --- /dev/null +++ b/internal/shell/jq_test.go @@ -0,0 +1,199 @@ +package shell + +import ( + "bytes" + "context" + "errors" + "io" + "strings" + "testing" + "time" +) + +// TestJQ_CtxCancel verifies that handleJQ polls ctx during iteration and +// returns ctx.Err() (not an interp.ExitStatus) when the context is +// cancelled. This is what lets hook timeouts interrupt long-running jq +// filters rather than waiting for the iterator to terminate naturally. +func TestJQ_CtxCancel(t *testing.T) { + t.Parallel() + + // `range(N)` generates a large stream of values. With a slurped input + // the filter produces all N values in sequence; ctx cancellation + // between values should short-circuit the loop. + const filter = "range(10000000)" + stdin := strings.NewReader("null\n") + + ctx, cancel := context.WithCancel(t.Context()) + // Cancel almost immediately so we catch the next iteration check. + cancel() + + err := handleJQ(ctx, []string{"jq", filter}, stdin, io.Discard, io.Discard) + if err == nil { + t.Fatal("expected ctx cancel error, got nil") + } + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got %v", err) + } +} + +// TestJQ_CtxCancel_DuringFilter verifies cancellation mid-stream: ctx is +// cancelled after jq has started producing output, and the loop must +// observe the cancel on the next iteration rather than running to +// completion. +func TestJQ_CtxCancel_DuringFilter(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), 50*time.Millisecond) + defer cancel() + + // 100M values; without ctx polling this would take many seconds to + // fully emit. With ctx polling the loop exits shortly after the + // deadline. 
+ stdin := strings.NewReader("null\n") + var stdout, stderr bytes.Buffer + + start := time.Now() + err := handleJQ(ctx, []string{"jq", "-c", "range(100000000)"}, stdin, &stdout, &stderr) + elapsed := time.Since(start) + + if err == nil { + t.Fatal("expected ctx timeout error, got nil") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected context.DeadlineExceeded, got %v", err) + } + // Allow generous slack for slow CI; the important invariant is that we + // don't run all 100M iterations (which would take orders of magnitude + // longer than 1s). + if elapsed > time.Second { + t.Fatalf("handleJQ took %v after 50ms timeout; ctx polling is not tight enough", elapsed) + } +} + +// slowReader serves bytes in small chunks with a fixed delay between +// Read calls. It never blocks indefinitely — each Read returns after +// chunkDelay — so cancellation must be observed via ctxReader's ctx +// check, not by the underlying reader itself. That isolates the +// behavior we want to test: the wrapper polling ctx between chunks. +type slowReader struct { + remaining []byte + chunk int + chunkDelay time.Duration +} + +func (s *slowReader) Read(p []byte) (int, error) { + if len(s.remaining) == 0 { + return 0, io.EOF + } + time.Sleep(s.chunkDelay) + n := min(len(p), min(s.chunk, len(s.remaining))) + copy(p, s.remaining[:n]) + s.remaining = s.remaining[n:] + return n, nil +} + +// TestJQ_CtxCancel_MidReadAll verifies that ctx cancellation observed +// *during* io.ReadAll — after several chunks have already been consumed +// — short-circuits the read via ctxReader, rather than draining the +// whole source. This is the guarantee the hook runner relies on when +// it feeds a large bytes.Reader payload. +// +// The reader serves bytes in 512-byte chunks with a 5ms gap between +// reads. ctx is cancelled after ~50ms, so several chunks have already +// been read when ctxReader first observes the cancellation. 
The test +// asserts that (a) we got a context.Canceled error and (b) the call +// returned well before the reader would have been fully drained. +func TestJQ_CtxCancel_MidReadAll(t *testing.T) { + t.Parallel() + + const ( + size = 64 * 1024 * 1024 // 64 MiB + chunk = 512 + chunkDelay = 5 * time.Millisecond + ) + // At 512 bytes / 5ms, draining 64 MiB would take ~11 minutes. Any + // return within a second proves cancel was observed mid-stream, not + // after EOF. + reader := &slowReader{ + remaining: bytes.Repeat([]byte("a"), size), + chunk: chunk, + chunkDelay: chunkDelay, + } + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + // Cancel after enough time that several Read calls have completed + // and io.ReadAll is actively consuming the source. + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + start := time.Now() + err := handleJQ(ctx, []string{"jq", "-R", "."}, reader, io.Discard, io.Discard) + elapsed := time.Since(start) + + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got %v", err) + } + // Generous slack for slow CI; the invariant is orders-of-magnitude + // faster than draining the full source. + if elapsed > time.Second { + t.Fatalf("mid-ReadAll cancel took %v; ctxReader is not polling between chunks", elapsed) + } + // Sanity check: we should have been cancelled mid-stream, not + // before any reads happened. If remaining == size, cancel fired so + // early nothing was consumed — that's a fast-fail path, not the + // mid-read guarantee we want to verify. 
+ consumed := size - len(reader.remaining) + if consumed == 0 { + t.Fatal("reader was never read from; test did not exercise mid-ReadAll cancel") + } + if consumed >= size { + t.Fatal("reader was fully drained; cancel was not observed mid-read") + } +} + +// TestJQ_CtxCancel_PreCancel verifies the fast-fail path: a ctx already +// cancelled before handleJQ is called returns context.Canceled +// immediately via the outer-loop guard, never entering io.ReadAll. +// Complements TestJQ_CtxCancel_MidReadAll. +func TestJQ_CtxCancel_PreCancel(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(t.Context()) + cancel() + + start := time.Now() + err := handleJQ(ctx, []string{"jq", "-R", "."}, + bytes.NewReader(bytes.Repeat([]byte("a"), 1024)), + io.Discard, io.Discard) + elapsed := time.Since(start) + + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got %v", err) + } + if elapsed > 100*time.Millisecond { + t.Fatalf("pre-cancel fast-fail took %v; outer guard is not firing", elapsed) + } +} + +// TestJQ_Success confirms the ctx-aware refactor did not regress the +// success path. +func TestJQ_Success(t *testing.T) { + t.Parallel() + + var stdout bytes.Buffer + err := handleJQ(t.Context(), + []string{"jq", "-c", ".a"}, + strings.NewReader(`{"a":1}`), + &stdout, io.Discard, + ) + if err != nil { + t.Fatalf("handleJQ returned error: %v", err) + } + if got := stdout.String(); got != "1\n" { + t.Fatalf("stdout = %q, want %q", got, "1\n") + } +} diff --git a/internal/shell/run.go b/internal/shell/run.go new file mode 100644 index 0000000000000000000000000000000000000000..cac7c530f0af415b6a9dbaa7292def1896ab318a --- /dev/null +++ b/internal/shell/run.go @@ -0,0 +1,158 @@ +package shell + +import ( + "context" + "fmt" + "io" + "strings" + + "mvdan.cc/sh/moreinterp/coreutils" + "mvdan.cc/sh/v3/expand" + "mvdan.cc/sh/v3/interp" + "mvdan.cc/sh/v3/syntax" +) + +// RunOptions configures a single stateless shell execution via [Run]. 
+// +// The zero value is not useful; at minimum Command must be set. Stdin, +// Stdout, and Stderr may be nil (nil readers/writers are treated as +// empty/discard). BlockFuncs may be nil to disable block-list enforcement — +// hooks use this to run user-authored commands with the same trust level as +// a shell alias. +type RunOptions struct { + // Command is the shell source to parse and execute. + Command string + // Cwd is the working directory for the execution. Required: callers + // must supply a non-empty value. Run does not silently fall back to + // the Crush process cwd — hooks and the bash tool have different + // notions of "default" and each owns that decision. + Cwd string + // Env is the full environment visible to the command. The caller is + // responsible for inheriting from os.Environ() if that's desired. + Env []string + // Stdin is the command's standard input. nil is equivalent to an empty + // input stream. + Stdin io.Reader + // Stdout receives the command's standard output. nil discards output. + Stdout io.Writer + // Stderr receives the command's standard error. nil discards output. + Stderr io.Writer + // BlockFuncs is an optional list of deny-list matchers applied before + // each command reaches the exec layer. nil disables blocking entirely. + BlockFuncs []BlockFunc +} + +// Run parses and executes a shell command using the same mvdan.cc/sh +// interpreter stack that the stateful [Shell] type uses (builtins, +// optional block list, optional Go coreutils). It is safe to call +// concurrently from multiple goroutines: each call builds its own +// [interp.Runner] and shares no state with other callers or with any +// [Shell] instance. +// +// Errors returned from the command itself (non-zero exit, context +// cancellation, parse failures) follow the same conventions as +// [Shell.Exec]: inspect with [IsInterrupt] and [ExitCode]. 
+func Run(ctx context.Context, opts RunOptions) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("command execution panic: %v", r) + } + }() + + if opts.Cwd == "" { + return fmt.Errorf("shell.Run: Cwd is required") + } + + stdout := opts.Stdout + if stdout == nil { + stdout = io.Discard + } + stderr := opts.Stderr + if stderr == nil { + stderr = io.Discard + } + + line, err := syntax.NewParser().Parse(strings.NewReader(opts.Command), "") + if err != nil { + return fmt.Errorf("could not parse command: %w", err) + } + + runner, err := newRunner(opts.Cwd, opts.Env, opts.Stdin, stdout, stderr, opts.BlockFuncs) + if err != nil { + return fmt.Errorf("could not run command: %w", err) + } + + return runner.Run(ctx, line) +} + +// newRunner constructs an [interp.Runner] configured with the standard +// Crush handler stack. Shared by the stateless [Run] entrypoint and the +// stateful [Shell] so the two surfaces cannot drift. +func newRunner(cwd string, env []string, stdin io.Reader, stdout, stderr io.Writer, blockFuncs []BlockFunc) (*interp.Runner, error) { + return interp.New( + interp.StdIO(stdin, stdout, stderr), + interp.Interactive(false), + interp.Env(expand.ListEnviron(env...)), + interp.Dir(cwd), + interp.ExecHandlers(standardHandlers(blockFuncs)...), + ) +} + +// standardHandlers returns the exec-handler middleware chain used by both +// [Run] and [Shell]. Order matters: +// 1. builtins first (so Crush's in-process jq wins over any PATH binary); +// 2. script dispatch (shebang / binary / shell-source for path-prefixed +// argv[0], no-op for bare commands) — runs before the block list so +// that deny rules see the already-resolved argv of anything the +// script exec's rather than the outer path-prefixed wrapper; +// 3. block list; +// 4. optional Go coreutils (only when useGoCoreUtils is on). 
+func standardHandlers(blockFuncs []BlockFunc) []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { + handlers := []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc{ + builtinHandler(), + scriptDispatchHandler(blockFuncs), + blockHandler(blockFuncs), + } + if useGoCoreUtils { + handlers = append(handlers, coreutils.ExecHandler) + } + return handlers +} + +// builtinHandler returns middleware that dispatches recognized Crush +// builtins to their in-process Go implementations. Currently: jq. +func builtinHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { + return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { + return func(ctx context.Context, args []string) error { + if len(args) == 0 { + return next(ctx, args) + } + switch args[0] { + case "jq": + hc := interp.HandlerCtx(ctx) + return handleJQ(ctx, args, hc.Stdin, hc.Stdout, hc.Stderr) + default: + return next(ctx, args) + } + } + } +} + +// blockHandler returns middleware that rejects commands matched by any of +// the provided [BlockFunc]s before they reach the underlying exec path. +// A nil or empty blockFuncs slice is a no-op. 
+func blockHandler(blockFuncs []BlockFunc) func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { + return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { + return func(ctx context.Context, args []string) error { + if len(args) == 0 { + return next(ctx, args) + } + for _, blockFunc := range blockFuncs { + if blockFunc(args) { + return fmt.Errorf("command is not allowed for security reasons: %q", args[0]) + } + } + return next(ctx, args) + } + } +} diff --git a/internal/shell/run_test.go b/internal/shell/run_test.go new file mode 100644 index 0000000000000000000000000000000000000000..15253a1f624bda0316c935ebe52f0d1519d60e63 --- /dev/null +++ b/internal/shell/run_test.go @@ -0,0 +1,252 @@ +package shell + +import ( + "bytes" + "context" + "errors" + "fmt" + "strings" + "sync" + "testing" + "time" +) + +func TestRun_Echo(t *testing.T) { + var stdout, stderr bytes.Buffer + err := Run(t.Context(), RunOptions{ + Command: "echo hi", + Cwd: t.TempDir(), + Stdout: &stdout, + Stderr: &stderr, + }) + if err != nil { + t.Fatalf("Run returned error: %v (stderr=%q)", err, stderr.String()) + } + if got := stdout.String(); got != "hi\n" { + t.Fatalf("stdout = %q, want %q", got, "hi\n") + } +} + +func TestRun_ExitCode(t *testing.T) { + err := Run(t.Context(), RunOptions{ + Command: "exit 7", + Cwd: t.TempDir(), + }) + if err == nil { + t.Fatal("expected error for exit 7, got nil") + } + if code := ExitCode(err); code != 7 { + t.Fatalf("ExitCode = %d, want 7", code) + } +} + +func TestRun_Stdin(t *testing.T) { + // Use the `read` shell builtin so the test doesn't depend on any + // external binary being on PATH (we pass an empty Env here). 
+ var stdout bytes.Buffer + err := Run(t.Context(), RunOptions{ + Command: "read line; echo got:$line", + Cwd: t.TempDir(), + Stdin: strings.NewReader("hello\n"), + Stdout: &stdout, + }) + if err != nil { + t.Fatalf("Run returned error: %v", err) + } + if got := stdout.String(); got != "got:hello\n" { + t.Fatalf("stdout = %q, want %q", got, "got:hello\n") + } +} + +func TestRun_Env(t *testing.T) { + var stdout bytes.Buffer + err := Run(t.Context(), RunOptions{ + Command: `echo "$FOO"`, + Cwd: t.TempDir(), + Env: []string{"FOO=bar"}, + Stdout: &stdout, + }) + if err != nil { + t.Fatalf("Run returned error: %v", err) + } + if got := stdout.String(); got != "bar\n" { + t.Fatalf("stdout = %q, want %q", got, "bar\n") + } +} + +func TestRun_Cwd(t *testing.T) { + dir := t.TempDir() + var stdout bytes.Buffer + err := Run(t.Context(), RunOptions{ + Command: "pwd", + Cwd: dir, + Stdout: &stdout, + }) + if err != nil { + t.Fatalf("Run returned error: %v", err) + } + // mvdan's pwd builtin resolves symlinks (e.g. /var -> /private/var on + // macOS). Compare against a suffix so we don't get bitten by that. 
+ got := strings.TrimRight(stdout.String(), "\n") + if !strings.HasSuffix(got, dir) && !strings.HasSuffix(dir, got) { + t.Fatalf("pwd = %q, want it to match %q", got, dir) + } +} + +func TestRun_JqBuiltin(t *testing.T) { + var stdout bytes.Buffer + err := Run(t.Context(), RunOptions{ + Command: `echo '{"a":1}' | jq .a`, + Cwd: t.TempDir(), + Stdout: &stdout, + }) + if err != nil { + t.Fatalf("Run returned error: %v", err) + } + if got := stdout.String(); got != "1\n" { + t.Fatalf("stdout = %q, want %q", got, "1\n") + } +} + +func TestRun_ParallelIsolation(t *testing.T) { + const n = 10 + var wg sync.WaitGroup + wg.Add(n) + errs := make([]error, n) + outs := make([]string, n) + dirs := make([]string, n) + for i := range n { + dirs[i] = t.TempDir() + go func(i int) { + defer wg.Done() + var stdout bytes.Buffer + errs[i] = Run(t.Context(), RunOptions{ + Command: `echo "$MARKER"`, + Cwd: dirs[i], + Env: []string{fmt.Sprintf("MARKER=id-%d", i)}, + Stdout: &stdout, + }) + outs[i] = stdout.String() + }(i) + } + wg.Wait() + for i := range n { + if errs[i] != nil { + t.Errorf("goroutine %d: err = %v", i, errs[i]) + continue + } + want := fmt.Sprintf("id-%d\n", i) + if outs[i] != want { + t.Errorf("goroutine %d: stdout = %q, want %q", i, outs[i], want) + } + } +} + +// TestRun_CtxCancel_BusyLoop verifies that a pure-shell loop respects ctx +// cancellation. mvdan's interpreter checks ctx between statements, so this +// should return quickly even without any external command. The test bounds +// its own wait via a select so a regression can't hang CI. 
+func TestRun_CtxCancel_BusyLoop(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), 500*time.Millisecond) + t.Cleanup(cancel) + + done := make(chan error, 1) + go func() { + done <- Run(ctx, RunOptions{ + Command: "while true; do :; done", + Cwd: t.TempDir(), + }) + }() + + select { + case err := <-done: + if !IsInterrupt(err) && !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected interrupt/deadline error, got: %v", err) + } + case <-time.After(1500 * time.Millisecond): + t.Fatal("Run did not return within 1.5s after ctx cancel") + } +} + +// TestRun_CtxCancel_ExternalSleep verifies ctx cancellation reaches an +// external process via mvdan's default exec. Uses sleep, which lives in +// coreutils on Windows and /bin on Unix. +func TestRun_CtxCancel_ExternalSleep(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), 200*time.Millisecond) + t.Cleanup(cancel) + + done := make(chan error, 1) + start := time.Now() + go func() { + done <- Run(ctx, RunOptions{ + Command: "sleep 30", + Cwd: t.TempDir(), + }) + }() + + select { + case err := <-done: + elapsed := time.Since(start) + if elapsed > time.Second { + t.Fatalf("sleep took too long to cancel: %v", elapsed) + } + if err == nil { + t.Fatal("expected non-nil error from cancelled sleep") + } + case <-time.After(time.Second): + t.Fatal("Run did not return within 1s after ctx cancel") + } +} + +func TestRun_ParseError(t *testing.T) { + err := Run(t.Context(), RunOptions{ + Command: "echo 'unterminated", + Cwd: t.TempDir(), + }) + if err == nil { + t.Fatal("expected parse error, got nil") + } + if !strings.Contains(err.Error(), "parse") { + t.Fatalf("error should mention parse: %v", err) + } +} + +func TestRun_BlockFuncs(t *testing.T) { + block := CommandsBlocker([]string{"forbidden"}) + var stderr bytes.Buffer + err := Run(t.Context(), RunOptions{ + Command: "forbidden", + Cwd: t.TempDir(), + Stderr: &stderr, + BlockFuncs: []BlockFunc{block}, + }) + if err == nil { + 
t.Fatal("expected error when running blocked command") + } + if !strings.Contains(err.Error(), "not allowed") { + t.Fatalf("expected 'not allowed' error, got: %v", err) + } +} + +func TestRun_RequiresCwd(t *testing.T) { + err := Run(t.Context(), RunOptions{ + Command: "echo hi", + }) + if err == nil { + t.Fatal("expected error when Cwd is empty, got nil") + } + if !strings.Contains(err.Error(), "Cwd is required") { + t.Fatalf("error should mention Cwd requirement: %v", err) + } +} + +func TestRun_DiscardsNilWriters(t *testing.T) { + // No panic when Stdout/Stderr are nil. + err := Run(t.Context(), RunOptions{ + Command: "echo hi; echo err >&2", + Cwd: t.TempDir(), + }) + if err != nil { + t.Fatalf("Run returned error: %v", err) + } +} diff --git a/internal/shell/shell.go b/internal/shell/shell.go index 5ca0a25b57d5e80778a6fa95bdd3eb991638db4b..a9c1c83c117fe9b58a55e3440e985ad3046e478a 100644 --- a/internal/shell/shell.go +++ b/internal/shell/shell.go @@ -21,8 +21,6 @@ import ( "sync" "github.com/charmbracelet/x/exp/slice" - "mvdan.cc/sh/moreinterp/coreutils" - "mvdan.cc/sh/v3/expand" "mvdan.cc/sh/v3/interp" "mvdan.cc/sh/v3/syntax" ) @@ -36,6 +34,20 @@ const ( ShellTypePowerShell ) +// CrushEnvMarkers returns a fresh slice of the environment variables that +// Crush unconditionally sets on every shell it spawns — both the interactive +// bash tool's [Shell] and the hook runner's [Run] calls. Tools that want to +// detect "am I being invoked by an AI agent?" can check any of these. +// Keeping them in one place guarantees the two shell surfaces cannot drift. +// A fresh slice is returned on every call so callers may append freely. +func CrushEnvMarkers() []string { + return []string{ + "CRUSH=1", + "AGENT=crush", + "AI_AGENT=crush", + } +} + // Logger interface for optional logging type Logger interface { InfoPersist(msg string, keysAndValues ...any) @@ -83,12 +95,7 @@ func NewShell(opts *Options) *Shell { } // Allow tools to detect execution by Crush. 
- env = append( - env, - "CRUSH=1", - "AGENT=crush", - "AI_AGENT=crush", - ) + env = append(env, CrushEnvMarkers()...) logger := opts.Logger if logger == nil { @@ -226,52 +233,10 @@ func splitArgsFlags(parts []string) (args []string, flags []string) { return args, flags } -func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { - return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { - return func(ctx context.Context, args []string) error { - if len(args) == 0 { - return next(ctx, args) - } - - for _, blockFunc := range s.blockFuncs { - if blockFunc(args) { - return fmt.Errorf("command is not allowed for security reasons: %q", args[0]) - } - } - - return next(ctx, args) - } - } -} - -func (s *Shell) builtinHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { - return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { - return func(ctx context.Context, args []string) error { - if len(args) == 0 { - return next(ctx, args) - } - - // Builtins. - switch args[0] { - case "jq": - hc := interp.HandlerCtx(ctx) - return handleJQ(args, hc.Stdin, hc.Stdout, hc.Stderr) - default: - return next(ctx, args) - } - } - } -} - -// newInterp creates a new interpreter with the current shell state -func (s *Shell) newInterp(stdout, stderr io.Writer) (*interp.Runner, error) { - return interp.New( - interp.StdIO(nil, stdout, stderr), - interp.Interactive(false), - interp.Env(expand.ListEnviron(s.env...)), - interp.Dir(s.cwd), - interp.ExecHandlers(s.execHandlers()...), - ) +// newInterp creates a new interpreter with the current shell state. A nil +// stdin is equivalent to an empty input stream. +func (s *Shell) newInterp(stdin io.Reader, stdout, stderr io.Writer) (*interp.Runner, error) { + return newRunner(s.cwd, s.env, stdin, stdout, stderr, s.blockFuncs) } // updateShellFromRunner updates the shell from the interpreter after execution. 
@@ -303,7 +268,7 @@ func (s *Shell) execCommon(ctx context.Context, command string, stdout, stderr i return fmt.Errorf("could not parse command: %w", err) } - runner, err = s.newInterp(stdout, stderr) + runner, err = s.newInterp(nil, stdout, stderr) if err != nil { return fmt.Errorf("could not run command: %w", err) } @@ -324,17 +289,6 @@ func (s *Shell) execStream(ctx context.Context, command string, stdout, stderr i return s.execCommon(ctx, command, stdout, stderr) } -func (s *Shell) execHandlers() []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { - handlers := []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc{ - s.builtinHandler(), - s.blockHandler(), - } - if useGoCoreUtils { - handlers = append(handlers, coreutils.ExecHandler) - } - return handlers -} - // IsInterrupt checks if an error is due to interruption func IsInterrupt(err error) bool { return errors.Is(err, context.Canceled) ||