From d03f9be2fbf0bb13893289bccad42bf88aeeb809 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 17 Nov 2025 12:08:51 +0100 Subject: [PATCH] wip: hooks implementation --- internal/agent/agent.go | 115 ++++- internal/agent/agent_tool.go | 2 +- internal/agent/common_test.go | 2 +- internal/agent/coordinator.go | 9 +- internal/agent/errors.go | 1 + internal/app/app.go | 6 + internal/config/config.go | 10 + internal/config/hooks.go | 100 +++++ internal/config/load.go | 5 + internal/db/messages.sql.go | 29 +- ...811000000_add_hook_outputs_to_messages.sql | 5 + internal/db/models.go | 1 + internal/db/sql/messages.sql | 4 +- internal/hooks/hooks.go | 415 ++++++++++++++++++ internal/hooks/transcript.go | 133 ++++++ internal/message/content.go | 34 +- internal/message/message.go | 24 +- internal/shell/shell.go | 33 +- 18 files changed, 892 insertions(+), 36 deletions(-) create mode 100644 internal/config/hooks.go create mode 100644 internal/db/migrations/20250811000000_add_hook_outputs_to_messages.sql create mode 100644 internal/hooks/hooks.go create mode 100644 internal/hooks/transcript.go diff --git a/internal/agent/agent.go b/internal/agent/agent.go index f4d8af377934245bd1c94f7cadf61bbadf60dd44..2dcd82b9dd34ba9bfec7a6250a76153bc64ba5ba 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -30,6 +30,7 @@ import ( "github.com/charmbracelet/crush/internal/agent/tools" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" + "github.com/charmbracelet/crush/internal/hooks" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/session" @@ -83,6 +84,7 @@ type sessionAgent struct { tools []fantasy.AgentTool sessions session.Service messages message.Service + hooks hooks.Service disableAutoSummarize bool isYolo bool @@ -99,6 +101,7 @@ type SessionAgentOptions struct { IsYolo bool Sessions session.Service Messages message.Service + Hooks hooks.Service Tools []fantasy.AgentTool } @@ -113,6 +116,7 @@ func NewSessionAgent( sessions: opts.Sessions, messages: opts.Messages, disableAutoSummarize: opts.DisableAutoSummarize, + hooks: opts.Hooks, tools: opts.Tools, isYolo: opts.IsYolo, messageQueue: csync.NewMap[string, []SessionAgentCall](), @@ -172,7 +176,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy } // Add the user message to the session. - _, err = a.createUserMessage(ctx, call) + userMsg, err := a.createUserMessage(ctx, call) if err != nil { return nil, err } @@ -182,19 +186,52 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy genCtx, cancel := context.WithCancel(ctx) a.activeRequests.Set(call.SessionID, cancel) - defer cancel() defer a.activeRequests.Del(call.SessionID) + // create the agent message asap to show loading + var currentAssistant *message.Message + assistantMessage, err := a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{ + Role: message.Assistant, + Parts: []message.ContentPart{}, + Model: a.largeModel.ModelCfg.Model, + Provider: a.largeModel.ModelCfg.Provider, + }) + if err != nil { + return nil, err + } + currentAssistant = &assistantMessage + + // run hooks after assistant message is created + // this way we show loading for the user + hooks, err := a.executeUserPromptSubmitHook(genCtx, call.SessionID, call.Prompt) + if err != nil { + return nil, err + } + userMsg.AddHookOutputs(hooks...) + if updateErr := a.messages.Update(genCtx, userMsg); updateErr != nil { + return nil, updateErr + } + + for _, hook := range hooks { + // execution stopped + if hook.Stop { + deleteErr := a.messages.Delete(genCtx, assistantMessage.ID) + if deleteErr != nil { + return nil, deleteErr + } + return nil, nil + } + } + history, files := a.preparePrompt(msgs, call.Attachments...) startTime := time.Now() a.eventPromptSent(call.SessionID) - var currentAssistant *message.Message var shouldSummarize bool result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{ - Prompt: call.Prompt, + Prompt: userMsg.ContentWithHooksContext(), Files: files, Messages: history, ProviderOptions: call.ProviderOptions, @@ -206,6 +243,22 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy FrequencyPenalty: call.FrequencyPenalty, // Before each step create a new assistant message. PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) { + // only add new assistant message when its not the first step + if options.StepNumber != 0 { + var assistantMsg message.Message + assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{ + Role: message.Assistant, + Parts: []message.ContentPart{}, + Model: a.largeModel.ModelCfg.Model, + Provider: a.largeModel.ModelCfg.Provider, + }) + currentAssistant = &assistantMsg + // create the message first so we show loading asap + if err != nil { + return callContext, prepared, err + } + } + prepared.Messages = options.Messages // Reset all cached items. for i := range prepared.Messages { @@ -219,6 +272,27 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy if createErr != nil { return callContext, prepared, createErr } + + // run hooks after assistant message is created + // this way we show loading for the user + hooks, hookErr := a.executeUserPromptSubmitHook(genCtx, call.SessionID, call.Prompt) + if hookErr != nil { + return callContext, prepared, hookErr + } + userMsg.AddHookOutputs(hooks...) + for _, hook := range hooks { + // execution stopped + if hook.Stop { + deleteErr := a.messages.Delete(genCtx, assistantMessage.ID) + if deleteErr != nil { + return callContext, prepared, deleteErr + } + return callContext, prepared, ErrHookCancellation + } + } + if updateErr := a.messages.Update(genCtx, userMsg); updateErr != nil { + return callContext, prepared, updateErr + } prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...) } @@ -242,18 +316,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...) } - var assistantMsg message.Message - assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{ - Role: message.Assistant, - Parts: []message.ContentPart{}, - Model: a.largeModel.ModelCfg.Model, - Provider: a.largeModel.ModelCfg.Provider, - }) - if err != nil { - return callContext, prepared, err - } - callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID) - currentAssistant = &assistantMsg + callContext = context.WithValue(callContext, tools.MessageIDContextKey, currentAssistant.ID) return callContext, prepared, err }, OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error { @@ -400,6 +463,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy if err != nil { isCancelErr := errors.Is(err, context.Canceled) isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied) + isHookCancelled := errors.Is(err, ErrHookCancellation) if currentAssistant == nil { return result, err } @@ -468,6 +532,8 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "") } else if isPermissionErr { currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "") + } else if isHookCancelled { + currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Hook canceled request", "") } else if errors.As(err, &providerErr) { currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message) } else if errors.As(err, &fantasyErr) { @@ -882,3 +948,20 @@ func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) { func (a *sessionAgent) Model() Model { return a.largeModel } + +func (a *sessionAgent) executeUserPromptSubmitHook( + ctx context.Context, + sessionID string, + prompt string, +) ([]message.HookOutput, error) { + if a.hooks == nil { + return nil, nil + } + return a.hooks.Execute(ctx, hooks.HookContext{ + EventType: config.UserPromptSubmit, + SessionID: sessionID, + UserPrompt: prompt, + Provider: a.largeModel.ModelCfg.Provider, + Model: a.largeModel.ModelCfg.Model, + }) +} diff --git a/internal/agent/agent_tool.go b/internal/agent/agent_tool.go index b937c20dd510eb9dff52bd348bf7baf21aafcf8c..3d0743f9bba54811182e22d509552c341e73161b 100644 --- a/internal/agent/agent_tool.go +++ b/internal/agent/agent_tool.go @@ -34,7 +34,7 @@ func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error) return nil, err } - agent, err := c.buildAgent(ctx, prompt, agentCfg) + agent, err := c.buildAgent(ctx, prompt, agentCfg, nil) if err != nil { return nil, err } diff --git a/internal/agent/common_test.go b/internal/agent/common_test.go index 58264e2637f3a44e45d54a66c72ee3b8d6c642a3..bebea79e9ffe022d6cc4b6dd485a04a61c20a3e1 100644 --- a/internal/agent/common_test.go +++ b/internal/agent/common_test.go @@ -149,7 +149,7 @@ func testSessionAgent(env fakeEnv, large, small fantasy.LanguageModel, systemPro DefaultMaxTokens: 10000, }, } - agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, "", systemPrompt, false, true, env.sessions, env.messages, tools}) + agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, "", systemPrompt, false, true, env.sessions, env.messages, nil, tools}) return agent } diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 4bfcc0062ae9a06dc858989f2cce925976d6d32b..1dcae362835955cd8d446480432035ed36f4069e 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -21,6 +21,7 @@ import ( "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/history" + "github.com/charmbracelet/crush/internal/hooks" "github.com/charmbracelet/crush/internal/log" "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/message" @@ -60,6 +61,7 @@ type coordinator struct { messages message.Service permissions permission.Service history history.Service + hooks hooks.Service lspClients *csync.Map[string, *lsp.Client] currentAgent SessionAgent @@ -76,6 +78,7 @@ func NewCoordinator( permissions permission.Service, history history.Service, lspClients *csync.Map[string, *lsp.Client], + hooks hooks.Service, ) (Coordinator, error) { c := &coordinator{ cfg: cfg, @@ -98,7 +101,7 @@ func NewCoordinator( return nil, err } - agent, err := c.buildAgent(ctx, prompt, agentCfg) + agent, err := c.buildAgent(ctx, prompt, agentCfg, hooks) if err != nil { return nil, err } @@ -274,7 +277,7 @@ func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderO return modelOptions, temp, topP, topK, freqPenalty, presPenalty } -func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) { +func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, hooks hooks.Service) (SessionAgent, error) { large, small, err := c.buildAgentModels(ctx) if err != nil { return nil, err @@ -286,6 +289,7 @@ func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, age } largeProviderCfg, _ := c.cfg.Providers.Get(large.ModelCfg.Provider) + result := NewSessionAgent(SessionAgentOptions{ large, small, @@ -295,6 +299,7 @@ func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, age c.permissions.SkipRequests(), c.sessions, c.messages, + hooks, nil, }) c.readyWg.Go(func() error { diff --git a/internal/agent/errors.go b/internal/agent/errors.go index 1b4f0dfce6c8b13a0b22c42ed651ba895e24ed9b..1cfed9366189a3afb646ccc34fb1ceb6f78b730e 100644 --- a/internal/agent/errors.go +++ b/internal/agent/errors.go @@ -10,6 +10,7 @@ var ( ErrSessionBusy = errors.New("session is currently processing another request") ErrEmptyPrompt = errors.New("prompt is empty") ErrSessionMissing = errors.New("session id is missing") + ErrHookCancellation = errors.New("hook cancelled the agent") ) func isCancelledErr(err error) bool { diff --git a/internal/app/app.go b/internal/app/app.go index d3e6d2133346df1adc11fc13a612b67cf25b46bd..18d7504b583cfc1cb6e6bc639b718f7a89b743f9 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -23,6 +23,7 @@ import ( "github.com/charmbracelet/crush/internal/db" "github.com/charmbracelet/crush/internal/format" "github.com/charmbracelet/crush/internal/history" + "github.com/charmbracelet/crush/internal/hooks" "github.com/charmbracelet/crush/internal/log" "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/message" @@ -42,6 +43,7 @@ type App struct { Messages message.Service History history.Service Permissions permission.Service + Hooks hooks.Service AgentCoordinator agent.Coordinator @@ -71,10 +73,13 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { allowedTools = cfg.Permissions.AllowedTools } + hooks := hooks.NewService(cfg.Hooks, cfg.WorkingDir(), nil, messages) + app := &App{ Sessions: sessions, Messages: messages, History: files, + Hooks: hooks, Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests, allowedTools), LSPClients: csync.NewMap[string, *lsp.Client](), @@ -323,6 +328,7 @@ func (app *App) InitCoderAgent(ctx context.Context) error { app.Permissions, app.History, app.LSPClients, + app.Hooks, ) if err != nil { slog.Error("Failed to create coder agent", "err", err) diff --git a/internal/config/config.go b/internal/config/config.go index 2adc45f050b46afb8788042035531c032159926e..defe26c472a1fed0ba13b09a53bfb938476dfdb9 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -326,6 +326,8 @@ type Config struct { Tools Tools `json:"tools,omitzero" jsonschema:"description=Tool configurations"` + Hooks HookConfig `json:"hooks,omitempty" jsonschema:"description=Hook configurations for lifecycle events"` + Agents map[string]Agent `json:"-"` // Internal @@ -355,6 +357,14 @@ func (c *Config) IsConfigured() bool { return len(c.EnabledProviders()) > 0 } +// validateHooks validates the hooks configuration. +func (c *Config) validateHooks() error { + if c.Hooks == nil { + return nil + } + return c.Hooks.Validate() +} + func (c *Config) GetModel(provider, model string) *catwalk.Model { if providerConfig, ok := c.Providers.Get(provider); ok { for _, m := range providerConfig.Models { diff --git a/internal/config/hooks.go b/internal/config/hooks.go new file mode 100644 index 0000000000000000000000000000000000000000..4d3a6c7a1b0199da5849e17f1bffe8cc7ff406f2 --- /dev/null +++ b/internal/config/hooks.go @@ -0,0 +1,100 @@ +package config + +import ( + "fmt" + "log/slog" +) + +// HookEventType represents the lifecycle event when a hook should run. +type HookEventType string + +const ( + // PreToolUse runs before tool calls and can block them. + PreToolUse HookEventType = "pre_tool_use" + // PostToolUse runs after tool calls complete. + PostToolUse HookEventType = "post_tool_use" + // UserPromptSubmit runs when the user submits a prompt, before processing. + UserPromptSubmit HookEventType = "user_prompt_submit" + // Stop runs when Crush finishes responding. + Stop HookEventType = "stop" + // SubagentStop runs when subagent tasks complete. + SubagentStop HookEventType = "subagent_stop" + // PreCompact runs before running a compact operation. + PreCompact HookEventType = "pre_compact" + // PermissionRequested runs when a permission is requested from the user. + PermissionRequested HookEventType = "permission_requested" +) + +// Hook represents a single hook command configuration. +type Hook struct { + // Type is the hook type: "command" or "prompt". + Type string `json:"type" jsonschema:"description=Hook type,enum=command,enum=prompt,default=command"` + // Command is the shell command to execute (for type: "command"). + // WARNING: Hook commands execute with Crush's full permissions. Only use trusted commands. + Command string `json:"command,omitempty" jsonschema:"description=Shell command to execute for this hook (executes with Crush's permissions),example=echo 'Hook executed'"` + // Prompt is the LLM prompt to execute (for type: "prompt"). + // Use $ARGUMENTS placeholder to include hook context JSON. + Prompt string `json:"prompt,omitempty" jsonschema:"description=LLM prompt for intelligent decision making,example=Analyze if all tasks are complete. Context: $ARGUMENTS. Return JSON with decision and reason."` + // Timeout is the maximum time in seconds to wait for the hook to complete. + // Default is 30 seconds. + Timeout *int `json:"timeout,omitempty" jsonschema:"description=Maximum time in seconds to wait for hook completion,default=30,minimum=1,maximum=300"` +} + +// Validate checks hook configuration invariants. +func (h *Hook) Validate() error { + switch h.Type { + case "prompt": + if h.Prompt == "" { + return fmt.Errorf("prompt-based hook missing 'prompt' field") + } + case "", "command": + if h.Command == "" { + return fmt.Errorf("command-based hook missing 'command' field") + } + default: + return fmt.Errorf("unsupported hook type: %s", h.Type) + } + if h.Timeout != nil { + if *h.Timeout < 1 { + slog.Warn("Hook timeout too low, using minimum", + "configured", *h.Timeout, "minimum", 1) + v := 1 + h.Timeout = &v + } + if *h.Timeout > 300 { + slog.Warn("Hook timeout too high, using maximum", + "configured", *h.Timeout, "maximum", 300) + v := 300 + h.Timeout = &v + } + } + return nil +} + +// HookMatcher represents a matcher for a specific event type. +type HookMatcher struct { + // Matcher is the tool name or pattern to match (for tool events). + // For non-tool events, this can be empty or "*" to match all. + // Supports pipe-separated tool names like "edit|write|multiedit". + Matcher string `json:"matcher,omitempty" jsonschema:"description=Tool name or pattern to match (e.g. 'bash' 'edit|write' for multiple or '*' for all),example=bash,example=edit|write|multiedit,example=*"` + // Hooks is the list of hooks to execute when the matcher matches. + Hooks []Hook `json:"hooks" jsonschema:"required,description=List of hooks to execute when matcher matches"` +} + +// HookConfig holds the complete hook configuration. +type HookConfig map[HookEventType][]HookMatcher + +// Validate validates the entire hook configuration. +func (c HookConfig) Validate() error { + for eventType, matchers := range c { + for i, matcher := range matchers { + for j := range matcher.Hooks { + if err := matcher.Hooks[j].Validate(); err != nil { + return fmt.Errorf("invalid hook config for %s matcher %d hook %d: %w", + eventType, i, j, err) + } + } + } + } + return nil +} diff --git a/internal/config/load.go b/internal/config/load.go index a766f838225692d0c4f732043f04613899e35a40..1a45b43f46bfb709430f0f68ae3a6882ef8e5cd7 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -81,6 +81,11 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) { } cfg.knownProviders = providers + // Validate hooks configuration. + if err := cfg.validateHooks(); err != nil { + return nil, fmt.Errorf("invalid hooks configuration: %w", err) + } + env := env.New() // Configure providers valueResolver := NewShellVariableResolver(env) diff --git a/internal/db/messages.sql.go b/internal/db/messages.sql.go index f10b9d5e2c47ec90aec9dc0f206d4a157fa7f6b0..e184379d65e8a41cfd6232bdbd51decd7f57b9ff 100644 --- a/internal/db/messages.sql.go +++ b/internal/db/messages.sql.go @@ -19,12 +19,13 @@ INSERT INTO messages ( model, provider, is_summary_message, + hook_outputs, created_at, updated_at ) VALUES ( - ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') + ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') ) -RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message +RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message, hook_outputs ` type CreateMessageParams struct { @@ -35,6 +36,7 @@ type CreateMessageParams struct { Model sql.NullString `json:"model"` Provider sql.NullString `json:"provider"` IsSummaryMessage int64 `json:"is_summary_message"` + HookOutputs string `json:"hook_outputs"` } func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (Message, error) { @@ -46,6 +48,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M arg.Model, arg.Provider, arg.IsSummaryMessage, + arg.HookOutputs, ) var i Message err := row.Scan( @@ -59,6 +62,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M &i.FinishedAt, &i.Provider, &i.IsSummaryMessage, + &i.HookOutputs, ) return i, err } @@ -84,7 +88,7 @@ func (q *Queries) DeleteSessionMessages(ctx context.Context, sessionID string) e } const getMessage = `-- name: GetMessage :one -SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message +SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message, hook_outputs FROM messages WHERE id = ? LIMIT 1 ` @@ -103,12 +107,13 @@ func (q *Queries) GetMessage(ctx context.Context, id string) (Message, error) { &i.FinishedAt, &i.Provider, &i.IsSummaryMessage, + &i.HookOutputs, ) return i, err } const listMessagesBySession = `-- name: ListMessagesBySession :many -SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message +SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message, hook_outputs FROM messages WHERE session_id = ? ORDER BY created_at ASC @@ -134,6 +139,7 @@ func (q *Queries) ListMessagesBySession(ctx context.Context, sessionID string) ( &i.FinishedAt, &i.Provider, &i.IsSummaryMessage, + &i.HookOutputs, ); err != nil { return nil, err } @@ -153,17 +159,24 @@ UPDATE messages SET parts = ?, finished_at = ?, + hook_outputs = ?, updated_at = strftime('%s', 'now') WHERE id = ? ` type UpdateMessageParams struct { - Parts string `json:"parts"` - FinishedAt sql.NullInt64 `json:"finished_at"` - ID string `json:"id"` + Parts string `json:"parts"` + FinishedAt sql.NullInt64 `json:"finished_at"` + HookOutputs string `json:"hook_outputs"` + ID string `json:"id"` } func (q *Queries) UpdateMessage(ctx context.Context, arg UpdateMessageParams) error { - _, err := q.exec(ctx, q.updateMessageStmt, updateMessage, arg.Parts, arg.FinishedAt, arg.ID) + _, err := q.exec(ctx, q.updateMessageStmt, updateMessage, + arg.Parts, + arg.FinishedAt, + arg.HookOutputs, + arg.ID, + ) return err } diff --git a/internal/db/migrations/20250811000000_add_hook_outputs_to_messages.sql b/internal/db/migrations/20250811000000_add_hook_outputs_to_messages.sql new file mode 100644 index 0000000000000000000000000000000000000000..790071594a14e0e3c21abccae9eb953bf5acc513 --- /dev/null +++ b/internal/db/migrations/20250811000000_add_hook_outputs_to_messages.sql @@ -0,0 +1,5 @@ +-- +goose Up +ALTER TABLE messages ADD COLUMN hook_outputs TEXT DEFAULT '[]' NOT NULL; + +-- +goose Down +ALTER TABLE messages DROP COLUMN hook_outputs; diff --git a/internal/db/models.go b/internal/db/models.go index ddced85da6628097d981b219ef8c768f50474c85..d1bc6a0ec6865324f94437cb7538b74a9b6354df 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -29,6 +29,7 @@ type Message struct { FinishedAt sql.NullInt64 `json:"finished_at"` Provider sql.NullString `json:"provider"` IsSummaryMessage int64 `json:"is_summary_message"` + HookOutputs string `json:"hook_outputs"` } type Session struct { diff --git a/internal/db/sql/messages.sql b/internal/db/sql/messages.sql index fc66b78c08b85c8fe1f7ec79985fb2edd4a03668..dfd605e52368d24a17f85b28f8264a14af6a3085 100644 --- a/internal/db/sql/messages.sql +++ b/internal/db/sql/messages.sql @@ -18,10 +18,11 @@ INSERT INTO messages ( model, provider, is_summary_message, + hook_outputs, created_at, updated_at ) VALUES ( - ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') + ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') ) RETURNING *; @@ -30,6 +31,7 @@ UPDATE messages SET parts = ?, finished_at = ?, + hook_outputs = ?, updated_at = strftime('%s', 'now') WHERE id = ?; diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go new file mode 100644 index 0000000000000000000000000000000000000000..8d3f6101f6b15568c4dbf14426da672bed77d30c --- /dev/null +++ b/internal/hooks/hooks.go @@ -0,0 +1,415 @@ +package hooks + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "regexp" + "strings" + "sync" + "time" + + "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/permission" + "github.com/charmbracelet/crush/internal/shell" +) + +const DefaultHookTimeout = 30 * time.Second + +// HookContext contains context information passed to hooks. +// Fields are populated based on event type - see field comments for availability. +type HookContext struct { + // EventType is the lifecycle event that triggered this hook. + // Available: All events + EventType config.HookEventType `json:"event_type"` + + // SessionID identifies the Crush session. + // Available: All events (if session exists) + SessionID string `json:"session_id,omitempty"` + + // TranscriptPath is the path to exported session transcript JSON file. + // Available: All events (if session exists and export succeeds) + TranscriptPath string `json:"transcript_path,omitempty"` + + // ToolName is the name of the tool being called. + // Available: pre_tool_use, post_tool_use, permission_requested + ToolName string `json:"tool_name,omitempty"` + + // ToolInput contains the tool's input parameters as key-value pairs. + // Available: pre_tool_use, post_tool_use + ToolInput map[string]any `json:"tool_input,omitempty"` + + // ToolResult is the string result returned by the tool. + // Available: post_tool_use + ToolResult string `json:"tool_result,omitempty"` + + // ToolError indicates whether the tool execution failed. + // Available: post_tool_use + ToolError bool `json:"tool_error,omitempty"` + + // UserPrompt is the prompt submitted by the user. + // Available: user_prompt_submit + UserPrompt string `json:"user_prompt,omitempty"` + + // Timestamp is when the hook was triggered (RFC3339 format). + // Available: All events + Timestamp time.Time `json:"timestamp"` + + // WorkingDir is the project working directory. + // Available: All events + WorkingDir string `json:"working_dir,omitempty"` + + // MessageID identifies the assistant message being processed. + // Available: pre_tool_use, post_tool_use, stop + MessageID string `json:"message_id,omitempty"` + + // Provider is the LLM provider name (e.g., "anthropic"). + // Available: All events with LLM interaction + Provider string `json:"provider,omitempty"` + + // Model is the LLM model name (e.g., "claude-3-5-sonnet-20241022"). + // Available: All events with LLM interaction + Model string `json:"model,omitempty"` + + // TokensUsed is the total tokens consumed in this interaction. + // Available: stop + TokensUsed int64 `json:"tokens_used,omitempty"` + + // TokensInput is the input tokens consumed in this interaction. + // Available: stop + TokensInput int64 `json:"tokens_input,omitempty"` + + // PermissionAction is the permission action being requested (e.g., "read", "write"). + // Available: permission_requested + PermissionAction string `json:"permission_action,omitempty"` + + // PermissionPath is the file path involved in the permission request. + // Available: permission_requested + PermissionPath string `json:"permission_path,omitempty"` + + // PermissionParams contains additional permission parameters. + // Available: permission_requested + PermissionRequest *permission.PermissionRequest `json:"permission_request,omitempty"` +} + +type HookDecision string + +const ( + HookDecisionBlock HookDecision = "block" + HookDecisionDeny HookDecision = "deny" + HookDecisionAllow HookDecision = "allow" + HookDecisionAsk HookDecision = "ask" +) + +type Service interface { + Execute(ctx context.Context, hookCtx HookContext) ([]message.HookOutput, error) + SetSmallModel(model fantasy.LanguageModel) +} + +type service struct { + config config.HookConfig + workingDir string + regexCache *csync.Map[string, *regexp.Regexp] + smallModel fantasy.LanguageModel + messages message.Service +} + +// NewService creates a new hook executor. +func NewService( + hookConfig config.HookConfig, + workingDir string, + smallModel fantasy.LanguageModel, + messages message.Service, +) Service { + return &service{ + config: hookConfig, + workingDir: workingDir, + regexCache: csync.NewMap[string, *regexp.Regexp](), + smallModel: smallModel, + messages: messages, + } +} + +// Execute implements Service. +func (s *service) Execute(ctx context.Context, hookCtx HookContext) ([]message.HookOutput, error) { + if s.config == nil { + return nil, nil + } + + // Check if context is already cancelled - prevents race conditions during cancellation. + if ctx.Err() != nil { + return nil, ctx.Err() + } + + hookCtx.Timestamp = time.Now() + hookCtx.WorkingDir = s.workingDir + + hooks := s.collectMatchingHooks(hookCtx) + if len(hooks) == 0 { + return nil, nil + } + + transcriptPath, cleanup, err := s.setupTranscript(ctx, hookCtx) + if err != nil { + slog.Warn("Failed to export transcript for hooks", "error", err) + } else if transcriptPath != "" { + hookCtx.TranscriptPath = transcriptPath + // Ensure cleanup happens even on panic. + defer func() { + cleanup() + }() + } + + results := make([]message.HookOutput, len(hooks)) + var wg sync.WaitGroup + + for i, hook := range hooks { + if ctx.Err() != nil { + return nil, ctx.Err() + } + wg.Add(1) + go func(idx int, h config.Hook) { + defer wg.Done() + result, err := s.executeHook(ctx, h, hookCtx) + if err != nil { + slog.Warn("Hook execution failed", + "event", hookCtx.EventType, + "error", err, + ) + } else { + results[idx] = *result + } + }(i, hook) + } + wg.Wait() + return results, nil +} + +func (s *service) setupTranscript(ctx context.Context, hookCtx HookContext) (string, func(), error) { + if hookCtx.SessionID == "" || s.messages == nil { + return "", func() {}, nil + } + + path, err := exportTranscript(ctx, s.messages, hookCtx.SessionID) + if err != nil { + return "", func() {}, err + } + + cleanup := func() { + if path != "" { + cleanupTranscript(path) + } + } + + return path, cleanup, nil +} + +func (s *service) executeHook(ctx context.Context, hook config.Hook, hookCtx HookContext) (*message.HookOutput, error) { + var result *message.HookOutput + var err error + + switch hook.Type { + case "prompt": + if hook.Prompt == "" { + return nil, fmt.Errorf("prompt-based hook missing 'prompt' field") + } + if s.smallModel == nil { + return nil, fmt.Errorf("prompt-based hook requires small model configuration") + } + result, err = s.executePromptHook(ctx, hook, hookCtx) + case "", "command": + if hook.Command == "" { + return nil, fmt.Errorf("command-based hook missing 'command' field") + } + slog.Info("executing") + result, err = s.executeCommandHook(ctx, hook, hookCtx) + default: + return nil, fmt.Errorf("unsupported hook type: %s", hook.Type) + } + + return result, err +} + +func (s *service) executePromptHook(ctx context.Context, hook config.Hook, hookCtx HookContext) (*message.HookOutput, error) { + panic("not implemented") +} + +func (s *service) executeCommandHook(ctx context.Context, hook config.Hook, hookCtx HookContext) (*message.HookOutput, error) { + timeout := DefaultHookTimeout + if hook.Timeout != nil { + timeout = time.Duration(*hook.Timeout) * time.Second + } + + execCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + contextJSON, err := json.Marshal(hookCtx) + if err != nil { + return nil, fmt.Errorf("failed to marshal hook context: %w", err) + } + + sh := shell.NewShell(&shell.Options{ + WorkingDir: s.workingDir, + }) + + sh.SetEnv("CRUSH_HOOK_EVENT", string(hookCtx.EventType)) + sh.SetEnv("CRUSH_HOOK_CONTEXT", string(contextJSON)) + sh.SetEnv("CRUSH_PROJECT_DIR", s.workingDir) + if hookCtx.SessionID != "" { + sh.SetEnv("CRUSH_SESSION_ID", hookCtx.SessionID) + } + if hookCtx.ToolName != "" { + sh.SetEnv("CRUSH_TOOL_NAME", hookCtx.ToolName) + } + + slog.Debug("Hook execution trace", + "event", hookCtx.EventType, + "command", hook.Command, + "timeout", timeout, + "context", hookCtx, + ) + + stdout, stderr, err := sh.ExecWithStdin(execCtx, hook.Command, string(contextJSON)) + + exitCode := shell.ExitCode(err) + + var result *message.HookOutput + switch exitCode { + case 0: + // if the event is UserPromptSubmit we want the output to be added to the context + if hookCtx.EventType == config.UserPromptSubmit { + result = &message.HookOutput{ + AdditionalContext: stdout, + } + } else { + result = &message.HookOutput{ + Message: stdout, + } + } + case 2: + result = &message.HookOutput{ + Decision: string(HookDecisionBlock), + Error: stderr, + } + return result, nil + default: + result = &message.HookOutput{ + Error: stderr, + } + return result, nil + } + + jsonOutput := parseHookOutput(stdout) + if jsonOutput == nil { + return result, nil + } + + result.Message = jsonOutput.Message + result.Stop = jsonOutput.Stop + result.Decision = jsonOutput.Decision + result.AdditionalContext = jsonOutput.AdditionalContext + result.UpdatedInput = jsonOutput.UpdatedInput + + // Trace output in debug mode + slog.Debug("Hook execution output", + "event", hookCtx.EventType, + "exit_code", exitCode, + "stdout_length", len(stdout), + "stderr_length", len(stderr), + "stdout", stdout, + "stderr", stderr, + ) + return result, nil +} + +func parseHookOutput(stdout string) *message.HookOutput { + stdout = strings.TrimSpace(stdout) + slog.Info(stdout) + if stdout == "" { + return nil + } + + var output message.HookOutput + if err := json.Unmarshal([]byte(stdout), &output); err != nil { + // Failed to parse as HookOutput + return nil + } + + return &output +} + +// SetSmallModel implements Service. +func (s *service) SetSmallModel(model fantasy.LanguageModel) { + panic("unimplemented") +} + +func (s *service) collectMatchingHooks(hookCtx HookContext) []config.Hook { + matchers, ok := s.config[hookCtx.EventType] + if !ok || len(matchers) == 0 { + return nil + } + + var hooks []config.Hook + for _, matcher := range matchers { + if !s.matcherApplies(matcher, hookCtx) { + continue + } + hooks = append(hooks, matcher.Hooks...) + } + return hooks +} + +func (s *service) matcherApplies(matcher config.HookMatcher, ctx HookContext) bool { + if ctx.EventType == config.PreToolUse || ctx.EventType == config.PostToolUse { + return s.matchesToolName(matcher.Matcher, ctx.ToolName) + } + + return matcher.Matcher == "" || matcher.Matcher == "*" +} + +func (s *service) matchesToolName(pattern, toolName string) bool { + if pattern == "" || pattern == "*" { + return true + } + + if pattern == toolName { + return true + } + + if strings.Contains(pattern, "|") { + for tool := range strings.SplitSeq(pattern, "|") { + tool = strings.TrimSpace(tool) + if tool == toolName { + return true + } + } + + return s.matchesRegex(pattern, toolName) + } + + return s.matchesRegex(pattern, toolName) +} + +func (s *service) matchesRegex(pattern, text string) bool { + re, ok := s.regexCache.Get(pattern) + if !ok { + compiled, err := regexp.Compile(pattern) + if err != nil { + // Not a valid regex, don't cache failures. + return false + } + re = s.regexCache.GetOrSet(pattern, func() *regexp.Regexp { + return compiled + }) + } + + if re == nil { + return false + } + + return re.MatchString(text) +} diff --git a/internal/hooks/transcript.go b/internal/hooks/transcript.go new file mode 100644 index 0000000000000000000000000000000000000000..9471e24b84b8b2cd9f3ca6d68f420bcac8daaa23 --- /dev/null +++ b/internal/hooks/transcript.go @@ -0,0 +1,133 @@ +package hooks + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/charmbracelet/crush/internal/message" +) + +// TranscriptMessage represents a message in the exported transcript. +type TranscriptMessage struct { + ID string `json:"id"` + Role string `json:"role"` + Content string `json:"content,omitempty"` + ToolCalls []TranscriptToolCall `json:"tool_calls,omitempty"` + ToolResults []TranscriptToolResult `json:"tool_results,omitempty"` + Timestamp string `json:"timestamp"` +} + +// TranscriptToolCall represents a tool call in the transcript. +type TranscriptToolCall struct { + ID string `json:"id"` + Name string `json:"name"` + Input string `json:"input"` +} + +// TranscriptToolResult represents a tool result in the transcript. +type TranscriptToolResult struct { + ToolCallID string `json:"tool_call_id"` + Name string `json:"name"` + Content string `json:"content"` + IsError bool `json:"is_error"` +} + +// Transcript represents the complete transcript structure. +type Transcript struct { + SessionID string `json:"session_id"` + Messages []TranscriptMessage `json:"messages"` +} + +// exportTranscript exports session messages to a temporary JSON file. +func exportTranscript( + ctx context.Context, + messages message.Service, + sessionID string, +) (string, error) { + // Get all messages for the session + msgs, err := messages.List(ctx, sessionID) + if err != nil { + return "", fmt.Errorf("failed to list messages: %w", err) + } + + // Convert to transcript format + transcript := Transcript{ + SessionID: sessionID, + Messages: make([]TranscriptMessage, 0, len(msgs)), + } + + for _, msg := range msgs { + tm := TranscriptMessage{ + ID: msg.ID, + Role: string(msg.Role), + Timestamp: time.Unix(msg.CreatedAt, 0).Format("2006-01-02T15:04:05Z07:00"), + } + + // Extract content + for _, part := range msg.Parts { + if text, ok := part.(message.TextContent); ok { + if tm.Content != "" { + tm.Content += "\n" + } + tm.Content += text.Text + } + } + + // Extract tool calls + if msg.Role == message.Assistant { + toolCalls := msg.ToolCalls() + for _, tc := range toolCalls { + tm.ToolCalls = append(tm.ToolCalls, TranscriptToolCall{ + ID: tc.ID, + Name: tc.Name, + Input: tc.Input, + }) + } + } + + // Extract tool results + if msg.Role == message.Tool { + toolResults := msg.ToolResults() + for _, tr := range toolResults { + tm.ToolResults = append(tm.ToolResults, TranscriptToolResult{ + ToolCallID: tr.ToolCallID, + Name: tr.Name, + Content: tr.Content, + IsError: tr.IsError, + }) + } + } + + transcript.Messages = append(transcript.Messages, tm) + } + + // Marshal to JSON + data, err := json.MarshalIndent(transcript, "", " ") + if err != nil { + return "", fmt.Errorf("failed to marshal transcript: %w", err) + } + + // Write to temporary file + tmpDir := os.TempDir() + filename := fmt.Sprintf("crush-transcript-%s.json", sessionID) + path := filepath.Join(tmpDir, filename) + + // Use restrictive permissions (0600) + if err := os.WriteFile(path, data, 0o600); err != nil { + return "", fmt.Errorf("failed to write transcript file: %w", err) + } + + return path, nil +} + +func cleanupTranscript(path string) { + if path != "" { + if err := os.Remove(path); err != nil { + fmt.Fprintf(os.Stderr, "Warning: failed to cleanup transcript file %s: %v\n", path, err) + } + } +} diff --git a/internal/message/content.go b/internal/message/content.go index 7f35678230759ab3dcfc13287d340d6f0327d722..00195debad6ead000ffc659ee2c5e2c508e3b5a4 100644 --- a/internal/message/content.go +++ b/internal/message/content.go @@ -133,6 +133,16 @@ type Message struct { CreatedAt int64 UpdatedAt int64 IsSummaryMessage bool + HookOutputs []HookOutput +} + +type HookOutput struct { + Stop bool `json:"stop"` + Error string `json:"error"` + Message string `json:"message"` + Decision string `json:"decision"` + UpdatedInput string `json:"updated_input"` + AdditionalContext string `json:"additional_context"` } func (m *Message) Content() TextContent { @@ -144,6 +154,23 @@ func (m *Message) Content() TextContent { return TextContent{} } +func (m *Message) ContentWithHooksContext() string { + text := strings.TrimSpace(m.Content().Text) + + var additionalContext []string + for _, hookOutput := range m.HookOutputs { + context := strings.TrimSpace(hookOutput.AdditionalContext) + if context != "" { + additionalContext = append(additionalContext, context) + } + } + if len(additionalContext) > 0 { + text += "## Additional Context\n" + text += strings.Join(additionalContext, "\n") + } + return text +} + func (m *Message) ReasoningContent() ReasoningContent { for _, part := range m.Parts { if c, ok := part.(ReasoningContent); ok { @@ -202,6 +229,11 @@ func (m *Message) IsFinished() bool { return false } +// AddHookOutputs appends multiple hook outputs to the message's hook outputs. +func (m *Message) AddHookOutputs(outputs ...HookOutput) { + m.HookOutputs = append(m.HookOutputs, outputs...) +} + func (m *Message) FinishPart() *Finish { for _, part := range m.Parts { if c, ok := part.(Finish); ok { @@ -429,7 +461,7 @@ func (m *Message) ToAIMessage() []fantasy.Message { switch m.Role { case User: var parts []fantasy.MessagePart - text := strings.TrimSpace(m.Content().Text) + text := strings.TrimSpace(m.ContentWithHooksContext()) if text != "" { parts = append(parts, fantasy.TextPart{Text: text}) } diff --git a/internal/message/message.go b/internal/message/message.go index 4cdf89b54f8eaf831d53a5fc51fdb5c71b4b953c..accd9ed7a39c43f2935763842d0b1f31acc36f90 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -18,6 +18,7 @@ type CreateMessageParams struct { Model string Provider string IsSummaryMessage bool + HookOutputs []HookOutput } type Service interface { @@ -69,6 +70,10 @@ func (s *service) Create(ctx context.Context, sessionID string, params CreateMes if params.IsSummaryMessage { isSummary = 1 } + hookOutputsJSON, err := json.Marshal(params.HookOutputs) + if err != nil { + return Message{}, err + } dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{ ID: uuid.New().String(), SessionID: sessionID, @@ -77,6 +82,7 @@ func (s *service) Create(ctx context.Context, sessionID string, params CreateMes Model: sql.NullString{String: string(params.Model), Valid: true}, Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""}, IsSummaryMessage: isSummary, + HookOutputs: string(hookOutputsJSON), }) if err != nil { return Message{}, err @@ -115,10 +121,15 @@ func (s *service) Update(ctx context.Context, message Message) error { finishedAt.Int64 = f.Time finishedAt.Valid = true } + hookOutputsJSON, err := json.Marshal(message.HookOutputs) + if err != nil { + return err + } err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{ - ID: message.ID, - Parts: string(parts), - FinishedAt: finishedAt, + ID: message.ID, + Parts: string(parts), + FinishedAt: finishedAt, + HookOutputs: string(hookOutputsJSON), }) if err != nil { return err @@ -156,6 +167,12 @@ func (s *service) fromDBItem(item db.Message) (Message, error) { if err != nil { return Message{}, err } + var hookOutputs []HookOutput + if item.HookOutputs != "" { + if err := json.Unmarshal([]byte(item.HookOutputs), &hookOutputs); err != nil { + return Message{}, err + } + } return Message{ ID: item.ID, SessionID: item.SessionID, @@ -166,6 +183,7 @@ func (s *service) fromDBItem(item db.Message) (Message, error) { CreatedAt: item.CreatedAt, UpdatedAt: item.UpdatedAt, IsSummaryMessage: item.IsSummaryMessage != 0, + HookOutputs: hookOutputs, }, nil } diff --git a/internal/shell/shell.go b/internal/shell/shell.go index f9f4656b82bbb6ee14b38469a20d493d98354b4a..173643aa8f43ca816165226afd5c59d1ef95fd25 100644 --- a/internal/shell/shell.go +++ b/internal/shell/shell.go @@ -111,6 +111,15 @@ func (s *Shell) ExecStream(ctx context.Context, command string, stdout, stderr i return s.execStream(ctx, command, stdout, stderr) } +func (s *Shell) ExecWithStdin(ctx context.Context, command string, stdin string) (string, string, error) { + s.mu.Lock() + defer s.mu.Unlock() + + var stdout, stderr bytes.Buffer + err := s.execCommonWithStdin(ctx, command, strings.NewReader(stdin), &stdout, &stderr) + return stdout.String(), stderr.String(), err +} + // GetWorkingDir returns the current working directory func (s *Shell) GetWorkingDir() string { s.mu.Lock() @@ -237,9 +246,9 @@ func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHand } // newInterp creates a new interpreter with the current shell state -func (s *Shell) newInterp(stdout, stderr io.Writer) (*interp.Runner, error) { +func (s *Shell) newInterp(stdin io.Reader, stdout, stderr io.Writer) (*interp.Runner, error) { return interp.New( - interp.StdIO(nil, stdout, stderr), + interp.StdIO(stdin, stdout, stderr), interp.Interactive(false), interp.Env(expand.ListEnviron(s.env...)), interp.Dir(s.cwd), @@ -263,7 +272,7 @@ func (s *Shell) execCommon(ctx context.Context, command string, stdout, stderr i return fmt.Errorf("could not parse command: %w", err) } - runner, err := s.newInterp(stdout, stderr) + runner, err := s.newInterp(nil, stdout, stderr) if err != nil { return fmt.Errorf("could not run command: %w", err) } @@ -286,6 +295,24 @@ func (s *Shell) execStream(ctx context.Context, command string, stdout, stderr i return s.execCommon(ctx, command, stdout, stderr) } +// execCommonWithStdin is like execCommon but with stdin support +func (s *Shell) execCommonWithStdin(ctx context.Context, command string, stdin io.Reader, stdout, stderr io.Writer) error { + line, err := syntax.NewParser().Parse(strings.NewReader(command), "") + if err != nil { + return fmt.Errorf("could not parse command: %w", err) + } + + runner, err := s.newInterp(stdin, stdout, stderr) + if err != nil { + return fmt.Errorf("could not run command: %w", err) + } + + err = runner.Run(ctx, line) + s.updateShellFromRunner(runner) + s.logger.InfoPersist("command finished", "command", command, "err", err) + return err +} + func (s *Shell) execHandlers() []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { handlers := []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc{ s.blockHandler(),