wip: hooks implementation

Kujtim Hoxha created

Change summary

internal/agent/agent.go                                                | 115 
internal/agent/agent_tool.go                                           |   2 
internal/agent/common_test.go                                          |   2 
internal/agent/coordinator.go                                          |   9 
internal/agent/errors.go                                               |   1 
internal/app/app.go                                                    |   6 
internal/config/config.go                                              |  10 
internal/config/hooks.go                                               | 100 
internal/config/load.go                                                |   5 
internal/db/messages.sql.go                                            |  29 
internal/db/migrations/20250811000000_add_hook_outputs_to_messages.sql |   5 
internal/db/models.go                                                  |   1 
internal/db/sql/messages.sql                                           |   4 
internal/hooks/hooks.go                                                | 415 
internal/hooks/transcript.go                                           | 133 
internal/message/content.go                                            |  34 
internal/message/message.go                                            |  24 
internal/shell/shell.go                                                |  33 
18 files changed, 892 insertions(+), 36 deletions(-)

Detailed changes

internal/agent/agent.go 🔗

@@ -30,6 +30,7 @@ import (
 	"github.com/charmbracelet/crush/internal/agent/tools"
 	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/csync"
+	"github.com/charmbracelet/crush/internal/hooks"
 	"github.com/charmbracelet/crush/internal/message"
 	"github.com/charmbracelet/crush/internal/permission"
 	"github.com/charmbracelet/crush/internal/session"
@@ -83,6 +84,7 @@ type sessionAgent struct {
 	tools                []fantasy.AgentTool
 	sessions             session.Service
 	messages             message.Service
+	hooks                hooks.Service
 	disableAutoSummarize bool
 	isYolo               bool
 
@@ -99,6 +101,7 @@ type SessionAgentOptions struct {
 	IsYolo               bool
 	Sessions             session.Service
 	Messages             message.Service
+	Hooks                hooks.Service
 	Tools                []fantasy.AgentTool
 }
 
@@ -113,6 +116,7 @@ func NewSessionAgent(
 		sessions:             opts.Sessions,
 		messages:             opts.Messages,
 		disableAutoSummarize: opts.DisableAutoSummarize,
+		hooks:                opts.Hooks,
 		tools:                opts.Tools,
 		isYolo:               opts.IsYolo,
 		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
@@ -172,7 +176,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
 	}
 
 	// Add the user message to the session.
-	_, err = a.createUserMessage(ctx, call)
+	userMsg, err := a.createUserMessage(ctx, call)
 	if err != nil {
 		return nil, err
 	}
@@ -182,19 +186,52 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
 
 	genCtx, cancel := context.WithCancel(ctx)
 	a.activeRequests.Set(call.SessionID, cancel)
-
 	defer cancel()
 	defer a.activeRequests.Del(call.SessionID)
 
+	// create the agent message asap to show loading
+	var currentAssistant *message.Message
+	assistantMessage, err := a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{
+		Role:     message.Assistant,
+		Parts:    []message.ContentPart{},
+		Model:    a.largeModel.ModelCfg.Model,
+		Provider: a.largeModel.ModelCfg.Provider,
+	})
+	if err != nil {
+		return nil, err
+	}
+	currentAssistant = &assistantMessage
+
+	// run hooks after assistant message is created
+	// this way we show loading for the user
+	hooks, err := a.executeUserPromptSubmitHook(genCtx, call.SessionID, call.Prompt)
+	if err != nil {
+		return nil, err
+	}
+	userMsg.AddHookOutputs(hooks...)
+	if updateErr := a.messages.Update(genCtx, userMsg); updateErr != nil {
+		return nil, updateErr
+	}
+
+	for _, hook := range hooks {
+		// execution stopped
+		if hook.Stop {
+			deleteErr := a.messages.Delete(genCtx, assistantMessage.ID)
+			if deleteErr != nil {
+				return nil, deleteErr
+			}
+			return nil, nil
+		}
+	}
+
 	history, files := a.preparePrompt(msgs, call.Attachments...)
 
 	startTime := time.Now()
 	a.eventPromptSent(call.SessionID)
 
-	var currentAssistant *message.Message
 	var shouldSummarize bool
 	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
-		Prompt:           call.Prompt,
+		Prompt:           userMsg.ContentWithHooksContext(),
 		Files:            files,
 		Messages:         history,
 		ProviderOptions:  call.ProviderOptions,
@@ -206,6 +243,22 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
 		FrequencyPenalty: call.FrequencyPenalty,
 		// Before each step create a new assistant message.
 		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
+			// only add new assistant message when its not the first step
+			if options.StepNumber != 0 {
+				var assistantMsg message.Message
+				assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
+					Role:     message.Assistant,
+					Parts:    []message.ContentPart{},
+					Model:    a.largeModel.ModelCfg.Model,
+					Provider: a.largeModel.ModelCfg.Provider,
+				})
+				currentAssistant = &assistantMsg
+				// create the message first so we show loading asap
+				if err != nil {
+					return callContext, prepared, err
+				}
+			}
+
 			prepared.Messages = options.Messages
 			// Reset all cached items.
 			for i := range prepared.Messages {
@@ -219,6 +272,27 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
 				if createErr != nil {
 					return callContext, prepared, createErr
 				}
+
+				// run hooks after assistant message is created
+				// this way we show loading for the user
+				hooks, hookErr := a.executeUserPromptSubmitHook(genCtx, call.SessionID, call.Prompt)
+				if hookErr != nil {
+					return callContext, prepared, hookErr
+				}
+				userMsg.AddHookOutputs(hooks...)
+				for _, hook := range hooks {
+					// execution stopped
+					if hook.Stop {
+						deleteErr := a.messages.Delete(genCtx, assistantMessage.ID)
+						if deleteErr != nil {
+							return callContext, prepared, deleteErr
+						}
+						return callContext, prepared, ErrHookCancellation
+					}
+				}
+				if updateErr := a.messages.Update(genCtx, userMsg); updateErr != nil {
+					return callContext, prepared, updateErr
+				}
 				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
 			}
 
