Detailed changes
@@ -30,6 +30,7 @@ import (
"github.com/charmbracelet/crush/internal/agent/tools"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/csync"
+ "github.com/charmbracelet/crush/internal/hooks"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/permission"
"github.com/charmbracelet/crush/internal/session"
@@ -83,6 +84,7 @@ type sessionAgent struct {
tools []fantasy.AgentTool
sessions session.Service
messages message.Service
+ hooks hooks.Service
disableAutoSummarize bool
isYolo bool
@@ -99,6 +101,7 @@ type SessionAgentOptions struct {
IsYolo bool
Sessions session.Service
Messages message.Service
+ Hooks hooks.Service
Tools []fantasy.AgentTool
}
@@ -113,6 +116,7 @@ func NewSessionAgent(
sessions: opts.Sessions,
messages: opts.Messages,
disableAutoSummarize: opts.DisableAutoSummarize,
+ hooks: opts.Hooks,
tools: opts.Tools,
isYolo: opts.IsYolo,
messageQueue: csync.NewMap[string, []SessionAgentCall](),
@@ -172,7 +176,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
}
// Add the user message to the session.
- _, err = a.createUserMessage(ctx, call)
+ userMsg, err := a.createUserMessage(ctx, call)
if err != nil {
return nil, err
}
@@ -182,19 +186,52 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
genCtx, cancel := context.WithCancel(ctx)
a.activeRequests.Set(call.SessionID, cancel)
-
defer cancel()
defer a.activeRequests.Del(call.SessionID)
+ // create the agent message asap to show loading
+ var currentAssistant *message.Message
+ assistantMessage, err := a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{
+ Role: message.Assistant,
+ Parts: []message.ContentPart{},
+ Model: a.largeModel.ModelCfg.Model,
+ Provider: a.largeModel.ModelCfg.Provider,
+ })
+ if err != nil {
+ return nil, err
+ }
+ currentAssistant = &assistantMessage
+
+ // run hooks after assistant message is created
+ // this way we show loading for the user
+ hooks, err := a.executeUserPromptSubmitHook(genCtx, call.SessionID, call.Prompt)
+ if err != nil {
+ return nil, err
+ }
+ userMsg.AddHookOutputs(hooks...)
+ if updateErr := a.messages.Update(genCtx, userMsg); updateErr != nil {
+ return nil, updateErr
+ }
+
+ for _, hook := range hooks {
+ // execution stopped
+ if hook.Stop {
+ deleteErr := a.messages.Delete(genCtx, assistantMessage.ID)
+ if deleteErr != nil {
+ return nil, deleteErr
+ }
+ return nil, nil
+ }
+ }
+
history, files := a.preparePrompt(msgs, call.Attachments...)
startTime := time.Now()
a.eventPromptSent(call.SessionID)
- var currentAssistant *message.Message
var shouldSummarize bool
result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
- Prompt: call.Prompt,
+ Prompt: userMsg.ContentWithHooksContext(),
Files: files,
Messages: history,
ProviderOptions: call.ProviderOptions,
@@ -206,6 +243,22 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
FrequencyPenalty: call.FrequencyPenalty,
// Before each step create a new assistant message.
PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
+ // only add new assistant message when its not the first step
+ if options.StepNumber != 0 {
+ var assistantMsg message.Message
+ assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
+ Role: message.Assistant,
+ Parts: []message.ContentPart{},
+ Model: a.largeModel.ModelCfg.Model,
+ Provider: a.largeModel.ModelCfg.Provider,
+ })
+ currentAssistant = &assistantMsg
+ // create the message first so we show loading asap
+ if err != nil {
+ return callContext, prepared, err
+ }
+ }
+
prepared.Messages = options.Messages
// Reset all cached items.
for i := range prepared.Messages {
@@ -219,6 +272,27 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
if createErr != nil {
return callContext, prepared, createErr
}
+
+ // run hooks after assistant message is created
+ // this way we show loading for the user
+ hooks, hookErr := a.executeUserPromptSubmitHook(genCtx, call.SessionID, call.Prompt)
+ if hookErr != nil {
+ return callContext, prepared, hookErr
+ }
+ userMsg.AddHookOutputs(hooks...)
+ for _, hook := range hooks {
+ // execution stopped
+ if hook.Stop {
+ deleteErr := a.messages.Delete(genCtx, assistantMessage.ID)
+ if deleteErr != nil {
+ return callContext, prepared, deleteErr
+ }
+ return callContext, prepared, ErrHookCancellation
+ }
+ }
+ if updateErr := a.messages.Update(genCtx, userMsg); updateErr != nil {
+ return callContext, prepared, updateErr
+ }
prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
}
@@ -242,18 +316,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
}
- var assistantMsg message.Message
- assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
- Role: message.Assistant,
- Parts: []message.ContentPart{},
- Model: a.largeModel.ModelCfg.Model,
- Provider: a.largeModel.ModelCfg.Provider,
- })
- if err != nil {
- return callContext, prepared, err
- }
- callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
- currentAssistant = &assistantMsg
+ callContext = context.WithValue(callContext, tools.MessageIDContextKey, currentAssistant.ID)
return callContext, prepared, err
},
OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
@@ -400,6 +463,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
if err != nil {
isCancelErr := errors.Is(err, context.Canceled)
isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
+ isHookCancelled := errors.Is(err, ErrHookCancellation)
if currentAssistant == nil {
return result, err
}
@@ -468,6 +532,8 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
} else if isPermissionErr {
currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
+ } else if isHookCancelled {
+ currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Hook canceled request", "")
} else if errors.As(err, &providerErr) {
currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
} else if errors.As(err, &fantasyErr) {
@@ -882,3 +948,20 @@ func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
func (a *sessionAgent) Model() Model {
return a.largeModel
}
+
+func (a *sessionAgent) executeUserPromptSubmitHook(
+ ctx context.Context,
+ sessionID string,
+ prompt string,
+) ([]message.HookOutput, error) {
+ if a.hooks == nil {
+ return nil, nil
+ }
+ return a.hooks.Execute(ctx, hooks.HookContext{
+ EventType: config.UserPromptSubmit,
+ SessionID: sessionID,
+ UserPrompt: prompt,
+ Provider: a.largeModel.ModelCfg.Provider,
+ Model: a.largeModel.ModelCfg.Model,
+ })
+}
@@ -34,7 +34,7 @@ func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error)
return nil, err
}
- agent, err := c.buildAgent(ctx, prompt, agentCfg)
+ agent, err := c.buildAgent(ctx, prompt, agentCfg, nil)
if err != nil {
return nil, err
}
@@ -149,7 +149,7 @@ func testSessionAgent(env fakeEnv, large, small fantasy.LanguageModel, systemPro
DefaultMaxTokens: 10000,
},
}
- agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, "", systemPrompt, false, true, env.sessions, env.messages, tools})
+ agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, "", systemPrompt, false, true, env.sessions, env.messages, nil, tools})
return agent
}
@@ -21,6 +21,7 @@ import (
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/history"
+ "github.com/charmbracelet/crush/internal/hooks"
"github.com/charmbracelet/crush/internal/log"
"github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/message"
@@ -60,6 +61,7 @@ type coordinator struct {
messages message.Service
permissions permission.Service
history history.Service
+ hooks hooks.Service
lspClients *csync.Map[string, *lsp.Client]
currentAgent SessionAgent
@@ -76,6 +78,7 @@ func NewCoordinator(
permissions permission.Service,
history history.Service,
lspClients *csync.Map[string, *lsp.Client],
+ hooks hooks.Service,
) (Coordinator, error) {
c := &coordinator{
cfg: cfg,
@@ -98,7 +101,7 @@ func NewCoordinator(
return nil, err
}
- agent, err := c.buildAgent(ctx, prompt, agentCfg)
+ agent, err := c.buildAgent(ctx, prompt, agentCfg, hooks)
if err != nil {
return nil, err
}
@@ -274,7 +277,7 @@ func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderO
return modelOptions, temp, topP, topK, freqPenalty, presPenalty
}
-func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
+func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, hooks hooks.Service) (SessionAgent, error) {
large, small, err := c.buildAgentModels(ctx)
if err != nil {
return nil, err
@@ -286,6 +289,7 @@ func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, age
}
largeProviderCfg, _ := c.cfg.Providers.Get(large.ModelCfg.Provider)
+
result := NewSessionAgent(SessionAgentOptions{
large,
small,
@@ -295,6 +299,7 @@ func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, age
c.permissions.SkipRequests(),
c.sessions,
c.messages,
+ hooks,
nil,
})
c.readyWg.Go(func() error {
@@ -10,6 +10,7 @@ var (
ErrSessionBusy = errors.New("session is currently processing another request")
ErrEmptyPrompt = errors.New("prompt is empty")
ErrSessionMissing = errors.New("session id is missing")
+ ErrHookCancellation = errors.New("hook cancelled the agent")
)
func isCancelledErr(err error) bool {
@@ -23,6 +23,7 @@ import (
"github.com/charmbracelet/crush/internal/db"
"github.com/charmbracelet/crush/internal/format"
"github.com/charmbracelet/crush/internal/history"
+ "github.com/charmbracelet/crush/internal/hooks"
"github.com/charmbracelet/crush/internal/log"
"github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/message"
@@ -42,6 +43,7 @@ type App struct {
Messages message.Service
History history.Service
Permissions permission.Service
+ Hooks hooks.Service
AgentCoordinator agent.Coordinator
@@ -71,10 +73,13 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
allowedTools = cfg.Permissions.AllowedTools
}
+ hooks := hooks.NewService(cfg.Hooks, cfg.WorkingDir(), nil, messages)
+
app := &App{
Sessions: sessions,
Messages: messages,
History: files,
+ Hooks: hooks,
Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests, allowedTools),
LSPClients: csync.NewMap[string, *lsp.Client](),
@@ -323,6 +328,7 @@ func (app *App) InitCoderAgent(ctx context.Context) error {
app.Permissions,
app.History,
app.LSPClients,
+ app.Hooks,
)
if err != nil {
slog.Error("Failed to create coder agent", "err", err)
@@ -326,6 +326,8 @@ type Config struct {
Tools Tools `json:"tools,omitzero" jsonschema:"description=Tool configurations"`
+ Hooks HookConfig `json:"hooks,omitempty" jsonschema:"description=Hook configurations for lifecycle events"`
+
Agents map[string]Agent `json:"-"`
// Internal
@@ -355,6 +357,14 @@ func (c *Config) IsConfigured() bool {
return len(c.EnabledProviders()) > 0
}
+// validateHooks validates the hooks configuration.
+func (c *Config) validateHooks() error {
+ if c.Hooks == nil {
+ return nil
+ }
+ return c.Hooks.Validate()
+}
+
func (c *Config) GetModel(provider, model string) *catwalk.Model {
if providerConfig, ok := c.Providers.Get(provider); ok {
for _, m := range providerConfig.Models {
@@ -0,0 +1,100 @@
+package config
+
+import (
+ "fmt"
+ "log/slog"
+)
+
+// HookEventType represents the lifecycle event when a hook should run.
+type HookEventType string
+
+const (
+ // PreToolUse runs before tool calls and can block them.
+ PreToolUse HookEventType = "pre_tool_use"
+ // PostToolUse runs after tool calls complete.
+ PostToolUse HookEventType = "post_tool_use"
+ // UserPromptSubmit runs when the user submits a prompt, before processing.
+ UserPromptSubmit HookEventType = "user_prompt_submit"
+ // Stop runs when Crush finishes responding.
+ Stop HookEventType = "stop"
+ // SubagentStop runs when subagent tasks complete.
+ SubagentStop HookEventType = "subagent_stop"
+ // PreCompact runs before running a compact operation.
+ PreCompact HookEventType = "pre_compact"
+ // PermissionRequested runs when a permission is requested from the user.
+ PermissionRequested HookEventType = "permission_requested"
+)
+
+// Hook represents a single hook command configuration.
+type Hook struct {
+ // Type is the hook type: "command" or "prompt".
+ Type string `json:"type" jsonschema:"description=Hook type,enum=command,enum=prompt,default=command"`
+ // Command is the shell command to execute (for type: "command").
+ // WARNING: Hook commands execute with Crush's full permissions. Only use trusted commands.
+ Command string `json:"command,omitempty" jsonschema:"description=Shell command to execute for this hook (executes with Crush's permissions),example=echo 'Hook executed'"`
+ // Prompt is the LLM prompt to execute (for type: "prompt").
+ // Use $ARGUMENTS placeholder to include hook context JSON.
+ Prompt string `json:"prompt,omitempty" jsonschema:"description=LLM prompt for intelligent decision making,example=Analyze if all tasks are complete. Context: $ARGUMENTS. Return JSON with decision and reason."`
+ // Timeout is the maximum time in seconds to wait for the hook to complete.
+ // Default is 30 seconds.
+ Timeout *int `json:"timeout,omitempty" jsonschema:"description=Maximum time in seconds to wait for hook completion,default=30,minimum=1,maximum=300"`
+}
+
+// Validate checks hook configuration invariants.
+func (h *Hook) Validate() error {
+ switch h.Type {
+ case "prompt":
+ if h.Prompt == "" {
+ return fmt.Errorf("prompt-based hook missing 'prompt' field")
+ }
+ case "", "command":
+ if h.Command == "" {
+ return fmt.Errorf("command-based hook missing 'command' field")
+ }
+ default:
+ return fmt.Errorf("unsupported hook type: %s", h.Type)
+ }
+ if h.Timeout != nil {
+ if *h.Timeout < 1 {
+ slog.Warn("Hook timeout too low, using minimum",
+ "configured", *h.Timeout, "minimum", 1)
+ v := 1
+ h.Timeout = &v
+ }
+ if *h.Timeout > 300 {
+ slog.Warn("Hook timeout too high, using maximum",
+ "configured", *h.Timeout, "maximum", 300)
+ v := 300
+ h.Timeout = &v
+ }
+ }
+ return nil
+}
+
+// HookMatcher represents a matcher for a specific event type.
+type HookMatcher struct {
+ // Matcher is the tool name or pattern to match (for tool events).
+ // For non-tool events, this can be empty or "*" to match all.
+ // Supports pipe-separated tool names like "edit|write|multiedit".
+ Matcher string `json:"matcher,omitempty" jsonschema:"description=Tool name or pattern to match (e.g. 'bash' 'edit|write' for multiple or '*' for all),example=bash,example=edit|write|multiedit,example=*"`
+ // Hooks is the list of hooks to execute when the matcher matches.
+ Hooks []Hook `json:"hooks" jsonschema:"required,description=List of hooks to execute when matcher matches"`
+}
+
+// HookConfig holds the complete hook configuration.
+type HookConfig map[HookEventType][]HookMatcher
+
+// Validate validates the entire hook configuration.
+func (c HookConfig) Validate() error {
+ for eventType, matchers := range c {
+ for i, matcher := range matchers {
+ for j := range matcher.Hooks {
+ if err := matcher.Hooks[j].Validate(); err != nil {
+ return fmt.Errorf("invalid hook config for %s matcher %d hook %d: %w",
+ eventType, i, j, err)
+ }
+ }
+ }
+ }
+ return nil
+}
@@ -81,6 +81,11 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) {
}
cfg.knownProviders = providers
+ // Validate hooks configuration.
+ if err := cfg.validateHooks(); err != nil {
+ return nil, fmt.Errorf("invalid hooks configuration: %w", err)
+ }
+
env := env.New()
// Configure providers
valueResolver := NewShellVariableResolver(env)
@@ -19,12 +19,13 @@ INSERT INTO messages (
model,
provider,
is_summary_message,
+ hook_outputs,
created_at,
updated_at
) VALUES (
- ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
+ ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
)
-RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message
+RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message, hook_outputs
`
type CreateMessageParams struct {
@@ -35,6 +36,7 @@ type CreateMessageParams struct {
Model sql.NullString `json:"model"`
Provider sql.NullString `json:"provider"`
IsSummaryMessage int64 `json:"is_summary_message"`
+ HookOutputs string `json:"hook_outputs"`
}
func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (Message, error) {
@@ -46,6 +48,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M
arg.Model,
arg.Provider,
arg.IsSummaryMessage,
+ arg.HookOutputs,
)
var i Message
err := row.Scan(
@@ -59,6 +62,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M
&i.FinishedAt,
&i.Provider,
&i.IsSummaryMessage,
+ &i.HookOutputs,
)
return i, err
}
@@ -84,7 +88,7 @@ func (q *Queries) DeleteSessionMessages(ctx context.Context, sessionID string) e
}
const getMessage = `-- name: GetMessage :one
-SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message
+SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message, hook_outputs
FROM messages
WHERE id = ? LIMIT 1
`
@@ -103,12 +107,13 @@ func (q *Queries) GetMessage(ctx context.Context, id string) (Message, error) {
&i.FinishedAt,
&i.Provider,
&i.IsSummaryMessage,
+ &i.HookOutputs,
)
return i, err
}
const listMessagesBySession = `-- name: ListMessagesBySession :many
-SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message
+SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message, hook_outputs
FROM messages
WHERE session_id = ?
ORDER BY created_at ASC
@@ -134,6 +139,7 @@ func (q *Queries) ListMessagesBySession(ctx context.Context, sessionID string) (
&i.FinishedAt,
&i.Provider,
&i.IsSummaryMessage,
+ &i.HookOutputs,
); err != nil {
return nil, err
}
@@ -153,17 +159,24 @@ UPDATE messages
SET
parts = ?,
finished_at = ?,
+ hook_outputs = ?,
updated_at = strftime('%s', 'now')
WHERE id = ?
`
type UpdateMessageParams struct {
- Parts string `json:"parts"`
- FinishedAt sql.NullInt64 `json:"finished_at"`
- ID string `json:"id"`
+ Parts string `json:"parts"`
+ FinishedAt sql.NullInt64 `json:"finished_at"`
+ HookOutputs string `json:"hook_outputs"`
+ ID string `json:"id"`
}
func (q *Queries) UpdateMessage(ctx context.Context, arg UpdateMessageParams) error {
- _, err := q.exec(ctx, q.updateMessageStmt, updateMessage, arg.Parts, arg.FinishedAt, arg.ID)
+ _, err := q.exec(ctx, q.updateMessageStmt, updateMessage,
+ arg.Parts,
+ arg.FinishedAt,
+ arg.HookOutputs,
+ arg.ID,
+ )
return err
}
@@ -0,0 +1,5 @@
+-- +goose Up
+ALTER TABLE messages ADD COLUMN hook_outputs TEXT DEFAULT '[]' NOT NULL;
+
+-- +goose Down
+ALTER TABLE messages DROP COLUMN hook_outputs;
@@ -29,6 +29,7 @@ type Message struct {
FinishedAt sql.NullInt64 `json:"finished_at"`
Provider sql.NullString `json:"provider"`
IsSummaryMessage int64 `json:"is_summary_message"`
+ HookOutputs string `json:"hook_outputs"`
}
type Session struct {
@@ -18,10 +18,11 @@ INSERT INTO messages (
model,
provider,
is_summary_message,
+ hook_outputs,
created_at,
updated_at
) VALUES (
- ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
+ ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
)
RETURNING *;
@@ -30,6 +31,7 @@ UPDATE messages
SET
parts = ?,
finished_at = ?,
+ hook_outputs = ?,
updated_at = strftime('%s', 'now')
WHERE id = ?;
@@ -0,0 +1,415 @@
+package hooks
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log/slog"
+ "regexp"
+ "strings"
+ "sync"
+ "time"
+
+ "charm.land/fantasy"
+ "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/csync"
+ "github.com/charmbracelet/crush/internal/message"
+ "github.com/charmbracelet/crush/internal/permission"
+ "github.com/charmbracelet/crush/internal/shell"
+)
+
+const DefaultHookTimeout = 30 * time.Second
+
+// HookContext contains context information passed to hooks.
+// Fields are populated based on event type - see field comments for availability.
+type HookContext struct {
+ // EventType is the lifecycle event that triggered this hook.
+ // Available: All events
+ EventType config.HookEventType `json:"event_type"`
+
+ // SessionID identifies the Crush session.
+ // Available: All events (if session exists)
+ SessionID string `json:"session_id,omitempty"`
+
+ // TranscriptPath is the path to exported session transcript JSON file.
+ // Available: All events (if session exists and export succeeds)
+ TranscriptPath string `json:"transcript_path,omitempty"`
+
+ // ToolName is the name of the tool being called.
+ // Available: pre_tool_use, post_tool_use, permission_requested
+ ToolName string `json:"tool_name,omitempty"`
+
+ // ToolInput contains the tool's input parameters as key-value pairs.
+ // Available: pre_tool_use, post_tool_use
+ ToolInput map[string]any `json:"tool_input,omitempty"`
+
+ // ToolResult is the string result returned by the tool.
+ // Available: post_tool_use
+ ToolResult string `json:"tool_result,omitempty"`
+
+ // ToolError indicates whether the tool execution failed.
+ // Available: post_tool_use
+ ToolError bool `json:"tool_error,omitempty"`
+
+ // UserPrompt is the prompt submitted by the user.
+ // Available: user_prompt_submit
+ UserPrompt string `json:"user_prompt,omitempty"`
+
+ // Timestamp is when the hook was triggered (RFC3339 format).
+ // Available: All events
+ Timestamp time.Time `json:"timestamp"`
+
+ // WorkingDir is the project working directory.
+ // Available: All events
+ WorkingDir string `json:"working_dir,omitempty"`
+
+ // MessageID identifies the assistant message being processed.
+ // Available: pre_tool_use, post_tool_use, stop
+ MessageID string `json:"message_id,omitempty"`
+
+ // Provider is the LLM provider name (e.g., "anthropic").
+ // Available: All events with LLM interaction
+ Provider string `json:"provider,omitempty"`
+
+ // Model is the LLM model name (e.g., "claude-3-5-sonnet-20241022").
+ // Available: All events with LLM interaction
+ Model string `json:"model,omitempty"`
+
+ // TokensUsed is the total tokens consumed in this interaction.
+ // Available: stop
+ TokensUsed int64 `json:"tokens_used,omitempty"`
+
+ // TokensInput is the input tokens consumed in this interaction.
+ // Available: stop
+ TokensInput int64 `json:"tokens_input,omitempty"`
+
+ // PermissionAction is the permission action being requested (e.g., "read", "write").
+ // Available: permission_requested
+ PermissionAction string `json:"permission_action,omitempty"`
+
+ // PermissionPath is the file path involved in the permission request.
+ // Available: permission_requested
+ PermissionPath string `json:"permission_path,omitempty"`
+
+ // PermissionParams contains additional permission parameters.
+ // Available: permission_requested
+ PermissionRequest *permission.PermissionRequest `json:"permission_request,omitempty"`
+}
+
+type HookDecision string
+
+const (
+ HookDecisionBlock HookDecision = "block"
+ HookDecisionDeny HookDecision = "deny"
+ HookDecisionAllow HookDecision = "allow"
+ HookDecisionAsk HookDecision = "ask"
+)
+
+type Service interface {
+ Execute(ctx context.Context, hookCtx HookContext) ([]message.HookOutput, error)
+ SetSmallModel(model fantasy.LanguageModel)
+}
+
+type service struct {
+ config config.HookConfig
+ workingDir string
+ regexCache *csync.Map[string, *regexp.Regexp]
+ smallModel fantasy.LanguageModel
+ messages message.Service
+}
+
+// NewService creates a new hook executor.
+func NewService(
+ hookConfig config.HookConfig,
+ workingDir string,
+ smallModel fantasy.LanguageModel,
+ messages message.Service,
+) Service {
+ return &service{
+ config: hookConfig,
+ workingDir: workingDir,
+ regexCache: csync.NewMap[string, *regexp.Regexp](),
+ smallModel: smallModel,
+ messages: messages,
+ }
+}
+
+// Execute implements Service.
+func (s *service) Execute(ctx context.Context, hookCtx HookContext) ([]message.HookOutput, error) {
+ if s.config == nil {
+ return nil, nil
+ }
+
+ // Check if context is already cancelled - prevents race conditions during cancellation.
+ if ctx.Err() != nil {
+ return nil, ctx.Err()
+ }
+
+ hookCtx.Timestamp = time.Now()
+ hookCtx.WorkingDir = s.workingDir
+
+ hooks := s.collectMatchingHooks(hookCtx)
+ if len(hooks) == 0 {
+ return nil, nil
+ }
+
+ transcriptPath, cleanup, err := s.setupTranscript(ctx, hookCtx)
+ if err != nil {
+ slog.Warn("Failed to export transcript for hooks", "error", err)
+ } else if transcriptPath != "" {
+ hookCtx.TranscriptPath = transcriptPath
+ // Ensure cleanup happens even on panic.
+ defer func() {
+ cleanup()
+ }()
+ }
+
+ results := make([]message.HookOutput, len(hooks))
+ var wg sync.WaitGroup
+
+ for i, hook := range hooks {
+ if ctx.Err() != nil {
+ return nil, ctx.Err()
+ }
+ wg.Add(1)
+ go func(idx int, h config.Hook) {
+ defer wg.Done()
+ result, err := s.executeHook(ctx, h, hookCtx)
+ if err != nil {
+ slog.Warn("Hook execution failed",
+ "event", hookCtx.EventType,
+ "error", err,
+ )
+ } else {
+ results[idx] = *result
+ }
+ }(i, hook)
+ }
+ wg.Wait()
+ return results, nil
+}
+
+func (s *service) setupTranscript(ctx context.Context, hookCtx HookContext) (string, func(), error) {
+ if hookCtx.SessionID == "" || s.messages == nil {
+ return "", func() {}, nil
+ }
+
+ path, err := exportTranscript(ctx, s.messages, hookCtx.SessionID)
+ if err != nil {
+ return "", func() {}, err
+ }
+
+ cleanup := func() {
+ if path != "" {
+ cleanupTranscript(path)
+ }
+ }
+
+ return path, cleanup, nil
+}
+
+func (s *service) executeHook(ctx context.Context, hook config.Hook, hookCtx HookContext) (*message.HookOutput, error) {
+ var result *message.HookOutput
+ var err error
+
+ switch hook.Type {
+ case "prompt":
+ if hook.Prompt == "" {
+ return nil, fmt.Errorf("prompt-based hook missing 'prompt' field")
+ }
+ if s.smallModel == nil {
+ return nil, fmt.Errorf("prompt-based hook requires small model configuration")
+ }
+ result, err = s.executePromptHook(ctx, hook, hookCtx)
+ case "", "command":
+ if hook.Command == "" {
+ return nil, fmt.Errorf("command-based hook missing 'command' field")
+ }
+ slog.Info("executing")
+ result, err = s.executeCommandHook(ctx, hook, hookCtx)
+ default:
+ return nil, fmt.Errorf("unsupported hook type: %s", hook.Type)
+ }
+
+ return result, err
+}
+
+func (s *service) executePromptHook(ctx context.Context, hook config.Hook, hookCtx HookContext) (*message.HookOutput, error) {
+ panic("not implemented")
+}
+
+func (s *service) executeCommandHook(ctx context.Context, hook config.Hook, hookCtx HookContext) (*message.HookOutput, error) {
+ timeout := DefaultHookTimeout
+ if hook.Timeout != nil {
+ timeout = time.Duration(*hook.Timeout) * time.Second
+ }
+
+ execCtx, cancel := context.WithTimeout(ctx, timeout)
+ defer cancel()
+
+ contextJSON, err := json.Marshal(hookCtx)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal hook context: %w", err)
+ }
+
+ sh := shell.NewShell(&shell.Options{
+ WorkingDir: s.workingDir,
+ })
+
+ sh.SetEnv("CRUSH_HOOK_EVENT", string(hookCtx.EventType))
+ sh.SetEnv("CRUSH_HOOK_CONTEXT", string(contextJSON))
+ sh.SetEnv("CRUSH_PROJECT_DIR", s.workingDir)
+ if hookCtx.SessionID != "" {
+ sh.SetEnv("CRUSH_SESSION_ID", hookCtx.SessionID)
+ }
+ if hookCtx.ToolName != "" {
+ sh.SetEnv("CRUSH_TOOL_NAME", hookCtx.ToolName)
+ }
+
+ slog.Debug("Hook execution trace",
+ "event", hookCtx.EventType,
+ "command", hook.Command,
+ "timeout", timeout,
+ "context", hookCtx,
+ )
+
+ stdout, stderr, err := sh.ExecWithStdin(execCtx, hook.Command, string(contextJSON))
+
+ exitCode := shell.ExitCode(err)
+
+ var result *message.HookOutput
+ switch exitCode {
+ case 0:
+ // if the event is UserPromptSubmit we want the output to be added to the context
+ if hookCtx.EventType == config.UserPromptSubmit {
+ result = &message.HookOutput{
+ AdditionalContext: stdout,
+ }
+ } else {
+ result = &message.HookOutput{
+ Message: stdout,
+ }
+ }
+ case 2:
+ result = &message.HookOutput{
+ Decision: string(HookDecisionBlock),
+ Error: stderr,
+ }
+ return result, nil
+ default:
+ result = &message.HookOutput{
+ Error: stderr,
+ }
+ return result, nil
+ }
+
+ jsonOutput := parseHookOutput(stdout)
+ if jsonOutput == nil {
+ return result, nil
+ }
+
+ result.Message = jsonOutput.Message
+ result.Stop = jsonOutput.Stop
+ result.Decision = jsonOutput.Decision
+ result.AdditionalContext = jsonOutput.AdditionalContext
+ result.UpdatedInput = jsonOutput.UpdatedInput
+
+ // Trace output in debug mode
+ slog.Debug("Hook execution output",
+ "event", hookCtx.EventType,
+ "exit_code", exitCode,
+ "stdout_length", len(stdout),
+ "stderr_length", len(stderr),
+ "stdout", stdout,
+ "stderr", stderr,
+ )
+ return result, nil
+}
+
+func parseHookOutput(stdout string) *message.HookOutput {
+ stdout = strings.TrimSpace(stdout)
+ slog.Info(stdout)
+ if stdout == "" {
+ return nil
+ }
+
+ var output message.HookOutput
+ if err := json.Unmarshal([]byte(stdout), &output); err != nil {
+ // Failed to parse as HookOutput
+ return nil
+ }
+
+ return &output
+}
+
+// SetSmallModel implements Service.
+func (s *service) SetSmallModel(model fantasy.LanguageModel) {
+ panic("unimplemented")
+}
+
+func (s *service) collectMatchingHooks(hookCtx HookContext) []config.Hook {
+ matchers, ok := s.config[hookCtx.EventType]
+ if !ok || len(matchers) == 0 {
+ return nil
+ }
+
+ var hooks []config.Hook
+ for _, matcher := range matchers {
+ if !s.matcherApplies(matcher, hookCtx) {
+ continue
+ }
+ hooks = append(hooks, matcher.Hooks...)
+ }
+ return hooks
+}
+
+func (s *service) matcherApplies(matcher config.HookMatcher, ctx HookContext) bool {
+ if ctx.EventType == config.PreToolUse || ctx.EventType == config.PostToolUse {
+ return s.matchesToolName(matcher.Matcher, ctx.ToolName)
+ }
+
+ return matcher.Matcher == "" || matcher.Matcher == "*"
+}
+
+func (s *service) matchesToolName(pattern, toolName string) bool {
+ if pattern == "" || pattern == "*" {
+ return true
+ }
+
+ if pattern == toolName {
+ return true
+ }
+
+ if strings.Contains(pattern, "|") {
+ for tool := range strings.SplitSeq(pattern, "|") {
+ tool = strings.TrimSpace(tool)
+ if tool == toolName {
+ return true
+ }
+ }
+
+ return s.matchesRegex(pattern, toolName)
+ }
+
+ return s.matchesRegex(pattern, toolName)
+}
+
+func (s *service) matchesRegex(pattern, text string) bool {
+ re, ok := s.regexCache.Get(pattern)
+ if !ok {
+ compiled, err := regexp.Compile(pattern)
+ if err != nil {
+ // Not a valid regex, don't cache failures.
+ return false
+ }
+ re = s.regexCache.GetOrSet(pattern, func() *regexp.Regexp {
+ return compiled
+ })
+ }
+
+ if re == nil {
+ return false
+ }
+
+ return re.MatchString(text)
+}
@@ -0,0 +1,133 @@
+package hooks
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "os"
+ "path/filepath"
+ "time"
+
+ "github.com/charmbracelet/crush/internal/message"
+)
+
+// TranscriptMessage represents a message in the exported transcript.
+type TranscriptMessage struct {
+ ID string `json:"id"`
+ Role string `json:"role"`
+ Content string `json:"content,omitempty"`
+ ToolCalls []TranscriptToolCall `json:"tool_calls,omitempty"`
+ ToolResults []TranscriptToolResult `json:"tool_results,omitempty"`
+ Timestamp string `json:"timestamp"`
+}
+
+// TranscriptToolCall represents a tool call in the transcript.
+type TranscriptToolCall struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+ Input string `json:"input"`
+}
+
+// TranscriptToolResult represents a tool result in the transcript.
+type TranscriptToolResult struct {
+ ToolCallID string `json:"tool_call_id"`
+ Name string `json:"name"`
+ Content string `json:"content"`
+ IsError bool `json:"is_error"`
+}
+
+// Transcript represents the complete transcript structure.
+type Transcript struct {
+ SessionID string `json:"session_id"`
+ Messages []TranscriptMessage `json:"messages"`
+}
+
+// exportTranscript exports session messages to a temporary JSON file.
+func exportTranscript(
+ ctx context.Context,
+ messages message.Service,
+ sessionID string,
+) (string, error) {
+ // Get all messages for the session
+ msgs, err := messages.List(ctx, sessionID)
+ if err != nil {
+ return "", fmt.Errorf("failed to list messages: %w", err)
+ }
+
+ // Convert to transcript format
+ transcript := Transcript{
+ SessionID: sessionID,
+ Messages: make([]TranscriptMessage, 0, len(msgs)),
+ }
+
+ for _, msg := range msgs {
+ tm := TranscriptMessage{
+ ID: msg.ID,
+ Role: string(msg.Role),
+ Timestamp: time.Unix(msg.CreatedAt, 0).Format("2006-01-02T15:04:05Z07:00"),
+ }
+
+ // Extract content
+ for _, part := range msg.Parts {
+ if text, ok := part.(message.TextContent); ok {
+ if tm.Content != "" {
+ tm.Content += "\n"
+ }
+ tm.Content += text.Text
+ }
+ }
+
+ // Extract tool calls
+ if msg.Role == message.Assistant {
+ toolCalls := msg.ToolCalls()
+ for _, tc := range toolCalls {
+ tm.ToolCalls = append(tm.ToolCalls, TranscriptToolCall{
+ ID: tc.ID,
+ Name: tc.Name,
+ Input: tc.Input,
+ })
+ }
+ }
+
+ // Extract tool results
+ if msg.Role == message.Tool {
+ toolResults := msg.ToolResults()
+ for _, tr := range toolResults {
+ tm.ToolResults = append(tm.ToolResults, TranscriptToolResult{
+ ToolCallID: tr.ToolCallID,
+ Name: tr.Name,
+ Content: tr.Content,
+ IsError: tr.IsError,
+ })
+ }
+ }
+
+ transcript.Messages = append(transcript.Messages, tm)
+ }
+
+ // Marshal to JSON
+ data, err := json.MarshalIndent(transcript, "", " ")
+ if err != nil {
+ return "", fmt.Errorf("failed to marshal transcript: %w", err)
+ }
+
+ // Write to temporary file
+ tmpDir := os.TempDir()
+ filename := fmt.Sprintf("crush-transcript-%s.json", sessionID)
+ path := filepath.Join(tmpDir, filename)
+
+ // Use restrictive permissions (0600)
+ if err := os.WriteFile(path, data, 0o600); err != nil {
+ return "", fmt.Errorf("failed to write transcript file: %w", err)
+ }
+
+ return path, nil
+}
+
+func cleanupTranscript(path string) {
+ if path != "" {
+ if err := os.Remove(path); err != nil {
+ fmt.Fprintf(os.Stderr, "Warning: failed to cleanup transcript file %s: %v\n", path, err)
+ }
+ }
+}
@@ -133,6 +133,16 @@ type Message struct {
CreatedAt int64
UpdatedAt int64
IsSummaryMessage bool
+ HookOutputs []HookOutput
+}
+
+type HookOutput struct {
+ Stop bool `json:"stop"`
+ Error string `json:"error"`
+ Message string `json:"message"`
+ Decision string `json:"decision"`
+ UpdatedInput string `json:"updated_input"`
+ AdditionalContext string `json:"additional_context"`
}
func (m *Message) Content() TextContent {
@@ -144,6 +154,23 @@ func (m *Message) Content() TextContent {
return TextContent{}
}
+func (m *Message) ContentWithHooksContext() string {
+ text := strings.TrimSpace(m.Content().Text)
+
+ var additionalContext []string
+ for _, hookOutput := range m.HookOutputs {
+ context := strings.TrimSpace(hookOutput.AdditionalContext)
+ if context != "" {
+ additionalContext = append(additionalContext, context)
+ }
+ }
+ if len(additionalContext) > 0 {
+ text += "## Additional Context\n"
+ text += strings.Join(additionalContext, "\n")
+ }
+ return text
+}
+
func (m *Message) ReasoningContent() ReasoningContent {
for _, part := range m.Parts {
if c, ok := part.(ReasoningContent); ok {
@@ -202,6 +229,11 @@ func (m *Message) IsFinished() bool {
return false
}
+// AddHookOutputs appends multiple hook outputs to the message's hook outputs.
+func (m *Message) AddHookOutputs(outputs ...HookOutput) {
+ m.HookOutputs = append(m.HookOutputs, outputs...)
+}
+
func (m *Message) FinishPart() *Finish {
for _, part := range m.Parts {
if c, ok := part.(Finish); ok {
@@ -429,7 +461,7 @@ func (m *Message) ToAIMessage() []fantasy.Message {
switch m.Role {
case User:
var parts []fantasy.MessagePart
- text := strings.TrimSpace(m.Content().Text)
+ text := strings.TrimSpace(m.ContentWithHooksContext())
if text != "" {
parts = append(parts, fantasy.TextPart{Text: text})
}
@@ -18,6 +18,7 @@ type CreateMessageParams struct {
Model string
Provider string
IsSummaryMessage bool
+ HookOutputs []HookOutput
}
type Service interface {
@@ -69,6 +70,10 @@ func (s *service) Create(ctx context.Context, sessionID string, params CreateMes
if params.IsSummaryMessage {
isSummary = 1
}
+ hookOutputsJSON, err := json.Marshal(params.HookOutputs)
+ if err != nil {
+ return Message{}, err
+ }
dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
ID: uuid.New().String(),
SessionID: sessionID,
@@ -77,6 +82,7 @@ func (s *service) Create(ctx context.Context, sessionID string, params CreateMes
Model: sql.NullString{String: string(params.Model), Valid: true},
Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""},
IsSummaryMessage: isSummary,
+ HookOutputs: string(hookOutputsJSON),
})
if err != nil {
return Message{}, err
@@ -115,10 +121,15 @@ func (s *service) Update(ctx context.Context, message Message) error {
finishedAt.Int64 = f.Time
finishedAt.Valid = true
}
+ hookOutputsJSON, err := json.Marshal(message.HookOutputs)
+ if err != nil {
+ return err
+ }
err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
- ID: message.ID,
- Parts: string(parts),
- FinishedAt: finishedAt,
+ ID: message.ID,
+ Parts: string(parts),
+ FinishedAt: finishedAt,
+ HookOutputs: string(hookOutputsJSON),
})
if err != nil {
return err
@@ -156,6 +167,12 @@ func (s *service) fromDBItem(item db.Message) (Message, error) {
if err != nil {
return Message{}, err
}
+ var hookOutputs []HookOutput
+ if item.HookOutputs != "" {
+ if err := json.Unmarshal([]byte(item.HookOutputs), &hookOutputs); err != nil {
+ return Message{}, err
+ }
+ }
return Message{
ID: item.ID,
SessionID: item.SessionID,
@@ -166,6 +183,7 @@ func (s *service) fromDBItem(item db.Message) (Message, error) {
CreatedAt: item.CreatedAt,
UpdatedAt: item.UpdatedAt,
IsSummaryMessage: item.IsSummaryMessage != 0,
+ HookOutputs: hookOutputs,
}, nil
}
@@ -111,6 +111,15 @@ func (s *Shell) ExecStream(ctx context.Context, command string, stdout, stderr i
return s.execStream(ctx, command, stdout, stderr)
}
+func (s *Shell) ExecWithStdin(ctx context.Context, command string, stdin string) (string, string, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ var stdout, stderr bytes.Buffer
+ err := s.execCommonWithStdin(ctx, command, strings.NewReader(stdin), &stdout, &stderr)
+ return stdout.String(), stderr.String(), err
+}
+
// GetWorkingDir returns the current working directory
func (s *Shell) GetWorkingDir() string {
s.mu.Lock()
@@ -237,9 +246,9 @@ func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHand
}
// newInterp creates a new interpreter with the current shell state
-func (s *Shell) newInterp(stdout, stderr io.Writer) (*interp.Runner, error) {
+func (s *Shell) newInterp(stdin io.Reader, stdout, stderr io.Writer) (*interp.Runner, error) {
return interp.New(
- interp.StdIO(nil, stdout, stderr),
+ interp.StdIO(stdin, stdout, stderr),
interp.Interactive(false),
interp.Env(expand.ListEnviron(s.env...)),
interp.Dir(s.cwd),
@@ -263,7 +272,7 @@ func (s *Shell) execCommon(ctx context.Context, command string, stdout, stderr i
return fmt.Errorf("could not parse command: %w", err)
}
- runner, err := s.newInterp(stdout, stderr)
+ runner, err := s.newInterp(nil, stdout, stderr)
if err != nil {
return fmt.Errorf("could not run command: %w", err)
}
@@ -286,6 +295,24 @@ func (s *Shell) execStream(ctx context.Context, command string, stdout, stderr i
return s.execCommon(ctx, command, stdout, stderr)
}
+// execCommonWithStdin is like execCommon but with stdin support
+func (s *Shell) execCommonWithStdin(ctx context.Context, command string, stdin io.Reader, stdout, stderr io.Writer) error {
+ line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
+ if err != nil {
+ return fmt.Errorf("could not parse command: %w", err)
+ }
+
+ runner, err := s.newInterp(stdin, stdout, stderr)
+ if err != nil {
+ return fmt.Errorf("could not run command: %w", err)
+ }
+
+ err = runner.Run(ctx, line)
+ s.updateShellFromRunner(runner)
+ s.logger.InfoPersist("command finished", "command", command, "err", err)
+ return err
+}
+
func (s *Shell) execHandlers() []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
handlers := []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc{
s.blockHandler(),