1package hooks
2
3import (
4 "context"
5 "fmt"
6 "log/slog"
7 "maps"
8 "os"
9 "path/filepath"
10 "runtime"
11 "slices"
12 "sort"
13 "strings"
14 "time"
15
16 "github.com/charmbracelet/crush/internal/csync"
17)
18
19type manager struct {
20 workingDir string
21 dataDir string
22 config *Config
23 executor *Executor
24 hooks *csync.Map[HookType, []string]
25}
26
27// NewManager creates a new hook manager.
28func NewManager(workingDir, dataDir string, cfg *Config) *manager {
29 if cfg == nil {
30 cfg = &Config{
31 TimeoutSeconds: 30,
32 Directories: []string{filepath.Join(dataDir, "hooks")},
33 }
34 }
35
36 defaultHooksDir := filepath.Join(dataDir, "hooks")
37
38 // Ensure default directory if not specified.
39 if len(cfg.Directories) == 0 {
40 cfg.Directories = []string{defaultHooksDir}
41 } else {
42 // Always include default hooks directory even when user overrides config.
43 if !slices.Contains(cfg.Directories, defaultHooksDir) {
44 cfg.Directories = append([]string{defaultHooksDir}, cfg.Directories...)
45 }
46 }
47
48 return &manager{
49 workingDir: workingDir,
50 dataDir: dataDir,
51 config: cfg,
52 executor: NewExecutor(workingDir),
53 hooks: csync.NewMap[HookType, []string](),
54 }
55}
56
// isExecutable reports whether info describes a hook script that may be run.
//
// Only names ending in ".sh" (compared case-insensitively) qualify. On
// Windows the extension alone is enough, because hooks run through a POSIX
// shell emulator; on every other platform at least one execute permission bit
// must also be set.
func isExecutable(info os.FileInfo) bool {
	// Owner, group and other execute bits.
	const anyExecBit = 0o111

	hasShellExt := strings.HasSuffix(strings.ToLower(info.Name()), ".sh")
	switch {
	case !hasShellExt:
		return false
	case runtime.GOOS == "windows":
		return true
	default:
		return info.Mode()&anyExecBit != 0
	}
}
71
72// executeHooks is the internal method that executes hooks for a given type.
73func (m *manager) executeHooks(ctx context.Context, hookType HookType, hookContext HookContext) (HookResult, error) {
74 if m.config.Disabled {
75 return HookResult{Continue: true}, nil
76 }
77
78 hookContext.HookType = hookType
79 hookContext.Environment = m.config.Environment
80
81 hooks := m.discoverHooks(hookType)
82 if len(hooks) == 0 {
83 return HookResult{Continue: true}, nil
84 }
85
86 slog.Debug("Executing hooks", "type", hookType, "count", len(hooks))
87
88 accumulated := HookResult{Continue: true}
89 for _, hookPath := range hooks {
90 if m.isDisabled(hookPath) {
91 slog.Debug("Skipping disabled hook", "path", hookPath)
92 continue
93 }
94
95 hookCtx, cancel := context.WithTimeout(ctx, time.Duration(m.config.TimeoutSeconds)*time.Second)
96
97 result, err := m.executor.Execute(hookCtx, hookPath, hookContext)
98 cancel()
99
100 if err != nil {
101 slog.Error("Hook execution failed", "path", hookPath, "error", err)
102 if hookType == HookPreToolUse {
103 accumulated.Continue = false
104 accumulated.Permission = "deny"
105 accumulated.Message = fmt.Sprintf("Hook failed: %v", err)
106 return accumulated, nil
107 }
108 continue
109 }
110
111 if result.Message != "" {
112 slog.Info("Hook message", "path", hookPath, "message", result.Message)
113 }
114
115 m.mergeResults(&accumulated, result)
116
117 if !result.Continue {
118 slog.Info("Hook stopped execution", "path", hookPath)
119 break
120 }
121 }
122
123 return accumulated, nil
124}
125
126// discoverHooks finds all executable hooks for the given type.
127func (m *manager) discoverHooks(hookType HookType) []string {
128 if cached, ok := m.hooks.Get(hookType); ok {
129 return cached
130 }
131
132 var hooks []string
133
134 for _, dir := range m.config.Directories {
135 if _, err := os.Stat(dir); err == nil {
136 entries, err := os.ReadDir(dir)
137 if err == nil {
138 for _, entry := range entries {
139 if entry.IsDir() {
140 continue
141 }
142
143 hookPath := filepath.Join(dir, entry.Name())
144
145 info, err := entry.Info()
146 if err != nil {
147 continue
148 }
149
150 if !isExecutable(info) {
151 continue
152 }
153
154 hooks = append(hooks, hookPath)
155 slog.Debug("Discovered catch-all hook", "path", hookPath, "type", hookType)
156 }
157 }
158 }
159
160 hookDir := filepath.Join(dir, string(hookType))
161 if _, err := os.Stat(hookDir); os.IsNotExist(err) {
162 continue
163 }
164
165 entries, err := os.ReadDir(hookDir)
166 if err != nil {
167 slog.Error("Failed to read hooks directory", "dir", hookDir, "error", err)
168 continue
169 }
170
171 for _, entry := range entries {
172 if entry.IsDir() {
173 continue
174 }
175
176 hookPath := filepath.Join(hookDir, entry.Name())
177
178 info, err := entry.Info()
179 if err != nil {
180 continue
181 }
182
183 if !isExecutable(info) {
184 slog.Debug("Skipping non-executable hook", "path", hookPath)
185 continue
186 }
187
188 hooks = append(hooks, hookPath)
189 }
190 }
191
192 if inlineHooks, ok := m.config.Inline[string(hookType)]; ok {
193 for _, inline := range inlineHooks {
194 hookPath, err := m.writeInlineHook(hookType, inline)
195 if err != nil {
196 slog.Error("Failed to write inline hook", "name", inline.Name, "error", err)
197 continue
198 }
199 hooks = append(hooks, hookPath)
200 }
201 }
202
203 sort.Strings(hooks)
204 m.hooks.Set(hookType, hooks)
205 return hooks
206}
207
208// writeInlineHook writes an inline hook script to a temp file.
209func (m *manager) writeInlineHook(hookType HookType, inline InlineHook) (string, error) {
210 tempDir := filepath.Join(m.dataDir, "hooks", ".inline", string(hookType))
211 if err := os.MkdirAll(tempDir, 0o755); err != nil {
212 return "", err
213 }
214
215 hookPath := filepath.Join(tempDir, inline.Name)
216 if err := os.WriteFile(hookPath, []byte(inline.Script), 0o755); err != nil {
217 return "", err
218 }
219
220 return hookPath, nil
221}
222
223// isDisabled checks if a hook is in the disabled list.
224func (m *manager) isDisabled(hookPath string) bool {
225 for _, dir := range m.config.Directories {
226 if rel, err := filepath.Rel(dir, hookPath); err == nil {
227 // Normalize to forward slashes for cross-platform comparison
228 rel = filepath.ToSlash(rel)
229 if slices.Contains(m.config.DisableHooks, rel) {
230 return true
231 }
232 }
233 }
234 return false
235}
236
237// mergeResults merges a new result into the accumulated result.
238func (m *manager) mergeResults(accumulated *HookResult, new *HookResult) {
239 accumulated.Continue = accumulated.Continue && new.Continue
240
241 if new.Permission != "" {
242 if new.Permission == "deny" {
243 accumulated.Permission = "deny"
244 } else if new.Permission == "approve" && accumulated.Permission == "" {
245 accumulated.Permission = "approve"
246 }
247 }
248
249 if new.ModifiedPrompt != nil {
250 accumulated.ModifiedPrompt = new.ModifiedPrompt
251 }
252
253 if len(new.ModifiedInput) > 0 {
254 if accumulated.ModifiedInput == nil {
255 accumulated.ModifiedInput = make(map[string]any)
256 }
257 maps.Copy(accumulated.ModifiedInput, new.ModifiedInput)
258 }
259
260 if len(new.ModifiedOutput) > 0 {
261 if accumulated.ModifiedOutput == nil {
262 accumulated.ModifiedOutput = make(map[string]any)
263 }
264 maps.Copy(accumulated.ModifiedOutput, new.ModifiedOutput)
265 }
266
267 if new.ContextContent != "" {
268 if accumulated.ContextContent == "" {
269 accumulated.ContextContent = new.ContextContent
270 } else {
271 accumulated.ContextContent += "\n\n" + new.ContextContent
272 }
273 }
274
275 accumulated.ContextFiles = append(accumulated.ContextFiles, new.ContextFiles...)
276
277 if new.Message != "" {
278 if accumulated.Message == "" {
279 accumulated.Message = new.Message
280 } else {
281 accumulated.Message += "; " + new.Message
282 }
283 }
284}
285
286// ListHooks implements Manager.
287func (m *manager) ListHooks(hookType HookType) []string {
288 return m.discoverHooks(hookType)
289}
290
291// ExecuteUserPromptSubmit executes user-prompt-submit hooks.
292func (m *manager) ExecuteUserPromptSubmit(ctx context.Context, sessionID, workingDir string, data UserPromptSubmitData) (HookResult, error) {
293 hookCtx := HookContext{
294 SessionID: sessionID,
295 WorkingDir: workingDir,
296 Data: data,
297 }
298
299 return m.executeHooks(ctx, HookUserPromptSubmit, hookCtx)
300}
301
302// ExecutePreToolUse executes pre-tool-use hooks.
303func (m *manager) ExecutePreToolUse(ctx context.Context, sessionID, workingDir string, data PreToolUseData) (HookResult, error) {
304 hookCtx := HookContext{
305 SessionID: sessionID,
306 WorkingDir: workingDir,
307 ToolName: data.ToolName,
308 ToolCallID: data.ToolCallID,
309 Data: data,
310 }
311
312 return m.executeHooks(ctx, HookPreToolUse, hookCtx)
313}
314
315// ExecutePostToolUse executes post-tool-use hooks.
316func (m *manager) ExecutePostToolUse(ctx context.Context, sessionID, workingDir string, data PostToolUseData) (HookResult, error) {
317 hookCtx := HookContext{
318 SessionID: sessionID,
319 WorkingDir: workingDir,
320 ToolName: data.ToolName,
321 ToolCallID: data.ToolCallID,
322 Data: data,
323 }
324
325 return m.executeHooks(ctx, HookPostToolUse, hookCtx)
326}
327
328// ExecuteStop executes stop hooks.
329func (m *manager) ExecuteStop(ctx context.Context, sessionID, workingDir string, data StopData) (HookResult, error) {
330 hookCtx := HookContext{
331 SessionID: sessionID,
332 WorkingDir: workingDir,
333 Data: data,
334 }
335
336 return m.executeHooks(ctx, HookStop, hookCtx)
337}