@@ -242,18 +316,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
 				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 			}
 
-			var assistantMsg message.Message
-			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
-				Role:     message.Assistant,
-				Parts:    []message.ContentPart{},
-				Model:    a.largeModel.ModelCfg.Model,
-				Provider: a.largeModel.ModelCfg.Provider,
-			})
-			if err != nil {
-				return callContext, prepared, err
-			}
-			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
-			currentAssistant = &assistantMsg
+			callContext = context.WithValue(callContext, tools.MessageIDContextKey, currentAssistant.ID)
 			return callContext, prepared, err
 		},
 		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
@@ -400,6 +463,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
 	if err != nil {
 		isCancelErr := errors.Is(err, context.Canceled)
 		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
+		isHookCancelled := errors.Is(err, ErrHookCancellation)
 		if currentAssistant == nil {
 			return result, err
 		}
@@ -468,6 +532,8 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
 			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
 		} else if isPermissionErr {
 			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
+		} else if isHookCancelled {
+			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Hook canceled request", "")
 		} else if errors.As(err, &providerErr) {
 			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
 		} else if errors.As(err, &fantasyErr) {
@@ -882,3 +948,20 @@ func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
 func (a *sessionAgent) Model() Model {
 	return a.largeModel
 }
+
+func (a *sessionAgent) executeUserPromptSubmitHook(
+	ctx context.Context,
+	sessionID string,
+	prompt string,
+) ([]message.HookOutput, error) {
+	if a.hooks == nil {
+		return nil, nil
+	}
+	return a.hooks.Execute(ctx, hooks.HookContext{
+		EventType:  config.UserPromptSubmit,
+		SessionID:  sessionID,
+		UserPrompt: prompt,
+		Provider:   a.largeModel.ModelCfg.Provider,
+		Model:      a.largeModel.ModelCfg.Model,
+	})
+}

internal/agent/agent_tool.go 🔗

@@ -34,7 +34,7 @@ func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error)
 		return nil, err
 	}
 
-	agent, err := c.buildAgent(ctx, prompt, agentCfg)
+	agent, err := c.buildAgent(ctx, prompt, agentCfg, nil)
 	if err != nil {
 		return nil, err
 	}

internal/agent/common_test.go 🔗

@@ -149,7 +149,7 @@ func testSessionAgent(env fakeEnv, large, small fantasy.LanguageModel, systemPro
 			DefaultMaxTokens: 10000,
 		},
 	}
-	agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, "", systemPrompt, false, true, env.sessions, env.messages, tools})
+	agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, "", systemPrompt, false, true, env.sessions, env.messages, nil, tools})
 	return agent
 }
 

internal/agent/coordinator.go 🔗

