hooks.go

  1package hooks
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"log/slog"
  8	"os"
  9	"regexp"
 10	"strings"
 11	"sync"
 12	"time"
 13
 14	"charm.land/fantasy"
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/csync"
 17	"github.com/charmbracelet/crush/internal/message"
 18	"github.com/charmbracelet/crush/internal/permission"
 19	"github.com/charmbracelet/crush/internal/shell"
 20)
 21
 22const DefaultHookTimeout = 30 * time.Second
 23
 24// HookContext contains context information passed to hooks.
 25// Fields are populated based on event type - see field comments for availability.
 26type HookContext struct {
 27	// EventType is the lifecycle event that triggered this hook.
 28	// Available: All events
 29	EventType config.HookEventType `json:"event_type"`
 30
 31	// SessionID identifies the Crush session.
 32	// Available: All events (if session exists)
 33	SessionID string `json:"session_id,omitempty"`
 34
 35	// TranscriptPath is the path to exported session transcript JSON file.
 36	// Available: All events (if session exists and export succeeds)
 37	TranscriptPath string `json:"transcript_path,omitempty"`
 38
 39	// ToolName is the name of the tool being called.
 40	// Available: pre_tool_use, post_tool_use, permission_requested
 41	ToolName string `json:"tool_name,omitempty"`
 42
 43	// ToolInput contains the tool's input parameters as key-value pairs.
 44	// Available: pre_tool_use, post_tool_use
 45	ToolInput map[string]any `json:"tool_input,omitempty"`
 46
 47	// ToolResult is the string result returned by the tool.
 48	// Available: post_tool_use
 49	ToolResult string `json:"tool_result,omitempty"`
 50
 51	// ToolError indicates whether the tool execution failed.
 52	// Available: post_tool_use
 53	ToolError bool `json:"tool_error,omitempty"`
 54
 55	// UserPrompt is the prompt submitted by the user.
 56	// Available: user_prompt_submit
 57	UserPrompt string `json:"user_prompt,omitempty"`
 58
 59	// Timestamp is when the hook was triggered (RFC3339 format).
 60	// Available: All events
 61	Timestamp time.Time `json:"timestamp"`
 62
 63	// WorkingDir is the project working directory.
 64	// Available: All events
 65	WorkingDir string `json:"working_dir,omitempty"`
 66
 67	// MessageID identifies the assistant message being processed.
 68	// Available: pre_tool_use, post_tool_use, stop
 69	MessageID string `json:"message_id,omitempty"`
 70
 71	// Provider is the LLM provider name (e.g., "anthropic").
 72	// Available: All events with LLM interaction
 73	Provider string `json:"provider,omitempty"`
 74
 75	// Model is the LLM model name (e.g., "claude-3-5-sonnet-20241022").
 76	// Available: All events with LLM interaction
 77	Model string `json:"model,omitempty"`
 78
 79	// TokensUsed is the total tokens consumed in this interaction.
 80	// Available: stop
 81	TokensUsed int64 `json:"tokens_used,omitempty"`
 82
 83	// TokensInput is the input tokens consumed in this interaction.
 84	// Available: stop
 85	TokensInput int64 `json:"tokens_input,omitempty"`
 86
 87	// PermissionAction is the permission action being requested (e.g., "read", "write").
 88	// Available: permission_requested
 89	PermissionAction string `json:"permission_action,omitempty"`
 90
 91	// PermissionPath is the file path involved in the permission request.
 92	// Available: permission_requested
 93	PermissionPath string `json:"permission_path,omitempty"`
 94
 95	// PermissionParams contains additional permission parameters.
 96	// Available: permission_requested
 97	PermissionRequest *permission.PermissionRequest `json:"permission_request,omitempty"`
 98}
 99
// HookDecision is a verdict a hook can report back in its output. Command
// hooks may return one in the "decision" field of their JSON output.
type HookDecision string

