runner.go

  1package hooks
  2
  3import (
  4	"bytes"
  5	"context"
  6	"log/slog"
  7	"regexp"
  8	"strings"
  9	"sync"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/shell"
 14)
 15
 16// abandonGrace is how long runOne waits after ctx cancellation for the
 17// shell goroutine to yield before returning control to the caller and
 18// letting the goroutine finish on its own. Mirrors the historical
 19// cmd.WaitDelay = time.Second behavior of the previous os/exec path.
 20const abandonGrace = time.Second
 21
// runShell is the shell executor used by runOne; it defaults to
// shell.Run. It is a package-level variable so tests can substitute a
// blocking or non-yielding implementation to exercise the
// abandon-on-timeout path without depending on the scheduling behavior
// of the real interpreter.
//
// runOne reads this variable from a goroutine it spawns, so any
// replacement should be installed before hooks start running.
var runShell = shell.Run
 27
// compiledHook pairs a HookConfig with its compiled matcher regex.
//
// cfg is the hook configuration exactly as it was handed to NewRunner.
// matcher is the regexp compiled from cfg.Matcher; a nil matcher means
// "match every tool" and is what NewRunner leaves in place when
// cfg.Matcher is the empty string.
type compiledHook struct {
	cfg     config.HookConfig
	matcher *regexp.Regexp
}
 34
