1package hooks
2
3import (
4 "context"
5 "fmt"
6 "log/slog"
7 "maps"
8 "os"
9 "path/filepath"
10 "runtime"
11 "slices"
12 "sort"
13 "strings"
14 "time"
15
16 "github.com/charmbracelet/crush/internal/csync"
17)
18
19type manager struct {
20 workingDir string
21 dataDir string
22 config *Config
23 executor *Executor
24 hooks *csync.Map[HookType, []string]
25}
26
27// NewManager creates a new hook manager.
28func NewManager(workingDir, dataDir string, cfg *Config) Manager {
29 if cfg == nil {
30 cfg = &Config{
31 Enabled: true,
32 TimeoutSeconds: 30,
33 Directories: []string{filepath.Join(dataDir, "hooks")},
34 }
35 }
36
37 // Ensure default directory if not specified.
38 if len(cfg.Directories) == 0 {
39 cfg.Directories = []string{filepath.Join(dataDir, "hooks")}
40 }
41
42 return &manager{
43 workingDir: workingDir,
44 dataDir: dataDir,
45 config: cfg,
46 executor: NewExecutor(workingDir),
47 hooks: csync.NewMap[HookType, []string](),
48 }
49}
50
// isExecutable reports whether info describes a runnable hook script.
//
// Only files whose name ends in ".sh" (compared case-insensitively) qualify.
// On Windows every such file is accepted, because hooks are run through a
// POSIX shell emulator and permission bits carry no meaning there. On other
// platforms at least one execute bit (owner, group, or other) must be set.
func isExecutable(info os.FileInfo) bool {
	lowerName := strings.ToLower(info.Name())
	if hasShellExt := strings.HasSuffix(lowerName, ".sh"); !hasShellExt {
		return false
	}

	switch runtime.GOOS {
	case "windows":
		return true
	default:
		const anyExecBit = 0o111
		return info.Mode().Perm()&anyExecBit != 0
	}
}
65
66// ExecuteHooks implements Manager.
67func (m *manager) ExecuteHooks(ctx context.Context, hookType HookType, hookContext HookContext) (HookResult, error) {
68 if !m.config.Enabled {
69 return HookResult{Continue: true}, nil
70 }
71
72 hookContext.HookType = hookType
73 hookContext.Environment = m.config.Environment
74
75 hooks := m.discoverHooks(hookType)
76 if len(hooks) == 0 {
77 return HookResult{Continue: true}, nil
78 }
79
80 slog.Debug("Executing hooks", "type", hookType, "count", len(hooks))
81
82 accumulated := HookResult{Continue: true}
83 for _, hookPath := range hooks {
84 if m.isDisabled(hookPath) {
85 slog.Debug("Skipping disabled hook", "path", hookPath)
86 continue
87 }
88
89 hookCtx, cancel := context.WithTimeout(ctx, time.Duration(m.config.TimeoutSeconds)*time.Second)
90
91 result, err := m.executor.Execute(hookCtx, hookPath, hookContext)
92 cancel()
93
94 if err != nil {
95 slog.Error("Hook execution failed", "path", hookPath, "error", err)
96 if hookType == HookPreToolUse {
97 accumulated.Continue = false
98 accumulated.Permission = "deny"
99 accumulated.Message = fmt.Sprintf("Hook failed: %v", err)
100 return accumulated, nil
101 }
102 continue
103 }
104
105 if result.Message != "" {
106 slog.Info("Hook message", "path", hookPath, "message", result.Message)
107 }
108
109 m.mergeResults(&accumulated, result)
110
111 if !result.Continue {
112 slog.Info("Hook stopped execution", "path", hookPath)
113 break
114 }
115 }
116
117 return accumulated, nil
118}
119
120// discoverHooks finds all executable hooks for the given type.
121func (m *manager) discoverHooks(hookType HookType) []string {
122 if cached, ok := m.hooks.Get(hookType); ok {
123 return cached
124 }
125
126 var hooks []string
127
128 for _, dir := range m.config.Directories {
129 if _, err := os.Stat(dir); err == nil {
130 entries, err := os.ReadDir(dir)
131 if err == nil {
132 for _, entry := range entries {
133 if entry.IsDir() {
134 continue
135 }
136
137 hookPath := filepath.Join(dir, entry.Name())
138
139 info, err := entry.Info()
140 if err != nil {
141 continue
142 }
143
144 if !isExecutable(info) {
145 continue
146 }
147
148 hooks = append(hooks, hookPath)
149 slog.Debug("Discovered catch-all hook", "path", hookPath, "type", hookType)
150 }
151 }
152 }
153
154 hookDir := filepath.Join(dir, string(hookType))
155 if _, err := os.Stat(hookDir); os.IsNotExist(err) {
156 continue
157 }
158
159 entries, err := os.ReadDir(hookDir)
160 if err != nil {
161 slog.Error("Failed to read hooks directory", "dir", hookDir, "error", err)
162 continue
163 }
164
165 for _, entry := range entries {
166 if entry.IsDir() {
167 continue
168 }
169
170 hookPath := filepath.Join(hookDir, entry.Name())
171
172 info, err := entry.Info()
173 if err != nil {
174 continue
175 }
176
177 if !isExecutable(info) {
178 slog.Debug("Skipping non-executable hook", "path", hookPath)
179 continue
180 }
181
182 hooks = append(hooks, hookPath)
183 }
184 }
185
186 if inlineHooks, ok := m.config.Inline[string(hookType)]; ok {
187 for _, inline := range inlineHooks {
188 hookPath, err := m.writeInlineHook(hookType, inline)
189 if err != nil {
190 slog.Error("Failed to write inline hook", "name", inline.Name, "error", err)
191 continue
192 }
193 hooks = append(hooks, hookPath)
194 }
195 }
196
197 sort.Strings(hooks)
198 m.hooks.Set(hookType, hooks)
199 return hooks
200}
201
202// writeInlineHook writes an inline hook script to a temp file.
203func (m *manager) writeInlineHook(hookType HookType, inline InlineHook) (string, error) {
204 tempDir := filepath.Join(m.dataDir, "hooks", ".inline", string(hookType))
205 if err := os.MkdirAll(tempDir, 0o755); err != nil {
206 return "", err
207 }
208
209 hookPath := filepath.Join(tempDir, inline.Name)
210 if err := os.WriteFile(hookPath, []byte(inline.Script), 0o755); err != nil {
211 return "", err
212 }
213
214 return hookPath, nil
215}
216
217// isDisabled checks if a hook is in the disabled list.
218func (m *manager) isDisabled(hookPath string) bool {
219 for _, dir := range m.config.Directories {
220 if rel, err := filepath.Rel(dir, hookPath); err == nil {
221 // Normalize to forward slashes for cross-platform comparison
222 rel = filepath.ToSlash(rel)
223 if slices.Contains(m.config.Disabled, rel) {
224 return true
225 }
226 }
227 }
228 return false
229}
230
231// mergeResults merges a new result into the accumulated result.
232func (m *manager) mergeResults(accumulated *HookResult, new *HookResult) {
233 accumulated.Continue = accumulated.Continue && new.Continue
234
235 if new.Permission != "" {
236 if new.Permission == "deny" {
237 accumulated.Permission = "deny"
238 } else if new.Permission == "ask" && accumulated.Permission != "deny" {
239 accumulated.Permission = "ask"
240 } else if new.Permission == "approve" && accumulated.Permission == "" {
241 accumulated.Permission = "approve"
242 }
243 }
244
245 if new.ModifiedPrompt != nil {
246 accumulated.ModifiedPrompt = new.ModifiedPrompt
247 }
248
249 if len(new.ModifiedInput) > 0 {
250 if accumulated.ModifiedInput == nil {
251 accumulated.ModifiedInput = make(map[string]any)
252 }
253 maps.Copy(accumulated.ModifiedInput, new.ModifiedInput)
254 }
255
256 if len(new.ModifiedOutput) > 0 {
257 if accumulated.ModifiedOutput == nil {
258 accumulated.ModifiedOutput = make(map[string]any)
259 }
260 maps.Copy(accumulated.ModifiedOutput, new.ModifiedOutput)
261 }
262
263 if new.ContextContent != "" {
264 if accumulated.ContextContent == "" {
265 accumulated.ContextContent = new.ContextContent
266 } else {
267 accumulated.ContextContent += "\n\n" + new.ContextContent
268 }
269 }
270
271 accumulated.ContextFiles = append(accumulated.ContextFiles, new.ContextFiles...)
272
273 if new.Message != "" {
274 if accumulated.Message == "" {
275 accumulated.Message = new.Message
276 } else {
277 accumulated.Message += "; " + new.Message
278 }
279 }
280}
281
282// ListHooks implements Manager.
283func (m *manager) ListHooks(hookType HookType) []string {
284 return m.discoverHooks(hookType)
285}