1package hooks
2
3import (
4 "bytes"
5 "context"
6 "log/slog"
7 "strings"
8 "sync"
9 "time"
10
11 "github.com/charmbracelet/crush/internal/config"
12 "github.com/charmbracelet/crush/internal/shell"
13)
14
// abandonGrace bounds how long runOne keeps waiting for the shell goroutine
// once the hook's context has been cancelled (deadline or parent). After
// this window runOne returns to its caller and leaves the goroutine to wind
// down by itself. The one-second value matches the cmd.WaitDelay that the
// earlier os/exec-based implementation used.
const abandonGrace = 1 * time.Second
20
// runShell is the executor runOne uses to run a hook command. By default it
// is shell.Run, Crush's embedded POSIX shell. It is a package-level variable
// rather than a direct call so tests can swap in a blocking or non-yielding
// stand-in. That lets them exercise runOne's abandon-on-timeout path
// without relying on how the real interpreter schedules work.
//
// NOTE(review): this is mutable package state shared by every Runner. Tests
// that replace it should restore the original value afterwards and should
// not run in parallel with other code paths that reach runOne.
var runShell = shell.Run
26
27// Runner executes hook commands and aggregates their results.
28type Runner struct {
29 hooks []config.HookConfig
30 cwd string
31 projectDir string
32}
33
34// NewRunner creates a Runner from the given hook configs.
35func NewRunner(hooks []config.HookConfig, cwd, projectDir string) *Runner {
36 return &Runner{
37 hooks: hooks,
38 cwd: cwd,
39 projectDir: projectDir,
40 }
41}
42
43// Hooks returns the hook configs the runner was created with.
44func (r *Runner) Hooks() []config.HookConfig {
45 return r.hooks
46}
47
48// Run executes all matching hooks for the given event and tool, returning
49// an aggregated result.
50func (r *Runner) Run(ctx context.Context, eventName, sessionID, toolName, toolInputJSON string) (AggregateResult, error) {
51 matching := r.matchingHooks(toolName)
52 if len(matching) == 0 {
53 return AggregateResult{Decision: DecisionNone}, nil
54 }
55
56 // Deduplicate by command string.
57 seen := make(map[string]bool, len(matching))
58 var deduped []config.HookConfig
59 for _, h := range matching {
60 if seen[h.Command] {
61 continue
62 }
63 seen[h.Command] = true
64 deduped = append(deduped, h)
65 }
66
67 envVars := BuildEnv(eventName, toolName, sessionID, r.cwd, r.projectDir, toolInputJSON)
68 payload := BuildPayload(eventName, sessionID, r.cwd, toolName, toolInputJSON)
69
70 results := make([]HookResult, len(deduped))
71 var wg sync.WaitGroup
72 wg.Add(len(deduped))
73
74 for i, h := range deduped {
75 go func(idx int, hook config.HookConfig) {
76 defer wg.Done()
77 results[idx] = r.runOne(ctx, hook, envVars, payload)
78 }(i, h)
79 }
80 wg.Wait()
81
82 agg := aggregate(results, toolInputJSON)
83 agg.Hooks = make([]HookInfo, len(deduped))
84 for i, h := range deduped {
85 agg.Hooks[i] = HookInfo{
86 Name: h.Command,
87 Matcher: h.Matcher,
88 Decision: results[i].Decision.String(),
89 Halt: results[i].Halt,
90 Reason: results[i].Reason,
91 InputRewrite: results[i].UpdatedInput != "",
92 }
93 }
94 slog.Info("Hook completed",
95 "event", eventName,
96 "tool", toolName,
97 "hooks", len(deduped),
98 "decision", agg.Decision.String(),
99 )
100 return agg, nil
101}
102
103// matchingHooks returns hooks whose matcher matches the tool name (or has
104// no matcher, which matches everything).
105func (r *Runner) matchingHooks(toolName string) []config.HookConfig {
106 var matched []config.HookConfig
107 for _, h := range r.hooks {
108 re := h.MatcherRegex()
109 if re == nil || re.MatchString(toolName) {
110 matched = append(matched, h)
111 }
112 }
113 return matched
114}
115
116// runOne executes a single hook command and returns its result.
117//
118// Execution goes through Crush's embedded POSIX shell (shell.Run) so the
119// same interpreter, builtins, and coreutils are visible to hooks as to
120// the bash tool. BlockFuncs are intentionally omitted: hooks are
121// user-authored config that carry the same trust as a shell alias.
122//
123// A hook that fails to yield after its deadline has passed is abandoned
124// after abandonGrace so the caller never blocks longer than
125// timeout + abandonGrace. Ownership of the stdout and stderr buffers is
126// strictly single-goroutine:
127// - before receiving from `done`, only the goroutine writes to them;
128// - after `done` delivers a value, the goroutine is finished and the
129// outer frame reads them;
130// - on the abandon path, the goroutine may still be writing and the
131// outer frame must not touch them again.
132func (r *Runner) runOne(parentCtx context.Context, hook config.HookConfig, envVars []string, payload []byte) HookResult {
133 timeout := hook.TimeoutDuration()
134 ctx, cancel := context.WithTimeout(parentCtx, timeout)
135 defer cancel()
136
137 var stdout, stderr bytes.Buffer
138 done := make(chan error, 1)
139 go func() {
140 done <- runShell(ctx, shell.RunOptions{
141 Command: hook.Command,
142 Cwd: r.cwd,
143 Env: envVars,
144 Stdin: bytes.NewReader(payload),
145 Stdout: &stdout,
146 Stderr: &stderr,
147 })
148 }()
149
150 var err error
151 select {
152 case err = <-done:
153 // Normal path: goroutine has finished, buffers are safe to read.
154 case <-ctx.Done():
155 select {
156 case err = <-done:
157 // Interpreter yielded within the grace period; safe to read.
158 case <-time.After(abandonGrace):
159 slog.Warn("Hook did not yield after cancel; abandoning goroutine",
160 "command", hook.Command,
161 "timeout", timeout,
162 )
163 // The goroutine may still be writing to stdout/stderr; do
164 // not read either buffer below this point.
165 return HookResult{Decision: DecisionNone}
166 }
167 }
168
169 if shell.IsInterrupt(err) {
170 // Distinguish timeout from parent cancellation.
171 if parentCtx.Err() != nil {
172 slog.Debug("Hook cancelled by parent context", "command", hook.Command)
173 } else {
174 slog.Warn("Hook timed out", "command", hook.Command, "timeout", timeout)
175 }
176 return HookResult{Decision: DecisionNone}
177 }
178
179 if err != nil {
180 exitCode := shell.ExitCode(err)
181 switch exitCode {
182 case 2:
183 // Exit code 2 = block this tool call. Stderr is the reason.
184 reason := strings.TrimSpace(stderr.String())
185 if reason == "" {
186 reason = "blocked by hook"
187 }
188 return HookResult{
189 Decision: DecisionDeny,
190 Reason: reason,
191 }
192 case HaltExitCode:
193 // Exit code 49 = halt the whole turn. Stderr is the reason.
194 reason := strings.TrimSpace(stderr.String())
195 if reason == "" {
196 reason = "turn halted by hook"
197 }
198 return HookResult{
199 Decision: DecisionDeny,
200 Halt: true,
201 Reason: reason,
202 }
203 default:
204 // Other non-zero exits are non-blocking errors.
205 slog.Warn("Hook failed with non-blocking error",
206 "command", hook.Command,
207 "exit_code", exitCode,
208 "stderr", strings.TrimSpace(stderr.String()),
209 "error", err,
210 )
211 return HookResult{Decision: DecisionNone}
212 }
213 }
214
215 // Exit code 0 — parse stdout JSON.
216 result := parseStdout(stdout.String())
217 slog.Debug("Hook executed",
218 "command", hook.Command,
219 "decision", result.Decision.String(),
220 )
221 return result
222}