1package hooks
2
import (
	"bytes"
	"context"
	"errors"
	"log/slog"
	"os/exec"
	"strings"
	"sync"
	"time"

	"github.com/charmbracelet/crush/internal/config"
)
14
15// Runner executes hook commands and aggregates their results.
16type Runner struct {
17 hooks []config.HookConfig
18 cwd string
19 projectDir string
20}
21
22// NewRunner creates a Runner from the given hook configs.
23func NewRunner(hooks []config.HookConfig, cwd, projectDir string) *Runner {
24 return &Runner{
25 hooks: hooks,
26 cwd: cwd,
27 projectDir: projectDir,
28 }
29}
30
31// Hooks returns the hook configs the runner was created with.
32func (r *Runner) Hooks() []config.HookConfig {
33 return r.hooks
34}
35
36// Run executes all matching hooks for the given event and tool, returning
37// an aggregated result.
38func (r *Runner) Run(ctx context.Context, eventName, sessionID, toolName, toolInputJSON string) (AggregateResult, error) {
39 matching := r.matchingHooks(toolName)
40 if len(matching) == 0 {
41 return AggregateResult{Decision: DecisionNone}, nil
42 }
43
44 // Deduplicate by command string.
45 seen := make(map[string]bool, len(matching))
46 var deduped []config.HookConfig
47 for _, h := range matching {
48 if seen[h.Command] {
49 continue
50 }
51 seen[h.Command] = true
52 deduped = append(deduped, h)
53 }
54
55 envVars := BuildEnv(eventName, toolName, sessionID, r.cwd, r.projectDir, toolInputJSON)
56 payload := BuildPayload(eventName, sessionID, r.cwd, toolName, toolInputJSON)
57
58 results := make([]HookResult, len(deduped))
59 var wg sync.WaitGroup
60 wg.Add(len(deduped))
61
62 for i, h := range deduped {
63 go func(idx int, hook config.HookConfig) {
64 defer wg.Done()
65 results[idx] = r.runOne(ctx, hook, envVars, payload)
66 }(i, h)
67 }
68 wg.Wait()
69
70 agg := aggregate(results, toolInputJSON)
71 agg.Hooks = make([]HookInfo, len(deduped))
72 for i, h := range deduped {
73 agg.Hooks[i] = HookInfo{
74 Name: h.Command,
75 Matcher: h.Matcher,
76 Decision: results[i].Decision.String(),
77 Halt: results[i].Halt,
78 Reason: results[i].Reason,
79 InputRewrite: results[i].UpdatedInput != "",
80 }
81 }
82 slog.Info("Hook completed",
83 "event", eventName,
84 "tool", toolName,
85 "hooks", len(deduped),
86 "decision", agg.Decision.String(),
87 )
88 return agg, nil
89}
90
91// matchingHooks returns hooks whose matcher matches the tool name (or has
92// no matcher, which matches everything).
93func (r *Runner) matchingHooks(toolName string) []config.HookConfig {
94 var matched []config.HookConfig
95 for _, h := range r.hooks {
96 re := h.MatcherRegex()
97 if re == nil || re.MatchString(toolName) {
98 matched = append(matched, h)
99 }
100 }
101 return matched
102}
103
104// runOne executes a single hook command and returns its result.
105func (r *Runner) runOne(parentCtx context.Context, hook config.HookConfig, envVars []string, payload []byte) HookResult {
106 timeout := hook.TimeoutDuration()
107 ctx, cancel := context.WithTimeout(parentCtx, timeout)
108 defer cancel()
109
110 cmd := exec.CommandContext(ctx, "sh", "-c", hook.Command)
111 cmd.WaitDelay = time.Second
112 cmd.Env = envVars
113 cmd.Dir = r.cwd
114 cmd.Stdin = bytes.NewReader(payload)
115
116 var stdout, stderr bytes.Buffer
117 cmd.Stdout = &stdout
118 cmd.Stderr = &stderr
119
120 err := cmd.Run()
121
122 if ctx.Err() != nil {
123 // Distinguish timeout from parent cancellation.
124 if parentCtx.Err() != nil {
125 slog.Debug("Hook cancelled by parent context", "command", hook.Command)
126 } else {
127 slog.Warn("Hook timed out", "command", hook.Command, "timeout", timeout)
128 }
129 return HookResult{Decision: DecisionNone}
130 }
131
132 if err != nil {
133 exitCode := -1
134 if cmd.ProcessState != nil {
135 exitCode = cmd.ProcessState.ExitCode()
136 }
137 switch exitCode {
138 case 2:
139 // Exit code 2 = block this tool call. Stderr is the reason.
140 reason := strings.TrimSpace(stderr.String())
141 if reason == "" {
142 reason = "blocked by hook"
143 }
144 return HookResult{
145 Decision: DecisionDeny,
146 Reason: reason,
147 }
148 case HaltExitCode:
149 // Exit code 49 = halt the whole turn. Stderr is the reason.
150 reason := strings.TrimSpace(stderr.String())
151 if reason == "" {
152 reason = "turn halted by hook"
153 }
154 return HookResult{
155 Decision: DecisionDeny,
156 Halt: true,
157 Reason: reason,
158 }
159 default:
160 // Other non-zero exits are non-blocking errors.
161 slog.Warn("Hook failed with non-blocking error",
162 "command", hook.Command,
163 "exit_code", exitCode,
164 "stderr", strings.TrimSpace(stderr.String()),
165 )
166 return HookResult{Decision: DecisionNone}
167 }
168 }
169
170 // Exit code 0 — parse stdout JSON.
171 result := parseStdout(stdout.String())
172 slog.Debug("Hook executed",
173 "command", hook.Command,
174 "decision", result.Decision.String(),
175 )
176 return result
177}