diff --git a/.agents/skills/shell-builtins/SKILL.md b/.agents/skills/shell-builtins/SKILL.md index a9914a08aa1b506006fcb7ef35ded092efc7f895..f5953d9ff1f00631361e8598d9401a1f8b1a809f 100644 --- a/.agents/skills/shell-builtins/SKILL.md +++ b/.agents/skills/shell-builtins/SKILL.md @@ -11,38 +11,59 @@ emulation. Commands can be intercepted before they reach the OS by adding ## How Builtins Work -Builtins live in `Shell.builtinHandler()` in `internal/shell/shell.go`. -This is an `interp.ExecHandlerFunc` middleware registered in -`execHandlers()` **before** the block handler, so builtins run even for -commands that would otherwise be blocked. +Builtins live in `builtinHandler()` in `internal/shell/run.go`. This is an +`interp.ExecHandlerFunc` middleware registered in `standardHandlers()` +**before** the block handler, so builtins run even for commands that would +otherwise be blocked. The same handler chain is shared by the stateful +`Shell` type and the stateless `Run` entrypoint used by the hook runner, +so builtins are available identically in the `bash` tool and in hooks. The handler is a switch on `args[0]`. Each case either handles the command inline or delegates to a helper function. ## Adding a New Builtin -1. **Add the case** to the switch in `builtinHandler()` in `shell.go`. +1. **Add the case** to the switch in `builtinHandler()` in `run.go`. 2. **Get I/O from the handler context**, not from `os.Stdin`/`os.Stdout`. This ensures the builtin works with pipes and redirections: ```go case "mycommand": hc := interp.HandlerCtx(ctx) - return handleMyCommand(args, hc.Stdin, hc.Stdout, hc.Stderr) + return handleMyCommand(ctx, args, hc.Stdin, hc.Stdout, hc.Stderr) ``` 3. **Implement the handler** in its own file (e.g., - `internal/shell/mycommand.go`). The function signature should accept - args, stdin, stdout, and stderr: + `internal/shell/mycommand.go`). The function signature must accept a + `context.Context` as the first parameter, plus args, stdin, stdout, and + stderr: ```go - func handleMyCommand(args []string, stdin io.Reader, stdout, stderr io.Writer) error { + func handleMyCommand(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) error { // args[0] is the command name ("mycommand"), args[1:] are arguments. // Write output to stdout, errors to stderr. // Return nil on success, or interp.ExitStatus(n) for non-zero exit codes. } ``` -4. **Return values**: return `nil` for success, `interp.ExitStatus(n)` for - non-zero exit codes. Write error messages to `stderr` before returning. -5. **No extra wiring needed** — `builtinHandler()` is already registered - in `execHandlers()`. +4. **Poll `ctx` in every unbounded loop.** Builtins that iterate over + input, emit values in a generator-style loop, or do any other work + that can exceed a few milliseconds MUST check `ctx.Err()` on each + iteration and return it verbatim when non-nil. Hook timeouts rely on + this: an unbounded builtin that never polls ctx cannot be interrupted + by a hook's `timeout_sec`, and the hook runner will have to abandon + the goroutine (see `internal/hooks/runner.go`). Returning `ctx.Err()` + (not `interp.ExitStatus(n)`) lets callers distinguish "command exited + non-zero" from "we ran out of time". + ```go + for _, item := range items { + if err := ctx.Err(); err != nil { + return err + } + // ... process item + } + ``` +5. **Return values**: return `nil` for success, `interp.ExitStatus(n)` for + non-zero exit codes, or `ctx.Err()` on cancellation. Write error + messages to `stderr` before returning. +6. **No extra wiring needed** — `builtinHandler()` is already registered + in `standardHandlers()`. ## Existing Builtins diff --git a/internal/shell/jq.go b/internal/shell/jq.go index ceac574df13c97befa05e8817b91a8d928a96a11..4204e5722f4f5ce6e6e6e2f9f8e1acf26ba5438a 100644 --- a/internal/shell/jq.go +++ b/internal/shell/jq.go @@ -1,6 +1,7 @@ package shell import ( + "context" "encoding/json" "fmt" "io" @@ -36,10 +37,16 @@ Options: // flags: -r (raw output), -c (compact output), -s (slurp), -n (null input), // -e (exit status), -R (raw input), and --arg name value. // +// ctx is polled at each iteration of the output loop and at each reader in +// [readInputs] so that hook timeouts or other cancellations can interrupt +// long-running queries. A cancelled context surfaces as ctx.Err(), not an +// [interp.ExitStatus], so callers (e.g. the hook runner) can distinguish +// "filter exited non-zero" from "we ran out of time". +// // Note that this is somewhat of a reimplmentation of the CLI of the glorious // github.com/itchyny/gojq, and we'd ideally get the CLI exposed upstream to // avoid this falling out of sync. -func handleJQ(args []string, stdin io.Reader, stdout, stderr io.Writer) error { +func handleJQ(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) error { var ( rawOutput bool compact bool @@ -143,8 +150,13 @@ func handleJQ(args []string, stdin io.Reader, stdout, stderr io.Writer) error { } // Build input values. - inputs, err := readInputs(stdin, fileArgs, nullInput, rawInput, slurp) + inputs, err := readInputs(ctx, stdin, fileArgs, nullInput, rawInput, slurp) if err != nil { + // Prefer surfacing ctx cancellation verbatim so timeouts are + // distinguishable from user input errors. + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } fmt.Fprintf(stderr, "jq: %s\n", err) return interp.ExitStatus(2) } @@ -153,6 +165,12 @@ func handleJQ(args []string, stdin io.Reader, stdout, stderr io.Writer) error { for _, input := range inputs { iter := code.Run(input, argValues...) for { + // Poll ctx on every value so a long-running filter (e.g. a + // generator over a slurped array) can be interrupted by hook + // timeouts without waiting for iter.Next to yield. + if err := ctx.Err(); err != nil { + return err + } v, ok := iter.Next() if !ok { break @@ -177,7 +195,22 @@ func handleJQ(args []string, stdin io.Reader, stdout, stderr io.Writer) error { } // readInputs reads JSON (or raw) input values from stdin or files. -func readInputs(stdin io.Reader, files []string, nullInput, rawInput, slurp bool) ([]any, error) { +// +// ctx is polled in three places so that a cancellation observed mid-read +// short-circuits promptly: +// - between readers (before opening the next file / consuming stdin); +// - on every io.Read call via ctxReader, so io.ReadAll on a large but +// non-blocking source (e.g. the bytes.NewReader payload the hook +// runner supplies) returns ctx.Err() on the next chunk boundary; +// - inside the post-read value accumulation loops (raw-input line +// split and JSON stream decode), which are otherwise unbounded in +// the size of the input. +// +// A reader that blocks forever in Read (e.g. an unterminated pipe) can +// still outlast ctx; the outer abandon-goroutine path in the hook +// runner (internal/hooks/runner.go) is the authoritative enforcer for +// that case. +func readInputs(ctx context.Context, stdin io.Reader, files []string, nullInput, rawInput, slurp bool) ([]any, error) { if nullInput { return []any{nil}, nil } @@ -198,8 +231,16 @@ func readInputs(stdin io.Reader, files []string, nullInput, rawInput, slurp bool var vals []any for _, r := range readers { - data, err := io.ReadAll(r) + if err := ctx.Err(); err != nil { + return nil, err + } + data, err := io.ReadAll(ctxReader{ctx: ctx, r: r}) if err != nil { + // ctxReader surfaces ctx.Err() verbatim; preserve it so the + // caller can distinguish cancellation from a parse error. + if ctxErr := ctx.Err(); ctxErr != nil { + return nil, ctxErr + } return nil, err } @@ -209,6 +250,9 @@ func readInputs(stdin io.Reader, files []string, nullInput, rawInput, slurp bool vals = append(vals, strings.Join(lines, "\n")) } else { for _, line := range lines { + if err := ctx.Err(); err != nil { + return nil, err + } if line != "" || !slurp { vals = append(vals, line) } @@ -221,6 +265,9 @@ func readInputs(stdin io.Reader, files []string, nullInput, rawInput, slurp bool dec := json.NewDecoder(strings.NewReader(string(data))) var streamVals []any for { + if err := ctx.Err(); err != nil { + return nil, err + } var v any if err := dec.Decode(&v); err != nil { if err == io.EOF { @@ -244,6 +291,24 @@ func readInputs(stdin io.Reader, files []string, nullInput, rawInput, slurp bool return vals, nil } +// ctxReader wraps an io.Reader so that each Read call checks ctx first. +// This makes io.ReadAll over a large but non-blocking source (e.g. a +// bytes.Reader of the hook stdin payload) cancellable on the next chunk +// boundary. A reader that itself blocks in Read will still outlast ctx — +// the hook runner's abandon-goroutine path is the enforcer of last resort +// for that case. +type ctxReader struct { + ctx context.Context + r io.Reader +} + +func (cr ctxReader) Read(p []byte) (int, error) { + if err := cr.ctx.Err(); err != nil { + return 0, err + } + return cr.r.Read(p) +} + // writeValue writes a single jq output value. func writeValue(w io.Writer, v any, raw, compact, join bool) error { if raw { diff --git a/internal/shell/jq_test.go b/internal/shell/jq_test.go new file mode 100644 index 0000000000000000000000000000000000000000..e449047b49b27da9b2d587b6defe9d8f81fe2888 --- /dev/null +++ b/internal/shell/jq_test.go @@ -0,0 +1,199 @@ +package shell + +import ( + "bytes" + "context" + "errors" + "io" + "strings" + "testing" + "time" +) + +// TestJQ_CtxCancel verifies that handleJQ polls ctx during iteration and +// returns ctx.Err() (not an interp.ExitStatus) when the context is +// cancelled. This is what lets hook timeouts interrupt long-running jq +// filters rather than waiting for the iterator to terminate naturally. +func TestJQ_CtxCancel(t *testing.T) { + t.Parallel() + + // `range(N)` generates a large stream of values. With a slurped input + // the filter produces all N values in sequence; ctx cancellation + // between values should short-circuit the loop. + const filter = "range(10000000)" + stdin := strings.NewReader("null\n") + + ctx, cancel := context.WithCancel(t.Context()) + // Cancel almost immediately so we catch the next iteration check. + cancel() + + err := handleJQ(ctx, []string{"jq", filter}, stdin, io.Discard, io.Discard) + if err == nil { + t.Fatal("expected ctx cancel error, got nil") + } + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got %v", err) + } +} + +// TestJQ_CtxCancel_DuringFilter verifies cancellation mid-stream: ctx is +// cancelled after jq has started producing output, and the loop must +// observe the cancel on the next iteration rather than running to +// completion. +func TestJQ_CtxCancel_DuringFilter(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), 50*time.Millisecond) + defer cancel() + + // 100M values; without ctx polling this would take many seconds to + // fully emit. With ctx polling the loop exits shortly after the + // deadline. + stdin := strings.NewReader("null\n") + var stdout, stderr bytes.Buffer + + start := time.Now() + err := handleJQ(ctx, []string{"jq", "-c", "range(100000000)"}, stdin, &stdout, &stderr) + elapsed := time.Since(start) + + if err == nil { + t.Fatal("expected ctx timeout error, got nil") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected context.DeadlineExceeded, got %v", err) + } + // Allow generous slack for slow CI; the important invariant is that we + // don't run all 100M iterations (which would take orders of magnitude + // longer than 1s). + if elapsed > time.Second { + t.Fatalf("handleJQ took %v after 50ms timeout; ctx polling is not tight enough", elapsed) + } +} + +// slowReader serves bytes in small chunks with a fixed delay between +// Read calls. It never blocks indefinitely — each Read returns after +// chunkDelay — so cancellation must be observed via ctxReader's ctx +// check, not by the underlying reader itself. That isolates the +// behavior we want to test: the wrapper polling ctx between chunks. +type slowReader struct { + remaining []byte + chunk int + chunkDelay time.Duration +} + +func (s *slowReader) Read(p []byte) (int, error) { + if len(s.remaining) == 0 { + return 0, io.EOF + } + time.Sleep(s.chunkDelay) + n := min(len(p), min(s.chunk, len(s.remaining))) + copy(p, s.remaining[:n]) + s.remaining = s.remaining[n:] + return n, nil +} + +// TestJQ_CtxCancel_MidReadAll verifies that ctx cancellation observed +// *during* io.ReadAll — after several chunks have already been consumed +// — short-circuits the read via ctxReader, rather than draining the +// whole source. This is the guarantee the hook runner relies on when +// it feeds a large bytes.Reader payload. +// +// The reader serves bytes in 512-byte chunks with a 5ms gap between +// reads. ctx is cancelled after ~50ms, so several chunks have already +// been read when ctxReader first observes the cancellation. The test +// asserts that (a) we got a context.Canceled error and (b) the call +// returned well before the reader would have been fully drained. +func TestJQ_CtxCancel_MidReadAll(t *testing.T) { + t.Parallel() + + const ( + size = 64 * 1024 * 1024 // 64 MiB + chunk = 512 + chunkDelay = 5 * time.Millisecond + ) + // At 512 bytes / 5ms, draining 64 MiB would take ~11 minutes. Any + // return within a second proves cancel was observed mid-stream, not + // after EOF. + reader := &slowReader{ + remaining: bytes.Repeat([]byte("a"), size), + chunk: chunk, + chunkDelay: chunkDelay, + } + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + // Cancel after enough time that several Read calls have completed + // and io.ReadAll is actively consuming the source. + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + start := time.Now() + err := handleJQ(ctx, []string{"jq", "-R", "."}, reader, io.Discard, io.Discard) + elapsed := time.Since(start) + + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got %v", err) + } + // Generous slack for slow CI; the invariant is orders-of-magnitude + // faster than draining the full source. + if elapsed > time.Second { + t.Fatalf("mid-ReadAll cancel took %v; ctxReader is not polling between chunks", elapsed) + } + // Sanity check: we should have been cancelled mid-stream, not + // before any reads happened. If remaining == size, cancel fired so + // early nothing was consumed — that's a fast-fail path, not the + // mid-read guarantee we want to verify. + consumed := size - len(reader.remaining) + if consumed == 0 { + t.Fatal("reader was never read from; test did not exercise mid-ReadAll cancel") + } + if consumed >= size { + t.Fatal("reader was fully drained; cancel was not observed mid-read") + } +} + +// TestJQ_CtxCancel_PreCancel verifies the fast-fail path: a ctx already +// cancelled before handleJQ is called returns context.Canceled +// immediately via the outer-loop guard, never entering io.ReadAll. +// Complements TestJQ_CtxCancel_MidReadAll. +func TestJQ_CtxCancel_PreCancel(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(t.Context()) + cancel() + + start := time.Now() + err := handleJQ(ctx, []string{"jq", "-R", "."}, + bytes.NewReader(bytes.Repeat([]byte("a"), 1024)), + io.Discard, io.Discard) + elapsed := time.Since(start) + + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got %v", err) + } + if elapsed > 100*time.Millisecond { + t.Fatalf("pre-cancel fast-fail took %v; outer guard is not firing", elapsed) + } +} + +// TestJQ_Success confirms the ctx-aware refactor did not regress the +// success path. +func TestJQ_Success(t *testing.T) { + t.Parallel() + + var stdout bytes.Buffer + err := handleJQ(t.Context(), + []string{"jq", "-c", ".a"}, + strings.NewReader(`{"a":1}`), + &stdout, io.Discard, + ) + if err != nil { + t.Fatalf("handleJQ returned error: %v", err) + } + if got := stdout.String(); got != "1\n" { + t.Fatalf("stdout = %q, want %q", got, "1\n") + } +} diff --git a/internal/shell/run.go b/internal/shell/run.go index d0ff921e7e31a479d9e03ff33ed1fb05340e308c..cac7c530f0af415b6a9dbaa7292def1896ab318a 100644 --- a/internal/shell/run.go +++ b/internal/shell/run.go @@ -130,7 +130,7 @@ func builtinHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { switch args[0] { case "jq": hc := interp.HandlerCtx(ctx) - return handleJQ(args, hc.Stdin, hc.Stdout, hc.Stderr) + return handleJQ(ctx, args, hc.Stdin, hc.Stdout, hc.Stderr) default: return next(ctx, args) }