manager.go

  1package hooks
  2
  3import (
  4	"context"
  5	"fmt"
  6	"log/slog"
  7	"maps"
  8	"os"
  9	"path/filepath"
 10	"runtime"
 11	"slices"
 12	"sort"
 13	"strings"
 14	"time"
 15
 16	"github.com/charmbracelet/crush/internal/csync"
 17)
 18
// manager is the Manager implementation returned by NewManager.
type manager struct {
	// workingDir is the directory handed to NewExecutor when the manager is built.
	workingDir string
	// dataDir is the base for the default hooks directory and for the
	// on-disk copies of inline hook scripts.
	dataDir    string
	// config holds the hook settings read by the manager (enabled flag,
	// timeout, directories, disabled list, inline hooks, environment).
	config     *Config
	// executor runs a single hook script.
	executor   *Executor
	// hooks caches the discovered hook paths per hook type.
	hooks      *csync.Map[HookType, []string]
}
 26
 27// NewManager creates a new hook manager.
 28func NewManager(workingDir, dataDir string, cfg *Config) Manager {
 29	if cfg == nil {
 30		cfg = &Config{
 31			Enabled:        true,
 32			TimeoutSeconds: 30,
 33			Directories:    []string{filepath.Join(dataDir, "hooks")},
 34		}
 35	}
 36
 37	// Ensure default directory if not specified.
 38	if len(cfg.Directories) == 0 {
 39		cfg.Directories = []string{filepath.Join(dataDir, "hooks")}
 40	}
 41
 42	return &manager{
 43		workingDir: workingDir,
 44		dataDir:    dataDir,
 45		config:     cfg,
 46		executor:   NewExecutor(workingDir),
 47		hooks:      csync.NewMap[HookType, []string](),
 48	}
 49}
 50
 51// isExecutable checks if a file is executable.
 52// On Unix: checks execute permission bits for .sh files.
 53// On Windows: only recognizes .sh extension (as we use POSIX shell emulator).
 54func isExecutable(info os.FileInfo) bool {
 55	name := strings.ToLower(info.Name())
 56	if !strings.HasSuffix(name, ".sh") {
 57		return false
 58	}
 59
 60	if runtime.GOOS == "windows" {
 61		return true
 62	}
 63	return info.Mode()&0o111 != 0
 64}
 65
 66// ExecuteHooks implements Manager.
 67func (m *manager) ExecuteHooks(ctx context.Context, hookType HookType, hookContext HookContext) (HookResult, error) {
 68	if !m.config.Enabled {
 69		return HookResult{Continue: true}, nil
 70	}
 71
 72	hookContext.HookType = hookType
 73	hookContext.Environment = m.config.Environment
 74
 75	hooks := m.discoverHooks(hookType)
 76	if len(hooks) == 0 {
 77		return HookResult{Continue: true}, nil
 78	}
 79
 80	slog.Debug("Executing hooks", "type", hookType, "count", len(hooks))
 81
 82	accumulated := HookResult{Continue: true}
 83	for _, hookPath := range hooks {
 84		if m.isDisabled(hookPath) {
 85			slog.Debug("Skipping disabled hook", "path", hookPath)
 86			continue
 87		}
 88
 89		hookCtx, cancel := context.WithTimeout(ctx, time.Duration(m.config.TimeoutSeconds)*time.Second)
 90
 91		result, err := m.executor.Execute(hookCtx, hookPath, hookContext)
 92		cancel()
 93
 94		if err != nil {
 95			slog.Error("Hook execution failed", "path", hookPath, "error", err)
 96			if hookType == HookPreToolUse {
 97				accumulated.Continue = false
 98				accumulated.Permission = "deny"
 99				accumulated.Message = fmt.Sprintf("Hook failed: %v", err)
100				return accumulated, nil
101			}
102			continue
103		}
104
105		if result.Message != "" {
106			slog.Info("Hook message", "path", hookPath, "message", result.Message)
107		}
108
109		m.mergeResults(&accumulated, result)
110
111		if !result.Continue {
112			slog.Info("Hook stopped execution", "path", hookPath)
113			break
114		}
115	}
116
117	return accumulated, nil
118}
119
120// discoverHooks finds all executable hooks for the given type.
121func (m *manager) discoverHooks(hookType HookType) []string {
122	if cached, ok := m.hooks.Get(hookType); ok {
123		return cached
124	}
125
126	var hooks []string
127
128	for _, dir := range m.config.Directories {
129		if _, err := os.Stat(dir); err == nil {
130			entries, err := os.ReadDir(dir)
131			if err == nil {
132				for _, entry := range entries {
133					if entry.IsDir() {
134						continue
135					}
136
137					hookPath := filepath.Join(dir, entry.Name())
138
139					info, err := entry.Info()
140					if err != nil {
141						continue
142					}
143
144					if !isExecutable(info) {
145						continue
146					}
147
148					hooks = append(hooks, hookPath)
149					slog.Debug("Discovered catch-all hook", "path", hookPath, "type", hookType)
150				}
151			}
152		}
153
154		hookDir := filepath.Join(dir, string(hookType))
155		if _, err := os.Stat(hookDir); os.IsNotExist(err) {
156			continue
157		}
158
159		entries, err := os.ReadDir(hookDir)
160		if err != nil {
161			slog.Error("Failed to read hooks directory", "dir", hookDir, "error", err)
162			continue
163		}
164
165		for _, entry := range entries {
166			if entry.IsDir() {
167				continue
168			}
169
170			hookPath := filepath.Join(hookDir, entry.Name())
171
172			info, err := entry.Info()
173			if err != nil {
174				continue
175			}
176
177			if !isExecutable(info) {
178				slog.Debug("Skipping non-executable hook", "path", hookPath)
179				continue
180			}
181
182			hooks = append(hooks, hookPath)
183		}
184	}
185
186	if inlineHooks, ok := m.config.Inline[string(hookType)]; ok {
187		for _, inline := range inlineHooks {
188			hookPath, err := m.writeInlineHook(hookType, inline)
189			if err != nil {
190				slog.Error("Failed to write inline hook", "name", inline.Name, "error", err)
191				continue
192			}
193			hooks = append(hooks, hookPath)
194		}
195	}
196
197	sort.Strings(hooks)
198	m.hooks.Set(hookType, hooks)
199	return hooks
200}
201
202// writeInlineHook writes an inline hook script to a temp file.
203func (m *manager) writeInlineHook(hookType HookType, inline InlineHook) (string, error) {
204	tempDir := filepath.Join(m.dataDir, "hooks", ".inline", string(hookType))
205	if err := os.MkdirAll(tempDir, 0o755); err != nil {
206		return "", err
207	}
208
209	hookPath := filepath.Join(tempDir, inline.Name)
210	if err := os.WriteFile(hookPath, []byte(inline.Script), 0o755); err != nil {
211		return "", err
212	}
213
214	return hookPath, nil
215}
216
217// isDisabled checks if a hook is in the disabled list.
218func (m *manager) isDisabled(hookPath string) bool {
219	for _, dir := range m.config.Directories {
220		if rel, err := filepath.Rel(dir, hookPath); err == nil {
221			// Normalize to forward slashes for cross-platform comparison
222			rel = filepath.ToSlash(rel)
223			if slices.Contains(m.config.Disabled, rel) {
224				return true
225			}
226		}
227	}
228	return false
229}
230
231// mergeResults merges a new result into the accumulated result.
232func (m *manager) mergeResults(accumulated *HookResult, new *HookResult) {
233	accumulated.Continue = accumulated.Continue && new.Continue
234
235	if new.Permission != "" {
236		if new.Permission == "deny" {
237			accumulated.Permission = "deny"
238		} else if new.Permission == "ask" && accumulated.Permission != "deny" {
239			accumulated.Permission = "ask"
240		} else if new.Permission == "approve" && accumulated.Permission == "" {
241			accumulated.Permission = "approve"
242		}
243	}
244
245	if new.ModifiedPrompt != nil {
246		accumulated.ModifiedPrompt = new.ModifiedPrompt
247	}
248
249	if len(new.ModifiedInput) > 0 {
250		if accumulated.ModifiedInput == nil {
251			accumulated.ModifiedInput = make(map[string]any)
252		}
253		maps.Copy(accumulated.ModifiedInput, new.ModifiedInput)
254	}
255
256	if len(new.ModifiedOutput) > 0 {
257		if accumulated.ModifiedOutput == nil {
258			accumulated.ModifiedOutput = make(map[string]any)
259		}
260		maps.Copy(accumulated.ModifiedOutput, new.ModifiedOutput)
261	}
262
263	if new.ContextContent != "" {
264		if accumulated.ContextContent == "" {
265			accumulated.ContextContent = new.ContextContent
266		} else {
267			accumulated.ContextContent += "\n\n" + new.ContextContent
268		}
269	}
270
271	accumulated.ContextFiles = append(accumulated.ContextFiles, new.ContextFiles...)
272
273	if new.Message != "" {
274		if accumulated.Message == "" {
275			accumulated.Message = new.Message
276		} else {
277			accumulated.Message += "; " + new.Message
278		}
279	}
280}
281
282// ListHooks implements Manager.
283func (m *manager) ListHooks(hookType HookType) []string {
284	return m.discoverHooks(hookType)
285}