runner.go

  1package hooks
  2
  3import (
  4	"bytes"
  5	"context"
  6	"log/slog"
  7	"strings"
  8	"sync"
  9	"time"
 10
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/charmbracelet/crush/internal/shell"
 13)
 14
 15// abandonGrace is how long runOne waits after ctx cancellation for the
 16// shell goroutine to yield before returning control to the caller and
 17// letting the goroutine finish on its own. Mirrors the historical
 18// cmd.WaitDelay = time.Second behavior of the previous os/exec path.
 19const abandonGrace = time.Second
 20
 21// runShell is the shell executor used by runOne. It is a package-level
 22// variable so tests can substitute a blocking or non-yielding
 23// implementation to exercise the abandon-on-timeout path without
 24// depending on the scheduling behavior of the real interpreter.
 25var runShell = shell.Run
 26
 27// Runner executes hook commands and aggregates their results.
 28type Runner struct {
 29	hooks      []config.HookConfig
 30	cwd        string
 31	projectDir string
 32}
 33
 34// NewRunner creates a Runner from the given hook configs.
 35func NewRunner(hooks []config.HookConfig, cwd, projectDir string) *Runner {
 36	return &Runner{
 37		hooks:      hooks,
 38		cwd:        cwd,
 39		projectDir: projectDir,
 40	}
 41}
 42
 43// Hooks returns the hook configs the runner was created with.
 44func (r *Runner) Hooks() []config.HookConfig {
 45	return r.hooks
 46}
 47
 48// Run executes all matching hooks for the given event and tool, returning
 49// an aggregated result.
 50func (r *Runner) Run(ctx context.Context, eventName, sessionID, toolName, toolInputJSON string) (AggregateResult, error) {
 51	matching := r.matchingHooks(toolName)
 52	if len(matching) == 0 {
 53		return AggregateResult{Decision: DecisionNone}, nil
 54	}
 55
 56	// Deduplicate by command string.
 57	seen := make(map[string]bool, len(matching))
 58	var deduped []config.HookConfig
 59	for _, h := range matching {
 60		if seen[h.Command] {
 61			continue
 62		}
 63		seen[h.Command] = true
 64		deduped = append(deduped, h)
 65	}
 66
 67	envVars := BuildEnv(eventName, toolName, sessionID, r.cwd, r.projectDir, toolInputJSON)
 68	payload := BuildPayload(eventName, sessionID, r.cwd, toolName, toolInputJSON)
 69
 70	results := make([]HookResult, len(deduped))
 71	var wg sync.WaitGroup
 72	wg.Add(len(deduped))
 73
 74	for i, h := range deduped {
 75		go func(idx int, hook config.HookConfig) {
 76			defer wg.Done()
 77			results[idx] = r.runOne(ctx, hook, envVars, payload)
 78		}(i, h)
 79	}
 80	wg.Wait()
 81
 82	agg := aggregate(results, toolInputJSON)
 83	agg.Hooks = make([]HookInfo, len(deduped))
 84	for i, h := range deduped {
 85		agg.Hooks[i] = HookInfo{
 86			Name:         h.Command,
 87			Matcher:      h.Matcher,
 88			Decision:     results[i].Decision.String(),
 89			Halt:         results[i].Halt,
 90			Reason:       results[i].Reason,
 91			InputRewrite: results[i].UpdatedInput != "",
 92		}
 93	}
 94	slog.Info("Hook completed",
 95		"event", eventName,
 96		"tool", toolName,
 97		"hooks", len(deduped),
 98		"decision", agg.Decision.String(),
 99	)
100	return agg, nil
101}
102
103// matchingHooks returns hooks whose matcher matches the tool name (or has
104// no matcher, which matches everything).
105func (r *Runner) matchingHooks(toolName string) []config.HookConfig {
106	var matched []config.HookConfig
107	for _, h := range r.hooks {
108		re := h.MatcherRegex()
109		if re == nil || re.MatchString(toolName) {
110			matched = append(matched, h)
111		}
112	}
113	return matched
114}
115
116// runOne executes a single hook command and returns its result.
117//
118// Execution goes through Crush's embedded POSIX shell (shell.Run) so the
119// same interpreter, builtins, and coreutils are visible to hooks as to
120// the bash tool. BlockFuncs are intentionally omitted: hooks are
121// user-authored config that carry the same trust as a shell alias.
122//
123// A hook that fails to yield after its deadline has passed is abandoned
124// after abandonGrace so the caller never blocks longer than
125// timeout + abandonGrace. Ownership of the stdout and stderr buffers is
126// strictly single-goroutine:
127//   - before receiving from `done`, only the goroutine writes to them;
128//   - after `done` delivers a value, the goroutine is finished and the
129//     outer frame reads them;
130//   - on the abandon path, the goroutine may still be writing and the
131//     outer frame must not touch them again.
132func (r *Runner) runOne(parentCtx context.Context, hook config.HookConfig, envVars []string, payload []byte) HookResult {
133	timeout := hook.TimeoutDuration()
134	ctx, cancel := context.WithTimeout(parentCtx, timeout)
135	defer cancel()
136
137	var stdout, stderr bytes.Buffer
138	done := make(chan error, 1)
139	go func() {
140		done <- runShell(ctx, shell.RunOptions{
141			Command: hook.Command,
142			Cwd:     r.cwd,
143			Env:     envVars,
144			Stdin:   bytes.NewReader(payload),
145			Stdout:  &stdout,
146			Stderr:  &stderr,
147		})
148	}()
149
150	var err error
151	select {
152	case err = <-done:
153		// Normal path: goroutine has finished, buffers are safe to read.
154	case <-ctx.Done():
155		select {
156		case err = <-done:
157			// Interpreter yielded within the grace period; safe to read.
158		case <-time.After(abandonGrace):
159			slog.Warn("Hook did not yield after cancel; abandoning goroutine",
160				"command", hook.Command,
161				"timeout", timeout,
162			)
163			// The goroutine may still be writing to stdout/stderr; do
164			// not read either buffer below this point.
165			return HookResult{Decision: DecisionNone}
166		}
167	}
168
169	if shell.IsInterrupt(err) {
170		// Distinguish timeout from parent cancellation.
171		if parentCtx.Err() != nil {
172			slog.Debug("Hook cancelled by parent context", "command", hook.Command)
173		} else {
174			slog.Warn("Hook timed out", "command", hook.Command, "timeout", timeout)
175		}
176		return HookResult{Decision: DecisionNone}
177	}
178
179	if err != nil {
180		exitCode := shell.ExitCode(err)
181		switch exitCode {
182		case 2:
183			// Exit code 2 = block this tool call. Stderr is the reason.
184			reason := strings.TrimSpace(stderr.String())
185			if reason == "" {
186				reason = "blocked by hook"
187			}
188			return HookResult{
189				Decision: DecisionDeny,
190				Reason:   reason,
191			}
192		case HaltExitCode:
193			// Exit code 49 = halt the whole turn. Stderr is the reason.
194			reason := strings.TrimSpace(stderr.String())
195			if reason == "" {
196				reason = "turn halted by hook"
197			}
198			return HookResult{
199				Decision: DecisionDeny,
200				Halt:     true,
201				Reason:   reason,
202			}
203		default:
204			// Other non-zero exits are non-blocking errors.
205			slog.Warn("Hook failed with non-blocking error",
206				"command", hook.Command,
207				"exit_code", exitCode,
208				"stderr", strings.TrimSpace(stderr.String()),
209				"error", err,
210			)
211			return HookResult{Decision: DecisionNone}
212		}
213	}
214
215	// Exit code 0 — parse stdout JSON.
216	result := parseStdout(stdout.String())
217	slog.Debug("Hook executed",
218		"command", hook.Command,
219		"decision", result.Decision.String(),
220	)
221	return result
222}