1package hooks
2
3import (
4 "context"
5 "fmt"
6 "log/slog"
7 "maps"
8 "os"
9 "path/filepath"
10 "runtime"
11 "slices"
12 "sort"
13 "strings"
14 "time"
15
16 "github.com/charmbracelet/crush/internal/csync"
17)
18
19type manager struct {
20 workingDir string
21 dataDir string
22 config *Config
23 executor *Executor
24 hooks *csync.Map[HookType, []string]
25}
26
27// NewManager creates a new hook manager.
28func NewManager(workingDir, dataDir string, cfg *Config) *manager {
29 if cfg == nil {
30 cfg = &Config{
31 TimeoutSeconds: 30,
32 Directories: []string{filepath.Join(dataDir, "hooks")},
33 }
34 }
35
36 // Ensure default directory if not specified.
37 if len(cfg.Directories) == 0 {
38 cfg.Directories = []string{filepath.Join(dataDir, "hooks")}
39 }
40
41 return &manager{
42 workingDir: workingDir,
43 dataDir: dataDir,
44 config: cfg,
45 executor: NewExecutor(workingDir),
46 hooks: csync.NewMap[HookType, []string](),
47 }
48}
49
// isExecutable reports whether info describes a runnable hook script.
//
// Only names ending in ".sh" (compared case-insensitively) qualify. On Windows
// the extension alone is sufficient; on every other OS at least one execute
// permission bit (user, group or other) must be set as well.
func isExecutable(info os.FileInfo) bool {
	if lowered := strings.ToLower(info.Name()); strings.HasSuffix(lowered, ".sh") {
		// The mode is only consulted when not running on Windows.
		return runtime.GOOS == "windows" || info.Mode()&0o111 != 0
	}
	return false
}
64
65// executeHooks is the internal method that executes hooks for a given type.
66func (m *manager) executeHooks(ctx context.Context, hookType HookType, hookContext HookContext) (HookResult, error) {
67 if m.config.Disabled {
68 return HookResult{Continue: true}, nil
69 }
70
71 hookContext.HookType = hookType
72 hookContext.Environment = m.config.Environment
73
74 hooks := m.discoverHooks(hookType)
75 if len(hooks) == 0 {
76 return HookResult{Continue: true}, nil
77 }
78
79 slog.Debug("Executing hooks", "type", hookType, "count", len(hooks))
80
81 accumulated := HookResult{Continue: true}
82 for _, hookPath := range hooks {
83 if m.isDisabled(hookPath) {
84 slog.Debug("Skipping disabled hook", "path", hookPath)
85 continue
86 }
87
88 hookCtx, cancel := context.WithTimeout(ctx, time.Duration(m.config.TimeoutSeconds)*time.Second)
89
90 result, err := m.executor.Execute(hookCtx, hookPath, hookContext)
91 cancel()
92
93 if err != nil {
94 slog.Error("Hook execution failed", "path", hookPath, "error", err)
95 if hookType == HookPreToolUse {
96 accumulated.Continue = false
97 accumulated.Permission = "deny"
98 accumulated.Message = fmt.Sprintf("Hook failed: %v", err)
99 return accumulated, nil
100 }
101 continue
102 }
103
104 if result.Message != "" {
105 slog.Info("Hook message", "path", hookPath, "message", result.Message)
106 }
107
108 m.mergeResults(&accumulated, result)
109
110 if !result.Continue {
111 slog.Info("Hook stopped execution", "path", hookPath)
112 break
113 }
114 }
115
116 return accumulated, nil
117}
118
119// discoverHooks finds all executable hooks for the given type.
120func (m *manager) discoverHooks(hookType HookType) []string {
121 if cached, ok := m.hooks.Get(hookType); ok {
122 return cached
123 }
124
125 var hooks []string
126
127 for _, dir := range m.config.Directories {
128 if _, err := os.Stat(dir); err == nil {
129 entries, err := os.ReadDir(dir)
130 if err == nil {
131 for _, entry := range entries {
132 if entry.IsDir() {
133 continue
134 }
135
136 hookPath := filepath.Join(dir, entry.Name())
137
138 info, err := entry.Info()
139 if err != nil {
140 continue
141 }
142
143 if !isExecutable(info) {
144 continue
145 }
146
147 hooks = append(hooks, hookPath)
148 slog.Debug("Discovered catch-all hook", "path", hookPath, "type", hookType)
149 }
150 }
151 }
152
153 hookDir := filepath.Join(dir, string(hookType))
154 if _, err := os.Stat(hookDir); os.IsNotExist(err) {
155 continue
156 }
157
158 entries, err := os.ReadDir(hookDir)
159 if err != nil {
160 slog.Error("Failed to read hooks directory", "dir", hookDir, "error", err)
161 continue
162 }
163
164 for _, entry := range entries {
165 if entry.IsDir() {
166 continue
167 }
168
169 hookPath := filepath.Join(hookDir, entry.Name())
170
171 info, err := entry.Info()
172 if err != nil {
173 continue
174 }
175
176 if !isExecutable(info) {
177 slog.Debug("Skipping non-executable hook", "path", hookPath)
178 continue
179 }
180
181 hooks = append(hooks, hookPath)
182 }
183 }
184
185 if inlineHooks, ok := m.config.Inline[string(hookType)]; ok {
186 for _, inline := range inlineHooks {
187 hookPath, err := m.writeInlineHook(hookType, inline)
188 if err != nil {
189 slog.Error("Failed to write inline hook", "name", inline.Name, "error", err)
190 continue
191 }
192 hooks = append(hooks, hookPath)
193 }
194 }
195
196 sort.Strings(hooks)
197 m.hooks.Set(hookType, hooks)
198 return hooks
199}
200
201// writeInlineHook writes an inline hook script to a temp file.
202func (m *manager) writeInlineHook(hookType HookType, inline InlineHook) (string, error) {
203 tempDir := filepath.Join(m.dataDir, "hooks", ".inline", string(hookType))
204 if err := os.MkdirAll(tempDir, 0o755); err != nil {
205 return "", err
206 }
207
208 hookPath := filepath.Join(tempDir, inline.Name)
209 if err := os.WriteFile(hookPath, []byte(inline.Script), 0o755); err != nil {
210 return "", err
211 }
212
213 return hookPath, nil
214}
215
216// isDisabled checks if a hook is in the disabled list.
217func (m *manager) isDisabled(hookPath string) bool {
218 for _, dir := range m.config.Directories {
219 if rel, err := filepath.Rel(dir, hookPath); err == nil {
220 // Normalize to forward slashes for cross-platform comparison
221 rel = filepath.ToSlash(rel)
222 if slices.Contains(m.config.DisableHooks, rel) {
223 return true
224 }
225 }
226 }
227 return false
228}
229
230// mergeResults merges a new result into the accumulated result.
231func (m *manager) mergeResults(accumulated *HookResult, new *HookResult) {
232 accumulated.Continue = accumulated.Continue && new.Continue
233
234 if new.Permission != "" {
235 if new.Permission == "deny" {
236 accumulated.Permission = "deny"
237 } else if new.Permission == "ask" && accumulated.Permission != "deny" {
238 accumulated.Permission = "ask"
239 } else if new.Permission == "approve" && accumulated.Permission == "" {
240 accumulated.Permission = "approve"
241 }
242 }
243
244 if new.ModifiedPrompt != nil {
245 accumulated.ModifiedPrompt = new.ModifiedPrompt
246 }
247
248 if len(new.ModifiedInput) > 0 {
249 if accumulated.ModifiedInput == nil {
250 accumulated.ModifiedInput = make(map[string]any)
251 }
252 maps.Copy(accumulated.ModifiedInput, new.ModifiedInput)
253 }
254
255 if len(new.ModifiedOutput) > 0 {
256 if accumulated.ModifiedOutput == nil {
257 accumulated.ModifiedOutput = make(map[string]any)
258 }
259 maps.Copy(accumulated.ModifiedOutput, new.ModifiedOutput)
260 }
261
262 if new.ContextContent != "" {
263 if accumulated.ContextContent == "" {
264 accumulated.ContextContent = new.ContextContent
265 } else {
266 accumulated.ContextContent += "\n\n" + new.ContextContent
267 }
268 }
269
270 accumulated.ContextFiles = append(accumulated.ContextFiles, new.ContextFiles...)
271
272 if new.Message != "" {
273 if accumulated.Message == "" {
274 accumulated.Message = new.Message
275 } else {
276 accumulated.Message += "; " + new.Message
277 }
278 }
279}
280
281// ListHooks implements Manager.
282func (m *manager) ListHooks(hookType HookType) []string {
283 return m.discoverHooks(hookType)
284}
285
286// ExecuteUserPromptSubmit executes user-prompt-submit hooks.
287func (m *manager) ExecuteUserPromptSubmit(ctx context.Context, sessionID, workingDir string, data UserPromptSubmitData) (HookResult, error) {
288 hookCtx := HookContext{
289 SessionID: sessionID,
290 WorkingDir: workingDir,
291 Data: data,
292 }
293
294 return m.executeHooks(ctx, HookUserPromptSubmit, hookCtx)
295}
296
297// ExecutePreToolUse executes pre-tool-use hooks.
298func (m *manager) ExecutePreToolUse(ctx context.Context, sessionID, workingDir string, data PreToolUseData) (HookResult, error) {
299 hookCtx := HookContext{
300 SessionID: sessionID,
301 WorkingDir: workingDir,
302 ToolName: data.ToolName,
303 ToolCallID: data.ToolCallID,
304 Data: data,
305 }
306
307 return m.executeHooks(ctx, HookPreToolUse, hookCtx)
308}
309
310// ExecutePostToolUse executes post-tool-use hooks.
311func (m *manager) ExecutePostToolUse(ctx context.Context, sessionID, workingDir string, data PostToolUseData) (HookResult, error) {
312 hookCtx := HookContext{
313 SessionID: sessionID,
314 WorkingDir: workingDir,
315 ToolName: data.ToolName,
316 ToolCallID: data.ToolCallID,
317 Data: data,
318 }
319
320 return m.executeHooks(ctx, HookPostToolUse, hookCtx)
321}
322
323// ExecuteStop executes stop hooks.
324func (m *manager) ExecuteStop(ctx context.Context, sessionID, workingDir string, data StopData) (HookResult, error) {
325 hookCtx := HookContext{
326 SessionID: sessionID,
327 WorkingDir: workingDir,
328 Data: data,
329 }
330
331 return m.executeHooks(ctx, HookStop, hookCtx)
332}