1package hooks
2
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"strings"
	"time"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/shell"
)
13
// DefaultHookTimeout is the per-hook execution limit applied when a hook's
// configuration does not provide its own timeout.
const DefaultHookTimeout time.Duration = 30 * time.Second
15
// HookContext contains context information passed to hooks.
//
// The whole struct is marshaled to JSON and handed to each hook command (on
// stdin and in the CRUSH_HOOK_CONTEXT environment variable). Fields tagged
// omitempty are left out of that JSON when they hold their zero value.
type HookContext struct {
	// EventType is the hook event being fired; Executor.Execute uses it to
	// look up which matchers to run.
	EventType config.HookEventType `json:"event_type"`

	// SessionID identifies the session the event belongs to, if any.
	SessionID string `json:"session_id,omitempty"`

	// Tool-call details. ToolName is compared against a matcher's pattern
	// for PreToolUse and PostToolUse events.
	// NOTE(review): ToolResult/ToolError presumably carry the tool's output
	// and whether it failed, and are presumably only populated for
	// PostToolUse — confirm against the callers.
	ToolName   string         `json:"tool_name,omitempty"`
	ToolInput  map[string]any `json:"tool_input,omitempty"`
	ToolResult string         `json:"tool_result,omitempty"`
	ToolError  bool           `json:"tool_error,omitempty"`

	// UserPrompt presumably holds the user's prompt text for prompt-related
	// events — confirm against the callers.
	UserPrompt string `json:"user_prompt,omitempty"`

	// Timestamp and WorkingDir are overwritten by Executor.Execute with the
	// current time and the executor's working directory, so any values the
	// caller sets here are not what hooks receive.
	Timestamp  time.Time `json:"timestamp"`
	WorkingDir string    `json:"working_dir,omitempty"`

	// Message and model metadata for the event, when available.
	MessageID string `json:"message_id,omitempty"`
	Provider  string `json:"provider,omitempty"`
	Model     string `json:"model,omitempty"`

	// Token accounting. NOTE(review): presumably TokensUsed is a total and
	// TokensInput the prompt/input share — confirm the exact definitions.
	TokensUsed  int64 `json:"tokens_used,omitempty"`
	TokensInput int64 `json:"tokens_input,omitempty"`
}
33
// Executor executes hooks based on configuration.
//
// It holds the hook configuration keyed by event type, the working directory
// reported to hooks, and a single shell instance in which every hook command
// runs. Environment variables for a hook are set on that shared shell right
// before the command is executed.
type Executor struct {
	config     config.HookConfig // hooks to run, looked up by event type in Execute
	workingDir string            // shell working directory; also copied into HookContext.WorkingDir
	shell      *shell.Shell      // shared by all hook runs; Execute does nothing if nil
}
40
41// NewExecutor creates a new hook executor.
42func NewExecutor(hookConfig config.HookConfig, workingDir string) *Executor {
43 shellInst := shell.NewShell(&shell.Options{
44 WorkingDir: workingDir,
45 })
46 return &Executor{
47 config: hookConfig,
48 workingDir: workingDir,
49 shell: shellInst,
50 }
51}
52
53// Execute runs all hooks matching the given event type and context.
54// Returns the first error encountered, causing subsequent hooks to be skipped.
55func (e *Executor) Execute(ctx context.Context, hookCtx HookContext) error {
56 if e.config == nil || e.shell == nil {
57 return nil
58 }
59
60 hookCtx.Timestamp = time.Now()
61 hookCtx.WorkingDir = e.workingDir
62
63 matchers, ok := e.config[hookCtx.EventType]
64 if !ok || len(matchers) == 0 {
65 return nil
66 }
67
68 for _, matcher := range matchers {
69 if ctx.Err() != nil {
70 return ctx.Err()
71 }
72
73 if !e.matcherApplies(matcher, hookCtx) {
74 continue
75 }
76
77 for _, hook := range matcher.Hooks {
78 if err := e.executeHook(ctx, hook, hookCtx); err != nil {
79 slog.Warn("Hook execution failed",
80 "event", hookCtx.EventType,
81 "matcher", matcher.Matcher,
82 "error", err,
83 )
84 return err
85 }
86 }
87 }
88
89 return nil
90}
91
92// matcherApplies checks if a matcher applies to the given context.
93func (e *Executor) matcherApplies(matcher config.HookMatcher, ctx HookContext) bool {
94 if matcher.Matcher == "" || matcher.Matcher == "*" {
95 return true
96 }
97
98 if ctx.EventType == config.PreToolUse || ctx.EventType == config.PostToolUse {
99 return matcher.Matcher == ctx.ToolName
100 }
101
102 return matcher.Matcher == "" || matcher.Matcher == "*"
103}
104
105// executeHook executes a single hook command.
106func (e *Executor) executeHook(ctx context.Context, hook config.Hook, hookCtx HookContext) error {
107 if hook.Type != "command" {
108 return fmt.Errorf("unsupported hook type: %s", hook.Type)
109 }
110
111 timeout := DefaultHookTimeout
112 if hook.Timeout != nil {
113 timeout = time.Duration(*hook.Timeout) * time.Second
114 }
115
116 execCtx, cancel := context.WithTimeout(ctx, timeout)
117 defer cancel()
118
119 contextJSON, err := json.Marshal(hookCtx)
120 if err != nil {
121 return fmt.Errorf("failed to marshal hook context: %w", err)
122 }
123
124 e.shell.SetEnv("CRUSH_HOOK_EVENT", string(hookCtx.EventType))
125 e.shell.SetEnv("CRUSH_HOOK_CONTEXT", string(contextJSON))
126 if hookCtx.SessionID != "" {
127 e.shell.SetEnv("CRUSH_SESSION_ID", hookCtx.SessionID)
128 }
129 if hookCtx.ToolName != "" {
130 e.shell.SetEnv("CRUSH_TOOL_NAME", hookCtx.ToolName)
131 }
132
133 slog.Debug("Executing hook",
134 "event", hookCtx.EventType,
135 "command", hook.Command,
136 "timeout", timeout,
137 )
138
139 fullCommand := fmt.Sprintf("%s <<'CRUSH_HOOK_EOF'\n%s\nCRUSH_HOOK_EOF\n", hook.Command, string(contextJSON))
140
141 stdout, stderr, err := e.shell.Exec(execCtx, fullCommand)
142 if err != nil {
143 return fmt.Errorf("hook command failed: %w: stdout=%s stderr=%s", err, stdout, stderr)
144 }
145
146 if stdout != "" || stderr != "" {
147 slog.Debug("Hook output",
148 "event", hookCtx.EventType,
149 "stdout", stdout,
150 "stderr", stderr,
151 )
152 }
153
154 return nil
155}