runner.go

  1package hooks
  2
  3import (
  4	"bytes"
  5	"context"
  6	"log/slog"
  7	"os/exec"
  8	"strings"
  9	"sync"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/config"
 13)
 14
 15// Runner executes hook commands and aggregates their results.
 16type Runner struct {
 17	hooks      []config.HookConfig
 18	cwd        string
 19	projectDir string
 20}
 21
 22// NewRunner creates a Runner from the given hook configs.
 23func NewRunner(hooks []config.HookConfig, cwd, projectDir string) *Runner {
 24	return &Runner{
 25		hooks:      hooks,
 26		cwd:        cwd,
 27		projectDir: projectDir,
 28	}
 29}
 30
 31// Hooks returns the hook configs the runner was created with.
 32func (r *Runner) Hooks() []config.HookConfig {
 33	return r.hooks
 34}
 35
 36// Run executes all matching hooks for the given event and tool, returning
 37// an aggregated result.
 38func (r *Runner) Run(ctx context.Context, eventName, sessionID, toolName, toolInputJSON string) (AggregateResult, error) {
 39	matching := r.matchingHooks(toolName)
 40	if len(matching) == 0 {
 41		return AggregateResult{Decision: DecisionNone}, nil
 42	}
 43
 44	// Deduplicate by command string.
 45	seen := make(map[string]bool, len(matching))
 46	var deduped []config.HookConfig
 47	for _, h := range matching {
 48		if seen[h.Command] {
 49			continue
 50		}
 51		seen[h.Command] = true
 52		deduped = append(deduped, h)
 53	}
 54
 55	envVars := BuildEnv(eventName, toolName, sessionID, r.cwd, r.projectDir, toolInputJSON)
 56	payload := BuildPayload(eventName, sessionID, r.cwd, toolName, toolInputJSON)
 57
 58	results := make([]HookResult, len(deduped))
 59	var wg sync.WaitGroup
 60	wg.Add(len(deduped))
 61
 62	for i, h := range deduped {
 63		go func(idx int, hook config.HookConfig) {
 64			defer wg.Done()
 65			results[idx] = r.runOne(ctx, hook, envVars, payload)
 66		}(i, h)
 67	}
 68	wg.Wait()
 69
 70	agg := aggregate(results, toolInputJSON)
 71	agg.Hooks = make([]HookInfo, len(deduped))
 72	for i, h := range deduped {
 73		agg.Hooks[i] = HookInfo{
 74			Name:         h.Command,
 75			Matcher:      h.Matcher,
 76			Decision:     results[i].Decision.String(),
 77			Halt:         results[i].Halt,
 78			Reason:       results[i].Reason,
 79			InputRewrite: results[i].UpdatedInput != "",
 80		}
 81	}
 82	slog.Info("Hook completed",
 83		"event", eventName,
 84		"tool", toolName,
 85		"hooks", len(deduped),
 86		"decision", agg.Decision.String(),
 87	)
 88	return agg, nil
 89}
 90
 91// matchingHooks returns hooks whose matcher matches the tool name (or has
 92// no matcher, which matches everything).
 93func (r *Runner) matchingHooks(toolName string) []config.HookConfig {
 94	var matched []config.HookConfig
 95	for _, h := range r.hooks {
 96		re := h.MatcherRegex()
 97		if re == nil || re.MatchString(toolName) {
 98			matched = append(matched, h)
 99		}
100	}
101	return matched
102}
103
104// runOne executes a single hook command and returns its result.
105func (r *Runner) runOne(parentCtx context.Context, hook config.HookConfig, envVars []string, payload []byte) HookResult {
106	timeout := hook.TimeoutDuration()
107	ctx, cancel := context.WithTimeout(parentCtx, timeout)
108	defer cancel()
109
110	cmd := exec.CommandContext(ctx, "sh", "-c", hook.Command)
111	cmd.WaitDelay = time.Second
112	cmd.Env = envVars
113	cmd.Dir = r.cwd
114	cmd.Stdin = bytes.NewReader(payload)
115
116	var stdout, stderr bytes.Buffer
117	cmd.Stdout = &stdout
118	cmd.Stderr = &stderr
119
120	err := cmd.Run()
121
122	if ctx.Err() != nil {
123		// Distinguish timeout from parent cancellation.
124		if parentCtx.Err() != nil {
125			slog.Debug("Hook cancelled by parent context", "command", hook.Command)
126		} else {
127			slog.Warn("Hook timed out", "command", hook.Command, "timeout", timeout)
128		}
129		return HookResult{Decision: DecisionNone}
130	}
131
132	if err != nil {
133		exitCode := -1
134		if cmd.ProcessState != nil {
135			exitCode = cmd.ProcessState.ExitCode()
136		}
137		switch exitCode {
138		case 2:
139			// Exit code 2 = block this tool call. Stderr is the reason.
140			reason := strings.TrimSpace(stderr.String())
141			if reason == "" {
142				reason = "blocked by hook"
143			}
144			return HookResult{
145				Decision: DecisionDeny,
146				Reason:   reason,
147			}
148		case HaltExitCode:
149			// Exit code 49 = halt the whole turn. Stderr is the reason.
150			reason := strings.TrimSpace(stderr.String())
151			if reason == "" {
152				reason = "turn halted by hook"
153			}
154			return HookResult{
155				Decision: DecisionDeny,
156				Halt:     true,
157				Reason:   reason,
158			}
159		default:
160			// Other non-zero exits are non-blocking errors.
161			slog.Warn("Hook failed with non-blocking error",
162				"command", hook.Command,
163				"exit_code", exitCode,
164				"stderr", strings.TrimSpace(stderr.String()),
165			)
166			return HookResult{Decision: DecisionNone}
167		}
168	}
169
170	// Exit code 0 — parse stdout JSON.
171	result := parseStdout(stdout.String())
172	slog.Debug("Hook executed",
173		"command", hook.Command,
174		"decision", result.Decision.String(),
175	)
176	return result
177}