1package agent
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8
9 "charm.land/fantasy"
10 "github.com/charmbracelet/crush/internal/agent/tools"
11 "github.com/charmbracelet/crush/internal/hooks"
12 "github.com/charmbracelet/crush/internal/permission"
13 "github.com/tidwall/sjson"
14)
15
16// hookedTool wraps a fantasy.AgentTool to run PreToolUse hooks before
17// delegating to the inner tool.
18type hookedTool struct {
19 inner fantasy.AgentTool
20 runner *hooks.Runner
21}
22
23func newHookedTool(inner fantasy.AgentTool, runner *hooks.Runner) *hookedTool {
24 return &hookedTool{inner: inner, runner: runner}
25}
26
27// wrapToolsWithHooks returns a tool slice with each entry wrapped in a
28// hookedTool. Returns the original slice unchanged when runner is nil or
29// when isSubAgent is true — sub-agents never fire hooks, the top-level
30// invocation of the sub-agent tool itself is wrapped on the caller's side.
31func wrapToolsWithHooks(tools []fantasy.AgentTool, runner *hooks.Runner, isSubAgent bool) []fantasy.AgentTool {
32 if runner == nil || isSubAgent {
33 return tools
34 }
35 out := make([]fantasy.AgentTool, len(tools))
36 for i, tool := range tools {
37 out[i] = newHookedTool(tool, runner)
38 }
39 return out
40}
41
42func (h *hookedTool) Info() fantasy.ToolInfo {
43 return h.inner.Info()
44}
45
46func (h *hookedTool) ProviderOptions() fantasy.ProviderOptions {
47 return h.inner.ProviderOptions()
48}
49
50func (h *hookedTool) SetProviderOptions(opts fantasy.ProviderOptions) {
51 h.inner.SetProviderOptions(opts)
52}
53
54func (h *hookedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
55 sessionID := tools.GetSessionFromContext(ctx)
56 result, err := h.runner.Run(ctx, hooks.EventPreToolUse, sessionID, call.Name, call.Input)
57 if err != nil {
58 slog.Warn("Hook execution error, proceeding with tool call",
59 "tool", call.Name, "error", err)
60 }
61
62 if result.Decision == hooks.DecisionDeny || result.Halt {
63 reason := fmt.Sprintf("Tool call blocked by hook. Reason: %s", result.Reason)
64 if result.Halt {
65 reason = fmt.Sprintf("Turn halted by hook. Reason: %s", result.Reason)
66 }
67 resp := fantasy.NewTextErrorResponse(reason)
68 // Halt ends the whole turn; a plain deny only blocks this tool
69 // call so the model can see the error and try something else.
70 resp.StopTurn = result.Halt
71 resp.Metadata = hookMetadataJSON(result)
72 return resp, nil
73 }
74
75 if result.UpdatedInput != "" {
76 call.Input = result.UpdatedInput
77 }
78
79 // An explicit allow from a hook pre-approves the permission prompt for
80 // this tool call. Deny is already handled above; silence falls through
81 // to the normal permission flow.
82 if result.Decision == hooks.DecisionAllow {
83 ctx = permission.WithHookApproval(ctx, call.ID)
84 }
85
86 resp, err := h.inner.Run(ctx, call)
87 if err != nil {
88 return resp, err
89 }
90
91 if result.Context != "" {
92 if resp.Content != "" {
93 resp.Content += "\n"
94 }
95 resp.Content += result.Context
96 }
97
98 resp.Metadata = mergeHookMetadata(resp.Metadata, result)
99 return resp, nil
100}
101
102// buildHookMetadata creates a HookMetadata from an AggregateResult.
103func buildHookMetadata(result hooks.AggregateResult) hooks.HookMetadata {
104 return hooks.HookMetadata{
105 HookCount: result.HookCount,
106 Decision: result.Decision.String(),
107 Halt: result.Halt,
108 Reason: result.Reason,
109 InputRewrite: result.UpdatedInput != "",
110 Hooks: result.Hooks,
111 }
112}
113
114// hookMetadataJSON builds a JSON string containing only the hook metadata.
115func hookMetadataJSON(result hooks.AggregateResult) string {
116 meta := buildHookMetadata(result)
117 data, err := json.Marshal(meta)
118 if err != nil {
119 return ""
120 }
121 return `{"hook":` + string(data) + `}`
122}
123
124// mergeHookMetadata injects hook metadata into existing tool metadata.
125func mergeHookMetadata(existing string, result hooks.AggregateResult) string {
126 if result.HookCount == 0 {
127 return existing
128 }
129 meta := buildHookMetadata(result)
130 data, err := json.Marshal(meta)
131 if err != nil {
132 return existing
133 }
134 if existing == "" {
135 existing = "{}"
136 }
137 merged, err := sjson.SetRaw(existing, "hook", string(data))
138 if err != nil {
139 return existing
140 }
141 return merged
142}