manager.go

  1package hooks
  2
  3import (
  4	"context"
  5	"fmt"
  6	"log/slog"
  7	"maps"
  8	"os"
  9	"path/filepath"
 10	"runtime"
 11	"slices"
 12	"sort"
 13	"strings"
 14	"time"
 15
 16	"github.com/charmbracelet/crush/internal/csync"
 17)
 18
 19type manager struct {
 20	workingDir string
 21	dataDir    string
 22	config     *Config
 23	executor   *Executor
 24	hooks      *csync.Map[HookType, []string]
 25}
 26
 27// NewManager creates a new hook manager.
 28func NewManager(workingDir, dataDir string, cfg *Config) *manager {
 29	if cfg == nil {
 30		cfg = &Config{
 31			TimeoutSeconds: 30,
 32			Directories:    []string{filepath.Join(dataDir, "hooks")},
 33		}
 34	}
 35
 36	// Ensure default directory if not specified.
 37	if len(cfg.Directories) == 0 {
 38		cfg.Directories = []string{filepath.Join(dataDir, "hooks")}
 39	}
 40
 41	return &manager{
 42		workingDir: workingDir,
 43		dataDir:    dataDir,
 44		config:     cfg,
 45		executor:   NewExecutor(workingDir),
 46		hooks:      csync.NewMap[HookType, []string](),
 47	}
 48}
 49
 50// isExecutable checks if a file is executable.
 51// On Unix: checks execute permission bits for .sh files.
 52// On Windows: only recognizes .sh extension (as we use POSIX shell emulator).
 53func isExecutable(info os.FileInfo) bool {
 54	name := strings.ToLower(info.Name())
 55	if !strings.HasSuffix(name, ".sh") {
 56		return false
 57	}
 58
 59	if runtime.GOOS == "windows" {
 60		return true
 61	}
 62	return info.Mode()&0o111 != 0
 63}
 64
 65// executeHooks is the internal method that executes hooks for a given type.
 66func (m *manager) executeHooks(ctx context.Context, hookType HookType, hookContext HookContext) (HookResult, error) {
 67	if m.config.Disabled {
 68		return HookResult{Continue: true}, nil
 69	}
 70
 71	hookContext.HookType = hookType
 72	hookContext.Environment = m.config.Environment
 73
 74	hooks := m.discoverHooks(hookType)
 75	if len(hooks) == 0 {
 76		return HookResult{Continue: true}, nil
 77	}
 78
 79	slog.Debug("Executing hooks", "type", hookType, "count", len(hooks))
 80
 81	accumulated := HookResult{Continue: true}
 82	for _, hookPath := range hooks {
 83		if m.isDisabled(hookPath) {
 84			slog.Debug("Skipping disabled hook", "path", hookPath)
 85			continue
 86		}
 87
 88		hookCtx, cancel := context.WithTimeout(ctx, time.Duration(m.config.TimeoutSeconds)*time.Second)
 89
 90		result, err := m.executor.Execute(hookCtx, hookPath, hookContext)
 91		cancel()
 92
 93		if err != nil {
 94			slog.Error("Hook execution failed", "path", hookPath, "error", err)
 95			if hookType == HookPreToolUse {
 96				accumulated.Continue = false
 97				accumulated.Permission = "deny"
 98				accumulated.Message = fmt.Sprintf("Hook failed: %v", err)
 99				return accumulated, nil
100			}
101			continue
102		}
103
104		if result.Message != "" {
105			slog.Info("Hook message", "path", hookPath, "message", result.Message)
106		}
107
108		m.mergeResults(&accumulated, result)
109
110		if !result.Continue {
111			slog.Info("Hook stopped execution", "path", hookPath)
112			break
113		}
114	}
115
116	return accumulated, nil
117}
118
119// discoverHooks finds all executable hooks for the given type.
120func (m *manager) discoverHooks(hookType HookType) []string {
121	if cached, ok := m.hooks.Get(hookType); ok {
122		return cached
123	}
124
125	var hooks []string
126
127	for _, dir := range m.config.Directories {
128		if _, err := os.Stat(dir); err == nil {
129			entries, err := os.ReadDir(dir)
130			if err == nil {
131				for _, entry := range entries {
132					if entry.IsDir() {
133						continue
134					}
135
136					hookPath := filepath.Join(dir, entry.Name())
137
138					info, err := entry.Info()
139					if err != nil {
140						continue
141					}
142
143					if !isExecutable(info) {
144						continue
145					}
146
147					hooks = append(hooks, hookPath)
148					slog.Debug("Discovered catch-all hook", "path", hookPath, "type", hookType)
149				}
150			}
151		}
152
153		hookDir := filepath.Join(dir, string(hookType))
154		if _, err := os.Stat(hookDir); os.IsNotExist(err) {
155			continue
156		}
157
158		entries, err := os.ReadDir(hookDir)
159		if err != nil {
160			slog.Error("Failed to read hooks directory", "dir", hookDir, "error", err)
161			continue
162		}
163
164		for _, entry := range entries {
165			if entry.IsDir() {
166				continue
167			}
168
169			hookPath := filepath.Join(hookDir, entry.Name())
170
171			info, err := entry.Info()
172			if err != nil {
173				continue
174			}
175
176			if !isExecutable(info) {
177				slog.Debug("Skipping non-executable hook", "path", hookPath)
178				continue
179			}
180
181			hooks = append(hooks, hookPath)
182		}
183	}
184
185	if inlineHooks, ok := m.config.Inline[string(hookType)]; ok {
186		for _, inline := range inlineHooks {
187			hookPath, err := m.writeInlineHook(hookType, inline)
188			if err != nil {
189				slog.Error("Failed to write inline hook", "name", inline.Name, "error", err)
190				continue
191			}
192			hooks = append(hooks, hookPath)
193		}
194	}
195
196	sort.Strings(hooks)
197	m.hooks.Set(hookType, hooks)
198	return hooks
199}
200
201// writeInlineHook writes an inline hook script to a temp file.
202func (m *manager) writeInlineHook(hookType HookType, inline InlineHook) (string, error) {
203	tempDir := filepath.Join(m.dataDir, "hooks", ".inline", string(hookType))
204	if err := os.MkdirAll(tempDir, 0o755); err != nil {
205		return "", err
206	}
207
208	hookPath := filepath.Join(tempDir, inline.Name)
209	if err := os.WriteFile(hookPath, []byte(inline.Script), 0o755); err != nil {
210		return "", err
211	}
212
213	return hookPath, nil
214}
215
216// isDisabled checks if a hook is in the disabled list.
217func (m *manager) isDisabled(hookPath string) bool {
218	for _, dir := range m.config.Directories {
219		if rel, err := filepath.Rel(dir, hookPath); err == nil {
220			// Normalize to forward slashes for cross-platform comparison
221			rel = filepath.ToSlash(rel)
222			if slices.Contains(m.config.DisableHooks, rel) {
223				return true
224			}
225		}
226	}
227	return false
228}
229
230// mergeResults merges a new result into the accumulated result.
231func (m *manager) mergeResults(accumulated *HookResult, new *HookResult) {
232	accumulated.Continue = accumulated.Continue && new.Continue
233
234	if new.Permission != "" {
235		if new.Permission == "deny" {
236			accumulated.Permission = "deny"
237		} else if new.Permission == "ask" && accumulated.Permission != "deny" {
238			accumulated.Permission = "ask"
239		} else if new.Permission == "approve" && accumulated.Permission == "" {
240			accumulated.Permission = "approve"
241		}
242	}
243
244	if new.ModifiedPrompt != nil {
245		accumulated.ModifiedPrompt = new.ModifiedPrompt
246	}
247
248	if len(new.ModifiedInput) > 0 {
249		if accumulated.ModifiedInput == nil {
250			accumulated.ModifiedInput = make(map[string]any)
251		}
252		maps.Copy(accumulated.ModifiedInput, new.ModifiedInput)
253	}
254
255	if len(new.ModifiedOutput) > 0 {
256		if accumulated.ModifiedOutput == nil {
257			accumulated.ModifiedOutput = make(map[string]any)
258		}
259		maps.Copy(accumulated.ModifiedOutput, new.ModifiedOutput)
260	}
261
262	if new.ContextContent != "" {
263		if accumulated.ContextContent == "" {
264			accumulated.ContextContent = new.ContextContent
265		} else {
266			accumulated.ContextContent += "\n\n" + new.ContextContent
267		}
268	}
269
270	accumulated.ContextFiles = append(accumulated.ContextFiles, new.ContextFiles...)
271
272	if new.Message != "" {
273		if accumulated.Message == "" {
274			accumulated.Message = new.Message
275		} else {
276			accumulated.Message += "; " + new.Message
277		}
278	}
279}
280
281// ListHooks implements Manager.
282func (m *manager) ListHooks(hookType HookType) []string {
283	return m.discoverHooks(hookType)
284}
285
286// ExecuteUserPromptSubmit executes user-prompt-submit hooks.
287func (m *manager) ExecuteUserPromptSubmit(ctx context.Context, sessionID, workingDir string, data UserPromptSubmitData) (HookResult, error) {
288	hookCtx := HookContext{
289		SessionID:  sessionID,
290		WorkingDir: workingDir,
291		Data:       data,
292	}
293
294	return m.executeHooks(ctx, HookUserPromptSubmit, hookCtx)
295}
296
297// ExecutePreToolUse executes pre-tool-use hooks.
298func (m *manager) ExecutePreToolUse(ctx context.Context, sessionID, workingDir string, data PreToolUseData) (HookResult, error) {
299	hookCtx := HookContext{
300		SessionID:  sessionID,
301		WorkingDir: workingDir,
302		ToolName:   data.ToolName,
303		ToolCallID: data.ToolCallID,
304		Data:       data,
305	}
306
307	return m.executeHooks(ctx, HookPreToolUse, hookCtx)
308}
309
310// ExecutePostToolUse executes post-tool-use hooks.
311func (m *manager) ExecutePostToolUse(ctx context.Context, sessionID, workingDir string, data PostToolUseData) (HookResult, error) {
312	hookCtx := HookContext{
313		SessionID:  sessionID,
314		WorkingDir: workingDir,
315		ToolName:   data.ToolName,
316		ToolCallID: data.ToolCallID,
317		Data:       data,
318	}
319
320	return m.executeHooks(ctx, HookPostToolUse, hookCtx)
321}
322
323// ExecuteStop executes stop hooks.
324func (m *manager) ExecuteStop(ctx context.Context, sessionID, workingDir string, data StopData) (HookResult, error) {
325	hookCtx := HookContext{
326		SessionID:  sessionID,
327		WorkingDir: workingDir,
328		Data:       data,
329	}
330
331	return m.executeHooks(ctx, HookStop, hookCtx)
332}