1package hooks
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "os"
9 "regexp"
10 "strings"
11 "sync"
12 "time"
13
14 "charm.land/fantasy"
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/csync"
17 "github.com/charmbracelet/crush/internal/message"
18 "github.com/charmbracelet/crush/internal/permission"
19 "github.com/charmbracelet/crush/internal/shell"
20)
21
// DefaultHookTimeout bounds how long a single hook may run when its
// configuration does not provide a timeout of its own.
const (
	DefaultHookTimeout time.Duration = 30 * time.Second
)
23
24// HookContext contains context information passed to hooks.
25// Fields are populated based on event type - see field comments for availability.
26type HookContext struct {
27 // EventType is the lifecycle event that triggered this hook.
28 // Available: All events
29 EventType config.HookEventType `json:"event_type"`
30
31 // SessionID identifies the Crush session.
32 // Available: All events (if session exists)
33 SessionID string `json:"session_id,omitempty"`
34
35 // TranscriptPath is the path to exported session transcript JSON file.
36 // Available: All events (if session exists and export succeeds)
37 TranscriptPath string `json:"transcript_path,omitempty"`
38
39 // ToolName is the name of the tool being called.
40 // Available: pre_tool_use, post_tool_use, permission_requested
41 ToolName string `json:"tool_name,omitempty"`
42
43 // ToolInput contains the tool's input parameters as key-value pairs.
44 // Available: pre_tool_use, post_tool_use
45 ToolInput map[string]any `json:"tool_input,omitempty"`
46
47 // ToolResult is the string result returned by the tool.
48 // Available: post_tool_use
49 ToolResult string `json:"tool_result,omitempty"`
50
51 // ToolError indicates whether the tool execution failed.
52 // Available: post_tool_use
53 ToolError bool `json:"tool_error,omitempty"`
54
55 // UserPrompt is the prompt submitted by the user.
56 // Available: user_prompt_submit
57 UserPrompt string `json:"user_prompt,omitempty"`
58
59 // Timestamp is when the hook was triggered (RFC3339 format).
60 // Available: All events
61 Timestamp time.Time `json:"timestamp"`
62
63 // WorkingDir is the project working directory.
64 // Available: All events
65 WorkingDir string `json:"working_dir,omitempty"`
66
67 // MessageID identifies the assistant message being processed.
68 // Available: pre_tool_use, post_tool_use, stop
69 MessageID string `json:"message_id,omitempty"`
70
71 // Provider is the LLM provider name (e.g., "anthropic").
72 // Available: All events with LLM interaction
73 Provider string `json:"provider,omitempty"`
74
75 // Model is the LLM model name (e.g., "claude-3-5-sonnet-20241022").
76 // Available: All events with LLM interaction
77 Model string `json:"model,omitempty"`
78
79 // TokensUsed is the total tokens consumed in this interaction.
80 // Available: stop
81 TokensUsed int64 `json:"tokens_used,omitempty"`
82
83 // TokensInput is the input tokens consumed in this interaction.
84 // Available: stop
85 TokensInput int64 `json:"tokens_input,omitempty"`
86
87 // PermissionAction is the permission action being requested (e.g., "read", "write").
88 // Available: permission_requested
89 PermissionAction string `json:"permission_action,omitempty"`
90
91 // PermissionPath is the file path involved in the permission request.
92 // Available: permission_requested
93 PermissionPath string `json:"permission_path,omitempty"`
94
95 // PermissionParams contains additional permission parameters.
96 // Available: permission_requested
97 PermissionRequest *permission.PermissionRequest `json:"permission_request,omitempty"`
98}
99
// HookDecision is a verdict a hook can report about the action it
// intercepted. Its string form is what ends up in message.HookOutput.Decision.
type HookDecision string

