From 4b9becadbc9651ce66079873824323ae89985d9d Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 21 Nov 2025 12:55:29 +0100 Subject: [PATCH] chore: add first hook implementation --- internal/agent/agent.go | 116 +++++++++++++-- internal/agent/agent_tool.go | 2 +- internal/agent/common_test.go | 15 +- internal/agent/coordinator.go | 55 +++---- internal/agent/errors.go | 9 +- internal/app/app.go | 1 + internal/db/messages.sql.go | 23 ++- ...0251120000000_add_metadata_to_messages.sql | 5 + internal/db/models.go | 1 + internal/db/sql/messages.sql | 4 +- internal/hooks/config.go | 8 +- internal/hooks/examples_test.go | 136 +++++------------- internal/hooks/manager.go | 59 +++++++- internal/hooks/manager_test.go | 90 +++--------- internal/hooks/types.go | 59 ++++++-- internal/message/content.go | 38 ++++- internal/message/message.go | 16 +++ 17 files changed, 385 insertions(+), 252 deletions(-) create mode 100644 internal/db/migrations/20251120000000_add_metadata_to_messages.sql diff --git a/internal/agent/agent.go b/internal/agent/agent.go index ec5bc19ba4efaf0cc15f46620711621a92dff2b9..99263093dff00bb1b0a63324c10354bc6da44756 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -30,6 +30,7 @@ import ( "github.com/charmbracelet/crush/internal/agent/tools" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" + "github.com/charmbracelet/crush/internal/hooks" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/session" @@ -85,6 +86,9 @@ type sessionAgent struct { messages message.Service disableAutoSummarize bool isYolo bool + isSubAgent bool + hooksManager hooks.Manager + workingDir string messageQueue *csync.Map[string, []SessionAgentCall] activeRequests *csync.Map[string, context.CancelFunc] @@ -97,6 +101,9 @@ type SessionAgentOptions struct { SystemPrompt string DisableAutoSummarize bool IsYolo bool + IsSubAgent bool + HooksManager hooks.Manager + WorkingDir string Sessions session.Service Messages message.Service Tools []fantasy.AgentTool @@ -115,6 +122,9 @@ func NewSessionAgent( disableAutoSummarize: opts.DisableAutoSummarize, tools: opts.Tools, isYolo: opts.IsYolo, + isSubAgent: opts.IsSubAgent, + hooksManager: opts.HooksManager, + workingDir: opts.WorkingDir, messageQueue: csync.NewMap[string, []SessionAgentCall](), activeRequests: csync.NewMap[string, context.CancelFunc](), } @@ -172,7 +182,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy } // Add the user message to the session. - _, err = a.createUserMessage(ctx, call) + msg, err := a.createUserMessage(ctx, call) if err != nil { return nil, err } @@ -186,15 +196,36 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy defer cancel() defer a.activeRequests.Del(call.SessionID) + // create the agent message asap to show loading + var currentAssistant *message.Message + assistantMessage, err := a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{ + Role: message.Assistant, + Parts: []message.ContentPart{}, + Model: a.largeModel.ModelCfg.Model, + Provider: a.largeModel.ModelCfg.Provider, + }) + if err != nil { + return nil, err + } + + currentAssistant = &assistantMessage + + hookErr := a.executePromptSubmitHook(genCtx, &msg, len(msgs) == 0) + if hookErr != nil { + // Delete the assistant message + // use the ctx since this could be a cancellation + deleteErr := a.messages.Delete(ctx, currentAssistant.ID) + return nil, cmp.Or(deleteErr, hookErr) + } + history, files := a.preparePrompt(msgs, call.Attachments...) startTime := time.Now() a.eventPromptSent(call.SessionID) - var currentAssistant *message.Message var shouldSummarize bool result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{ - Prompt: call.Prompt, + Prompt: msg.ContentWithHookContext(), Files: files, Messages: history, ProviderOptions: call.ProviderOptions, @@ -206,6 +237,21 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy FrequencyPenalty: call.FrequencyPenalty, // Before each step create a new assistant message. PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) { + // only add new assistant message when its not the first step + if options.StepNumber != 0 { + var assistantMsg message.Message + assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{ + Role: message.Assistant, + Parts: []message.ContentPart{}, + Model: a.largeModel.ModelCfg.Model, + Provider: a.largeModel.ModelCfg.Provider, + }) + currentAssistant = &assistantMsg + // create the message first so we show loading asap + if err != nil { + return callContext, prepared, err + } + } prepared.Messages = options.Messages // Reset all cached items. for i := range prepared.Messages { @@ -219,6 +265,12 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy if createErr != nil { return callContext, prepared, createErr } + + hookErr := a.executePromptSubmitHook(ctx, &msg, len(msgs) == 0) + if hookErr != nil { + return callContext, prepared, hookErr + } + prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...) } @@ -242,18 +294,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...) } - var assistantMsg message.Message - assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{ - Role: message.Assistant, - Parts: []message.ContentPart{}, - Model: a.largeModel.ModelCfg.Model, - Provider: a.largeModel.ModelCfg.Provider, - }) - if err != nil { - return callContext, prepared, err - } - callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID) - currentAssistant = &assistantMsg + callContext = context.WithValue(callContext, tools.MessageIDContextKey, currentAssistant.ID) return callContext, prepared, err }, OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error { @@ -882,3 +923,48 @@ func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) { func (a *sessionAgent) Model() Model { return a.largeModel } + +// executePromptSubmitHook executes the user-prompt-submit hook and applies modifications to the call. +// Only runs for main agent (not sub-agents). +func (a *sessionAgent) executePromptSubmitHook(ctx context.Context, msg *message.Message, isFirstMessage bool) error { + // Skip if sub-agent or no hooks manager. + if a.isSubAgent || a.hooksManager == nil { + return nil + } + + // Convert attachments to file paths. + attachmentPaths := make([]string, len(msg.BinaryContent())) + for i, att := range msg.BinaryContent() { + attachmentPaths[i] = att.Path + } + + hookResult, err := a.hooksManager.ExecuteUserPromptSubmit(ctx, msg.SessionID, a.workingDir, hooks.UserPromptSubmitData{ + Prompt: msg.Content().Text, + Attachments: attachmentPaths, + Model: a.largeModel.CatwalkCfg.ID, + Provider: a.largeModel.Model.Provider(), + IsFirstMessage: isFirstMessage, + }) + if err != nil { + return fmt.Errorf("hook execution failed: %w", err) + } + + // Apply hook modifications to the prompt. + if hookResult.ModifiedPrompt != nil { + for i, part := range msg.Parts { + if _, ok := part.(message.TextContent); ok { + msg.Parts[i] = message.TextContent{Text: *hookResult.ModifiedPrompt} + } + } + } + msg.AddHookResult(hookResult) + err = a.messages.Update(ctx, *msg) + if err != nil { + return err + } + // If hook returned Continue: false, stop execution. + if !hookResult.Continue { + return ErrHookExecutionStop + } + return nil +} diff --git a/internal/agent/agent_tool.go b/internal/agent/agent_tool.go index b937c20dd510eb9dff52bd348bf7baf21aafcf8c..b43197e9c0393ef7dde69512502408310051ba60 100644 --- a/internal/agent/agent_tool.go +++ b/internal/agent/agent_tool.go @@ -34,7 +34,7 @@ func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error) return nil, err } - agent, err := c.buildAgent(ctx, prompt, agentCfg) + agent, err := c.buildAgent(ctx, prompt, agentCfg, true) if err != nil { return nil, err } diff --git a/internal/agent/common_test.go b/internal/agent/common_test.go index 58264e2637f3a44e45d54a66c72ee3b8d6c642a3..def31b5d55de3ab35cf643381e9fd8990dd827b5 100644 --- a/internal/agent/common_test.go +++ b/internal/agent/common_test.go @@ -149,7 +149,20 @@ func testSessionAgent(env fakeEnv, large, small fantasy.LanguageModel, systemPro DefaultMaxTokens: 10000, }, } - agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, "", systemPrompt, false, true, env.sessions, env.messages, tools}) + agent := NewSessionAgent(SessionAgentOptions{ + LargeModel: largeModel, + SmallModel: smallModel, + SystemPromptPrefix: "", + SystemPrompt: systemPrompt, + DisableAutoSummarize: false, + IsYolo: true, + IsSubAgent: false, + HooksManager: nil, + WorkingDir: "", + Sessions: env.sessions, + Messages: env.messages, + Tools: tools, + }) return agent } diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 4bfcc0062ae9a06dc858989f2cce925976d6d32b..ea1219dcee7d60d039a8693eff9541f8ce40c0b8 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -21,6 +21,7 @@ import ( "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/history" + "github.com/charmbracelet/crush/internal/hooks" "github.com/charmbracelet/crush/internal/log" "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/message" @@ -55,12 +56,13 @@ type Coordinator interface { } type coordinator struct { - cfg *config.Config - sessions session.Service - messages message.Service - permissions permission.Service - history history.Service - lspClients *csync.Map[string, *lsp.Client] + cfg *config.Config + sessions session.Service + messages message.Service + permissions permission.Service + history history.Service + lspClients *csync.Map[string, *lsp.Client] + hooksManager hooks.Manager currentAgent SessionAgent agents map[string]SessionAgent @@ -76,15 +78,17 @@ func NewCoordinator( permissions permission.Service, history history.Service, lspClients *csync.Map[string, *lsp.Client], + hooksManager hooks.Manager, ) (Coordinator, error) { c := &coordinator{ - cfg: cfg, - sessions: sessions, - messages: messages, - permissions: permissions, - history: history, - lspClients: lspClients, - agents: make(map[string]SessionAgent), + cfg: cfg, + sessions: sessions, + messages: messages, + permissions: permissions, + history: history, + lspClients: lspClients, + hooksManager: hooksManager, + agents: make(map[string]SessionAgent), } agentCfg, ok := cfg.Agents[config.AgentCoder] @@ -98,7 +102,7 @@ func NewCoordinator( return nil, err } - agent, err := c.buildAgent(ctx, prompt, agentCfg) + agent, err := c.buildAgent(ctx, prompt, agentCfg, false) if err != nil { return nil, err } @@ -274,7 +278,7 @@ func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderO return modelOptions, temp, topP, topK, freqPenalty, presPenalty } -func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) { +func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) { large, small, err := c.buildAgentModels(ctx) if err != nil { return nil, err @@ -287,15 +291,18 @@ func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, age largeProviderCfg, _ := c.cfg.Providers.Get(large.ModelCfg.Provider) result := NewSessionAgent(SessionAgentOptions{ - large, - small, - largeProviderCfg.SystemPromptPrefix, - systemPrompt, - c.cfg.Options.DisableAutoSummarize, - c.permissions.SkipRequests(), - c.sessions, - c.messages, - nil, + LargeModel: large, + SmallModel: small, + SystemPromptPrefix: largeProviderCfg.SystemPromptPrefix, + SystemPrompt: systemPrompt, + DisableAutoSummarize: c.cfg.Options.DisableAutoSummarize, + IsYolo: c.permissions.SkipRequests(), + IsSubAgent: isSubAgent, + HooksManager: c.hooksManager, + WorkingDir: c.cfg.WorkingDir(), + Sessions: c.sessions, + Messages: c.messages, + Tools: nil, }) c.readyWg.Go(func() error { tools, err := c.buildTools(ctx, agent) diff --git a/internal/agent/errors.go b/internal/agent/errors.go index 1b4f0dfce6c8b13a0b22c42ed651ba895e24ed9b..562b69f6da97dd3aaf6dd2f1342939a9ad3e596e 100644 --- a/internal/agent/errors.go +++ b/internal/agent/errors.go @@ -6,10 +6,11 @@ import ( ) var ( - ErrRequestCancelled = errors.New("request canceled by user") - ErrSessionBusy = errors.New("session is currently processing another request") - ErrEmptyPrompt = errors.New("prompt is empty") - ErrSessionMissing = errors.New("session id is missing") + ErrRequestCancelled = errors.New("request canceled by user") + ErrSessionBusy = errors.New("session is currently processing another request") + ErrEmptyPrompt = errors.New("prompt is empty") + ErrSessionMissing = errors.New("session id is missing") + ErrHookExecutionStop = errors.New("hook stopped execution") ) func isCancelledErr(err error) bool { diff --git a/internal/app/app.go b/internal/app/app.go index 1becd7223cb3b4367c1ff080ed88a77ec4525750..3d284604d98abc5047f887e437c35542a0eb39c6 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -333,6 +333,7 @@ func (app *App) InitCoderAgent(ctx context.Context) error { app.Permissions, app.History, app.LSPClients, + app.HooksManager, ) if err != nil { slog.Error("Failed to create coder agent", "err", err) diff --git a/internal/db/messages.sql.go b/internal/db/messages.sql.go index f10b9d5e2c47ec90aec9dc0f206d4a157fa7f6b0..e96274385c67bc7061dd095fc372b3639536bb38 100644 --- a/internal/db/messages.sql.go +++ b/internal/db/messages.sql.go @@ -19,12 +19,13 @@ INSERT INTO messages ( model, provider, is_summary_message, + metadata, created_at, updated_at ) VALUES ( - ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') + ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') ) -RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message +RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message, metadata ` type CreateMessageParams struct { @@ -35,6 +36,7 @@ type CreateMessageParams struct { Model sql.NullString `json:"model"` Provider sql.NullString `json:"provider"` IsSummaryMessage int64 `json:"is_summary_message"` + Metadata string `json:"metadata"` } func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (Message, error) { @@ -46,6 +48,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M arg.Model, arg.Provider, arg.IsSummaryMessage, + arg.Metadata, ) var i Message err := row.Scan( @@ -59,6 +62,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M &i.FinishedAt, &i.Provider, &i.IsSummaryMessage, + &i.Metadata, ) return i, err } @@ -84,7 +88,7 @@ func (q *Queries) DeleteSessionMessages(ctx context.Context, sessionID string) e } const getMessage = `-- name: GetMessage :one -SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message +SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message, metadata FROM messages WHERE id = ? LIMIT 1 ` @@ -103,12 +107,13 @@ func (q *Queries) GetMessage(ctx context.Context, id string) (Message, error) { &i.FinishedAt, &i.Provider, &i.IsSummaryMessage, + &i.Metadata, ) return i, err } const listMessagesBySession = `-- name: ListMessagesBySession :many -SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message +SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message, metadata FROM messages WHERE session_id = ? ORDER BY created_at ASC @@ -134,6 +139,7 @@ func (q *Queries) ListMessagesBySession(ctx context.Context, sessionID string) ( &i.FinishedAt, &i.Provider, &i.IsSummaryMessage, + &i.Metadata, ); err != nil { return nil, err } @@ -152,6 +158,7 @@ const updateMessage = `-- name: UpdateMessage :exec UPDATE messages SET parts = ?, + metadata = ?, finished_at = ?, updated_at = strftime('%s', 'now') WHERE id = ? @@ -159,11 +166,17 @@ WHERE id = ? type UpdateMessageParams struct { Parts string `json:"parts"` + Metadata string `json:"metadata"` FinishedAt sql.NullInt64 `json:"finished_at"` ID string `json:"id"` } func (q *Queries) UpdateMessage(ctx context.Context, arg UpdateMessageParams) error { - _, err := q.exec(ctx, q.updateMessageStmt, updateMessage, arg.Parts, arg.FinishedAt, arg.ID) + _, err := q.exec(ctx, q.updateMessageStmt, updateMessage, + arg.Parts, + arg.Metadata, + arg.FinishedAt, + arg.ID, + ) return err } diff --git a/internal/db/migrations/20251120000000_add_metadata_to_messages.sql b/internal/db/migrations/20251120000000_add_metadata_to_messages.sql new file mode 100644 index 0000000000000000000000000000000000000000..06aea133eebff3189956e14c70b36d5e20e71556 --- /dev/null +++ b/internal/db/migrations/20251120000000_add_metadata_to_messages.sql @@ -0,0 +1,5 @@ +-- +goose Up +ALTER TABLE messages ADD COLUMN metadata TEXT DEFAULT '{}' NOT NULL; + +-- +goose Down +ALTER TABLE messages DROP COLUMN metadata; diff --git a/internal/db/models.go b/internal/db/models.go index ddced85da6628097d981b219ef8c768f50474c85..e9dc7187be7912608fec6fd9fcfb5ea23bfac5f1 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -29,6 +29,7 @@ type Message struct { FinishedAt sql.NullInt64 `json:"finished_at"` Provider sql.NullString `json:"provider"` IsSummaryMessage int64 `json:"is_summary_message"` + Metadata string `json:"metadata"` } type Session struct { diff --git a/internal/db/sql/messages.sql b/internal/db/sql/messages.sql index fc66b78c08b85c8fe1f7ec79985fb2edd4a03668..114ac4e910e2e1752a691c72dbe317d5924f9275 100644 --- a/internal/db/sql/messages.sql +++ b/internal/db/sql/messages.sql @@ -18,10 +18,11 @@ INSERT INTO messages ( model, provider, is_summary_message, + metadata, created_at, updated_at ) VALUES ( - ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') + ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') ) RETURNING *; @@ -29,6 +30,7 @@ RETURNING *; UPDATE messages SET parts = ?, + metadata = ?, finished_at = ?, updated_at = strftime('%s', 'now') WHERE id = ?; diff --git a/internal/hooks/config.go b/internal/hooks/config.go index f9223f17e365e9cd11d6a92b4526c4c93b04f5e7..1cc30462ccdfd0b32c67207a640677e3e40da852 100644 --- a/internal/hooks/config.go +++ b/internal/hooks/config.go @@ -2,8 +2,8 @@ package hooks // Config defines hook system configuration. type Config struct { - // Enabled controls whether hooks are executed. - Enabled bool `json:"enabled,omitempty" jsonschema:"description=Enable or disable hook execution,default=true"` + // Disabled controls whether hooks are executed. + Disabled bool `json:"disabled,omitempty" jsonschema:"description=Disable hook execution,default=false"` // TimeoutSeconds is the maximum time a hook can run. TimeoutSeconds int `json:"timeout_seconds,omitempty" jsonschema:"description=Maximum execution time for hooks in seconds,default=30,example=30"` @@ -16,10 +16,10 @@ type Config struct { // Map key is the hook type (e.g., "pre-tool-use"). Inline map[string][]InlineHook `json:"inline,omitempty" jsonschema:"description=Inline hook scripts defined in configuration"` - // Disabled is a list of hook paths to skip. + // DisableHooks is a list of hook paths to skip. // Paths are relative to the hooks directory. // Example: ["pre-tool-use/02-slow-check.sh"] - Disabled []string `json:"disabled,omitempty" jsonschema:"description=List of hook paths to disable,example=pre-tool-use/02-slow-check.sh"` + DisableHooks []string `json:"disable_hooks,omitempty" jsonschema:"description=List of hook paths to disable,example=pre-tool-use/02-slow-check.sh"` // Environment variables to pass to hooks. Environment map[string]string `json:"environment,omitempty" jsonschema:"description=Environment variables to pass to all hooks"` diff --git a/internal/hooks/examples_test.go b/internal/hooks/examples_test.go index 77651aee1ac5f14516c42c590a1b474f028294dc..83dde9dfcc93a701e7c68a7a383262e28c739987 100644 --- a/internal/hooks/examples_test.go +++ b/internal/hooks/examples_test.go @@ -35,15 +35,11 @@ fi manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) // Test: Should block "rm -rf /" - result, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{ - SessionID: "test", - WorkingDir: tempDir, + result, err := manager.ExecutePreToolUse(context.Background(), "test", tempDir, PreToolUseData{ ToolName: "bash", ToolCallID: "call-1", - Data: map[string]any{ - "tool_input": map[string]any{ - "command": "rm -rf /", - }, + ToolInput: map[string]any{ + "command": "rm -rf /", }, }) @@ -53,15 +49,11 @@ fi assert.Contains(t, result.Message, "Blocked dangerous command") // Test: Should allow safe commands - result2, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{ - SessionID: "test", - WorkingDir: tempDir, + result2, err := manager.ExecutePreToolUse(context.Background(), "test", tempDir, PreToolUseData{ ToolName: "bash", ToolCallID: "call-2", - Data: map[string]any{ - "tool_input": map[string]any{ - "command": "ls -la", - }, + ToolInput: map[string]any{ + "command": "ls -la", }, }) @@ -94,12 +86,9 @@ esac manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) // Test: Should auto-approve view tool - result, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{ - SessionID: "test", - WorkingDir: tempDir, + result, err := manager.ExecutePreToolUse(context.Background(), "test", tempDir, PreToolUseData{ ToolName: "view", ToolCallID: "call-1", - Data: map[string]any{}, }) require.NoError(t, err) @@ -108,15 +97,11 @@ esac assert.Contains(t, result.Message, "Auto-approved read-only tool") // Test: Should auto-approve safe bash commands - result2, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{ - SessionID: "test", - WorkingDir: tempDir, + result2, err := manager.ExecutePreToolUse(context.Background(), "test", tempDir, PreToolUseData{ ToolName: "bash", ToolCallID: "call-2", - Data: map[string]any{ - "tool_input": map[string]any{ - "command": "ls -la", - }, + ToolInput: map[string]any{ + "command": "ls -la", }, }) @@ -156,12 +141,8 @@ fi manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) - result, err := manager.ExecuteHooks(context.Background(), HookUserPromptSubmit, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{ - "prompt": "help me", - }, + result, err := manager.ExecuteUserPromptSubmit(context.Background(), "test", tempDir, UserPromptSubmitData{ + Prompt: "help me", }) require.NoError(t, err) @@ -188,12 +169,9 @@ echo "$TIMESTAMP|$CRUSH_TOOL_NAME|$CRUSH_TOOL_CALL_ID" >> "$AUDIT_FILE" manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) - result, err := manager.ExecuteHooks(context.Background(), HookPostToolUse, HookContext{ - SessionID: "test", - WorkingDir: tempDir, + result, err := manager.ExecutePostToolUse(context.Background(), "test", tempDir, PostToolUseData{ ToolName: "bash", ToolCallID: "call-123", - Data: map[string]any{}, }) require.NoError(t, err) @@ -221,18 +199,10 @@ echo "Hook: $CRUSH_HOOK_TYPE" >> "` + logFile + `" manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) // Test with different hook types - _, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{}, - }) + _, err := manager.ExecutePreToolUse(context.Background(), "test", tempDir, PreToolUseData{}) require.NoError(t, err) - _, err = manager.ExecuteHooks(context.Background(), HookUserPromptSubmit, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{}, - }) + _, err = manager.ExecuteUserPromptSubmit(context.Background(), "test", tempDir, UserPromptSubmitData{}) require.NoError(t, err) // Verify both hook types were logged @@ -271,11 +241,7 @@ fi manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) - result, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{}, - }) + result, err := manager.ExecutePreToolUse(context.Background(), "test", tempDir, PreToolUseData{}) require.NoError(t, err) assert.False(t, result.Continue, "Should stop execution when rate limit exceeded") @@ -302,11 +268,7 @@ fi manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) - result, err := manager.ExecuteHooks(context.Background(), HookUserPromptSubmit, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{}, - }) + result, err := manager.ExecuteUserPromptSubmit(context.Background(), "test", tempDir, UserPromptSubmitData{}) require.NoError(t, err) assert.True(t, result.Continue) @@ -330,15 +292,11 @@ echo "{\"modified_input\": {\"command\": \"$SAFE_CMD\"}}" manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) - result, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{ - SessionID: "test", - WorkingDir: tempDir, + result, err := manager.ExecutePreToolUse(context.Background(), "test", tempDir, PreToolUseData{ ToolName: "bash", ToolCallID: "call-1", - Data: map[string]any{ - "tool_input": map[string]any{ - "command": "rm --force file.txt", - }, + ToolInput: map[string]any{ + "command": "rm --force file.txt", }, }) @@ -363,11 +321,7 @@ export CRUSH_MESSAGE="Auto-approved" manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) - result, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{}, - }) + result, err := manager.ExecutePreToolUse(context.Background(), "test", tempDir, PreToolUseData{}) require.NoError(t, err) assert.True(t, result.Continue) @@ -403,11 +357,7 @@ fi manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) - result, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{}, - }) + result, err := manager.ExecutePreToolUse(context.Background(), "test", tempDir, PreToolUseData{}) require.NoError(t, err) assert.False(t, result.Continue, "Exit code 2 should stop execution") @@ -442,13 +392,9 @@ export CRUSH_MODIFIED_PROMPT="Enhanced: $PROMPT" manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) - result, err := manager.ExecuteHooks(context.Background(), HookUserPromptSubmit, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{ - "prompt": "original prompt", - "model": "gpt-4", - }, + result, err := manager.ExecuteUserPromptSubmit(context.Background(), "test", tempDir, UserPromptSubmitData{ + Prompt: "original prompt", + Model: "gpt-4", }) require.NoError(t, err) @@ -480,25 +426,17 @@ fi manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) // Test: First message - result1, err := manager.ExecuteHooks(context.Background(), HookUserPromptSubmit, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{ - "prompt": "first prompt", - "is_first_message": true, - }, + result1, err := manager.ExecuteUserPromptSubmit(context.Background(), "test", tempDir, UserPromptSubmitData{ + Prompt: "first prompt", + IsFirstMessage: true, }) require.NoError(t, err) assert.Contains(t, result1.ContextContent, "This is the first message") // Test: Follow-up message - result2, err := manager.ExecuteHooks(context.Background(), HookUserPromptSubmit, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{ - "prompt": "follow-up prompt", - "is_first_message": false, - }, + result2, err := manager.ExecuteUserPromptSubmit(context.Background(), "test", tempDir, UserPromptSubmitData{ + Prompt: "follow-up prompt", + IsFirstMessage: false, }) require.NoError(t, err) assert.Contains(t, result2.ContextContent, "This is a follow-up message") @@ -532,11 +470,7 @@ export CRUSH_MESSAGE="${CRUSH_MESSAGE:-}; third" manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) - result, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{}, - }) + result, err := manager.ExecutePreToolUse(context.Background(), "test", tempDir, PreToolUseData{}) require.NoError(t, err) // Messages should be merged in order @@ -563,11 +497,7 @@ echo '{"message": "Combined output", "modified_input": {"key": "value"}}' manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) - result, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{}, - }) + result, err := manager.ExecutePreToolUse(context.Background(), "test", tempDir, PreToolUseData{}) require.NoError(t, err) assert.True(t, result.Continue) diff --git a/internal/hooks/manager.go b/internal/hooks/manager.go index b17753647f184b748ab2769c4a38cf06baf516c2..386986d6aa2656edde4b61cf40d5277a162847f1 100644 --- a/internal/hooks/manager.go +++ b/internal/hooks/manager.go @@ -25,10 +25,9 @@ type manager struct { } // NewManager creates a new hook manager. -func NewManager(workingDir, dataDir string, cfg *Config) Manager { +func NewManager(workingDir, dataDir string, cfg *Config) *manager { if cfg == nil { cfg = &Config{ - Enabled: true, TimeoutSeconds: 30, Directories: []string{filepath.Join(dataDir, "hooks")}, } @@ -63,9 +62,9 @@ func isExecutable(info os.FileInfo) bool { return info.Mode()&0o111 != 0 } -// ExecuteHooks implements Manager. -func (m *manager) ExecuteHooks(ctx context.Context, hookType HookType, hookContext HookContext) (HookResult, error) { - if !m.config.Enabled { +// executeHooks is the internal method that executes hooks for a given type. +func (m *manager) executeHooks(ctx context.Context, hookType HookType, hookContext HookContext) (HookResult, error) { + if m.config.Disabled { return HookResult{Continue: true}, nil } @@ -220,7 +219,7 @@ func (m *manager) isDisabled(hookPath string) bool { if rel, err := filepath.Rel(dir, hookPath); err == nil { // Normalize to forward slashes for cross-platform comparison rel = filepath.ToSlash(rel) - if slices.Contains(m.config.Disabled, rel) { + if slices.Contains(m.config.DisableHooks, rel) { return true } } @@ -283,3 +282,51 @@ func (m *manager) mergeResults(accumulated *HookResult, new *HookResult) { func (m *manager) ListHooks(hookType HookType) []string { return m.discoverHooks(hookType) } + +// ExecuteUserPromptSubmit executes user-prompt-submit hooks. +func (m *manager) ExecuteUserPromptSubmit(ctx context.Context, sessionID, workingDir string, data UserPromptSubmitData) (HookResult, error) { + hookCtx := HookContext{ + SessionID: sessionID, + WorkingDir: workingDir, + Data: data, + } + + return m.executeHooks(ctx, HookUserPromptSubmit, hookCtx) +} + +// ExecutePreToolUse executes pre-tool-use hooks. +func (m *manager) ExecutePreToolUse(ctx context.Context, sessionID, workingDir string, data PreToolUseData) (HookResult, error) { + hookCtx := HookContext{ + SessionID: sessionID, + WorkingDir: workingDir, + ToolName: data.ToolName, + ToolCallID: data.ToolCallID, + Data: data, + } + + return m.executeHooks(ctx, HookPreToolUse, hookCtx) +} + +// ExecutePostToolUse executes post-tool-use hooks. +func (m *manager) ExecutePostToolUse(ctx context.Context, sessionID, workingDir string, data PostToolUseData) (HookResult, error) { + hookCtx := HookContext{ + SessionID: sessionID, + WorkingDir: workingDir, + ToolName: data.ToolName, + ToolCallID: data.ToolCallID, + Data: data, + } + + return m.executeHooks(ctx, HookPostToolUse, hookCtx) +} + +// ExecuteStop executes stop hooks. +func (m *manager) ExecuteStop(ctx context.Context, sessionID, workingDir string, data StopData) (HookResult, error) { + hookCtx := HookContext{ + SessionID: sessionID, + WorkingDir: workingDir, + Data: data, + } + + return m.executeHooks(ctx, HookStop, hookCtx) +} diff --git a/internal/hooks/manager_test.go b/internal/hooks/manager_test.go index e8ec12ea9d782408ff3f4e2ad25da4431bb56fef..a0566947ed2ee31a093efb2c40f223da3c93cf45 100644 --- a/internal/hooks/manager_test.go +++ b/internal/hooks/manager_test.go @@ -109,12 +109,8 @@ crush_add_context "Context from hook 2" mgr := NewManager(tempDir, dataDir, nil) ctx := context.Background() - result, err := mgr.ExecuteHooks(ctx, HookUserPromptSubmit, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{ - "prompt": "test prompt", - }, + result, err := mgr.ExecuteUserPromptSubmit(ctx, "test", tempDir, UserPromptSubmitData{ + Prompt: "test prompt", }) require.NoError(t, err) @@ -145,11 +141,7 @@ export CRUSH_MESSAGE="should not see this" mgr := NewManager(tempDir, dataDir, nil) ctx := context.Background() - result, err := mgr.ExecuteHooks(ctx, HookPreToolUse, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{}, - }) + result, err := mgr.ExecutePreToolUse(ctx, "test", tempDir, PreToolUseData{}) require.NoError(t, err) assert.False(t, result.Continue) @@ -180,11 +172,7 @@ export CRUSH_PERMISSION=deny mgr := NewManager(tempDir, dataDir, nil) ctx := context.Background() - result, err := mgr.ExecuteHooks(ctx, HookPreToolUse, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{}, - }) + result, err := mgr.ExecutePreToolUse(ctx, "test", tempDir, PreToolUseData{}) require.NoError(t, err) assert.Equal(t, "deny", result.Permission) @@ -211,19 +199,14 @@ export CRUSH_MESSAGE="disabled" require.NoError(t, err) cfg := &Config{ - Enabled: true, TimeoutSeconds: 30, Directories: []string{filepath.Join(dataDir, "hooks")}, - Disabled: []string{"pre-tool-use/02-disabled.sh"}, + DisableHooks: []string{"pre-tool-use/02-disabled.sh"}, } mgr := NewManager(tempDir, dataDir, cfg) ctx := context.Background() - result, err := mgr.ExecuteHooks(ctx, HookPreToolUse, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{}, - }) + result, err := mgr.ExecutePreToolUse(ctx, "test", tempDir, PreToolUseData{}) require.NoError(t, err) assert.Equal(t, "enabled", result.Message) @@ -234,7 +217,6 @@ export CRUSH_MESSAGE="disabled" dataDir := filepath.Join(tempDir, ".crush") cfg := &Config{ - Enabled: true, TimeoutSeconds: 30, Directories: []string{filepath.Join(dataDir, "hooks")}, Inline: map[string][]InlineHook{ @@ -251,11 +233,7 @@ export CRUSH_MESSAGE="inline hook executed" mgr := NewManager(tempDir, dataDir, cfg) ctx := context.Background() - result, err := mgr.ExecuteHooks(ctx, HookUserPromptSubmit, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{}, - }) + result, err := mgr.ExecuteUserPromptSubmit(ctx, "test", tempDir, UserPromptSubmitData{}) require.NoError(t, err) assert.Equal(t, "inline hook executed", result.Message) @@ -265,17 +243,11 @@ export CRUSH_MESSAGE="inline hook executed" tempDir := t.TempDir() dataDir := filepath.Join(tempDir, ".crush") - cfg := &Config{ - Enabled: false, - } + cfg := &Config{} mgr := NewManager(tempDir, dataDir, cfg) ctx := context.Background() - result, err := mgr.ExecuteHooks(ctx, HookPreToolUse, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{}, - }) + result, err := mgr.ExecutePreToolUse(ctx, "test", tempDir, PreToolUseData{}) require.NoError(t, err) assert.True(t, result.Continue) @@ -288,11 +260,7 @@ export CRUSH_MESSAGE="inline hook executed" mgr := NewManager(tempDir, dataDir, nil) ctx := context.Background() - result, err := mgr.ExecuteHooks(ctx, HookPreToolUse, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{}, - }) + result, err := mgr.ExecutePreToolUse(ctx, "test", tempDir, PreToolUseData{}) require.NoError(t, err) assert.True(t, result.Continue) @@ -313,11 +281,7 @@ exit 1 mgr := NewManager(tempDir, dataDir, nil) ctx := context.Background() - result, err := mgr.ExecuteHooks(ctx, HookPreToolUse, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{}, - }) + result, err := mgr.ExecutePreToolUse(ctx, "test", tempDir, PreToolUseData{}) require.NoError(t, err) assert.False(t, result.Continue) @@ -372,22 +336,14 @@ export CRUSH_MESSAGE="$CRUSH_MESSAGE; specific hook" // Test PreToolUse - should execute both catch-all and specific. ctx := context.Background() - result, err := mgr.ExecuteHooks(ctx, HookPreToolUse, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{}, - }) + result, err := mgr.ExecutePreToolUse(ctx, "test", tempDir, PreToolUseData{}) require.NoError(t, err) assert.Contains(t, result.Message, "catch-all: pre-tool-use") assert.Contains(t, result.Message, "specific hook") // Test UserPromptSubmit - should only execute catch-all. - result2, err := mgr.ExecuteHooks(ctx, HookUserPromptSubmit, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{}, - }) + result2, err := mgr.ExecuteUserPromptSubmit(ctx, "test", tempDir, UserPromptSubmitData{}) require.NoError(t, err) assert.Equal(t, "catch-all: user-prompt-submit", result2.Message) @@ -412,7 +368,6 @@ fi require.NoError(t, err) cfg := &Config{ - Enabled: true, TimeoutSeconds: 30, Directories: []string{filepath.Join(dataDir, "hooks")}, Environment: map[string]string{ @@ -423,11 +378,7 @@ fi mgr := NewManager(tempDir, dataDir, cfg) ctx := context.Background() - result, err := mgr.ExecuteHooks(ctx, HookPreToolUse, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{}, - }) + result, err := mgr.ExecutePreToolUse(ctx, "test", tempDir, PreToolUseData{}) require.NoError(t, err) assert.Equal(t, "config environment variables received", result.Message) @@ -440,7 +391,6 @@ fi require.NoError(t, os.MkdirAll(readOnlyDir, 0o555)) // Read-only cfg := &Config{ - Enabled: true, TimeoutSeconds: 30, Directories: []string{filepath.Join(readOnlyDir, "hooks")}, Inline: map[string][]InlineHook{ @@ -458,11 +408,7 @@ fi // Should not error even though inline hook write fails. // The hook will be skipped and logged. - result, err := mgr.ExecuteHooks(ctx, HookPreToolUse, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{}, - }) + result, err := mgr.ExecutePreToolUse(ctx, "test", tempDir, PreToolUseData{}) require.NoError(t, err) assert.True(t, result.Continue) // Should continue despite write failure @@ -511,11 +457,7 @@ export CRUSH_MESSAGE="auto-approved" mgr := NewManager(tempDir, dataDir, nil) ctx := context.Background() - result, err := mgr.ExecuteHooks(ctx, HookPreToolUse, HookContext{ - SessionID: "test", - WorkingDir: tempDir, - Data: map[string]any{}, - }) + result, err := mgr.ExecutePreToolUse(ctx, "test", tempDir, PreToolUseData{}) require.NoError(t, err) assert.Equal(t, "approve", result.Permission) diff --git a/internal/hooks/types.go b/internal/hooks/types.go index 5fc2b6196efd2fd1adcc73809def33a944d33a17..24e4b30a65ab7d34fe980a892e62b6919aba3696 100644 --- a/internal/hooks/types.go +++ b/internal/hooks/types.go @@ -40,7 +40,7 @@ type HookContext struct { // For PreToolUse: tool_name, tool_call_id, tool_input // For PostToolUse: tool_name, tool_call_id, tool_input, tool_output, execution_time_ms // For Stop: reason - Data map[string]any + Data any // ToolName is the tool name (for tool hooks only). ToolName string @@ -54,40 +54,73 @@ type HookContext struct { // HookResult contains the result of hook execution. type HookResult struct { + // HookType the hook type + HookType HookType `json:"hook_type"` + // Name the name of the hook (usually the file name) + Name string `json:"name"` + // Path hook path + Path string `json:"path"` + // AllResults stores all results for this event + AllResults []HookResult `json:"all_results,omitempty"` // Continue indicates whether to continue execution. // If false, execution stops. - Continue bool + Continue bool `json:"continue"` // Permission decision (for PreToolUse hooks only). // Values: "ask" (default), "approve", "deny" - Permission string + Permission string `json:"permission"` // ModifiedPrompt is the modified user prompt (for UserPromptSubmit). - ModifiedPrompt *string + ModifiedPrompt *string `json:"modified_prompt"` // ModifiedInput is the modified tool input parameters (for PreToolUse). // This is a map that can be merged with the original tool input. - ModifiedInput map[string]any + ModifiedInput map[string]any `json:"modified_input"` // ModifiedOutput is the modified tool output (for PostToolUse). - ModifiedOutput map[string]any + ModifiedOutput map[string]any `json:"modified_output"` // ContextContent is raw text content to add to LLM context. - ContextContent string + ContextContent string `json:"context_content"` // ContextFiles is a list of file paths to load and add to LLM context. - ContextFiles []string + ContextFiles []string `json:"context_files"` // Message is a user-facing message (logged and potentially displayed). - Message string + Message string `json:"message"` } // Manager coordinates hook discovery and execution. type Manager interface { - // ExecuteHooks executes all hooks for the given type in order. - // Returns accumulated results from all hooks. - ExecuteHooks(ctx context.Context, hookType HookType, context HookContext) (HookResult, error) - // ListHooks returns all discovered hooks for a given type. ListHooks(hookType HookType) []string + + // ExecuteUserPromptSubmit executes the UserPromptSubmit event + ExecuteUserPromptSubmit(ctx context.Context, sessionID, workingDir string, data UserPromptSubmitData) (HookResult, error) +} + +type UserPromptSubmitData struct { + Prompt string `json:"prompt"` + Attachments []string `json:"attachments"` + Model string `json:"model"` + Provider string `json:"provider"` + IsFirstMessage bool `json:"is_first_message"` +} + +type PreToolUseData struct { + ToolName string `json:"tool_name"` + ToolCallID string `json:"tool_call_id"` + ToolInput map[string]any `json:"tool_input"` +} + +type PostToolUseData struct { + ToolName string `json:"tool_name"` + ToolCallID string `json:"tool_call_id"` + ToolInput map[string]any `json:"tool_input"` + ToolOutput map[string]any `json:"tool_output"` + ExecutionTimeMs int64 `json:"execution_time_ms"` +} + +type StopData struct { + Reason string `json:"reason"` } diff --git a/internal/message/content.go b/internal/message/content.go index 358ad120d8f87109ea8888984ad236b155388788..da93f8ad5562ed0aba2140f30bbec58fb2e08278 100644 --- a/internal/message/content.go +++ b/internal/message/content.go @@ -3,6 +3,7 @@ package message import ( "encoding/base64" "errors" + "fmt" "slices" "strings" "time" @@ -12,6 +13,7 @@ import ( "charm.land/fantasy/providers/google" "charm.land/fantasy/providers/openai" "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/hooks" ) type MessageRole string @@ -134,6 +136,11 @@ type Message struct { CreatedAt int64 UpdatedAt int64 IsSummaryMessage bool + Metadata MessageMetadata +} + +type MessageMetadata struct { + Hooks *hooks.HookResult `json:"hooks,omitempty"` } func (m *Message) Content() TextContent { @@ -145,6 +152,30 @@ func (m *Message) Content() TextContent { return TextContent{} } +func (m *Message) ContentWithHookContext() string { + text := strings.TrimSpace(m.Content().Text) + if m.Metadata.Hooks == nil { + return text + } + + hookContext := strings.TrimSpace(m.Metadata.Hooks.ContextContent) + if hookContext != "" { + text += fmt.Sprintf("## Additional Context\n %s", hookContext) + } + + if len(m.Metadata.Hooks.ContextFiles) > 0 { + text += "\n## Additional Context Files\n" + + for _, file := range m.Metadata.Hooks.ContextFiles { + text += fmt.Sprintf("- %s\n", file) + } + + text += "**Note: Read these files if needed**" + } + + return text +} + func (m *Message) ReasoningContent() ReasoningContent { for _, part := range m.Parts { if c, ok := part.(ReasoningContent); ok { @@ -241,6 +272,11 @@ func (m *Message) AppendContent(delta string) { } } +// AddHookResult adds the result of the hooks for this message +func (m *Message) AddHookResult(hooks hooks.HookResult) { + m.Metadata.Hooks = &hooks +} + func (m *Message) AppendReasoningContent(delta string) { found := false for i, part := range m.Parts { @@ -431,7 +467,7 @@ func (m *Message) ToAIMessage() []fantasy.Message { switch m.Role { case User: var parts []fantasy.MessagePart - text := strings.TrimSpace(m.Content().Text) + text := strings.TrimSpace(m.ContentWithHookContext()) if text != "" { parts = append(parts, fantasy.TextPart{Text: text}) } diff --git a/internal/message/message.go b/internal/message/message.go index 4cdf89b54f8eaf831d53a5fc51fdb5c71b4b953c..e393e9243ed808ca9ee819733722726ee7df4820 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -18,6 +18,7 @@ type CreateMessageParams struct { Model string Provider string IsSummaryMessage bool + Metadata MessageMetadata } type Service interface { @@ -65,6 +66,10 @@ func (s *service) Create(ctx context.Context, sessionID string, params CreateMes if err != nil { return Message{}, err } + metadataJSON, err := json.Marshal(params.Metadata) + if err != nil { + return Message{}, err + } isSummary := int64(0) if params.IsSummaryMessage { isSummary = 1 @@ -77,6 +82,7 @@ func (s *service) Create(ctx context.Context, sessionID string, params CreateMes Model: sql.NullString{String: string(params.Model), Valid: true}, Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""}, IsSummaryMessage: isSummary, + Metadata: string(metadataJSON), }) if err != nil { return Message{}, err @@ -110,6 +116,10 @@ func (s *service) Update(ctx context.Context, message Message) error { if err != nil { return err } + metadata, err := json.Marshal(message.Metadata) + if err != nil { + return err + } finishedAt := sql.NullInt64{} if f := message.FinishPart(); f != nil { finishedAt.Int64 = f.Time @@ -118,6 +128,7 @@ func (s *service) Update(ctx context.Context, message Message) error { err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{ ID: message.ID, Parts: string(parts), + Metadata: string(metadata), FinishedAt: finishedAt, }) if err != nil { @@ -156,6 +167,10 @@ func (s *service) fromDBItem(item db.Message) (Message, error) { if err != nil { return Message{}, err } + var metadata MessageMetadata + if err := json.Unmarshal([]byte(item.Metadata), &metadata); err != nil { + return Message{}, err + } return Message{ ID: item.ID, SessionID: item.SessionID, @@ -166,6 +181,7 @@ func (s *service) fromDBItem(item db.Message) (Message, error) { CreatedAt: item.CreatedAt, UpdatedAt: item.UpdatedAt, IsSummaryMessage: item.IsSummaryMessage != 0, + Metadata: metadata, }, nil }