// Known hook decisions.
const (
	// HookDecisionBlock is reported automatically when a command hook exits
	// with status 2.
	HookDecisionBlock HookDecision = "block"

	// HookDecisionDeny is a decision a hook may return explicitly.
	HookDecisionDeny HookDecision = "deny"

	// HookDecisionAllow is a decision a hook may return explicitly.
	HookDecisionAllow HookDecision = "allow"

	// HookDecisionAsk is a decision a hook may return explicitly.
	HookDecisionAsk HookDecision = "ask"
)
108
109type Service interface {
110	Execute(ctx context.Context, hookCtx HookContext) ([]message.HookOutput, error)
111	SetSmallModel(model fantasy.LanguageModel)
112}
113
114type service struct {
115	config     config.HookConfig
116	workingDir string
117	regexCache *csync.Map[string, *regexp.Regexp]
118	smallModel fantasy.LanguageModel
119	messages   message.Service
120}
121
122// NewService creates a new hook executor.
123func NewService(
124	hookConfig config.HookConfig,
125	workingDir string,
126	smallModel fantasy.LanguageModel,
127	messages message.Service,
128) Service {
129	return &service{
130		config:     hookConfig,
131		workingDir: workingDir,
132		regexCache: csync.NewMap[string, *regexp.Regexp](),
133		smallModel: smallModel,
134		messages:   messages,
135	}
136}
137
138// Execute implements Service.
139func (s *service) Execute(ctx context.Context, hookCtx HookContext) ([]message.HookOutput, error) {
140	if s.config == nil {
141		return nil, nil
142	}
143
144	// Check if context is already cancelled - prevents race conditions during cancellation.
145	if ctx.Err() != nil {
146		return nil, ctx.Err()
147	}
148
149	hookCtx.Timestamp = time.Now()
150	hookCtx.WorkingDir = s.workingDir
151
152	hooks := s.collectMatchingHooks(hookCtx)
153	if len(hooks) == 0 {
154		return nil, nil
155	}
156
157	transcriptPath, cleanup, err := s.setupTranscript(ctx, hookCtx)
158	if err != nil {
159		slog.Warn("Failed to export transcript for hooks", "error", err)
160	} else if transcriptPath != "" {
161		hookCtx.TranscriptPath = transcriptPath
162		// Ensure cleanup happens even on panic.
163		defer func() {
164			cleanup()
165		}()
166	}
167
168	results := make([]message.HookOutput, len(hooks))
169	var wg sync.WaitGroup
170
171	for i, hook := range hooks {
172		if ctx.Err() != nil {
173			return nil, ctx.Err()
174		}
175		wg.Add(1)
176		go func(idx int, h config.Hook) {
177			defer wg.Done()
178			result, err := s.executeHook(ctx, h, hookCtx)
179			if err != nil {
180				slog.Warn("Hook execution failed",
181					"event", hookCtx.EventType,
182					"error", err,
183				)
184			} else {
185				results[idx] = *result
186			}
187		}(i, hook)
188	}
189	wg.Wait()
190	return results, nil
191}
192
193func (s *service) setupTranscript(ctx context.Context, hookCtx HookContext) (string, func(), error) {
194	if hookCtx.SessionID == "" || s.messages == nil {
195		return "", func() {}, nil
196	}
197
198	path, err := exportTranscript(ctx, s.messages, hookCtx.SessionID)
199	if err != nil {
200		return "", func() {}, err
201	}
202
203	cleanup := func() {
204		if path != "" {
205			cleanupTranscript(path)
206		}
207	}
208
209	return path, cleanup, nil
210}
211
212func (s *service) executeHook(ctx context.Context, hook config.Hook, hookCtx HookContext) (*message.HookOutput, error) {
213	var result *message.HookOutput
214	var err error
215
216	switch hook.Type {
217	case "prompt":
218		if hook.Prompt == "" {
219			return nil, fmt.Errorf("prompt-based hook missing 'prompt' field")
220		}
221		if s.smallModel == nil {
222			return nil, fmt.Errorf("prompt-based hook requires small model configuration")
223		}
224		result, err = s.executePromptHook(ctx, hook, hookCtx)
225	case "", "command":
226		if hook.Command == "" {
227			return nil, fmt.Errorf("command-based hook missing 'command' field")
228		}
229		slog.Info("executing")
230		result, err = s.executeCommandHook(ctx, hook, hookCtx)
231	default:
232		return nil, fmt.Errorf("unsupported hook type: %s", hook.Type)
233	}
234
235	if result != nil {
236		result.EventType = string(hookCtx.EventType)
237	}
238
239	return result, err
240}
241
242func (s *service) executePromptHook(ctx context.Context, hook config.Hook, hookCtx HookContext) (*message.HookOutput, error) {
243	contextJSON, err := json.Marshal(hookCtx)
244	if err != nil {
245		return nil, fmt.Errorf("failed to marshal hook context: %w", err)
246	}
247
248	var finalPrompt string
249	if strings.Contains(hook.Prompt, "$ARGUMENTS") {
250		finalPrompt = strings.ReplaceAll(hook.Prompt, "$ARGUMENTS", string(contextJSON))
251	} else {
252		finalPrompt = fmt.Sprintf("%s\n\nContext: %s", hook.Prompt, string(contextJSON))
253	}
254
255	timeout := DefaultHookTimeout
256	if hook.Timeout != nil {
257		timeout = time.Duration(*hook.Timeout) * time.Second
258	}
259
260	execCtx, cancel := context.WithTimeout(ctx, timeout)
261	defer cancel()
262
263	type readTranscriptParams struct{}
264	readTranscriptTool := fantasy.NewAgentTool(
265		"read_transcript",
266		"Used to read the conversation so far",
267		func(ctx context.Context, params readTranscriptParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
268			if hookCtx.TranscriptPath == "" {
269				return fantasy.NewTextErrorResponse("No transcript available"), nil
270			}
271			data, err := os.ReadFile(hookCtx.TranscriptPath)
272			if err != nil {
273				return fantasy.NewTextErrorResponse(err.Error()), nil
274			}
275			return fantasy.NewTextResponse(string(data)), nil
276		})
277
278	var output *message.HookOutput
279	outputTool := fantasy.NewAgentTool(
280		"output",
281		"Used to submit the output, remember you MUST call this tool at the end",
282		func(ctx context.Context, params message.HookOutput, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
283			output = &params
284			return fantasy.NewTextResponse("ouptut submitted"), nil
285		})
286	agent := fantasy.NewAgent(
287		s.smallModel,
288		fantasy.WithSystemPrompt(`You are a helpful sub agent used in a larger agents conversation loop,
289			your goal is to intercept the conversation and fulfill the intermediate requests, makesure to ALWAYS use the output tool at the end to output your decision`),
290		fantasy.WithTools(readTranscriptTool, outputTool),
291	)
292
293	_, err = agent.Generate(execCtx, fantasy.AgentCall{
294		Prompt: finalPrompt,
295	})
296	if err != nil {
297		return nil, err
298	}
299	return output, nil
300}
301
302func (s *service) executeCommandHook(ctx context.Context, hook config.Hook, hookCtx HookContext) (*message.HookOutput, error) {
303	timeout := DefaultHookTimeout
304	if hook.Timeout != nil {
305		timeout = time.Duration(*hook.Timeout) * time.Second
306	}
307
308	execCtx, cancel := context.WithTimeout(ctx, timeout)
309	defer cancel()
310
311	contextJSON, err := json.Marshal(hookCtx)
312	if err != nil {
313		return nil, fmt.Errorf("failed to marshal hook context: %w", err)
314	}
315
316	sh := shell.NewShell(&shell.Options{
317		WorkingDir: s.workingDir,
318	})
319
320	sh.SetEnv("CRUSH_HOOK_EVENT", string(hookCtx.EventType))
321	sh.SetEnv("CRUSH_HOOK_CONTEXT", string(contextJSON))
322	sh.SetEnv("CRUSH_PROJECT_DIR", s.workingDir)
323	if hookCtx.SessionID != "" {
324		sh.SetEnv("CRUSH_SESSION_ID", hookCtx.SessionID)
325	}
326	if hookCtx.ToolName != "" {
327		sh.SetEnv("CRUSH_TOOL_NAME", hookCtx.ToolName)
328	}
329
330	slog.Debug("Hook execution trace",
331		"event", hookCtx.EventType,
332		"command", hook.Command,
333		"timeout", timeout,
334		"context", hookCtx,
335	)
336
337	stdout, stderr, err := sh.ExecWithStdin(execCtx, hook.Command, string(contextJSON))
338
339	exitCode := shell.ExitCode(err)
340
341	var result *message.HookOutput
342	switch exitCode {
343	case 0:
344		// if the event is  UserPromptSubmit we want the output to be added to the context
345		if hookCtx.EventType == config.UserPromptSubmit {
346			result = &message.HookOutput{
347				AdditionalContext: stdout,
348			}
349		} else {
350			result = &message.HookOutput{
351				Message: stdout,
352			}
353		}
354	case 2:
355		result = &message.HookOutput{
356			Decision: string(HookDecisionBlock),
357			Error:    stderr,
358		}
359		return result, nil
360	default:
361		result = &message.HookOutput{
362			Error: stderr,
363		}
364		return result, nil
365	}
366
367	jsonOutput := parseHookOutput(stdout)
368	if jsonOutput == nil {
369		return result, nil
370	}
371
372	result.Message = jsonOutput.Message
373	result.Stop = jsonOutput.Stop
374	result.Decision = jsonOutput.Decision
375	result.AdditionalContext = jsonOutput.AdditionalContext
376	result.UpdatedInput = jsonOutput.UpdatedInput
377
378	// Trace output in debug mode
379	slog.Debug("Hook execution output",
380		"event", hookCtx.EventType,
381		"exit_code", exitCode,
382		"stdout_length", len(stdout),
383		"stderr_length", len(stderr),
384		"stdout", stdout,
385		"stderr", stderr,
386	)
387	return result, nil
388}
389
390func parseHookOutput(stdout string) *message.HookOutput {
391	stdout = strings.TrimSpace(stdout)
392	slog.Info(stdout)
393	if stdout == "" {
394		return nil
395	}
396
397	var output message.HookOutput
398	if err := json.Unmarshal([]byte(stdout), &output); err != nil {
399		// Failed to parse as HookOutput
400		return nil
401	}
402
403	return &output
404}
405
406func (s *service) SetSmallModel(model fantasy.LanguageModel) {
407	s.smallModel = model
408}
409
410func (s *service) collectMatchingHooks(hookCtx HookContext) []config.Hook {
411	matchers, ok := s.config[hookCtx.EventType]
412	if !ok || len(matchers) == 0 {
413		return nil
414	}
415
416	var hooks []config.Hook
417	for _, matcher := range matchers {
418		if !s.matcherApplies(matcher, hookCtx) {
419			continue
420		}
421		hooks = append(hooks, matcher.Hooks...)
422	}
423	return hooks
424}
425
426func (s *service) matcherApplies(matcher config.HookMatcher, ctx HookContext) bool {
427	if ctx.EventType == config.PreToolUse || ctx.EventType == config.PostToolUse {
428		return s.matchesToolName(matcher.Matcher, ctx.ToolName)
429	}
430
431	return matcher.Matcher == "" || matcher.Matcher == "*"
432}
433
434func (s *service) matchesToolName(pattern, toolName string) bool {
435	if pattern == "" || pattern == "*" {
436		return true
437	}
438
439	if pattern == toolName {
440		return true
441	}
442
443	if strings.Contains(pattern, "|") {
444		for tool := range strings.SplitSeq(pattern, "|") {
445			tool = strings.TrimSpace(tool)
446			if tool == toolName {
447				return true
448			}
449		}
450
451		return s.matchesRegex(pattern, toolName)
452	}
453
454	return s.matchesRegex(pattern, toolName)
455}
456
457func (s *service) matchesRegex(pattern, text string) bool {
458	re, ok := s.regexCache.Get(pattern)
459	if !ok {
460		compiled, err := regexp.Compile(pattern)
461		if err != nil {
462			// Not a valid regex, don't cache failures.
463			return false
464		}
465		re = s.regexCache.GetOrSet(pattern, func() *regexp.Regexp {
466			return compiled
467		})
468	}
469
470	if re == nil {
471		return false
472	}
473
474	return re.MatchString(text)
475}