// Known hook decisions.
const (
	// HookDecisionBlock is reported, for example, when a command hook exits
	// with status 2.
	HookDecisionBlock HookDecision = "block"
	HookDecisionDeny  HookDecision = "deny"
	HookDecisionAllow HookDecision = "allow"
	HookDecisionAsk   HookDecision = "ask"
)
108
109type Service interface {
110 Execute(ctx context.Context, hookCtx HookContext) ([]message.HookOutput, error)
111 SetSmallModel(model fantasy.LanguageModel)
112}
113
114type service struct {
115 config config.HookConfig
116 workingDir string
117 regexCache *csync.Map[string, *regexp.Regexp]
118 smallModel fantasy.LanguageModel
119 messages message.Service
120}
121
122// NewService creates a new hook executor.
123func NewService(
124 hookConfig config.HookConfig,
125 workingDir string,
126 smallModel fantasy.LanguageModel,
127 messages message.Service,
128) Service {
129 return &service{
130 config: hookConfig,
131 workingDir: workingDir,
132 regexCache: csync.NewMap[string, *regexp.Regexp](),
133 smallModel: smallModel,
134 messages: messages,
135 }
136}
137
138// Execute implements Service.
139func (s *service) Execute(ctx context.Context, hookCtx HookContext) ([]message.HookOutput, error) {
140 if s.config == nil {
141 return nil, nil
142 }
143
144 // Check if context is already cancelled - prevents race conditions during cancellation.
145 if ctx.Err() != nil {
146 return nil, ctx.Err()
147 }
148
149 hookCtx.Timestamp = time.Now()
150 hookCtx.WorkingDir = s.workingDir
151
152 hooks := s.collectMatchingHooks(hookCtx)
153 if len(hooks) == 0 {
154 return nil, nil
155 }
156
157 transcriptPath, cleanup, err := s.setupTranscript(ctx, hookCtx)
158 if err != nil {
159 slog.Warn("Failed to export transcript for hooks", "error", err)
160 } else if transcriptPath != "" {
161 hookCtx.TranscriptPath = transcriptPath
162 // Ensure cleanup happens even on panic.
163 defer func() {
164 cleanup()
165 }()
166 }
167
168 results := make([]message.HookOutput, len(hooks))
169 var wg sync.WaitGroup
170
171 for i, hook := range hooks {
172 if ctx.Err() != nil {
173 return nil, ctx.Err()
174 }
175 wg.Add(1)
176 go func(idx int, h config.Hook) {
177 defer wg.Done()
178 result, err := s.executeHook(ctx, h, hookCtx)
179 if err != nil {
180 slog.Warn("Hook execution failed",
181 "event", hookCtx.EventType,
182 "error", err,
183 )
184 } else {
185 results[idx] = *result
186 }
187 }(i, hook)
188 }
189 wg.Wait()
190 return results, nil
191}
192
193func (s *service) setupTranscript(ctx context.Context, hookCtx HookContext) (string, func(), error) {
194 if hookCtx.SessionID == "" || s.messages == nil {
195 return "", func() {}, nil
196 }
197
198 path, err := exportTranscript(ctx, s.messages, hookCtx.SessionID)
199 if err != nil {
200 return "", func() {}, err
201 }
202
203 cleanup := func() {
204 if path != "" {
205 cleanupTranscript(path)
206 }
207 }
208
209 return path, cleanup, nil
210}
211
212func (s *service) executeHook(ctx context.Context, hook config.Hook, hookCtx HookContext) (*message.HookOutput, error) {
213 var result *message.HookOutput
214 var err error
215
216 switch hook.Type {
217 case "prompt":
218 if hook.Prompt == "" {
219 return nil, fmt.Errorf("prompt-based hook missing 'prompt' field")
220 }
221 if s.smallModel == nil {
222 return nil, fmt.Errorf("prompt-based hook requires small model configuration")
223 }
224 result, err = s.executePromptHook(ctx, hook, hookCtx)
225 case "", "command":
226 if hook.Command == "" {
227 return nil, fmt.Errorf("command-based hook missing 'command' field")
228 }
229 slog.Info("executing")
230 result, err = s.executeCommandHook(ctx, hook, hookCtx)
231 default:
232 return nil, fmt.Errorf("unsupported hook type: %s", hook.Type)
233 }
234
235 if result != nil {
236 result.EventType = string(hookCtx.EventType)
237 }
238
239 return result, err
240}
241
242func (s *service) executePromptHook(ctx context.Context, hook config.Hook, hookCtx HookContext) (*message.HookOutput, error) {
243 contextJSON, err := json.Marshal(hookCtx)
244 if err != nil {
245 return nil, fmt.Errorf("failed to marshal hook context: %w", err)
246 }
247
248 var finalPrompt string
249 if strings.Contains(hook.Prompt, "$ARGUMENTS") {
250 finalPrompt = strings.ReplaceAll(hook.Prompt, "$ARGUMENTS", string(contextJSON))
251 } else {
252 finalPrompt = fmt.Sprintf("%s\n\nContext: %s", hook.Prompt, string(contextJSON))
253 }
254
255 timeout := DefaultHookTimeout
256 if hook.Timeout != nil {
257 timeout = time.Duration(*hook.Timeout) * time.Second
258 }
259
260 execCtx, cancel := context.WithTimeout(ctx, timeout)
261 defer cancel()
262
263 type readTranscriptParams struct{}
264 readTranscriptTool := fantasy.NewAgentTool(
265 "read_transcript",
266 "Used to read the conversation so far",
267 func(ctx context.Context, params readTranscriptParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
268 if hookCtx.TranscriptPath == "" {
269 return fantasy.NewTextErrorResponse("No transcript available"), nil
270 }
271 data, err := os.ReadFile(hookCtx.TranscriptPath)
272 if err != nil {
273 return fantasy.NewTextErrorResponse(err.Error()), nil
274 }
275 return fantasy.NewTextResponse(string(data)), nil
276 })
277
278 var output *message.HookOutput
279 outputTool := fantasy.NewAgentTool(
280 "output",
281 "Used to submit the output, remember you MUST call this tool at the end",
282 func(ctx context.Context, params message.HookOutput, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
283 output = ¶ms
284 return fantasy.NewTextResponse("ouptut submitted"), nil
285 })
286 agent := fantasy.NewAgent(
287 s.smallModel,
288 fantasy.WithSystemPrompt(`You are a helpful sub agent used in a larger agents conversation loop,
289 your goal is to intercept the conversation and fulfill the intermediate requests, makesure to ALWAYS use the output tool at the end to output your decision`),
290 fantasy.WithTools(readTranscriptTool, outputTool),
291 )
292
293 _, err = agent.Generate(execCtx, fantasy.AgentCall{
294 Prompt: finalPrompt,
295 })
296 if err != nil {
297 return nil, err
298 }
299 return output, nil
300}
301
302func (s *service) executeCommandHook(ctx context.Context, hook config.Hook, hookCtx HookContext) (*message.HookOutput, error) {
303 timeout := DefaultHookTimeout
304 if hook.Timeout != nil {
305 timeout = time.Duration(*hook.Timeout) * time.Second
306 }
307
308 execCtx, cancel := context.WithTimeout(ctx, timeout)
309 defer cancel()
310
311 contextJSON, err := json.Marshal(hookCtx)
312 if err != nil {
313 return nil, fmt.Errorf("failed to marshal hook context: %w", err)
314 }
315
316 sh := shell.NewShell(&shell.Options{
317 WorkingDir: s.workingDir,
318 })
319
320 sh.SetEnv("CRUSH_HOOK_EVENT", string(hookCtx.EventType))
321 sh.SetEnv("CRUSH_HOOK_CONTEXT", string(contextJSON))
322 sh.SetEnv("CRUSH_PROJECT_DIR", s.workingDir)
323 if hookCtx.SessionID != "" {
324 sh.SetEnv("CRUSH_SESSION_ID", hookCtx.SessionID)
325 }
326 if hookCtx.ToolName != "" {
327 sh.SetEnv("CRUSH_TOOL_NAME", hookCtx.ToolName)
328 }
329
330 slog.Debug("Hook execution trace",
331 "event", hookCtx.EventType,
332 "command", hook.Command,
333 "timeout", timeout,
334 "context", hookCtx,
335 )
336
337 stdout, stderr, err := sh.ExecWithStdin(execCtx, hook.Command, string(contextJSON))
338
339 exitCode := shell.ExitCode(err)
340
341 var result *message.HookOutput
342 switch exitCode {
343 case 0:
344 // if the event is UserPromptSubmit we want the output to be added to the context
345 if hookCtx.EventType == config.UserPromptSubmit {
346 result = &message.HookOutput{
347 AdditionalContext: stdout,
348 }
349 } else {
350 result = &message.HookOutput{
351 Message: stdout,
352 }
353 }
354 case 2:
355 result = &message.HookOutput{
356 Decision: string(HookDecisionBlock),
357 Error: stderr,
358 }
359 return result, nil
360 default:
361 result = &message.HookOutput{
362 Error: stderr,
363 }
364 return result, nil
365 }
366
367 jsonOutput := parseHookOutput(stdout)
368 if jsonOutput == nil {
369 return result, nil
370 }
371
372 result.Message = jsonOutput.Message
373 result.Stop = jsonOutput.Stop
374 result.Decision = jsonOutput.Decision
375 result.AdditionalContext = jsonOutput.AdditionalContext
376 result.UpdatedInput = jsonOutput.UpdatedInput
377
378 // Trace output in debug mode
379 slog.Debug("Hook execution output",
380 "event", hookCtx.EventType,
381 "exit_code", exitCode,
382 "stdout_length", len(stdout),
383 "stderr_length", len(stderr),
384 "stdout", stdout,
385 "stderr", stderr,
386 )
387 return result, nil
388}
389
390func parseHookOutput(stdout string) *message.HookOutput {
391 stdout = strings.TrimSpace(stdout)
392 slog.Info(stdout)
393 if stdout == "" {
394 return nil
395 }
396
397 var output message.HookOutput
398 if err := json.Unmarshal([]byte(stdout), &output); err != nil {
399 // Failed to parse as HookOutput
400 return nil
401 }
402
403 return &output
404}
405
406func (s *service) SetSmallModel(model fantasy.LanguageModel) {
407 s.smallModel = model
408}
409
410func (s *service) collectMatchingHooks(hookCtx HookContext) []config.Hook {
411 matchers, ok := s.config[hookCtx.EventType]
412 if !ok || len(matchers) == 0 {
413 return nil
414 }
415
416 var hooks []config.Hook
417 for _, matcher := range matchers {
418 if !s.matcherApplies(matcher, hookCtx) {
419 continue
420 }
421 hooks = append(hooks, matcher.Hooks...)
422 }
423 return hooks
424}
425
426func (s *service) matcherApplies(matcher config.HookMatcher, ctx HookContext) bool {
427 if ctx.EventType == config.PreToolUse || ctx.EventType == config.PostToolUse {
428 return s.matchesToolName(matcher.Matcher, ctx.ToolName)
429 }
430
431 return matcher.Matcher == "" || matcher.Matcher == "*"
432}
433
434func (s *service) matchesToolName(pattern, toolName string) bool {
435 if pattern == "" || pattern == "*" {
436 return true
437 }
438
439 if pattern == toolName {
440 return true
441 }
442
443 if strings.Contains(pattern, "|") {
444 for tool := range strings.SplitSeq(pattern, "|") {
445 tool = strings.TrimSpace(tool)
446 if tool == toolName {
447 return true
448 }
449 }
450
451 return s.matchesRegex(pattern, toolName)
452 }
453
454 return s.matchesRegex(pattern, toolName)
455}
456
457func (s *service) matchesRegex(pattern, text string) bool {
458 re, ok := s.regexCache.Get(pattern)
459 if !ok {
460 compiled, err := regexp.Compile(pattern)
461 if err != nil {
462 // Not a valid regex, don't cache failures.
463 return false
464 }
465 re = s.regexCache.GetOrSet(pattern, func() *regexp.Regexp {
466 return compiled
467 })
468 }
469
470 if re == nil {
471 return false
472 }
473
474 return re.MatchString(text)
475}