1package hooks
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "strings"
9 "time"
10
11 "github.com/charmbracelet/crush/internal/config"
12 "github.com/charmbracelet/crush/internal/shell"
13)
14
// DefaultHookTimeout is how long a hook command may run when the hook's
// configuration does not supply its own timeout.
const DefaultHookTimeout time.Duration = 30 * time.Second
16
// HookContext contains context information passed to hooks.
//
// The executor JSON-encodes a HookContext and gives it to each hook command
// twice: on the command's stdin and in the CRUSH_HOOK_CONTEXT environment
// variable. Fields tagged omitempty are dropped from that JSON when they hold
// their zero value.
//
// Execute overwrites Timestamp (with the current time) and WorkingDir (with
// the executor's working directory) before running any hook, so callers do
// not need to fill them in. EventType selects which configured matchers run,
// and for PreToolUse/PostToolUse events ToolName is what a matcher's tool
// pattern is compared against. The remaining fields are passed through to the
// hook unchanged.
type HookContext struct {
	EventType          config.HookEventType `json:"event_type"`
	SessionID          string               `json:"session_id,omitempty"`
	ToolName           string               `json:"tool_name,omitempty"`
	ToolInput          map[string]any       `json:"tool_input,omitempty"`
	ToolResult         string               `json:"tool_result,omitempty"`
	ToolError          bool                 `json:"tool_error,omitempty"`
	UserPrompt         string               `json:"user_prompt,omitempty"`
	Timestamp          time.Time            `json:"timestamp"`
	WorkingDir         string               `json:"working_dir,omitempty"`
	MessageID          string               `json:"message_id,omitempty"`
	Provider           string               `json:"provider,omitempty"`
	Model              string               `json:"model,omitempty"`
	TokensUsed         int64                `json:"tokens_used,omitempty"`
	TokensInput        int64                `json:"tokens_input,omitempty"`
	PermissionAction   string               `json:"permission_action,omitempty"`
	PermissionPath     string               `json:"permission_path,omitempty"`
	PermissionParams   any                  `json:"permission_params,omitempty"`
	PermissionToolCall string               `json:"permission_tool_call,omitempty"`
}
38
// Executor executes hooks based on configuration.
//
// config maps each event type to its list of matchers. workingDir is reported
// to hooks as HookContext.WorkingDir and is the directory the shell runs in.
// shell is created once by NewExecutor and reused for every hook command; the
// per-hook CRUSH_* environment variables are set on this shared instance via
// SetEnv. Execute is a no-op when config or shell is nil.
type Executor struct {
	config     config.HookConfig
	workingDir string
	shell      *shell.Shell
}
45
46// NewExecutor creates a new hook executor.
47func NewExecutor(hookConfig config.HookConfig, workingDir string) *Executor {
48 shellInst := shell.NewShell(&shell.Options{
49 WorkingDir: workingDir,
50 })
51 return &Executor{
52 config: hookConfig,
53 workingDir: workingDir,
54 shell: shellInst,
55 }
56}
57
58// Execute runs all hooks matching the given event type and context.
59// Returns the first error encountered, causing subsequent hooks to be skipped.
60func (e *Executor) Execute(ctx context.Context, hookCtx HookContext) error {
61 if e.config == nil || e.shell == nil {
62 return nil
63 }
64
65 hookCtx.Timestamp = time.Now()
66 hookCtx.WorkingDir = e.workingDir
67
68 matchers, ok := e.config[hookCtx.EventType]
69 if !ok || len(matchers) == 0 {
70 return nil
71 }
72
73 for _, matcher := range matchers {
74 if ctx.Err() != nil {
75 return ctx.Err()
76 }
77
78 if !e.matcherApplies(matcher, hookCtx) {
79 continue
80 }
81
82 for _, hook := range matcher.Hooks {
83 if err := e.executeHook(ctx, hook, hookCtx); err != nil {
84 slog.Warn("Hook execution failed",
85 "event", hookCtx.EventType,
86 "matcher", matcher.Matcher,
87 "error", err,
88 )
89 return err
90 }
91 }
92 }
93
94 return nil
95}
96
97// matcherApplies checks if a matcher applies to the given context.
98func (e *Executor) matcherApplies(matcher config.HookMatcher, ctx HookContext) bool {
99 if matcher.Matcher == "" || matcher.Matcher == "*" {
100 return true
101 }
102
103 if ctx.EventType == config.PreToolUse || ctx.EventType == config.PostToolUse {
104 return matchesToolName(matcher.Matcher, ctx.ToolName)
105 }
106
107 // For non-tool events, only empty or wildcard matchers apply
108 return matcher.Matcher == "" || matcher.Matcher == "*"
109}
110
// matchesToolName reports whether toolName is selected by pattern.
//
// An empty pattern or "*" matches every tool. Otherwise pattern is one or
// more tool names separated by "|", e.g. "edit|write|multiedit". Whitespace
// around each name is ignored, whether or not the pattern contains a pipe.
// Empty names (from "edit||write" or a trailing "|") are skipped so they can
// never match an empty toolName.
func matchesToolName(pattern, toolName string) bool {
	if pattern == "" || pattern == "*" {
		return true
	}

	// A pattern without "|" splits into a single name, so one loop covers
	// both the single-name and the multi-name form.
	for _, name := range strings.Split(pattern, "|") {
		if name = strings.TrimSpace(name); name != "" && name == toolName {
			return true
		}
	}

	return false
}
137
138// executeHook executes a single hook command.
139func (e *Executor) executeHook(ctx context.Context, hook config.Hook, hookCtx HookContext) error {
140 if hook.Type != "command" {
141 return fmt.Errorf("unsupported hook type: %s", hook.Type)
142 }
143
144 timeout := DefaultHookTimeout
145 if hook.Timeout != nil {
146 timeout = time.Duration(*hook.Timeout) * time.Second
147 }
148
149 execCtx, cancel := context.WithTimeout(ctx, timeout)
150 defer cancel()
151
152 contextJSON, err := json.Marshal(hookCtx)
153 if err != nil {
154 return fmt.Errorf("failed to marshal hook context: %w", err)
155 }
156
157 e.shell.SetEnv("CRUSH_HOOK_EVENT", string(hookCtx.EventType))
158 e.shell.SetEnv("CRUSH_HOOK_CONTEXT", string(contextJSON))
159 if hookCtx.SessionID != "" {
160 e.shell.SetEnv("CRUSH_SESSION_ID", hookCtx.SessionID)
161 }
162 if hookCtx.ToolName != "" {
163 e.shell.SetEnv("CRUSH_TOOL_NAME", hookCtx.ToolName)
164 }
165
166 slog.Debug("Executing hook",
167 "event", hookCtx.EventType,
168 "command", hook.Command,
169 "timeout", timeout,
170 )
171
172 stdout, stderr, err := e.shell.ExecWithStdin(execCtx, hook.Command, string(contextJSON))
173 if err != nil {
174 return fmt.Errorf("hook command failed: %w: stdout=%s stderr=%s", err, stdout, stderr)
175 }
176
177 if stdout != "" || stderr != "" {
178 slog.Debug("Hook output",
179 "event", hookCtx.EventType,
180 "stdout", stdout,
181 "stderr", stderr,
182 )
183 }
184
185 return nil
186}