fix(shell): ctx-aware jq builtin

Christian Rocha created

This in necessary so hook timeouts can now interrupt long-running
filters and large stdin reads for the builtin jq tool.

Change summary

.agents/skills/shell-builtins/SKILL.md |  47 ++++-
internal/shell/jq.go                   |  73 +++++++++
internal/shell/jq_test.go              | 199 ++++++++++++++++++++++++++++
internal/shell/run.go                  |   2 
4 files changed, 303 insertions(+), 18 deletions(-)

Detailed changes

.agents/skills/shell-builtins/SKILL.md 🔗

@@ -11,38 +11,59 @@ emulation. Commands can be intercepted before they reach the OS by adding
 
 ## How Builtins Work
 
-Builtins live in `Shell.builtinHandler()` in `internal/shell/shell.go`.
-This is an `interp.ExecHandlerFunc` middleware registered in
-`execHandlers()` **before** the block handler, so builtins run even for
-commands that would otherwise be blocked.
+Builtins live in `builtinHandler()` in `internal/shell/run.go`. This is an
+`interp.ExecHandlerFunc` middleware registered in `standardHandlers()`
+**before** the block handler, so builtins run even for commands that would
+otherwise be blocked. The same handler chain is shared by the stateful
+`Shell` type and the stateless `Run` entrypoint used by the hook runner,
+so builtins are available identically in the `bash` tool and in hooks.
 
 The handler is a switch on `args[0]`. Each case either handles the command
 inline or delegates to a helper function.
 
 ## Adding a New Builtin
 
-1. **Add the case** to the switch in `builtinHandler()` in `shell.go`.
+1. **Add the case** to the switch in `builtinHandler()` in `run.go`.
 2. **Get I/O from the handler context**, not from `os.Stdin`/`os.Stdout`.
    This ensures the builtin works with pipes and redirections:
    ```go
    case "mycommand":
        hc := interp.HandlerCtx(ctx)
-       return handleMyCommand(args, hc.Stdin, hc.Stdout, hc.Stderr)
+       return handleMyCommand(ctx, args, hc.Stdin, hc.Stdout, hc.Stderr)
    ```
 3. **Implement the handler** in its own file (e.g.,
-   `internal/shell/mycommand.go`). The function signature should accept
-   args, stdin, stdout, and stderr:
+   `internal/shell/mycommand.go`). The function signature must accept a
+   `context.Context` as the first parameter, plus args, stdin, stdout, and
+   stderr:
    ```go
-   func handleMyCommand(args []string, stdin io.Reader, stdout, stderr io.Writer) error {
+   func handleMyCommand(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) error {
        // args[0] is the command name ("mycommand"), args[1:] are arguments.
        // Write output to stdout, errors to stderr.
        // Return nil on success, or interp.ExitStatus(n) for non-zero exit codes.
    }
    ```
-4. **Return values**: return `nil` for success, `interp.ExitStatus(n)` for
-   non-zero exit codes. Write error messages to `stderr` before returning.
-5. **No extra wiring needed** — `builtinHandler()` is already registered
-   in `execHandlers()`.
+4. **Poll `ctx` in every unbounded loop.** Builtins that iterate over
+   input, emit values in a generator-style loop, or do any other work
+   that can exceed a few milliseconds MUST check `ctx.Err()` on each
+   iteration and return it verbatim when non-nil. Hook timeouts rely on
+   this: an unbounded builtin that never polls ctx cannot be interrupted
+   by a hook's `timeout_sec`, and the hook runner will have to abandon
+   the goroutine (see `internal/hooks/runner.go`). Returning `ctx.Err()`
+   (not `interp.ExitStatus(n)`) lets callers distinguish "command exited
+   non-zero" from "we ran out of time".
+   ```go
+   for _, item := range items {
+       if err := ctx.Err(); err != nil {
+           return err
+       }
+       // ... process item
+   }
+   ```
+5. **Return values**: return `nil` for success, `interp.ExitStatus(n)` for
+   non-zero exit codes, or `ctx.Err()` on cancellation. Write error
+   messages to `stderr` before returning.
+6. **No extra wiring needed** — `builtinHandler()` is already registered
+   in `standardHandlers()`.
 
 ## Existing Builtins
 

internal/shell/jq.go 🔗

@@ -1,6 +1,7 @@
 package shell
 
 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"io"
@@ -36,10 +37,16 @@ Options:
 // flags: -r (raw output), -c (compact output), -s (slurp), -n (null input),
 // -e (exit status), -R (raw input), and --arg name value.
 //
