hooks.go

  1package hooks
  2
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"strings"
	"time"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/shell"
)
 13
 14const DefaultHookTimeout = 30 * time.Second
 15
 16// HookContext contains context information passed to hooks.
 17type HookContext struct {
 18	EventType   config.HookEventType `json:"event_type"`
 19	SessionID   string               `json:"session_id,omitempty"`
 20	ToolName    string               `json:"tool_name,omitempty"`
 21	ToolInput   map[string]any       `json:"tool_input,omitempty"`
 22	ToolResult  string               `json:"tool_result,omitempty"`
 23	ToolError   bool                 `json:"tool_error,omitempty"`
 24	UserPrompt  string               `json:"user_prompt,omitempty"`
 25	Timestamp   time.Time            `json:"timestamp"`
 26	WorkingDir  string               `json:"working_dir,omitempty"`
 27	MessageID   string               `json:"message_id,omitempty"`
 28	Provider    string               `json:"provider,omitempty"`
 29	Model       string               `json:"model,omitempty"`
 30	TokensUsed  int64                `json:"tokens_used,omitempty"`
 31	TokensInput int64                `json:"tokens_input,omitempty"`
 32}
 33
 34// Executor executes hooks based on configuration.
 35type Executor struct {
 36	config     config.HookConfig
 37	workingDir string
 38	shell      *shell.Shell
 39}
 40
 41// NewExecutor creates a new hook executor.
 42func NewExecutor(hookConfig config.HookConfig, workingDir string) *Executor {
 43	shellInst := shell.NewShell(&shell.Options{
 44		WorkingDir: workingDir,
 45	})
 46	return &Executor{
 47		config:     hookConfig,
 48		workingDir: workingDir,
 49		shell:      shellInst,
 50	}
 51}
 52
 53// Execute runs all hooks matching the given event type and context.
 54// Returns the first error encountered, causing subsequent hooks to be skipped.
 55func (e *Executor) Execute(ctx context.Context, hookCtx HookContext) error {
 56	if e.config == nil || e.shell == nil {
 57		return nil
 58	}
 59
 60	hookCtx.Timestamp = time.Now()
 61	hookCtx.WorkingDir = e.workingDir
 62
 63	matchers, ok := e.config[hookCtx.EventType]
 64	if !ok || len(matchers) == 0 {
 65		return nil
 66	}
 67
 68	for _, matcher := range matchers {
 69		if ctx.Err() != nil {
 70			return ctx.Err()
 71		}
 72
 73		if !e.matcherApplies(matcher, hookCtx) {
 74			continue
 75		}
 76
 77		for _, hook := range matcher.Hooks {
 78			if err := e.executeHook(ctx, hook, hookCtx); err != nil {
 79				slog.Warn("Hook execution failed",
 80					"event", hookCtx.EventType,
 81					"matcher", matcher.Matcher,
 82					"error", err,
 83				)
 84				return err
 85			}
 86		}
 87	}
 88
 89	return nil
 90}
 91
 92// matcherApplies checks if a matcher applies to the given context.
 93func (e *Executor) matcherApplies(matcher config.HookMatcher, ctx HookContext) bool {
 94	if matcher.Matcher == "" || matcher.Matcher == "*" {
 95		return true
 96	}
 97
 98	if ctx.EventType == config.PreToolUse || ctx.EventType == config.PostToolUse {
 99		return matcher.Matcher == ctx.ToolName
100	}
101
102	return matcher.Matcher == "" || matcher.Matcher == "*"
103}
104
105// executeHook executes a single hook command.
106func (e *Executor) executeHook(ctx context.Context, hook config.Hook, hookCtx HookContext) error {
107	if hook.Type != "command" {
108		return fmt.Errorf("unsupported hook type: %s", hook.Type)
109	}
110
111	timeout := DefaultHookTimeout
112	if hook.Timeout != nil {
113		timeout = time.Duration(*hook.Timeout) * time.Second
114	}
115
116	execCtx, cancel := context.WithTimeout(ctx, timeout)
117	defer cancel()
118
119	contextJSON, err := json.Marshal(hookCtx)
120	if err != nil {
121		return fmt.Errorf("failed to marshal hook context: %w", err)
122	}
123
124	e.shell.SetEnv("CRUSH_HOOK_EVENT", string(hookCtx.EventType))
125	e.shell.SetEnv("CRUSH_HOOK_CONTEXT", string(contextJSON))
126	if hookCtx.SessionID != "" {
127		e.shell.SetEnv("CRUSH_SESSION_ID", hookCtx.SessionID)
128	}
129	if hookCtx.ToolName != "" {
130		e.shell.SetEnv("CRUSH_TOOL_NAME", hookCtx.ToolName)
131	}
132
133	slog.Debug("Executing hook",
134		"event", hookCtx.EventType,
135		"command", hook.Command,
136		"timeout", timeout,
137	)
138
139	fullCommand := fmt.Sprintf("%s <<'CRUSH_HOOK_EOF'\n%s\nCRUSH_HOOK_EOF\n", hook.Command, string(contextJSON))
140
141	stdout, stderr, err := e.shell.Exec(execCtx, fullCommand)
142	if err != nil {
143		return fmt.Errorf("hook command failed: %w: stdout=%s stderr=%s", err, stdout, stderr)
144	}
145
146	if stdout != "" || stderr != "" {
147		slog.Debug("Hook output",
148			"event", hookCtx.EventType,
149			"stdout", stdout,
150			"stderr", stderr,
151		)
152	}
153
154	return nil
155}