hooked_tool.go

  1package agent
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"log/slog"
  8
  9	"charm.land/fantasy"
 10	"github.com/charmbracelet/crush/internal/agent/tools"
 11	"github.com/charmbracelet/crush/internal/hooks"
 12	"github.com/charmbracelet/crush/internal/permission"
 13	"github.com/tidwall/sjson"
 14)
 15
 16// hookedTool wraps a fantasy.AgentTool to run PreToolUse hooks before
 17// delegating to the inner tool.
 18type hookedTool struct {
 19	inner  fantasy.AgentTool
 20	runner *hooks.Runner
 21}
 22
 23func newHookedTool(inner fantasy.AgentTool, runner *hooks.Runner) *hookedTool {
 24	return &hookedTool{inner: inner, runner: runner}
 25}
 26
 27// wrapToolsWithHooks returns a tool slice with each entry wrapped in a
 28// hookedTool. Returns the original slice unchanged when runner is nil or
 29// when isSubAgent is true — sub-agents never fire hooks, the top-level
 30// invocation of the sub-agent tool itself is wrapped on the caller's side.
 31func wrapToolsWithHooks(tools []fantasy.AgentTool, runner *hooks.Runner, isSubAgent bool) []fantasy.AgentTool {
 32	if runner == nil || isSubAgent {
 33		return tools
 34	}
 35	out := make([]fantasy.AgentTool, len(tools))
 36	for i, tool := range tools {
 37		out[i] = newHookedTool(tool, runner)
 38	}
 39	return out
 40}
 41
 42func (h *hookedTool) Info() fantasy.ToolInfo {
 43	return h.inner.Info()
 44}
 45
 46func (h *hookedTool) ProviderOptions() fantasy.ProviderOptions {
 47	return h.inner.ProviderOptions()
 48}
 49
 50func (h *hookedTool) SetProviderOptions(opts fantasy.ProviderOptions) {
 51	h.inner.SetProviderOptions(opts)
 52}
 53
 54func (h *hookedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 55	sessionID := tools.GetSessionFromContext(ctx)
 56	result, err := h.runner.Run(ctx, hooks.EventPreToolUse, sessionID, call.Name, call.Input)
 57	if err != nil {
 58		slog.Warn("Hook execution error, proceeding with tool call",
 59			"tool", call.Name, "error", err)
 60	}
 61
 62	if result.Decision == hooks.DecisionDeny || result.Halt {
 63		reason := fmt.Sprintf("Tool call blocked by hook. Reason: %s", result.Reason)
 64		if result.Halt {
 65			reason = fmt.Sprintf("Turn halted by hook. Reason: %s", result.Reason)
 66		}
 67		resp := fantasy.NewTextErrorResponse(reason)
 68		// Halt ends the whole turn; a plain deny only blocks this tool
 69		// call so the model can see the error and try something else.
 70		resp.StopTurn = result.Halt
 71		resp.Metadata = hookMetadataJSON(result)
 72		return resp, nil
 73	}
 74
 75	if result.UpdatedInput != "" {
 76		call.Input = result.UpdatedInput
 77	}
 78
 79	// An explicit allow from a hook pre-approves the permission prompt for
 80	// this tool call. Deny is already handled above; silence falls through
 81	// to the normal permission flow.
 82	if result.Decision == hooks.DecisionAllow {
 83		ctx = permission.WithHookApproval(ctx, call.ID)
 84	}
 85
 86	resp, err := h.inner.Run(ctx, call)
 87	if err != nil {
 88		return resp, err
 89	}
 90
 91	if result.Context != "" {
 92		if resp.Content != "" {
 93			resp.Content += "\n"
 94		}
 95		resp.Content += result.Context
 96	}
 97
 98	resp.Metadata = mergeHookMetadata(resp.Metadata, result)
 99	return resp, nil
100}
101
102// buildHookMetadata creates a HookMetadata from an AggregateResult.
103func buildHookMetadata(result hooks.AggregateResult) hooks.HookMetadata {
104	return hooks.HookMetadata{
105		HookCount:    result.HookCount,
106		Decision:     result.Decision.String(),
107		Halt:         result.Halt,
108		Reason:       result.Reason,
109		InputRewrite: result.UpdatedInput != "",
110		Hooks:        result.Hooks,
111	}
112}
113
114// hookMetadataJSON builds a JSON string containing only the hook metadata.
115func hookMetadataJSON(result hooks.AggregateResult) string {
116	meta := buildHookMetadata(result)
117	data, err := json.Marshal(meta)
118	if err != nil {
119		return ""
120	}
121	return `{"hook":` + string(data) + `}`
122}
123
124// mergeHookMetadata injects hook metadata into existing tool metadata.
125func mergeHookMetadata(existing string, result hooks.AggregateResult) string {
126	if result.HookCount == 0 {
127		return existing
128	}
129	meta := buildHookMetadata(result)
130	data, err := json.Marshal(meta)
131	if err != nil {
132		return existing
133	}
134	if existing == "" {
135		existing = "{}"
136	}
137	merged, err := sjson.SetRaw(existing, "hook", string(data))
138	if err != nil {
139		return existing
140	}
141	return merged
142}