runner.go

  1package hooks
  2
  3import (
  4	"bytes"
  5	"context"
  6	"log/slog"
  7	"regexp"
  8	"strings"
  9	"sync"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/shell"
 14)
 15
 // abandonGrace bounds the extra time runOne waits, once its context is
// done, for the shell goroutine to return. After this grace period
// runOne stops waiting and returns; the goroutine is left to finish in
// the background. The one-second value matches the
// cmd.WaitDelay = time.Second setting of the earlier os/exec-based path.
const abandonGrace time.Duration = 1 * time.Second
 21
 22// runShell is the shell executor used by runOne. It is a package-level
 23// variable so tests can substitute a blocking or non-yielding
 24// implementation to exercise the abandon-on-timeout path without
 25// depending on the scheduling behavior of the real interpreter.
 26var runShell = shell.Run
 27
 28// compiledHook pairs a HookConfig with its compiled matcher regex. A nil
 29// matcher means "match every tool".
 30type compiledHook struct {
 31	cfg     config.HookConfig
 32	matcher *regexp.Regexp
 33}
 34
 35// Runner executes hook commands and aggregates their results.
 36type Runner struct {
 37	hooks      []compiledHook
 38	cwd        string
 39	projectDir string
 40}
 41
 42// NewRunner creates a Runner from the given hook configs. Each hook's
 43// Matcher is compiled here so the Runner is self-sufficient; callers do
 44// not have to pre-compile matchers on the config, and reloads or merges
 45// that rebuild HookConfig values can't silently strip compiled state.
 46//
 47// Hooks whose matcher fails to compile are skipped with a warning rather
 48// than treated as match-everything. ValidateHooks is expected to have
 49// caught syntax errors earlier, so this is defense in depth.
 50func NewRunner(hooks []config.HookConfig, cwd, projectDir string) *Runner {
 51	compiled := make([]compiledHook, 0, len(hooks))
 52	for _, h := range hooks {
 53		ch := compiledHook{cfg: h}
 54		if h.Matcher != "" {
 55			re, err := regexp.Compile(h.Matcher)
 56			if err != nil {
 57				slog.Warn(
 58					"Hook matcher failed to compile; skipping hook",
 59					"matcher", h.Matcher,
 60					"command", h.Command,
 61					"error", err,
 62				)
 63				continue
 64			}
 65			ch.matcher = re
 66		}
 67		compiled = append(compiled, ch)
 68	}
 69	return &Runner{
 70		hooks:      compiled,
 71		cwd:        cwd,
 72		projectDir: projectDir,
 73	}
 74}
 75
 76// Hooks returns the hook configs the runner was created with, in config
 77// order. Hooks whose matcher failed to compile at construction are
 78// omitted. Intended for diagnostics; callers should not rely on ordering
 79// or identity beyond that.
 80func (r *Runner) Hooks() []config.HookConfig {
 81	out := make([]config.HookConfig, len(r.hooks))
 82	for i, h := range r.hooks {
 83		out[i] = h.cfg
 84	}
 85	return out
 86}
 87
 88// Run executes all matching hooks for the given event and tool, returning
 89// an aggregated result.
 90func (r *Runner) Run(ctx context.Context, eventName, sessionID, toolName, toolInputJSON string) (AggregateResult, error) {
 91	matching := r.matchingHooks(toolName)
 92	if len(matching) == 0 {
 93		return AggregateResult{Decision: DecisionNone}, nil
 94	}
 95
 96	// Deduplicate by command string.
 97	seen := make(map[string]bool, len(matching))
 98	var deduped []config.HookConfig
 99	for _, h := range matching {
100		if seen[h.Command] {
101			continue
102		}
103		seen[h.Command] = true
104		deduped = append(deduped, h)
105	}
106
107	envVars := BuildEnv(eventName, toolName, sessionID, r.cwd, r.projectDir, toolInputJSON)
108	payload := BuildPayload(eventName, sessionID, r.cwd, toolName, toolInputJSON)
109
110	results := make([]HookResult, len(deduped))
111	var wg sync.WaitGroup
112	wg.Add(len(deduped))
113
114	for i, h := range deduped {
115		go func(idx int, hook config.HookConfig) {
116			defer wg.Done()
117			results[idx] = r.runOne(ctx, hook, envVars, payload)
118		}(i, h)
119	}
120	wg.Wait()
121
122	agg := aggregate(results, toolInputJSON)
123	agg.Hooks = make([]HookInfo, len(deduped))
124	for i, h := range deduped {
125		agg.Hooks[i] = HookInfo{
126			Name:         h.DisplayName(),
127			Matcher:      h.Matcher,
128			Decision:     results[i].Decision.String(),
129			Halt:         results[i].Halt,
130			Reason:       results[i].Reason,
131			InputRewrite: results[i].UpdatedInput != "",
132		}
133	}
134	slog.Info(
135		"Hook completed",
136		"event", eventName,
137		"tool", toolName,
138		"hooks", len(deduped),
139		"decision", agg.Decision.String(),
140	)
141	return agg, nil
142}
143
144// matchingHooks returns hooks whose matcher matches the tool name (or has
145// no matcher, which matches everything).
146func (r *Runner) matchingHooks(toolName string) []config.HookConfig {
147	var matched []config.HookConfig
148	for _, h := range r.hooks {
149		if h.matcher == nil || h.matcher.MatchString(toolName) {
150			matched = append(matched, h.cfg)
151		}
152	}
153	return matched
154}
155
156// runOne executes a single hook command and returns its result.
157//
158// Execution goes through Crush's embedded POSIX shell (shell.Run) so the
159// same interpreter, builtins, and coreutils are visible to hooks as to
160// the bash tool. BlockFuncs are intentionally omitted: hooks are
161// user-authored config that carry the same trust as a shell alias.
162//
163// A hook that fails to yield after its deadline has passed is abandoned
164// after abandonGrace so the caller never blocks longer than
165// timeout + abandonGrace. Ownership of the stdout and stderr buffers is
166// strictly single-goroutine:
167//   - before receiving from `done`, only the goroutine writes to them;
168//   - after `done` delivers a value, the goroutine is finished and the
169//     outer frame reads them;
170//   - on the abandon path, the goroutine may still be writing and the
171//     outer frame must not touch them again.
172func (r *Runner) runOne(parentCtx context.Context, hook config.HookConfig, envVars []string, payload []byte) HookResult {
173	timeout := hook.TimeoutDuration()
174	ctx, cancel := context.WithTimeout(parentCtx, timeout)
175	defer cancel()
176
177	var stdout, stderr bytes.Buffer
178	done := make(chan error, 1)
179	go func() {
180		done <- runShell(ctx, shell.RunOptions{
181			Command: hook.Command,
182			Cwd:     r.cwd,
183			Env:     envVars,
184			Stdin:   bytes.NewReader(payload),
185			Stdout:  &stdout,
186			Stderr:  &stderr,
187		})
188	}()
189
190	var err error
191	select {
192	case err = <-done:
193		// Normal path: goroutine has finished, buffers are safe to read.
194	case <-ctx.Done():
195		select {
196		case err = <-done:
197			// Interpreter yielded within the grace period; safe to read.
198		case <-time.After(abandonGrace):
199			slog.Warn(
200				"Hook did not yield after cancel; abandoning goroutine",
201				"command", hook.Command,
202				"timeout", timeout,
203			)
204			// The goroutine may still be writing to stdout/stderr; do
205			// not read either buffer below this point.
206			return HookResult{Decision: DecisionNone}
207		}
208	}
209
210	if shell.IsInterrupt(err) {
211		// Distinguish timeout from parent cancellation.
212		if parentCtx.Err() != nil {
213			slog.Debug("Hook cancelled by parent context", "command", hook.Command)
214		} else {
215			slog.Warn("Hook timed out", "command", hook.Command, "timeout", timeout)
216		}
217		return HookResult{Decision: DecisionNone}
218	}
219
220	if err != nil {
221		exitCode := shell.ExitCode(err)
222		switch exitCode {
223		case 2:
224			// Exit code 2 = block this tool call. Stderr is the reason.
225			reason := strings.TrimSpace(stderr.String())
226			if reason == "" {
227				reason = "blocked by hook"
228			}
229			return HookResult{
230				Decision: DecisionDeny,
231				Reason:   reason,
232			}
233		case HaltExitCode:
234			// Exit code 49 = halt the whole turn. Stderr is the reason.
235			reason := strings.TrimSpace(stderr.String())
236			if reason == "" {
237				reason = "turn halted by hook"
238			}
239			return HookResult{
240				Decision: DecisionDeny,
241				Halt:     true,
242				Reason:   reason,
243			}
244		default:
245			// Other non-zero exits are non-blocking errors.
246			slog.Warn(
247				"Hook failed with non-blocking error",
248				"command", hook.Command,
249				"exit_code", exitCode,
250				"stderr", strings.TrimSpace(stderr.String()),
251				"error", err,
252			)
253			return HookResult{Decision: DecisionNone}
254		}
255	}
256
257	// Exit code 0 — parse stdout JSON.
258	result := parseStdout(stdout.String())
259	slog.Debug(
260		"Hook executed",
261		"command", hook.Command,
262		"decision", result.Decision.String(),
263	)
264	return result
265}