runner.go

  1package hooks
  2
  3import (
  4	"bytes"
  5	"context"
  6	"log/slog"
  7	"os/exec"
  8	"regexp"
  9	"strings"
 10	"sync"
 11	"time"
 12
 13	"github.com/charmbracelet/crush/internal/config"
 14)
 15
 16// compiledHook pairs a HookConfig with its compiled matcher regex. A nil
 17// matcher means "match every tool".
 18type compiledHook struct {
 19	cfg     config.HookConfig
 20	matcher *regexp.Regexp
 21}
 22
 23// Runner executes hook commands and aggregates their results.
 24type Runner struct {
 25	hooks      []compiledHook
 26	cwd        string
 27	projectDir string
 28}
 29
 30// NewRunner creates a Runner from the given hook configs. Each hook's
 31// Matcher is compiled here so the Runner is self-sufficient; callers do
 32// not have to pre-compile matchers on the config, and reloads or merges
 33// that rebuild HookConfig values can't silently strip compiled state.
 34//
 35// Hooks whose matcher fails to compile are skipped with a warning rather
 36// than treated as match-everything. ValidateHooks is expected to have
 37// caught syntax errors earlier, so this is defense in depth.
 38func NewRunner(hooks []config.HookConfig, cwd, projectDir string) *Runner {
 39	compiled := make([]compiledHook, 0, len(hooks))
 40	for _, h := range hooks {
 41		ch := compiledHook{cfg: h}
 42		if h.Matcher != "" {
 43			re, err := regexp.Compile(h.Matcher)
 44			if err != nil {
 45				slog.Warn("Hook matcher failed to compile; skipping hook",
 46					"matcher", h.Matcher,
 47					"command", h.Command,
 48					"error", err,
 49				)
 50				continue
 51			}
 52			ch.matcher = re
 53		}
 54		compiled = append(compiled, ch)
 55	}
 56	return &Runner{
 57		hooks:      compiled,
 58		cwd:        cwd,
 59		projectDir: projectDir,
 60	}
 61}
 62
 63// Hooks returns the hook configs the runner was created with, in config
 64// order. Hooks whose matcher failed to compile at construction are
 65// omitted. Intended for diagnostics; callers should not rely on ordering
 66// or identity beyond that.
 67func (r *Runner) Hooks() []config.HookConfig {
 68	out := make([]config.HookConfig, len(r.hooks))
 69	for i, h := range r.hooks {
 70		out[i] = h.cfg
 71	}
 72	return out
 73}
 74
 75// Run executes all matching hooks for the given event and tool, returning
 76// an aggregated result.
 77func (r *Runner) Run(ctx context.Context, eventName, sessionID, toolName, toolInputJSON string) (AggregateResult, error) {
 78	matching := r.matchingHooks(toolName)
 79	if len(matching) == 0 {
 80		return AggregateResult{Decision: DecisionNone}, nil
 81	}
 82
 83	// Deduplicate by command string.
 84	seen := make(map[string]bool, len(matching))
 85	var deduped []config.HookConfig
 86	for _, h := range matching {
 87		if seen[h.Command] {
 88			continue
 89		}
 90		seen[h.Command] = true
 91		deduped = append(deduped, h)
 92	}
 93
 94	envVars := BuildEnv(eventName, toolName, sessionID, r.cwd, r.projectDir, toolInputJSON)
 95	payload := BuildPayload(eventName, sessionID, r.cwd, toolName, toolInputJSON)
 96
 97	results := make([]HookResult, len(deduped))
 98	var wg sync.WaitGroup
 99	wg.Add(len(deduped))
100
101	for i, h := range deduped {
102		go func(idx int, hook config.HookConfig) {
103			defer wg.Done()
104			results[idx] = r.runOne(ctx, hook, envVars, payload)
105		}(i, h)
106	}
107	wg.Wait()
108
109	agg := aggregate(results, toolInputJSON)
110	agg.Hooks = make([]HookInfo, len(deduped))
111	for i, h := range deduped {
112		agg.Hooks[i] = HookInfo{
113			Name:         h.Command,
114			Matcher:      h.Matcher,
115			Decision:     results[i].Decision.String(),
116			Halt:         results[i].Halt,
117			Reason:       results[i].Reason,
118			InputRewrite: results[i].UpdatedInput != "",
119		}
120	}
121	slog.Info("Hook completed",
122		"event", eventName,
123		"tool", toolName,
124		"hooks", len(deduped),
125		"decision", agg.Decision.String(),
126	)
127	return agg, nil
128}
129
130// matchingHooks returns hooks whose matcher matches the tool name (or has
131// no matcher, which matches everything).
132func (r *Runner) matchingHooks(toolName string) []config.HookConfig {
133	var matched []config.HookConfig
134	for _, h := range r.hooks {
135		if h.matcher == nil || h.matcher.MatchString(toolName) {
136			matched = append(matched, h.cfg)
137		}
138	}
139	return matched
140}
141
142// runOne executes a single hook command and returns its result.
143func (r *Runner) runOne(parentCtx context.Context, hook config.HookConfig, envVars []string, payload []byte) HookResult {
144	timeout := hook.TimeoutDuration()
145	ctx, cancel := context.WithTimeout(parentCtx, timeout)
146	defer cancel()
147
148	cmd := exec.CommandContext(ctx, "sh", "-c", hook.Command)
149	cmd.WaitDelay = time.Second
150	cmd.Env = envVars
151	cmd.Dir = r.cwd
152	cmd.Stdin = bytes.NewReader(payload)
153
154	var stdout, stderr bytes.Buffer
155	cmd.Stdout = &stdout
156	cmd.Stderr = &stderr
157
158	err := cmd.Run()
159
160	if ctx.Err() != nil {
161		// Distinguish timeout from parent cancellation.
162		if parentCtx.Err() != nil {
163			slog.Debug("Hook cancelled by parent context", "command", hook.Command)
164		} else {
165			slog.Warn("Hook timed out", "command", hook.Command, "timeout", timeout)
166		}
167		return HookResult{Decision: DecisionNone}
168	}
169
170	if err != nil {
171		exitCode := -1
172		if cmd.ProcessState != nil {
173			exitCode = cmd.ProcessState.ExitCode()
174		}
175		switch exitCode {
176		case 2:
177			// Exit code 2 = block this tool call. Stderr is the reason.
178			reason := strings.TrimSpace(stderr.String())
179			if reason == "" {
180				reason = "blocked by hook"
181			}
182			return HookResult{
183				Decision: DecisionDeny,
184				Reason:   reason,
185			}
186		case HaltExitCode:
187			// Exit code 49 = halt the whole turn. Stderr is the reason.
188			reason := strings.TrimSpace(stderr.String())
189			if reason == "" {
190				reason = "turn halted by hook"
191			}
192			return HookResult{
193				Decision: DecisionDeny,
194				Halt:     true,
195				Reason:   reason,
196			}
197		default:
198			// Other non-zero exits are non-blocking errors.
199			slog.Warn("Hook failed with non-blocking error",
200				"command", hook.Command,
201				"exit_code", exitCode,
202				"stderr", strings.TrimSpace(stderr.String()),
203			)
204			return HookResult{Decision: DecisionNone}
205		}
206	}
207
208	// Exit code 0 — parse stdout JSON.
209	result := parseStdout(stdout.String())
210	slog.Debug("Hook executed",
211		"command", hook.Command,
212		"decision", result.Decision.String(),
213	)
214	return result
215}