// Runner executes hook commands and aggregates their results.
//
// hooks holds the hooks accepted by NewRunner, in the order they were
// supplied, each paired with its compiled matcher. cwd is passed to every
// hook as its working directory and is also included in the environment
// and payload built for the hook. projectDir is forwarded to BuildEnv
// when the hook environment is assembled.
type Runner struct {
	hooks      []compiledHook
	cwd        string
	projectDir string
}
 41
 42// NewRunner creates a Runner from the given hook configs. Each hook's
 43// Matcher is compiled here so the Runner is self-sufficient; callers do
 44// not have to pre-compile matchers on the config, and reloads or merges
 45// that rebuild HookConfig values can't silently strip compiled state.
 46//
 47// Hooks whose matcher fails to compile are skipped with a warning rather
 48// than treated as match-everything. ValidateHooks is expected to have
 49// caught syntax errors earlier, so this is defense in depth.
 50func NewRunner(hooks []config.HookConfig, cwd, projectDir string) *Runner {
 51	compiled := make([]compiledHook, 0, len(hooks))
 52	for _, h := range hooks {
 53		ch := compiledHook{cfg: h}
 54		if h.Matcher != "" {
 55			re, err := regexp.Compile(h.Matcher)
 56			if err != nil {
 57				slog.Warn("Hook matcher failed to compile; skipping hook",
 58					"matcher", h.Matcher,
 59					"command", h.Command,
 60					"error", err,
 61				)
 62				continue
 63			}
 64			ch.matcher = re
 65		}
 66		compiled = append(compiled, ch)
 67	}
 68	return &Runner{
 69		hooks:      compiled,
 70		cwd:        cwd,
 71		projectDir: projectDir,
 72	}
 73}
 74
 75// Hooks returns the hook configs the runner was created with, in config
 76// order. Hooks whose matcher failed to compile at construction are
 77// omitted. Intended for diagnostics; callers should not rely on ordering
 78// or identity beyond that.
 79func (r *Runner) Hooks() []config.HookConfig {
 80	out := make([]config.HookConfig, len(r.hooks))
 81	for i, h := range r.hooks {
 82		out[i] = h.cfg
 83	}
 84	return out
 85}
 86
 87// Run executes all matching hooks for the given event and tool, returning
 88// an aggregated result.
 89func (r *Runner) Run(ctx context.Context, eventName, sessionID, toolName, toolInputJSON string) (AggregateResult, error) {
 90	matching := r.matchingHooks(toolName)
 91	if len(matching) == 0 {
 92		return AggregateResult{Decision: DecisionNone}, nil
 93	}
 94
 95	// Deduplicate by command string.
 96	seen := make(map[string]bool, len(matching))
 97	var deduped []config.HookConfig
 98	for _, h := range matching {
 99		if seen[h.Command] {
100			continue
101		}
102		seen[h.Command] = true
103		deduped = append(deduped, h)
104	}
105
106	envVars := BuildEnv(eventName, toolName, sessionID, r.cwd, r.projectDir, toolInputJSON)
107	payload := BuildPayload(eventName, sessionID, r.cwd, toolName, toolInputJSON)
108
109	results := make([]HookResult, len(deduped))
110	var wg sync.WaitGroup
111	wg.Add(len(deduped))
112
113	for i, h := range deduped {
114		go func(idx int, hook config.HookConfig) {
115			defer wg.Done()
116			results[idx] = r.runOne(ctx, hook, envVars, payload)
117		}(i, h)
118	}
119	wg.Wait()
120
121	agg := aggregate(results, toolInputJSON)
122	agg.Hooks = make([]HookInfo, len(deduped))
123	for i, h := range deduped {
124		agg.Hooks[i] = HookInfo{
125			Name:         h.Command,
126			Matcher:      h.Matcher,
127			Decision:     results[i].Decision.String(),
128			Halt:         results[i].Halt,
129			Reason:       results[i].Reason,
130			InputRewrite: results[i].UpdatedInput != "",
131		}
132	}
133	slog.Info("Hook completed",
134		"event", eventName,
135		"tool", toolName,
136		"hooks", len(deduped),
137		"decision", agg.Decision.String(),
138	)
139	return agg, nil
140}
141
142// matchingHooks returns hooks whose matcher matches the tool name (or has
143// no matcher, which matches everything).
144func (r *Runner) matchingHooks(toolName string) []config.HookConfig {
145	var matched []config.HookConfig
146	for _, h := range r.hooks {
147		if h.matcher == nil || h.matcher.MatchString(toolName) {
148			matched = append(matched, h.cfg)
149		}
150	}
151	return matched
152}
153
154// runOne executes a single hook command and returns its result.
155//
156// Execution goes through Crush's embedded POSIX shell (shell.Run) so the
157// same interpreter, builtins, and coreutils are visible to hooks as to
158// the bash tool. BlockFuncs are intentionally omitted: hooks are
159// user-authored config that carry the same trust as a shell alias.
160//
161// A hook that fails to yield after its deadline has passed is abandoned
162// after abandonGrace so the caller never blocks longer than
163// timeout + abandonGrace. Ownership of the stdout and stderr buffers is
164// strictly single-goroutine:
165//   - before receiving from `done`, only the goroutine writes to them;
166//   - after `done` delivers a value, the goroutine is finished and the
167//     outer frame reads them;
168//   - on the abandon path, the goroutine may still be writing and the
169//     outer frame must not touch them again.
170func (r *Runner) runOne(parentCtx context.Context, hook config.HookConfig, envVars []string, payload []byte) HookResult {
171	timeout := hook.TimeoutDuration()
172	ctx, cancel := context.WithTimeout(parentCtx, timeout)
173	defer cancel()
174
175	var stdout, stderr bytes.Buffer
176	done := make(chan error, 1)
177	go func() {
178		done <- runShell(ctx, shell.RunOptions{
179			Command: hook.Command,
180			Cwd:     r.cwd,
181			Env:     envVars,
182			Stdin:   bytes.NewReader(payload),
183			Stdout:  &stdout,
184			Stderr:  &stderr,
185		})
186	}()
187
188	var err error
189	select {
190	case err = <-done:
191		// Normal path: goroutine has finished, buffers are safe to read.
192	case <-ctx.Done():
193		select {
194		case err = <-done:
195			// Interpreter yielded within the grace period; safe to read.
196		case <-time.After(abandonGrace):
197			slog.Warn("Hook did not yield after cancel; abandoning goroutine",
198				"command", hook.Command,
199				"timeout", timeout,
200			)
201			// The goroutine may still be writing to stdout/stderr; do
202			// not read either buffer below this point.
203			return HookResult{Decision: DecisionNone}
204		}
205	}
206
207	if shell.IsInterrupt(err) {
208		// Distinguish timeout from parent cancellation.
209		if parentCtx.Err() != nil {
210			slog.Debug("Hook cancelled by parent context", "command", hook.Command)
211		} else {
212			slog.Warn("Hook timed out", "command", hook.Command, "timeout", timeout)
213		}
214		return HookResult{Decision: DecisionNone}
215	}
216
217	if err != nil {
218		exitCode := shell.ExitCode(err)
219		switch exitCode {
220		case 2:
221			// Exit code 2 = block this tool call. Stderr is the reason.
222			reason := strings.TrimSpace(stderr.String())
223			if reason == "" {
224				reason = "blocked by hook"
225			}
226			return HookResult{
227				Decision: DecisionDeny,
228				Reason:   reason,
229			}
230		case HaltExitCode:
231			// Exit code 49 = halt the whole turn. Stderr is the reason.
232			reason := strings.TrimSpace(stderr.String())
233			if reason == "" {
234				reason = "turn halted by hook"
235			}
236			return HookResult{
237				Decision: DecisionDeny,
238				Halt:     true,
239				Reason:   reason,
240			}
241		default:
242			// Other non-zero exits are non-blocking errors.
243			slog.Warn("Hook failed with non-blocking error",
244				"command", hook.Command,
245				"exit_code", exitCode,
246				"stderr", strings.TrimSpace(stderr.String()),
247				"error", err,
248			)
249			return HookResult{Decision: DecisionNone}
250		}
251	}
252
253	// Exit code 0 — parse stdout JSON.
254	result := parseStdout(stdout.String())
255	slog.Debug("Hook executed",
256		"command", hook.Command,
257		"decision", result.Decision.String(),
258	)
259	return result
260}