+// ctx is polled at each iteration of the output loop and at each reader in
+// [readInputs] so that hook timeouts or other cancellations can interrupt
+// long-running queries. A cancelled context surfaces as ctx.Err(), not an
+// [interp.ExitStatus], so callers (e.g. the hook runner) can distinguish
+// "filter exited non-zero" from "we ran out of time".
+//
 // Note that this is somewhat of a reimplmentation of the CLI of the glorious
 // github.com/itchyny/gojq, and we'd ideally get the CLI exposed upstream to
 // avoid this falling out of sync.
-func handleJQ(args []string, stdin io.Reader, stdout, stderr io.Writer) error {
+func handleJQ(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) error {
 	var (
 		rawOutput  bool
 		compact    bool
@@ -143,8 +150,13 @@ func handleJQ(args []string, stdin io.Reader, stdout, stderr io.Writer) error {
 	}
 
 	// Build input values.
-	inputs, err := readInputs(stdin, fileArgs, nullInput, rawInput, slurp)
+	inputs, err := readInputs(ctx, stdin, fileArgs, nullInput, rawInput, slurp)
 	if err != nil {
+		// Prefer surfacing ctx cancellation verbatim so timeouts are
+		// distinguishable from user input errors.
+		if ctxErr := ctx.Err(); ctxErr != nil {
+			return ctxErr
+		}
 		fmt.Fprintf(stderr, "jq: %s\n", err)
 		return interp.ExitStatus(2)
 	}
@@ -153,6 +165,12 @@ func handleJQ(args []string, stdin io.Reader, stdout, stderr io.Writer) error {
 	for _, input := range inputs {
 		iter := code.Run(input, argValues...)
 		for {
+			// Poll ctx on every value so a long-running filter (e.g. a
+			// generator over a slurped array) can be interrupted by hook
+			// timeouts without waiting for iter.Next to yield.
+			if err := ctx.Err(); err != nil {
+				return err
+			}
 			v, ok := iter.Next()
 			if !ok {
 				break
@@ -177,7 +195,22 @@ func handleJQ(args []string, stdin io.Reader, stdout, stderr io.Writer) error {
 }
 
 // readInputs reads JSON (or raw) input values from stdin or files.
-func readInputs(stdin io.Reader, files []string, nullInput, rawInput, slurp bool) ([]any, error) {
+//
+// ctx is polled in three places so that a cancellation observed mid-read
+// short-circuits promptly:
+//   - between readers (before opening the next file / consuming stdin);
+//   - on every io.Read call via ctxReader, so io.ReadAll on a large but
+//     non-blocking source (e.g. the bytes.NewReader payload the hook
+//     runner supplies) returns ctx.Err() on the next chunk boundary;
+//   - inside the post-read value accumulation loops (raw-input line
+//     split and JSON stream decode), which are otherwise unbounded in
+//     the size of the input.
+//
+// A reader that blocks forever in Read (e.g. an unterminated pipe) can
+// still outlast ctx; the outer abandon-goroutine path in the hook
+// runner (internal/hooks/runner.go) is the authoritative enforcer for
+// that case.
+func readInputs(ctx context.Context, stdin io.Reader, files []string, nullInput, rawInput, slurp bool) ([]any, error) {
 	if nullInput {
 		return []any{nil}, nil
 	}
@@ -198,8 +231,16 @@ func readInputs(stdin io.Reader, files []string, nullInput, rawInput, slurp bool
 
 	var vals []any
 	for _, r := range readers {
-		data, err := io.ReadAll(r)
+		if err := ctx.Err(); err != nil {
+			return nil, err
+		}
+		data, err := io.ReadAll(ctxReader{ctx: ctx, r: r})
 		if err != nil {
+			// ctxReader surfaces ctx.Err() verbatim; preserve it so the
+			// caller can distinguish cancellation from a parse error.
+			if ctxErr := ctx.Err(); ctxErr != nil {
+				return nil, ctxErr
+			}
 			return nil, err
 		}
 
@@ -209,6 +250,9 @@ func readInputs(stdin io.Reader, files []string, nullInput, rawInput, slurp bool
 				vals = append(vals, strings.Join(lines, "\n"))
 			} else {
 				for _, line := range lines {
+					if err := ctx.Err(); err != nil {
+						return nil, err
+					}
 					if line != "" || !slurp {
 						vals = append(vals, line)
 					}
@@ -221,6 +265,9 @@ func readInputs(stdin io.Reader, files []string, nullInput, rawInput, slurp bool
 		dec := json.NewDecoder(strings.NewReader(string(data)))
 		var streamVals []any
 		for {
+			if err := ctx.Err(); err != nil {
+				return nil, err
+			}
 			var v any
 			if err := dec.Decode(&v); err != nil {
 				if err == io.EOF {
@@ -244,6 +291,24 @@ func readInputs(stdin io.Reader, files []string, nullInput, rawInput, slurp bool
 	return vals, nil
 }
 
+// ctxReader wraps an io.Reader so that each Read call checks ctx first.
+// This makes io.ReadAll over a large but non-blocking source (e.g. a
+// bytes.Reader of the hook stdin payload) cancellable on the next chunk
+// boundary. A reader that itself blocks in Read will still outlast ctx —
+// the hook runner's abandon-goroutine path is the enforcer of last resort
+// for that case.
+type ctxReader struct {
+	ctx context.Context
+	r   io.Reader
+}
+
+func (cr ctxReader) Read(p []byte) (int, error) {
+	if err := cr.ctx.Err(); err != nil {
+		return 0, err
+	}
+	return cr.r.Read(p)
+}
+
 // writeValue writes a single jq output value.
 func writeValue(w io.Writer, v any, raw, compact, join bool) error {
 	if raw {

internal/shell/jq_test.go 🔗

@@ -0,0 +1,199 @@
+package shell
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"io"
+	"strings"
+	"testing"
+	"time"
+)
+
+// TestJQ_CtxCancel verifies that handleJQ polls ctx during iteration and
+// returns ctx.Err() (not an interp.ExitStatus) when the context is
+// cancelled. This is what lets hook timeouts interrupt long-running jq
+// filters rather than waiting for the iterator to terminate naturally.
+func TestJQ_CtxCancel(t *testing.T) {
+	t.Parallel()
+
+	// `range(N)` generates a large stream of values. With a slurped input
+	// the filter produces all N values in sequence; ctx cancellation
+	// between values should short-circuit the loop.
+	const filter = "range(10000000)"
+	stdin := strings.NewReader("null\n")
+
+	ctx, cancel := context.WithCancel(t.Context())
+	// Cancel almost immediately so we catch the next iteration check.
+	cancel()
+
+	err := handleJQ(ctx, []string{"jq", filter}, stdin, io.Discard, io.Discard)
+	if err == nil {
+		t.Fatal("expected ctx cancel error, got nil")
+	}
+	if !errors.Is(err, context.Canceled) {
+		t.Fatalf("expected context.Canceled, got %v", err)
+	}
+}
+
+// TestJQ_CtxCancel_DuringFilter verifies cancellation mid-stream: ctx is
+// cancelled after jq has started producing output, and the loop must
+// observe the cancel on the next iteration rather than running to
+// completion.
+func TestJQ_CtxCancel_DuringFilter(t *testing.T) {
+	t.Parallel()
+
+	ctx, cancel := context.WithTimeout(t.Context(), 50*time.Millisecond)
+	defer cancel()
+
+	// 100M values; without ctx polling this would take many seconds to
+	// fully emit. With ctx polling the loop exits shortly after the
+	// deadline.
+	stdin := strings.NewReader("null\n")
+	var stdout, stderr bytes.Buffer
+
+	start := time.Now()
+	err := handleJQ(ctx, []string{"jq", "-c", "range(100000000)"}, stdin, &stdout, &stderr)
+	elapsed := time.Since(start)
+
+	if err == nil {
+		t.Fatal("expected ctx timeout error, got nil")
+	}
+	if !errors.Is(err, context.DeadlineExceeded) {
+		t.Fatalf("expected context.DeadlineExceeded, got %v", err)
+	}
+	// Allow generous slack for slow CI; the important invariant is that we
+	// don't run all 100M iterations (which would take orders of magnitude
+	// longer than 1s).
+	if elapsed > time.Second {
+		t.Fatalf("handleJQ took %v after 50ms timeout; ctx polling is not tight enough", elapsed)
+	}
+}
+
+// slowReader serves bytes in small chunks with a fixed delay between
+// Read calls. It never blocks indefinitely — each Read returns after
+// chunkDelay — so cancellation must be observed via ctxReader's ctx
+// check, not by the underlying reader itself. That isolates the
+// behavior we want to test: the wrapper polling ctx between chunks.
+type slowReader struct {
+	remaining  []byte
+	chunk      int
+	chunkDelay time.Duration
+}
+
+func (s *slowReader) Read(p []byte) (int, error) {
+	if len(s.remaining) == 0 {
+		return 0, io.EOF
+	}
+	time.Sleep(s.chunkDelay)
+	n := min(len(p), min(s.chunk, len(s.remaining)))
+	copy(p, s.remaining[:n])
+	s.remaining = s.remaining[n:]
+	return n, nil
+}
+
+// TestJQ_CtxCancel_MidReadAll verifies that ctx cancellation observed
+// *during* io.ReadAll — after several chunks have already been consumed
+// — short-circuits the read via ctxReader, rather than draining the
+// whole source. This is the guarantee the hook runner relies on when
+// it feeds a large bytes.Reader payload.
+//
+// The reader serves bytes in 512-byte chunks with a 5ms gap between
+// reads. ctx is cancelled after ~50ms, so several chunks have already
+// been read when ctxReader first observes the cancellation. The test
+// asserts that (a) we got a context.Canceled error and (b) the call
+// returned well before the reader would have been fully drained.
+func TestJQ_CtxCancel_MidReadAll(t *testing.T) {
+	t.Parallel()
+
+	const (
+		size       = 64 * 1024 * 1024 // 64 MiB
+		chunk      = 512
+		chunkDelay = 5 * time.Millisecond
+	)
+	// At 512 bytes / 5ms, draining 64 MiB would take ~11 minutes. Any
+	// return within a second proves cancel was observed mid-stream, not
+	// after EOF.
+	reader := &slowReader{
+		remaining:  bytes.Repeat([]byte("a"), size),
+		chunk:      chunk,
+		chunkDelay: chunkDelay,
+	}
+
+	ctx, cancel := context.WithCancel(t.Context())
+	defer cancel()
+
+	// Cancel after enough time that several Read calls have completed
+	// and io.ReadAll is actively consuming the source.
+	go func() {
+		time.Sleep(50 * time.Millisecond)
+		cancel()
+	}()
+
+	start := time.Now()
+	err := handleJQ(ctx, []string{"jq", "-R", "."}, reader, io.Discard, io.Discard)
+	elapsed := time.Since(start)
+
+	if !errors.Is(err, context.Canceled) {
+		t.Fatalf("expected context.Canceled, got %v", err)
+	}
+	// Generous slack for slow CI; the invariant is orders-of-magnitude
+	// faster than draining the full source.
+	if elapsed > time.Second {
+		t.Fatalf("mid-ReadAll cancel took %v; ctxReader is not polling between chunks", elapsed)
+	}
+	// Sanity check: we should have been cancelled mid-stream, not
+	// before any reads happened. If remaining == size, cancel fired so
+	// early nothing was consumed — that's a fast-fail path, not the
+	// mid-read guarantee we want to verify.
+	consumed := size - len(reader.remaining)
+	if consumed == 0 {
+		t.Fatal("reader was never read from; test did not exercise mid-ReadAll cancel")
+	}
+	if consumed >= size {
+		t.Fatal("reader was fully drained; cancel was not observed mid-read")
+	}
+}
+
+// TestJQ_CtxCancel_PreCancel verifies the fast-fail path: a ctx already
+// cancelled before handleJQ is called returns context.Canceled
+// immediately via the outer-loop guard, never entering io.ReadAll.
+// Complements TestJQ_CtxCancel_MidReadAll.
+func TestJQ_CtxCancel_PreCancel(t *testing.T) {
+	t.Parallel()
+
+	ctx, cancel := context.WithCancel(t.Context())
+	cancel()
+
+	start := time.Now()
+	err := handleJQ(ctx, []string{"jq", "-R", "."},
+		bytes.NewReader(bytes.Repeat([]byte("a"), 1024)),
+		io.Discard, io.Discard)
+	elapsed := time.Since(start)
+
+	if !errors.Is(err, context.Canceled) {
+		t.Fatalf("expected context.Canceled, got %v", err)
+	}
+	if elapsed > 100*time.Millisecond {
+		t.Fatalf("pre-cancel fast-fail took %v; outer guard is not firing", elapsed)
+	}
+}
+
+// TestJQ_Success confirms the ctx-aware refactor did not regress the
+// success path.
+func TestJQ_Success(t *testing.T) {
+	t.Parallel()
+
+	var stdout bytes.Buffer
+	err := handleJQ(t.Context(),
+		[]string{"jq", "-c", ".a"},
+		strings.NewReader(`{"a":1}`),
+		&stdout, io.Discard,
+	)
+	if err != nil {
+		t.Fatalf("handleJQ returned error: %v", err)
+	}
+	if got := stdout.String(); got != "1\n" {
+		t.Fatalf("stdout = %q, want %q", got, "1\n")
+	}
+}

internal/shell/run.go 🔗

@@ -130,7 +130,7 @@ func builtinHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
 			switch args[0] {
 			case "jq":
 				hc := interp.HandlerCtx(ctx)
-				return handleJQ(args, hc.Stdin, hc.Stdout, hc.Stderr)
+				return handleJQ(ctx, args, hc.Stdin, hc.Stdout, hc.Stderr)
 			default:
 				return next(ctx, args)
 			}