@@ -21,6 +21,7 @@ import (
 	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/csync"
 	"github.com/charmbracelet/crush/internal/history"
+	"github.com/charmbracelet/crush/internal/hooks"
 	"github.com/charmbracelet/crush/internal/log"
 	"github.com/charmbracelet/crush/internal/lsp"
 	"github.com/charmbracelet/crush/internal/message"
@@ -60,6 +61,7 @@ type coordinator struct {
 	messages    message.Service
 	permissions permission.Service
 	history     history.Service
+	hooks       hooks.Service
 	lspClients  *csync.Map[string, *lsp.Client]
 
 	currentAgent SessionAgent
@@ -76,6 +78,7 @@ func NewCoordinator(
 	permissions permission.Service,
 	history history.Service,
 	lspClients *csync.Map[string, *lsp.Client],
+	hooks hooks.Service,
 ) (Coordinator, error) {
 	c := &coordinator{
 		cfg:         cfg,
@@ -98,7 +101,7 @@ func NewCoordinator(
 		return nil, err
 	}
 
-	agent, err := c.buildAgent(ctx, prompt, agentCfg)
+	agent, err := c.buildAgent(ctx, prompt, agentCfg, hooks)
 	if err != nil {
 		return nil, err
 	}
@@ -274,7 +277,7 @@ func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderO
 	return modelOptions, temp, topP, topK, freqPenalty, presPenalty
 }
 
-func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
+func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, hooks hooks.Service) (SessionAgent, error) {
 	large, small, err := c.buildAgentModels(ctx)
 	if err != nil {
 		return nil, err
@@ -286,6 +289,7 @@ func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, age
 	}
 
 	largeProviderCfg, _ := c.cfg.Providers.Get(large.ModelCfg.Provider)
+
 	result := NewSessionAgent(SessionAgentOptions{
 		large,
 		small,
@@ -295,6 +299,7 @@ func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, age
 		c.permissions.SkipRequests(),
 		c.sessions,
 		c.messages,
+		hooks,
 		nil,
 	})
 	c.readyWg.Go(func() error {

internal/agent/errors.go 🔗

@@ -10,6 +10,7 @@ var (
 	ErrSessionBusy      = errors.New("session is currently processing another request")
 	ErrEmptyPrompt      = errors.New("prompt is empty")
 	ErrSessionMissing   = errors.New("session id is missing")
+	ErrHookCancellation = errors.New("hook cancelled the agent")
 )
 
 func isCancelledErr(err error) bool {

internal/app/app.go 🔗

@@ -23,6 +23,7 @@ import (
 	"github.com/charmbracelet/crush/internal/db"
 	"github.com/charmbracelet/crush/internal/format"
 	"github.com/charmbracelet/crush/internal/history"
+	"github.com/charmbracelet/crush/internal/hooks"
 	"github.com/charmbracelet/crush/internal/log"
 	"github.com/charmbracelet/crush/internal/lsp"
 	"github.com/charmbracelet/crush/internal/message"
@@ -42,6 +43,7 @@ type App struct {
 	Messages    message.Service
 	History     history.Service
 	Permissions permission.Service
+	Hooks       hooks.Service
 
 	AgentCoordinator agent.Coordinator
 
@@ -71,10 +73,13 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
 		allowedTools = cfg.Permissions.AllowedTools
 	}
 
+	hooks := hooks.NewService(cfg.Hooks, cfg.WorkingDir(), nil, messages)
+
 	app := &App{
 		Sessions:    sessions,
 		Messages:    messages,
 		History:     files,
+		Hooks:       hooks,
 		Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests, allowedTools),
 		LSPClients:  csync.NewMap[string, *lsp.Client](),
 
@@ -323,6 +328,7 @@ func (app *App) InitCoderAgent(ctx context.Context) error {
 		app.Permissions,
 		app.History,
 		app.LSPClients,
+		app.Hooks,
 	)
 	if err != nil {
 		slog.Error("Failed to create coder agent", "err", err)

internal/config/config.go 🔗

@@ -326,6 +326,8 @@ type Config struct {
 
 	Tools Tools `json:"tools,omitzero" jsonschema:"description=Tool configurations"`
 
+	Hooks HookConfig `json:"hooks,omitempty" jsonschema:"description=Hook configurations for lifecycle events"`
+
 	Agents map[string]Agent `json:"-"`
 
 	// Internal
@@ -355,6 +357,14 @@ func (c *Config) IsConfigured() bool {
 	return len(c.EnabledProviders()) > 0
 }
 
+// validateHooks validates the hooks configuration.
+func (c *Config) validateHooks() error {
+	if c.Hooks == nil {
+		return nil
+	}
+	return c.Hooks.Validate()
+}
+
 func (c *Config) GetModel(provider, model string) *catwalk.Model {
 	if providerConfig, ok := c.Providers.Get(provider); ok {
 		for _, m := range providerConfig.Models {

internal/config/hooks.go 🔗

@@ -0,0 +1,100 @@
+package config
+
+import (
+	"fmt"
+	"log/slog"
+)
+
+// HookEventType represents the lifecycle event when a hook should run.
+type HookEventType string
+
+const (
+	// PreToolUse runs before tool calls and can block them.
+	PreToolUse HookEventType = "pre_tool_use"
+	// PostToolUse runs after tool calls complete.
+	PostToolUse HookEventType = "post_tool_use"
+	// UserPromptSubmit runs when the user submits a prompt, before processing.
+	UserPromptSubmit HookEventType = "user_prompt_submit"
+	// Stop runs when Crush finishes responding.
+	Stop HookEventType = "stop"
+	// SubagentStop runs when subagent tasks complete.
+	SubagentStop HookEventType = "subagent_stop"
+	// PreCompact runs before running a compact operation.
+	PreCompact HookEventType = "pre_compact"
+	// PermissionRequested runs when a permission is requested from the user.
+	PermissionRequested HookEventType = "permission_requested"
+)
+
+// Hook represents a single hook command configuration.
+type Hook struct {
+	// Type is the hook type: "command" or "prompt".
+	Type string `json:"type" jsonschema:"description=Hook type,enum=command,enum=prompt,default=command"`
+	// Command is the shell command to execute (for type: "command").
+	// WARNING: Hook commands execute with Crush's full permissions. Only use trusted commands.
+	Command string `json:"command,omitempty" jsonschema:"description=Shell command to execute for this hook (executes with Crush's permissions),example=echo 'Hook executed'"`
+	// Prompt is the LLM prompt to execute (for type: "prompt").
+	// Use $ARGUMENTS placeholder to include hook context JSON.
+	Prompt string `json:"prompt,omitempty" jsonschema:"description=LLM prompt for intelligent decision making,example=Analyze if all tasks are complete. Context: $ARGUMENTS. Return JSON with decision and reason."`
+	// Timeout is the maximum time in seconds to wait for the hook to complete.
+	// Default is 30 seconds.
+	Timeout *int `json:"timeout,omitempty" jsonschema:"description=Maximum time in seconds to wait for hook completion,default=30,minimum=1,maximum=300"`
+}
+
+// Validate checks hook configuration invariants.
+func (h *Hook) Validate() error {
+	switch h.Type {
+	case "prompt":
+		if h.Prompt == "" {
+			return fmt.Errorf("prompt-based hook missing 'prompt' field")
+		}
+	case "", "command":
+		if h.Command == "" {
+			return fmt.Errorf("command-based hook missing 'command' field")
+		}
+	default:
+		return fmt.Errorf("unsupported hook type: %s", h.Type)
+	}
+	if h.Timeout != nil {
+		if *h.Timeout < 1 {
+			slog.Warn("Hook timeout too low, using minimum",
+				"configured", *h.Timeout, "minimum", 1)
+			v := 1
+			h.Timeout = &v
+		}
+		if *h.Timeout > 300 {
+			slog.Warn("Hook timeout too high, using maximum",
+				"configured", *h.Timeout, "maximum", 300)
+			v := 300
+			h.Timeout = &v
+		}
+	}
+	return nil
+}
+
+// HookMatcher represents a matcher for a specific event type.
+type HookMatcher struct {
+	// Matcher is the tool name or pattern to match (for tool events).
+	// For non-tool events, this can be empty or "*" to match all.
+	// Supports pipe-separated tool names like "edit|write|multiedit".
+	Matcher string `json:"matcher,omitempty" jsonschema:"description=Tool name or pattern to match (e.g. 'bash' 'edit|write' for multiple or '*' for all),example=bash,example=edit|write|multiedit,example=*"`
+	// Hooks is the list of hooks to execute when the matcher matches.
+	Hooks []Hook `json:"hooks" jsonschema:"required,description=List of hooks to execute when matcher matches"`
+}
+
+// HookConfig holds the complete hook configuration.
+type HookConfig map[HookEventType][]HookMatcher
+
+// Validate validates the entire hook configuration.
+func (c HookConfig) Validate() error {
+	for eventType, matchers := range c {
+		for i, matcher := range matchers {
+			for j := range matcher.Hooks {
+				if err := matcher.Hooks[j].Validate(); err != nil {
+					return fmt.Errorf("invalid hook config for %s matcher %d hook %d: %w",
+						eventType, i, j, err)
+				}
+			}
+		}
+	}
+	return nil
+}

internal/config/load.go 🔗

@@ -81,6 +81,11 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) {
 	}
 	cfg.knownProviders = providers
 
+	// Validate hooks configuration.
+	if err := cfg.validateHooks(); err != nil {
+		return nil, fmt.Errorf("invalid hooks configuration: %w", err)
+	}
+
 	env := env.New()
 	// Configure providers
 	valueResolver := NewShellVariableResolver(env)

internal/db/messages.sql.go 🔗

@@ -19,12 +19,13 @@ INSERT INTO messages (
     model,
     provider,
     is_summary_message,
+    hook_outputs,
     created_at,
     updated_at
 ) VALUES (
-    ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
+    ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
 )
-RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message
+RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message, hook_outputs
 `
 
 type CreateMessageParams struct {
@@ -35,6 +36,7 @@ type CreateMessageParams struct {
 	Model            sql.NullString `json:"model"`
 	Provider         sql.NullString `json:"provider"`
 	IsSummaryMessage int64          `json:"is_summary_message"`
+	HookOutputs      string         `json:"hook_outputs"`
 }
 
 func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (Message, error) {
@@ -46,6 +48,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M
 		arg.Model,
 		arg.Provider,
 		arg.IsSummaryMessage,
+		arg.HookOutputs,
 	)
 	var i Message
 	err := row.Scan(
@@ -59,6 +62,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M
 		&i.FinishedAt,
 		&i.Provider,
 		&i.IsSummaryMessage,
+		&i.HookOutputs,
 	)
 	return i, err
 }
@@ -84,7 +88,7 @@ func (q *Queries) DeleteSessionMessages(ctx context.Context, sessionID string) e
 }
 
 const getMessage = `-- name: GetMessage :one
-SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message
+SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message, hook_outputs
 FROM messages
 WHERE id = ? LIMIT 1
 `
@@ -103,12 +107,13 @@ func (q *Queries) GetMessage(ctx context.Context, id string) (Message, error) {
 		&i.FinishedAt,
 		&i.Provider,
 		&i.IsSummaryMessage,
+		&i.HookOutputs,
 	)
 	return i, err
 }
 
 const listMessagesBySession = `-- name: ListMessagesBySession :many
-SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message
+SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message, hook_outputs
 FROM messages
 WHERE session_id = ?
 ORDER BY created_at ASC
@@ -134,6 +139,7 @@ func (q *Queries) ListMessagesBySession(ctx context.Context, sessionID string) (
 			&i.FinishedAt,
 			&i.Provider,
 			&i.IsSummaryMessage,
+			&i.HookOutputs,
 		); err != nil {
 			return nil, err
 		}
@@ -153,17 +159,24 @@ UPDATE messages
 SET
     parts = ?,
     finished_at = ?,
+    hook_outputs = ?,
     updated_at = strftime('%s', 'now')
 WHERE id = ?
 `
 
 type UpdateMessageParams struct {
-	Parts      string        `json:"parts"`
-	FinishedAt sql.NullInt64 `json:"finished_at"`
-	ID         string        `json:"id"`
+	Parts       string        `json:"parts"`
+	FinishedAt  sql.NullInt64 `json:"finished_at"`
+	HookOutputs string        `json:"hook_outputs"`
+	ID          string        `json:"id"`
 }
 
 func (q *Queries) UpdateMessage(ctx context.Context, arg UpdateMessageParams) error {
-	_, err := q.exec(ctx, q.updateMessageStmt, updateMessage, arg.Parts, arg.FinishedAt, arg.ID)
+	_, err := q.exec(ctx, q.updateMessageStmt, updateMessage,
+		arg.Parts,
+		arg.FinishedAt,
+		arg.HookOutputs,
+		arg.ID,
+	)
 	return err
 }

internal/db/models.go 🔗

@@ -29,6 +29,7 @@ type Message struct {
 	FinishedAt       sql.NullInt64  `json:"finished_at"`
 	Provider         sql.NullString `json:"provider"`
 	IsSummaryMessage int64          `json:"is_summary_message"`
+	HookOutputs      string         `json:"hook_outputs"`
 }
 
 type Session struct {

internal/db/sql/messages.sql 🔗

@@ -18,10 +18,11 @@ INSERT INTO messages (
     model,
     provider,
     is_summary_message,
+    hook_outputs,
     created_at,
     updated_at
 ) VALUES (
-    ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
+    ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
 )
 RETURNING *;
 
@@ -30,6 +31,7 @@ UPDATE messages
 SET
     parts = ?,
     finished_at = ?,
+    hook_outputs = ?,
     updated_at = strftime('%s', 'now')
 WHERE id = ?;
 

internal/hooks/hooks.go 🔗

@@ -0,0 +1,415 @@
+package hooks
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"log/slog"
+	"regexp"
+	"strings"
+	"sync"
+	"time"
+
+	"charm.land/fantasy"
+	"github.com/charmbracelet/crush/internal/config"
+	"github.com/charmbracelet/crush/internal/csync"
+	"github.com/charmbracelet/crush/internal/message"
+	"github.com/charmbracelet/crush/internal/permission"
+	"github.com/charmbracelet/crush/internal/shell"
+)
+
+const DefaultHookTimeout = 30 * time.Second
+
+// HookContext contains context information passed to hooks.
+// Fields are populated based on event type - see field comments for availability.
+type HookContext struct {
+	// EventType is the lifecycle event that triggered this hook.
+	// Available: All events
+	EventType config.HookEventType `json:"event_type"`
+
+	// SessionID identifies the Crush session.
+	// Available: All events (if session exists)
+	SessionID string `json:"session_id,omitempty"`
+
+	// TranscriptPath is the path to exported session transcript JSON file.
+	// Available: All events (if session exists and export succeeds)
+	TranscriptPath string `json:"transcript_path,omitempty"`
+
+	// ToolName is the name of the tool being called.
+	// Available: pre_tool_use, post_tool_use, permission_requested
+	ToolName string `json:"tool_name,omitempty"`
+
+	// ToolInput contains the tool's input parameters as key-value pairs.
+	// Available: pre_tool_use, post_tool_use
+	ToolInput map[string]any `json:"tool_input,omitempty"`
+
+	// ToolResult is the string result returned by the tool.
+	// Available: post_tool_use
+	ToolResult string `json:"tool_result,omitempty"`
+
+	// ToolError indicates whether the tool execution failed.
+	// Available: post_tool_use
+	ToolError bool `json:"tool_error,omitempty"`
+
+	// UserPrompt is the prompt submitted by the user.
+	// Available: user_prompt_submit
+	UserPrompt string `json:"user_prompt,omitempty"`
+
+	// Timestamp is when the hook was triggered (RFC3339 format).
+	// Available: All events
+	Timestamp time.Time `json:"timestamp"`
+
+	// WorkingDir is the project working directory.
+	// Available: All events
+	WorkingDir string `json:"working_dir,omitempty"`
+
+	// MessageID identifies the assistant message being processed.
+	// Available: pre_tool_use, post_tool_use, stop
+	MessageID string `json:"message_id,omitempty"`
+
+	// Provider is the LLM provider name (e.g., "anthropic").
+	// Available: All events with LLM interaction
+	Provider string `json:"provider,omitempty"`
+
+	// Model is the LLM model name (e.g., "claude-3-5-sonnet-20241022").
+	// Available: All events with LLM interaction
+	Model string `json:"model,omitempty"`
+
+	// TokensUsed is the total tokens consumed in this interaction.
+	// Available: stop
+	TokensUsed int64 `json:"tokens_used,omitempty"`
+
+	// TokensInput is the input tokens consumed in this interaction.
+	// Available: stop
+	TokensInput int64 `json:"tokens_input,omitempty"`
+
+	// PermissionAction is the permission action being requested (e.g., "read", "write").
+	// Available: permission_requested
+	PermissionAction string `json:"permission_action,omitempty"`
+
+	// PermissionPath is the file path involved in the permission request.
+	// Available: permission_requested
+	PermissionPath string `json:"permission_path,omitempty"`
+
+	// PermissionParams contains additional permission parameters.
+	// Available: permission_requested
+	PermissionRequest *permission.PermissionRequest `json:"permission_request,omitempty"`
+}
+
+type HookDecision string
+
+const (
+	HookDecisionBlock HookDecision = "block"
+	HookDecisionDeny  HookDecision = "deny"
+	HookDecisionAllow HookDecision = "allow"
+	HookDecisionAsk   HookDecision = "ask"
+)
+
+type Service interface {
+	Execute(ctx context.Context, hookCtx HookContext) ([]message.HookOutput, error)
+	SetSmallModel(model fantasy.LanguageModel)
+}
+
+type service struct {
+	config     config.HookConfig
+	workingDir string
+	regexCache *csync.Map[string, *regexp.Regexp]
+	smallModel fantasy.LanguageModel
+	messages   message.Service
+}
+
+// NewService creates a new hook executor.
+func NewService(
+	hookConfig config.HookConfig,
+	workingDir string,
+	smallModel fantasy.LanguageModel,
+	messages message.Service,
+) Service {
+	return &service{
+		config:     hookConfig,
+		workingDir: workingDir,
+		regexCache: csync.NewMap[string, *regexp.Regexp](),
+		smallModel: smallModel,
+		messages:   messages,
+	}
+}
+
+// Execute implements Service.
+func (s *service) Execute(ctx context.Context, hookCtx HookContext) ([]message.HookOutput, error) {
+	if s.config == nil {
+		return nil, nil
+	}
+
+	// Check if context is already cancelled - prevents race conditions during cancellation.
+	if ctx.Err() != nil {
+		return nil, ctx.Err()
+	}
+
+	hookCtx.Timestamp = time.Now()
+	hookCtx.WorkingDir = s.workingDir
+
+	hooks := s.collectMatchingHooks(hookCtx)
+	if len(hooks) == 0 {
+		return nil, nil
+	}
+
+	transcriptPath, cleanup, err := s.setupTranscript(ctx, hookCtx)
+	if err != nil {
+		slog.Warn("Failed to export transcript for hooks", "error", err)
+	} else if transcriptPath != "" {
+		hookCtx.TranscriptPath = transcriptPath
+		// Ensure cleanup happens even on panic.
+		defer func() {
+			cleanup()
+		}()
+	}
+
+	results := make([]message.HookOutput, len(hooks))
+	var wg sync.WaitGroup
+
+	for i, hook := range hooks {
+		if ctx.Err() != nil {
+			return nil, ctx.Err()
+		}
+		wg.Add(1)
+		go func(idx int, h config.Hook) {
+			defer wg.Done()
+			result, err := s.executeHook(ctx, h, hookCtx)
+			if err != nil {
+				slog.Warn("Hook execution failed",
+					"event", hookCtx.EventType,
+					"error", err,
+				)
+			} else {
+				results[idx] = *result
+			}
+		}(i, hook)
+	}
+	wg.Wait()
+	return results, nil
+}
+
+func (s *service) setupTranscript(ctx context.Context, hookCtx HookContext) (string, func(), error) {
+	if hookCtx.SessionID == "" || s.messages == nil {
+		return "", func() {}, nil
+	}
+
+	path, err := exportTranscript(ctx, s.messages, hookCtx.SessionID)
+	if err != nil {
+		return "", func() {}, err
+	}
+
+	cleanup := func() {
+		if path != "" {
+			cleanupTranscript(path)
+		}
+	}
+
+	return path, cleanup, nil
+}
+
+func (s *service) executeHook(ctx context.Context, hook config.Hook, hookCtx HookContext) (*message.HookOutput, error) {
+	var result *message.HookOutput
+	var err error
+
+	switch hook.Type {
+	case "prompt":
+		if hook.Prompt == "" {
+			return nil, fmt.Errorf("prompt-based hook missing 'prompt' field")
+		}
+		if s.smallModel == nil {
+			return nil, fmt.Errorf("prompt-based hook requires small model configuration")
+		}
+		result, err = s.executePromptHook(ctx, hook, hookCtx)
+	case "", "command":
+		if hook.Command == "" {
+			return nil, fmt.Errorf("command-based hook missing 'command' field")
+		}
+		slog.Info("executing")
+		result, err = s.executeCommandHook(ctx, hook, hookCtx)
+	default:
+		return nil, fmt.Errorf("unsupported hook type: %s", hook.Type)
+	}
+
+	return result, err
+}
+
+func (s *service) executePromptHook(ctx context.Context, hook config.Hook, hookCtx HookContext) (*message.HookOutput, error) {
+	panic("not implemented")
+}
+
+func (s *service) executeCommandHook(ctx context.Context, hook config.Hook, hookCtx HookContext) (*message.HookOutput, error) {
+	timeout := DefaultHookTimeout
+	if hook.Timeout != nil {
+		timeout = time.Duration(*hook.Timeout) * time.Second
+	}
+
+	execCtx, cancel := context.WithTimeout(ctx, timeout)
+	defer cancel()
+
+	contextJSON, err := json.Marshal(hookCtx)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal hook context: %w", err)
+	}
+
+	sh := shell.NewShell(&shell.Options{
+		WorkingDir: s.workingDir,
+	})
+
+	sh.SetEnv("CRUSH_HOOK_EVENT", string(hookCtx.EventType))
+	sh.SetEnv("CRUSH_HOOK_CONTEXT", string(contextJSON))
+	sh.SetEnv("CRUSH_PROJECT_DIR", s.workingDir)
+	if hookCtx.SessionID != "" {
+		sh.SetEnv("CRUSH_SESSION_ID", hookCtx.SessionID)
+	}
+	if hookCtx.ToolName != "" {
+		sh.SetEnv("CRUSH_TOOL_NAME", hookCtx.ToolName)
+	}
+
+	slog.Debug("Hook execution trace",
+		"event", hookCtx.EventType,
+		"command", hook.Command,
+		"timeout", timeout,
+		"context", hookCtx,
+	)
+
+	stdout, stderr, err := sh.ExecWithStdin(execCtx, hook.Command, string(contextJSON))
+
+	exitCode := shell.ExitCode(err)
+
+	var result *message.HookOutput
+	switch exitCode {
+	case 0:
+		// if the event is  UserPromptSubmit we want the output to be added to the context
+		if hookCtx.EventType == config.UserPromptSubmit {
+			result = &message.HookOutput{
+				AdditionalContext: stdout,
+			}
+		} else {
+			result = &message.HookOutput{
+				Message: stdout,
+			}
+		}
+	case 2:
+		result = &message.HookOutput{
+			Decision: string(HookDecisionBlock),
+			Error:    stderr,
+		}
+		return result, nil
+	default:
+		result = &message.HookOutput{
+			Error: stderr,
+		}
+		return result, nil
+	}
+
+	jsonOutput := parseHookOutput(stdout)
+	if jsonOutput == nil {
+		return result, nil
+	}
+
+	result.Message = jsonOutput.Message
+	result.Stop = jsonOutput.Stop
+	result.Decision = jsonOutput.Decision
+	result.AdditionalContext = jsonOutput.AdditionalContext
+	result.UpdatedInput = jsonOutput.UpdatedInput
+
+	// Trace output in debug mode
+	slog.Debug("Hook execution output",
+		"event", hookCtx.EventType,
+		"exit_code", exitCode,
+		"stdout_length", len(stdout),
+		"stderr_length", len(stderr),
+		"stdout", stdout,
+		"stderr", stderr,
+	)
+	return result, nil
+}
+
+func parseHookOutput(stdout string) *message.HookOutput {
+	stdout = strings.TrimSpace(stdout)
+	slog.Info(stdout)
+	if stdout == "" {
+		return nil
+	}
+
+	var output message.HookOutput
+	if err := json.Unmarshal([]byte(stdout), &output); err != nil {
+		// Failed to parse as HookOutput
+		return nil
+	}
+
+	return &output
+}
+
+// SetSmallModel implements Service.
+func (s *service) SetSmallModel(model fantasy.LanguageModel) {
+	panic("unimplemented")
+}
+
+func (s *service) collectMatchingHooks(hookCtx HookContext) []config.Hook {
+	matchers, ok := s.config[hookCtx.EventType]
+	if !ok || len(matchers) == 0 {
+		return nil
+	}
+
+	var hooks []config.Hook
+	for _, matcher := range matchers {
+		if !s.matcherApplies(matcher, hookCtx) {
+			continue
+		}
+		hooks = append(hooks, matcher.Hooks...)
+	}
+	return hooks
+}
+
+func (s *service) matcherApplies(matcher config.HookMatcher, ctx HookContext) bool {
+	if ctx.EventType == config.PreToolUse || ctx.EventType == config.PostToolUse {
+		return s.matchesToolName(matcher.Matcher, ctx.ToolName)
+	}
+
+	return matcher.Matcher == "" || matcher.Matcher == "*"
+}
+
+func (s *service) matchesToolName(pattern, toolName string) bool {
+	if pattern == "" || pattern == "*" {
+		return true
+	}
+
+	if pattern == toolName {
+		return true
+	}
+
+	if strings.Contains(pattern, "|") {
+		for tool := range strings.SplitSeq(pattern, "|") {
+			tool = strings.TrimSpace(tool)
+			if tool == toolName {
+				return true
+			}
+		}
+
+		return s.matchesRegex(pattern, toolName)
+	}
+
+	return s.matchesRegex(pattern, toolName)
+}
+
+func (s *service) matchesRegex(pattern, text string) bool {
+	re, ok := s.regexCache.Get(pattern)
+	if !ok {
+		compiled, err := regexp.Compile(pattern)
+		if err != nil {
+			// Not a valid regex, don't cache failures.
+			return false
+		}
+		re = s.regexCache.GetOrSet(pattern, func() *regexp.Regexp {
+			return compiled
+		})
+	}
+
+	if re == nil {
+		return false
+	}
+
+	return re.MatchString(text)
+}

internal/hooks/transcript.go 🔗

@@ -0,0 +1,133 @@
+package hooks
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"os"
+	"path/filepath"
+	"time"
+
+	"github.com/charmbracelet/crush/internal/message"
+)
+
+// TranscriptMessage represents a message in the exported transcript.
+type TranscriptMessage struct {
+	ID          string                 `json:"id"`
+	Role        string                 `json:"role"`
+	Content     string                 `json:"content,omitempty"`
+	ToolCalls   []TranscriptToolCall   `json:"tool_calls,omitempty"`
+	ToolResults []TranscriptToolResult `json:"tool_results,omitempty"`
+	Timestamp   string                 `json:"timestamp"`
+}
+
+// TranscriptToolCall represents a tool call in the transcript.
+type TranscriptToolCall struct {
+	ID    string `json:"id"`
+	Name  string `json:"name"`
+	Input string `json:"input"`
+}
+
+// TranscriptToolResult represents a tool result in the transcript.
+type TranscriptToolResult struct {
+	ToolCallID string `json:"tool_call_id"`
+	Name       string `json:"name"`
+	Content    string `json:"content"`
+	IsError    bool   `json:"is_error"`
+}
+
+// Transcript represents the complete transcript structure.
+type Transcript struct {
+	SessionID string              `json:"session_id"`
+	Messages  []TranscriptMessage `json:"messages"`
+}
+
+// exportTranscript exports session messages to a temporary JSON file.
+func exportTranscript(
+	ctx context.Context,
+	messages message.Service,
+	sessionID string,
+) (string, error) {
+	// Get all messages for the session
+	msgs, err := messages.List(ctx, sessionID)
+	if err != nil {
+		return "", fmt.Errorf("failed to list messages: %w", err)
+	}
+
+	// Convert to transcript format
+	transcript := Transcript{
+		SessionID: sessionID,
+		Messages:  make([]TranscriptMessage, 0, len(msgs)),
+	}
+
+	for _, msg := range msgs {
+		tm := TranscriptMessage{
+			ID:        msg.ID,
+			Role:      string(msg.Role),
+			Timestamp: time.Unix(msg.CreatedAt, 0).Format("2006-01-02T15:04:05Z07:00"),
+		}
+
+		// Extract content
+		for _, part := range msg.Parts {
+			if text, ok := part.(message.TextContent); ok {
+				if tm.Content != "" {
+					tm.Content += "\n"
+				}
+				tm.Content += text.Text
+			}
+		}
+
+		// Extract tool calls
+		if msg.Role == message.Assistant {
+			toolCalls := msg.ToolCalls()
+			for _, tc := range toolCalls {
+				tm.ToolCalls = append(tm.ToolCalls, TranscriptToolCall{
+					ID:    tc.ID,
+					Name:  tc.Name,
+					Input: tc.Input,
+				})
+			}
+		}
+
+		// Extract tool results
+		if msg.Role == message.Tool {
+			toolResults := msg.ToolResults()
+			for _, tr := range toolResults {
+				tm.ToolResults = append(tm.ToolResults, TranscriptToolResult{
+					ToolCallID: tr.ToolCallID,
+					Name:       tr.Name,
+					Content:    tr.Content,
+					IsError:    tr.IsError,
+				})
+			}
+		}
+
+		transcript.Messages = append(transcript.Messages, tm)
+	}
+
+	// Marshal to JSON
+	data, err := json.MarshalIndent(transcript, "", "  ")
+	if err != nil {
+		return "", fmt.Errorf("failed to marshal transcript: %w", err)
+	}
+
+	// Write to temporary file
+	tmpDir := os.TempDir()
+	filename := fmt.Sprintf("crush-transcript-%s.json", sessionID)
+	path := filepath.Join(tmpDir, filename)
+
+	// Use restrictive permissions (0600)
+	if err := os.WriteFile(path, data, 0o600); err != nil {
+		return "", fmt.Errorf("failed to write transcript file: %w", err)
+	}
+
+	return path, nil
+}
+
+func cleanupTranscript(path string) {
+	if path != "" {
+		if err := os.Remove(path); err != nil {
+			fmt.Fprintf(os.Stderr, "Warning: failed to cleanup transcript file %s: %v\n", path, err)
+		}
+	}
+}

internal/message/content.go 🔗

@@ -133,6 +133,16 @@ type Message struct {
 	CreatedAt        int64
 	UpdatedAt        int64
 	IsSummaryMessage bool
+	HookOutputs      []HookOutput
+}
+
+type HookOutput struct {
+	Stop              bool   `json:"stop"`
+	Error             string `json:"error"`
+	Message           string `json:"message"`
+	Decision          string `json:"decision"`
+	UpdatedInput      string `json:"updated_input"`
+	AdditionalContext string `json:"additional_context"`
 }
 
 func (m *Message) Content() TextContent {
@@ -144,6 +154,23 @@ func (m *Message) Content() TextContent {
 	return TextContent{}
 }
 
+func (m *Message) ContentWithHooksContext() string {
+	text := strings.TrimSpace(m.Content().Text)
+
+	var additionalContext []string
+	for _, hookOutput := range m.HookOutputs {
+		context := strings.TrimSpace(hookOutput.AdditionalContext)
+		if context != "" {
+			additionalContext = append(additionalContext, context)
+		}
+	}
+	if len(additionalContext) > 0 {
+		text += "## Additional Context\n"
+		text += strings.Join(additionalContext, "\n")
+	}
+	return text
+}
+
 func (m *Message) ReasoningContent() ReasoningContent {
 	for _, part := range m.Parts {
 		if c, ok := part.(ReasoningContent); ok {
@@ -202,6 +229,11 @@ func (m *Message) IsFinished() bool {
 	return false
 }
 
+// AddHookOutputs appends multiple hook outputs to the message's hook outputs.
+func (m *Message) AddHookOutputs(outputs ...HookOutput) {
+	m.HookOutputs = append(m.HookOutputs, outputs...)
+}
+
 func (m *Message) FinishPart() *Finish {
 	for _, part := range m.Parts {
 		if c, ok := part.(Finish); ok {
@@ -429,7 +461,7 @@ func (m *Message) ToAIMessage() []fantasy.Message {
 	switch m.Role {
 	case User:
 		var parts []fantasy.MessagePart
-		text := strings.TrimSpace(m.Content().Text)
+		text := strings.TrimSpace(m.ContentWithHooksContext())
 		if text != "" {
 			parts = append(parts, fantasy.TextPart{Text: text})
 		}

internal/message/message.go 🔗

@@ -18,6 +18,7 @@ type CreateMessageParams struct {
 	Model            string
 	Provider         string
 	IsSummaryMessage bool
+	HookOutputs      []HookOutput
 }
 
 type Service interface {
@@ -69,6 +70,10 @@ func (s *service) Create(ctx context.Context, sessionID string, params CreateMes
 	if params.IsSummaryMessage {
 		isSummary = 1
 	}
+	hookOutputsJSON, err := json.Marshal(params.HookOutputs)
+	if err != nil {
+		return Message{}, err
+	}
 	dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
 		ID:               uuid.New().String(),
 		SessionID:        sessionID,
@@ -77,6 +82,7 @@ func (s *service) Create(ctx context.Context, sessionID string, params CreateMes
 		Model:            sql.NullString{String: string(params.Model), Valid: true},
 		Provider:         sql.NullString{String: params.Provider, Valid: params.Provider != ""},
 		IsSummaryMessage: isSummary,
+		HookOutputs:      string(hookOutputsJSON),
 	})
 	if err != nil {
 		return Message{}, err
@@ -115,10 +121,15 @@ func (s *service) Update(ctx context.Context, message Message) error {
 		finishedAt.Int64 = f.Time
 		finishedAt.Valid = true
 	}
+	hookOutputsJSON, err := json.Marshal(message.HookOutputs)
+	if err != nil {
+		return err
+	}
 	err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
-		ID:         message.ID,
-		Parts:      string(parts),
-		FinishedAt: finishedAt,
+		ID:          message.ID,
+		Parts:       string(parts),
+		FinishedAt:  finishedAt,
+		HookOutputs: string(hookOutputsJSON),
 	})
 	if err != nil {
 		return err
@@ -156,6 +167,12 @@ func (s *service) fromDBItem(item db.Message) (Message, error) {
 	if err != nil {
 		return Message{}, err
 	}
+	var hookOutputs []HookOutput
+	if item.HookOutputs != "" {
+		if err := json.Unmarshal([]byte(item.HookOutputs), &hookOutputs); err != nil {
+			return Message{}, err
+		}
+	}
 	return Message{
 		ID:               item.ID,
 		SessionID:        item.SessionID,
@@ -166,6 +183,7 @@ func (s *service) fromDBItem(item db.Message) (Message, error) {
 		CreatedAt:        item.CreatedAt,
 		UpdatedAt:        item.UpdatedAt,
 		IsSummaryMessage: item.IsSummaryMessage != 0,
+		HookOutputs:      hookOutputs,
 	}, nil
 }
 

internal/shell/shell.go 🔗

@@ -111,6 +111,15 @@ func (s *Shell) ExecStream(ctx context.Context, command string, stdout, stderr i
 	return s.execStream(ctx, command, stdout, stderr)
 }
 
+func (s *Shell) ExecWithStdin(ctx context.Context, command string, stdin string) (string, string, error) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	var stdout, stderr bytes.Buffer
+	err := s.execCommonWithStdin(ctx, command, strings.NewReader(stdin), &stdout, &stderr)
+	return stdout.String(), stderr.String(), err
+}
+
 // GetWorkingDir returns the current working directory
 func (s *Shell) GetWorkingDir() string {
 	s.mu.Lock()
@@ -237,9 +246,9 @@ func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHand
 }
 
 // newInterp creates a new interpreter with the current shell state
-func (s *Shell) newInterp(stdout, stderr io.Writer) (*interp.Runner, error) {
+func (s *Shell) newInterp(stdin io.Reader, stdout, stderr io.Writer) (*interp.Runner, error) {
 	return interp.New(
-		interp.StdIO(nil, stdout, stderr),
+		interp.StdIO(stdin, stdout, stderr),
 		interp.Interactive(false),
 		interp.Env(expand.ListEnviron(s.env...)),
 		interp.Dir(s.cwd),
@@ -263,7 +272,7 @@ func (s *Shell) execCommon(ctx context.Context, command string, stdout, stderr i
 		return fmt.Errorf("could not parse command: %w", err)
 	}
 
-	runner, err := s.newInterp(stdout, stderr)
+	runner, err := s.newInterp(nil, stdout, stderr)
 	if err != nil {
 		return fmt.Errorf("could not run command: %w", err)
 	}
@@ -286,6 +295,24 @@ func (s *Shell) execStream(ctx context.Context, command string, stdout, stderr i
 	return s.execCommon(ctx, command, stdout, stderr)
 }
 
+// execCommonWithStdin is like execCommon but with stdin support
+func (s *Shell) execCommonWithStdin(ctx context.Context, command string, stdin io.Reader, stdout, stderr io.Writer) error {
+	line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
+	if err != nil {
+		return fmt.Errorf("could not parse command: %w", err)
+	}
+
+	runner, err := s.newInterp(stdin, stdout, stderr)
+	if err != nil {
+		return fmt.Errorf("could not run command: %w", err)
+	}
+
+	err = runner.Run(ctx, line)
+	s.updateShellFromRunner(runner)
+	s.logger.InfoPersist("command finished", "command", command, "err", err)
+	return err
+}
+
 func (s *Shell) execHandlers() []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
 	handlers := []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc{
 		s.blockHandler(),