manager.go

  1package hooks
  2
  3import (
  4	"context"
  5	"fmt"
  6	"log/slog"
  7	"maps"
  8	"os"
  9	"path/filepath"
 10	"runtime"
 11	"slices"
 12	"sort"
 13	"strings"
 14	"time"
 15
 16	"github.com/charmbracelet/crush/internal/csync"
 17)
 18
// manager is the hook manager returned by NewManager. It discovers hook
// scripts in the configured directories (plus inline hooks from the config)
// and runs them through its Executor.
type manager struct {
	// workingDir is the directory the manager was created for; it is also
	// passed to NewExecutor.
	workingDir string
	// dataDir is the base for the default hooks directory (dataDir/hooks)
	// and for materialized inline hooks (dataDir/hooks/.inline/<type>).
	dataDir    string
	// config supplies directories, timeout, disabled hooks, environment and
	// inline hooks.
	config     *Config
	// executor runs a single hook script (see executeHooks).
	executor   *Executor
	// hooks caches discovered hook paths per hook type. discoverHooks fills
	// it lazily; nothing in this file ever clears it.
	hooks      *csync.Map[HookType, []string]
}
 26
 27// NewManager creates a new hook manager.
 28func NewManager(workingDir, dataDir string, cfg *Config) *manager {
 29	if cfg == nil {
 30		cfg = &Config{
 31			TimeoutSeconds: 30,
 32			Directories:    []string{filepath.Join(dataDir, "hooks")},
 33		}
 34	}
 35
 36	defaultHooksDir := filepath.Join(dataDir, "hooks")
 37
 38	// Ensure default directory if not specified.
 39	if len(cfg.Directories) == 0 {
 40		cfg.Directories = []string{defaultHooksDir}
 41	} else {
 42		// Always include default hooks directory even when user overrides config.
 43		if !slices.Contains(cfg.Directories, defaultHooksDir) {
 44			cfg.Directories = append([]string{defaultHooksDir}, cfg.Directories...)
 45		}
 46	}
 47
 48	return &manager{
 49		workingDir: workingDir,
 50		dataDir:    dataDir,
 51		config:     cfg,
 52		executor:   NewExecutor(workingDir),
 53		hooks:      csync.NewMap[HookType, []string](),
 54	}
 55}
 56
 57// isExecutable checks if a file is executable.
 58// On Unix: checks execute permission bits for .sh files.
 59// On Windows: only recognizes .sh extension (as we use POSIX shell emulator).
 60func isExecutable(info os.FileInfo) bool {
 61	name := strings.ToLower(info.Name())
 62	if !strings.HasSuffix(name, ".sh") {
 63		return false
 64	}
 65
 66	if runtime.GOOS == "windows" {
 67		return true
 68	}
 69	return info.Mode()&0o111 != 0
 70}
 71
 72// executeHooks is the internal method that executes hooks for a given type.
 73func (m *manager) executeHooks(ctx context.Context, hookType HookType, hookContext HookContext) (HookResult, error) {
 74	if m.config.Disabled {
 75		return HookResult{Continue: true}, nil
 76	}
 77
 78	hookContext.HookType = hookType
 79	hookContext.Environment = m.config.Environment
 80
 81	hooks := m.discoverHooks(hookType)
 82	if len(hooks) == 0 {
 83		return HookResult{Continue: true}, nil
 84	}
 85
 86	slog.Debug("Executing hooks", "type", hookType, "count", len(hooks))
 87
 88	accumulated := HookResult{Continue: true}
 89	for _, hookPath := range hooks {
 90		if m.isDisabled(hookPath) {
 91			slog.Debug("Skipping disabled hook", "path", hookPath)
 92			continue
 93		}
 94
 95		hookCtx, cancel := context.WithTimeout(ctx, time.Duration(m.config.TimeoutSeconds)*time.Second)
 96
 97		result, err := m.executor.Execute(hookCtx, hookPath, hookContext)
 98		cancel()
 99
100		if err != nil {
101			slog.Error("Hook execution failed", "path", hookPath, "error", err)
102			if hookType == HookPreToolUse {
103				accumulated.Continue = false
104				accumulated.Permission = "deny"
105				accumulated.Message = fmt.Sprintf("Hook failed: %v", err)
106				return accumulated, nil
107			}
108			continue
109		}
110
111		if result.Message != "" {
112			slog.Info("Hook message", "path", hookPath, "message", result.Message)
113		}
114
115		m.mergeResults(&accumulated, result)
116
117		if !result.Continue {
118			slog.Info("Hook stopped execution", "path", hookPath)
119			break
120		}
121	}
122
123	return accumulated, nil
124}
125
126// discoverHooks finds all executable hooks for the given type.
127func (m *manager) discoverHooks(hookType HookType) []string {
128	if cached, ok := m.hooks.Get(hookType); ok {
129		return cached
130	}
131
132	var hooks []string
133
134	for _, dir := range m.config.Directories {
135		if _, err := os.Stat(dir); err == nil {
136			entries, err := os.ReadDir(dir)
137			if err == nil {
138				for _, entry := range entries {
139					if entry.IsDir() {
140						continue
141					}
142
143					hookPath := filepath.Join(dir, entry.Name())
144
145					info, err := entry.Info()
146					if err != nil {
147						continue
148					}
149
150					if !isExecutable(info) {
151						continue
152					}
153
154					hooks = append(hooks, hookPath)
155					slog.Debug("Discovered catch-all hook", "path", hookPath, "type", hookType)
156				}
157			}
158		}
159
160		hookDir := filepath.Join(dir, string(hookType))
161		if _, err := os.Stat(hookDir); os.IsNotExist(err) {
162			continue
163		}
164
165		entries, err := os.ReadDir(hookDir)
166		if err != nil {
167			slog.Error("Failed to read hooks directory", "dir", hookDir, "error", err)
168			continue
169		}
170
171		for _, entry := range entries {
172			if entry.IsDir() {
173				continue
174			}
175
176			hookPath := filepath.Join(hookDir, entry.Name())
177
178			info, err := entry.Info()
179			if err != nil {
180				continue
181			}
182
183			if !isExecutable(info) {
184				slog.Debug("Skipping non-executable hook", "path", hookPath)
185				continue
186			}
187
188			hooks = append(hooks, hookPath)
189		}
190	}
191
192	if inlineHooks, ok := m.config.Inline[string(hookType)]; ok {
193		for _, inline := range inlineHooks {
194			hookPath, err := m.writeInlineHook(hookType, inline)
195			if err != nil {
196				slog.Error("Failed to write inline hook", "name", inline.Name, "error", err)
197				continue
198			}
199			hooks = append(hooks, hookPath)
200		}
201	}
202
203	sort.Strings(hooks)
204	m.hooks.Set(hookType, hooks)
205	return hooks
206}
207
208// writeInlineHook writes an inline hook script to a temp file.
209func (m *manager) writeInlineHook(hookType HookType, inline InlineHook) (string, error) {
210	tempDir := filepath.Join(m.dataDir, "hooks", ".inline", string(hookType))
211	if err := os.MkdirAll(tempDir, 0o755); err != nil {
212		return "", err
213	}
214
215	hookPath := filepath.Join(tempDir, inline.Name)
216	if err := os.WriteFile(hookPath, []byte(inline.Script), 0o755); err != nil {
217		return "", err
218	}
219
220	return hookPath, nil
221}
222
223// isDisabled checks if a hook is in the disabled list.
224func (m *manager) isDisabled(hookPath string) bool {
225	for _, dir := range m.config.Directories {
226		if rel, err := filepath.Rel(dir, hookPath); err == nil {
227			// Normalize to forward slashes for cross-platform comparison
228			rel = filepath.ToSlash(rel)
229			if slices.Contains(m.config.DisableHooks, rel) {
230				return true
231			}
232		}
233	}
234	return false
235}
236
237// mergeResults merges a new result into the accumulated result.
238func (m *manager) mergeResults(accumulated *HookResult, new *HookResult) {
239	accumulated.Continue = accumulated.Continue && new.Continue
240
241	if new.Permission != "" {
242		if new.Permission == "deny" {
243			accumulated.Permission = "deny"
244		} else if new.Permission == "approve" && accumulated.Permission == "" {
245			accumulated.Permission = "approve"
246		}
247	}
248
249	if new.ModifiedPrompt != nil {
250		accumulated.ModifiedPrompt = new.ModifiedPrompt
251	}
252
253	if len(new.ModifiedInput) > 0 {
254		if accumulated.ModifiedInput == nil {
255			accumulated.ModifiedInput = make(map[string]any)
256		}
257		maps.Copy(accumulated.ModifiedInput, new.ModifiedInput)
258	}
259
260	if len(new.ModifiedOutput) > 0 {
261		if accumulated.ModifiedOutput == nil {
262			accumulated.ModifiedOutput = make(map[string]any)
263		}
264		maps.Copy(accumulated.ModifiedOutput, new.ModifiedOutput)
265	}
266
267	if new.ContextContent != "" {
268		if accumulated.ContextContent == "" {
269			accumulated.ContextContent = new.ContextContent
270		} else {
271			accumulated.ContextContent += "\n\n" + new.ContextContent
272		}
273	}
274
275	accumulated.ContextFiles = append(accumulated.ContextFiles, new.ContextFiles...)
276
277	if new.Message != "" {
278		if accumulated.Message == "" {
279			accumulated.Message = new.Message
280		} else {
281			accumulated.Message += "; " + new.Message
282		}
283	}
284}
285
286// ListHooks implements Manager.
287func (m *manager) ListHooks(hookType HookType) []string {
288	return m.discoverHooks(hookType)
289}
290
291// ExecuteUserPromptSubmit executes user-prompt-submit hooks.
292func (m *manager) ExecuteUserPromptSubmit(ctx context.Context, sessionID, workingDir string, data UserPromptSubmitData) (HookResult, error) {
293	hookCtx := HookContext{
294		SessionID:  sessionID,
295		WorkingDir: workingDir,
296		Data:       data,
297	}
298
299	return m.executeHooks(ctx, HookUserPromptSubmit, hookCtx)
300}
301
302// ExecutePreToolUse executes pre-tool-use hooks.
303func (m *manager) ExecutePreToolUse(ctx context.Context, sessionID, workingDir string, data PreToolUseData) (HookResult, error) {
304	hookCtx := HookContext{
305		SessionID:  sessionID,
306		WorkingDir: workingDir,
307		ToolName:   data.ToolName,
308		ToolCallID: data.ToolCallID,
309		Data:       data,
310	}
311
312	return m.executeHooks(ctx, HookPreToolUse, hookCtx)
313}
314
315// ExecutePostToolUse executes post-tool-use hooks.
316func (m *manager) ExecutePostToolUse(ctx context.Context, sessionID, workingDir string, data PostToolUseData) (HookResult, error) {
317	hookCtx := HookContext{
318		SessionID:  sessionID,
319		WorkingDir: workingDir,
320		ToolName:   data.ToolName,
321		ToolCallID: data.ToolCallID,
322		Data:       data,
323	}
324
325	return m.executeHooks(ctx, HookPostToolUse, hookCtx)
326}
327
328// ExecuteStop executes stop hooks.
329func (m *manager) ExecuteStop(ctx context.Context, sessionID, workingDir string, data StopData) (HookResult, error) {
330	hookCtx := HookContext{
331		SessionID:  sessionID,
332		WorkingDir: workingDir,
333		Data:       data,
334	}
335
336	return m.executeHooks(ctx, HookStop, hookCtx)
337}