From b208bb3aed3269d5d76696be543a54150b3176fc Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Wed, 26 Nov 2025 14:21:45 +0100 Subject: [PATCH] wip: other hooks --- go.mod | 6 +- go.sum | 10 +- internal/agent/agent.go | 192 ++++++++++++++++++++++++++++++ internal/agent/errors.go | 1 + internal/agent/tools/bash.go | 25 ++-- internal/agent/tools/download.go | 24 ++-- internal/agent/tools/edit.go | 99 +++++++-------- internal/agent/tools/fetch.go | 26 ++-- internal/agent/tools/ls.go | 8 +- internal/agent/tools/mcp-tools.go | 25 ++-- internal/agent/tools/multiedit.go | 16 ++- internal/agent/tools/tools.go | 58 ++++++++- internal/agent/tools/view.go | 8 +- internal/agent/tools/write.go | 31 ++--- internal/hooks/README.md | 15 +-- internal/hooks/helpers.sh | 7 -- internal/hooks/manager.go | 2 - internal/hooks/types.go | 9 ++ internal/message/content.go | 26 ++-- 19 files changed, 417 insertions(+), 171 deletions(-) diff --git a/go.mod b/go.mod index f8c69810e844780a76445b55036098c63757f71e..ae383bc809e69efeb0dc14fe34e16a6b59914439 100644 --- a/go.mod +++ b/go.mod @@ -69,7 +69,7 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect github.com/RealAlexandreAI/json-repair v0.0.14 // indirect github.com/andybalholm/cascadia v1.3.3 // indirect - github.com/aws/aws-sdk-go-v2 v1.39.6 // indirect + github.com/aws/aws-sdk-go-v2 v1.40.0 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 // indirect github.com/aws/aws-sdk-go-v2/config v1.27.27 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.17.27 // indirect @@ -168,7 +168,7 @@ require ( golang.org/x/term v0.36.0 // indirect golang.org/x/time v0.12.0 // indirect google.golang.org/api v0.239.0 // indirect - google.golang.org/genai v1.34.0 // indirect + google.golang.org/genai v1.36.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect google.golang.org/grpc v1.74.2 // indirect google.golang.org/protobuf v1.36.10 // indirect @@ -176,3 +176,5 @@ require ( gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace charm.land/fantasy => ../../fantasy diff --git a/go.sum b/go.sum index 1458d7a78717609e015f21ea9d52b3b45a6df4dc..ebc184135dd903484acba2505ad9b3e6c721f7e0 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,6 @@ charm.land/bubbles/v2 v2.0.0-rc.1 h1:EiIFVAc3Zi/yY86td+79mPhHR7AqZ1OxF+6ztpOCRaM charm.land/bubbles/v2 v2.0.0-rc.1/go.mod h1:5AbN6cEd/47gkEf8TgiQ2O3RZ5QxMS14l9W+7F9fPC4= charm.land/bubbletea/v2 v2.0.0-rc.1.0.20251117161017-15f884bd2973 h1:Ay8VWyn/CbwltswomzWXj0m5KKfSJavFfCDCxI+j8qo= charm.land/bubbletea/v2 v2.0.0-rc.1.0.20251117161017-15f884bd2973/go.mod h1:IXFmnCnMLTWw/KQ9rEatSYqbAPAYi8kA3Yqwa1SFnLk= -charm.land/fantasy v0.3.2 h1:yHTsSZ25LcICMRw3xzdz3OkaZtDQch+B5ljJo17HxgU= -charm.land/fantasy v0.3.2/go.mod h1:sV8Ns/JTJHOaYOHPgVRDugMheAyxsW/nmdpVGrycYEk= charm.land/lipgloss/v2 v2.0.0-beta.3.0.20251106193318-19329a3e8410 h1:D9PbaszZYpB4nj+d6HTWr1onlmlyuGVNfL9gAi8iB3k= charm.land/lipgloss/v2 v2.0.0-beta.3.0.20251106193318-19329a3e8410/go.mod h1:1qZyvvVCenJO2M1ac2mX0yyiIZJoZmDM4DG4s0udJkU= charm.land/x/vcr v0.1.1 h1:PXCFMUG0rPtyk35rhfzYCJEduOzWXCIbrXTFq4OF/9Q= @@ -44,8 +42,8 @@ github.com/andybalholm/cascadia v1.3.3 h1:AG2YHrzJIm4BZ19iwJ/DAua6Btl3IwJX+VI4kk github.com/andybalholm/cascadia v1.3.3/go.mod h1:xNd9bqTn98Ln4DwST8/nG+H0yuB8Hmgu1YHNnWw0GeA= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= -github.com/aws/aws-sdk-go-v2 v1.39.6 h1:2JrPCVgWJm7bm83BDwY5z8ietmeJUbh3O2ACnn+Xsqk= -github.com/aws/aws-sdk-go-v2 v1.39.6/go.mod h1:c9pm7VwuW0UPxAEYGyTmyurVcNrbF6Rt/wixFqDhcjE= +github.com/aws/aws-sdk-go-v2 v1.40.0 h1:/WMUA0kjhZExjOQN2z3oLALDREea1A7TobfuiBrKlwc= +github.com/aws/aws-sdk-go-v2 v1.40.0/go.mod h1:c9pm7VwuW0UPxAEYGyTmyurVcNrbF6Rt/wixFqDhcjE= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 h1:tW1/Rkad38LA15X4UQtjXZXNKsCgkshC3EbmcUmghTg= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3/go.mod h1:UbnqO+zjqk3uIt9yCACHJ9IVNhyhOCnYk8yA19SAWrM= github.com/aws/aws-sdk-go-v2/config v1.27.27 h1:HdqgGt1OAP0HkEDDShEl0oSYa9ZZBSOmKpdpsDMdO90= @@ -456,8 +454,8 @@ golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.239.0 h1:2hZKUnFZEy81eugPs4e2XzIJ5SOwQg0G82bpXD65Puo= google.golang.org/api v0.239.0/go.mod h1:cOVEm2TpdAGHL2z+UwyS+kmlGr3bVWQQ6sYEqkKje50= -google.golang.org/genai v1.34.0 h1:lPRJRO+HqRX1SwFo1Xb/22nZ5MBEPUbXDl61OoDxlbY= -google.golang.org/genai v1.34.0/go.mod h1:7pAilaICJlQBonjKKJNhftDFv3SREhZcTe9F6nRcjbg= +google.golang.org/genai v1.36.0 h1:sJCIjqTAmwrtAIaemtTiKkg2TO1RxnYEusTmEQ3nGxM= +google.golang.org/genai v1.36.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 h1:fc6jSaCT0vBduLYZHYrBBNY4dsWuvgyff9noRNDdBeE= google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= google.golang.org/grpc v1.74.2 h1:WoosgB65DlWVC9FqI82dGsZhWFNBSLjQ84bjROOpMu4= diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 0cb63debefe0e61e4868f40ecdc8367f32473e0b..6a5456dc17c11f4658d27b6405f2e29e5cba3ab0 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -11,6 +11,7 @@ import ( "cmp" "context" _ "embed" + "encoding/json" "errors" "fmt" "log/slog" @@ -196,6 +197,14 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy defer cancel() defer a.activeRequests.Del(call.SessionID) + // Track completion reason for stop hook + var stopReason string + defer func() { + if stopReason != "" { + a.executeStopHook(ctx, call.SessionID, stopReason) + } + }() + // create the agent message asap to show loading var currentAssistant *message.Message assistantMessage, err := a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{ @@ -212,6 +221,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy hookErr := a.executePromptSubmitHook(genCtx, &msg, len(msgs) == 0) if hookErr != nil { + stopReason = "error" // Delete the assistant message // use the ctx since this could be a cancellation deleteErr := a.messages.Delete(ctx, currentAssistant.ID) @@ -223,6 +233,9 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy startTime := time.Now() a.eventPromptSent(call.SessionID) + // Map to store post-tool-use hook results for OnToolResult callback + postToolHookResults := csync.NewMap[string, hooks.HookResult]() + var shouldSummarize bool result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{ Prompt: msg.ContentWithHookContext(), @@ -359,6 +372,9 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy currentAssistant.AddToolCall(toolCall) return a.messages.Update(genCtx, *currentAssistant) }, + PreToolExecute: func(ctx context.Context, toolCall fantasy.ToolCall) (context.Context, *fantasy.ToolCall, error) { + return a.executePreToolUseHook(ctx, call.SessionID, toolCall, currentAssistant) + }, OnToolResult: func(result fantasy.ToolResultContent) error { var resultContent string isError := false @@ -384,6 +400,10 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy IsError: isError, Metadata: result.ClientMetadata, } + // Attach hook result if available + if hookRes, ok := postToolHookResults.Get(result.ToolCallID); ok { + toolResult.HookResult = &hookRes + } _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{ Role: message.Tool, Parts: []message.ContentPart{ @@ -395,6 +415,14 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy } return nil }, + PostToolExecute: func(ctx context.Context, toolCall fantasy.ToolCall, response fantasy.ToolResponse, executionTimeMs int64) (*fantasy.ToolResponse, error) { + modifiedResponse, hookResult, err := a.executePostToolUseHook(ctx, call.SessionID, toolCall, response, executionTimeMs) + if hookResult != nil { + // Store for OnToolResult callback + postToolHookResults.Set(toolCall.ID, *hookResult) + } + return modifiedResponse, err + }, OnStepFinish: func(stepResult fantasy.StepResult) error { finishReason := message.FinishReasonUnknown switch stepResult.FinishReason { @@ -440,6 +468,17 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy if err != nil { isCancelErr := errors.Is(err, context.Canceled) isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied) + isHookDenied := errors.Is(err, ErrHookDenied) + + // Set stop reason for defer + if isCancelErr { + stopReason = "cancelled" + } else if isPermissionErr || isHookDenied { + stopReason = "permission_denied" + } else { + stopReason = "error" + } + if currentAssistant == nil { return result, err } @@ -484,6 +523,8 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy content = "Tool execution canceled by user" } else if isPermissionErr { content = "User denied permission" + } else if isHookDenied { + content = "Hook denied execution" } toolResult := message.ToolResult{ ToolCallID: tc.ID, @@ -508,6 +549,8 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "") } else if isPermissionErr { currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "") + } else if isHookDenied { + currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Hook denied execution", "") } else if errors.As(err, &providerErr) { currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message) } else if errors.As(err, &fantasyErr) { @@ -525,6 +568,9 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy } wg.Wait() + // Set completion reason for stop hook + stopReason = "completed" + if shouldSummarize { a.activeRequests.Del(call.SessionID) if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil { @@ -967,3 +1013,149 @@ func (a *sessionAgent) executePromptSubmitHook(ctx context.Context, msg *message } return nil } + +// executePreToolUseHook executes the pre-tool-use hook and applies modifications. +// Only runs for main agent (not sub-agents). +func (a *sessionAgent) executePreToolUseHook(ctx context.Context, sessionID string, toolCall fantasy.ToolCall, currentAssistant *message.Message) (context.Context, *fantasy.ToolCall, error) { + // Skip if sub-agent or no hooks manager. + if a.isSubAgent || a.hooksManager == nil { + return ctx, nil, nil + } + + // Parse tool input to map + var toolInput map[string]any + if err := json.Unmarshal([]byte(toolCall.Input), &toolInput); err != nil { + // If we can't parse the input, skip the hook + return ctx, nil, nil + } + + hookResult, err := a.hooksManager.ExecutePreToolUse(ctx, sessionID, a.workingDir, hooks.PreToolUseData{ + ToolName: toolCall.Name, + ToolCallID: toolCall.ID, + ToolInput: toolInput, + }) + if err != nil { + return ctx, nil, fmt.Errorf("pre-tool-use hook execution failed: %w", err) + } + + // Store hook result in the current assistant's tool call + for _, tc := range currentAssistant.ToolCalls() { + if tc.ID == toolCall.ID { + tc.HookResult = &hookResult + currentAssistant.AddToolCall(tc) + if updateErr := a.messages.Update(ctx, *currentAssistant); updateErr != nil { + slog.Error("failed to update assistant message with pre-hook result", "error", updateErr) + } + break + } + } + + // If hook returned Continue: false, deny execution. + if !hookResult.Continue { + return ctx, nil, ErrHookDenied + } + + // Set permission in context for tools to use + if hookResult.Permission != "" { + ctx = tools.SetHookPermissionInContext(ctx, hookResult.Permission) + } + + // Apply modified input if present. + if len(hookResult.ModifiedInput) > 0 { + // Merge modified input with original + for k, v := range hookResult.ModifiedInput { + toolInput[k] = v + } + + modifiedInputJSON, err := json.Marshal(toolInput) + if err != nil { + return ctx, nil, fmt.Errorf("failed to marshal modified input: %w", err) + } + + modifiedCall := toolCall + modifiedCall.Input = string(modifiedInputJSON) + return ctx, &modifiedCall, nil + } + + return ctx, nil, nil +} + +// executePostToolUseHook executes the post-tool-use hook and applies modifications. +// Only runs for main agent (not sub-agents). +func (a *sessionAgent) executePostToolUseHook(ctx context.Context, sessionID string, toolCall fantasy.ToolCall, response fantasy.ToolResponse, executionTimeMs int64) (*fantasy.ToolResponse, *hooks.HookResult, error) { + // Skip if sub-agent or no hooks manager. + if a.isSubAgent || a.hooksManager == nil { + return nil, nil, nil + } + + // Parse tool input to map + var toolInput map[string]any + if err := json.Unmarshal([]byte(toolCall.Input), &toolInput); err != nil { + return nil, nil, nil + } + + // Parse tool output to map + toolOutput := map[string]any{ + "success": !response.IsError, + "content": response.Content, + } + if response.Metadata != "" { + toolOutput["metadata"] = response.Metadata + } + + hookResult, err := a.hooksManager.ExecutePostToolUse(ctx, sessionID, a.workingDir, hooks.PostToolUseData{ + ToolName: toolCall.Name, + ToolCallID: toolCall.ID, + ToolInput: toolInput, + ToolOutput: toolOutput, + ExecutionTimeMs: executionTimeMs, + }) + if err != nil { + return nil, nil, fmt.Errorf("post-tool-use hook execution failed: %w", err) + } + + // If hook returned Continue: false, return error to stop execution. + if !hookResult.Continue { + return nil, &hookResult, ErrHookDenied + } + + // Apply modified output if present. + if len(hookResult.ModifiedOutput) > 0 { + modifiedResponse := response + + // Apply modifications + if content, ok := hookResult.ModifiedOutput["content"].(string); ok { + modifiedResponse.Content = content + } + if success, ok := hookResult.ModifiedOutput["success"].(bool); ok { + modifiedResponse.IsError = !success + } + if metadata, ok := hookResult.ModifiedOutput["metadata"].(string); ok { + modifiedResponse.Metadata = metadata + } + + return &modifiedResponse, &hookResult, nil + } + + return nil, &hookResult, nil +} + +// executeStopHook executes the stop hook when agent loop ends. +// Only runs for main agent (not sub-agents). Errors are logged but don't fail. +func (a *sessionAgent) executeStopHook(ctx context.Context, sessionID, reason string) { + // Skip if sub-agent or no hooks manager. + if a.isSubAgent || a.hooksManager == nil { + return + } + + // Use a fresh context with timeout to ensure hook runs even if parent is cancelled + hookCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + _, err := a.hooksManager.ExecuteStop(hookCtx, sessionID, a.workingDir, hooks.StopData{ + Reason: reason, + }) + if err != nil { + slog.Error("stop hook execution failed", "session_id", sessionID, "reason", reason, "error", err) + } +} diff --git a/internal/agent/errors.go b/internal/agent/errors.go index 562b69f6da97dd3aaf6dd2f1342939a9ad3e596e..f564d8bce948c481168c0ca7a215f6fde1c8fe8f 100644 --- a/internal/agent/errors.go +++ b/internal/agent/errors.go @@ -11,6 +11,7 @@ var ( ErrEmptyPrompt = errors.New("prompt is empty") ErrSessionMissing = errors.New("session id is missing") ErrHookExecutionStop = errors.New("hook stopped execution") + ErrHookDenied = errors.New("hook denied execution") ) func isCancelledErr(err error) bool { diff --git a/internal/agent/tools/bash.go b/internal/agent/tools/bash.go index c3f0bc8cd24a6c4ff7c6f775e357c90b3dc99802..71eb19bcb5069046746d18fef9dd6082e8e7a3ae 100644 --- a/internal/agent/tools/bash.go +++ b/internal/agent/tools/bash.go @@ -215,18 +215,19 @@ func NewBashTool(permissions permission.Service, workingDir string, attribution return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for executing shell command") } if !isSafeReadOnly { - p := permissions.Request( - permission.CreatePermissionRequest{ - SessionID: sessionID, - Path: execWorkingDir, - ToolCallID: call.ID, - ToolName: BashToolName, - Action: "execute", - Description: fmt.Sprintf("Execute command: %s", params.Command), - Params: BashPermissionsParams(params), - }, - ) - if !p { + granted, err := CheckHookPermission(ctx, permissions, permission.CreatePermissionRequest{ + SessionID: sessionID, + Path: execWorkingDir, + ToolCallID: call.ID, + ToolName: BashToolName, + Action: "execute", + Description: fmt.Sprintf("Execute command: %s", params.Command), + Params: BashPermissionsParams(params), + }) + if err != nil { + return fantasy.ToolResponse{}, err + } + if !granted { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } } diff --git a/internal/agent/tools/download.go b/internal/agent/tools/download.go index 2dfd43d61fb3cafc4cec9da62253fc48510de45d..4bd3068e1aa225c8bdd59ce34bc3b44ac8d68ba6 100644 --- a/internal/agent/tools/download.go +++ b/internal/agent/tools/download.go @@ -70,18 +70,18 @@ func NewDownloadTool(permissions permission.Service, workingDir string, client * return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for downloading files") } - p := permissions.Request( - permission.CreatePermissionRequest{ - SessionID: sessionID, - Path: filePath, - ToolName: DownloadToolName, - Action: "download", - Description: fmt.Sprintf("Download file from URL: %s to %s", params.URL, filePath), - Params: DownloadPermissionsParams(params), - }, - ) - - if !p { + granted, err := CheckHookPermission(ctx, permissions, permission.CreatePermissionRequest{ + SessionID: sessionID, + Path: filePath, + ToolName: DownloadToolName, + Action: "download", + Description: fmt.Sprintf("Download file from URL: %s to %s", params.URL, filePath), + Params: DownloadPermissionsParams(params), + }) + if err != nil { + return fantasy.ToolResponse{}, err + } + if !granted { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } diff --git a/internal/agent/tools/edit.go b/internal/agent/tools/edit.go index 7012afc8f525a39c2c431ec7327a9fb2d378ef42..b6f2b63f7329d945e07022727f73b40528d5983c 100644 --- a/internal/agent/tools/edit.go +++ b/internal/agent/tools/edit.go @@ -128,22 +128,23 @@ func createNewFile(edit editContext, filePath, content string, call fantasy.Tool content, strings.TrimPrefix(filePath, edit.workingDir), ) - p := edit.permissions.Request( - permission.CreatePermissionRequest{ - SessionID: sessionID, - Path: fsext.PathOrPrefix(filePath, edit.workingDir), - ToolCallID: call.ID, - ToolName: EditToolName, - Action: "write", - Description: fmt.Sprintf("Create file %s", filePath), - Params: EditPermissionsParams{ - FilePath: filePath, - OldContent: "", - NewContent: content, - }, + granted, err := CheckHookPermission(edit.ctx, edit.permissions, permission.CreatePermissionRequest{ + SessionID: sessionID, + Path: fsext.PathOrPrefix(filePath, edit.workingDir), + ToolCallID: call.ID, + ToolName: EditToolName, + Action: "write", + Description: fmt.Sprintf("Create file %s", filePath), + Params: EditPermissionsParams{ + FilePath: filePath, + OldContent: "", + NewContent: content, }, - ) - if !p { + }) + if err != nil { + return fantasy.ToolResponse{}, err + } + if !granted { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } @@ -176,8 +177,7 @@ func createNewFile(edit editContext, filePath, content string, call fantasy.Tool NewContent: content, Additions: additions, Removals: removals, - }, - ), nil + }), nil } func deleteContent(edit editContext, filePath, oldString string, replaceAll bool, call fantasy.ToolCall) (fantasy.ToolResponse, error) { @@ -249,22 +249,23 @@ func deleteContent(edit editContext, filePath, oldString string, replaceAll bool strings.TrimPrefix(filePath, edit.workingDir), ) - p := edit.permissions.Request( - permission.CreatePermissionRequest{ - SessionID: sessionID, - Path: fsext.PathOrPrefix(filePath, edit.workingDir), - ToolCallID: call.ID, - ToolName: EditToolName, - Action: "write", - Description: fmt.Sprintf("Delete content from file %s", filePath), - Params: EditPermissionsParams{ - FilePath: filePath, - OldContent: oldContent, - NewContent: newContent, - }, + granted, err := CheckHookPermission(edit.ctx, edit.permissions, permission.CreatePermissionRequest{ + SessionID: sessionID, + Path: fsext.PathOrPrefix(filePath, edit.workingDir), + ToolCallID: call.ID, + ToolName: EditToolName, + Action: "write", + Description: fmt.Sprintf("Delete content from file %s", filePath), + Params: EditPermissionsParams{ + FilePath: filePath, + OldContent: oldContent, + NewContent: newContent, }, - ) - if !p { + }) + if err != nil { + return fantasy.ToolResponse{}, err + } + if !granted { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } @@ -309,8 +310,7 @@ func deleteContent(edit editContext, filePath, oldString string, replaceAll bool NewContent: newContent, Additions: additions, Removals: removals, - }, - ), nil + }), nil } func replaceContent(edit editContext, filePath, oldString, newString string, replaceAll bool, call fantasy.ToolCall) (fantasy.ToolResponse, error) { @@ -384,22 +384,23 @@ func replaceContent(edit editContext, filePath, oldString, newString string, rep strings.TrimPrefix(filePath, edit.workingDir), ) - p := edit.permissions.Request( - permission.CreatePermissionRequest{ - SessionID: sessionID, - Path: fsext.PathOrPrefix(filePath, edit.workingDir), - ToolCallID: call.ID, - ToolName: EditToolName, - Action: "write", - Description: fmt.Sprintf("Replace content in file %s", filePath), - Params: EditPermissionsParams{ - FilePath: filePath, - OldContent: oldContent, - NewContent: newContent, - }, + granted, err := CheckHookPermission(edit.ctx, edit.permissions, permission.CreatePermissionRequest{ + SessionID: sessionID, + Path: fsext.PathOrPrefix(filePath, edit.workingDir), + ToolCallID: call.ID, + ToolName: EditToolName, + Action: "write", + Description: fmt.Sprintf("Replace content in file %s", filePath), + Params: EditPermissionsParams{ + FilePath: filePath, + OldContent: oldContent, + NewContent: newContent, }, - ) - if !p { + }) + if err != nil { + return fantasy.ToolResponse{}, err + } + if !granted { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } diff --git a/internal/agent/tools/fetch.go b/internal/agent/tools/fetch.go index 1ae1d6b9be487ba014a8dbe4d0bce7c9754dc584..4d7317abbd02acea091417a796faedfeff56baec 100644 --- a/internal/agent/tools/fetch.go +++ b/internal/agent/tools/fetch.go @@ -55,19 +55,19 @@ func NewFetchTool(permissions permission.Service, workingDir string, client *htt return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file") } - p := permissions.Request( - permission.CreatePermissionRequest{ - SessionID: sessionID, - Path: workingDir, - ToolCallID: call.ID, - ToolName: FetchToolName, - Action: "fetch", - Description: fmt.Sprintf("Fetch content from URL: %s", params.URL), - Params: FetchPermissionsParams(params), - }, - ) - - if !p { + granted, err := CheckHookPermission(ctx, permissions, permission.CreatePermissionRequest{ + SessionID: sessionID, + Path: workingDir, + ToolCallID: call.ID, + ToolName: FetchToolName, + Action: "fetch", + Description: fmt.Sprintf("Fetch content from URL: %s", params.URL), + Params: FetchPermissionsParams(params), + }) + if err != nil { + return fantasy.ToolResponse{}, err + } + if !granted { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } diff --git a/internal/agent/tools/ls.go b/internal/agent/tools/ls.go index 7725910c11675c941d791a6ec5d57a190535ee9e..94bd763dba3f324b1a1151386e1eec4b5244a77b 100644 --- a/internal/agent/tools/ls.go +++ b/internal/agent/tools/ls.go @@ -79,7 +79,7 @@ func NewLsTool(permissions permission.Service, workingDir string, lsConfig confi return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for accessing directories outside working directory") } - granted := permissions.Request( + granted, err := CheckHookPermission(ctx, permissions, permission.CreatePermissionRequest{ SessionID: sessionID, Path: absSearchPath, @@ -88,8 +88,10 @@ func NewLsTool(permissions permission.Service, workingDir string, lsConfig confi Action: "list", Description: fmt.Sprintf("List directory outside working directory: %s", absSearchPath), Params: LSPermissionsParams(params), - }, - ) + }) + if err != nil { + return fantasy.ToolResponse{}, err + } if !granted { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied diff --git a/internal/agent/tools/mcp-tools.go b/internal/agent/tools/mcp-tools.go index e75a6ccb7c92c6a4925960e0590eb5f3b2bac47e..2402e62f04bb6d724086730336513d16d9d4f026 100644 --- a/internal/agent/tools/mcp-tools.go +++ b/internal/agent/tools/mcp-tools.go @@ -89,18 +89,19 @@ func (m *Tool) Run(ctx context.Context, params fantasy.ToolCall) (fantasy.ToolRe return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file") } permissionDescription := fmt.Sprintf("execute %s with the following parameters:", m.Info().Name) - p := m.permissions.Request( - permission.CreatePermissionRequest{ - SessionID: sessionID, - ToolCallID: params.ID, - Path: m.workingDir, - ToolName: m.Info().Name, - Action: "execute", - Description: permissionDescription, - Params: params.Input, - }, - ) - if !p { + granted, err := CheckHookPermission(ctx, m.permissions, permission.CreatePermissionRequest{ + SessionID: sessionID, + ToolCallID: params.ID, + Path: m.workingDir, + ToolName: m.Info().Name, + Action: "execute", + Description: permissionDescription, + Params: params.Input, + }) + if err != nil { + return fantasy.ToolResponse{}, err + } + if !granted { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } diff --git a/internal/agent/tools/multiedit.go b/internal/agent/tools/multiedit.go index 72b7583afe0e43a84a13629ae58644b795fc0755..631d088022c1932914a6dc6f64ef5355884ba0a6 100644 --- a/internal/agent/tools/multiedit.go +++ b/internal/agent/tools/multiedit.go @@ -165,7 +165,7 @@ func processMultiEditWithCreation(edit editContext, params MultiEditParams, call // Check permissions _, additions, removals := diff.GenerateDiff("", currentContent, strings.TrimPrefix(params.FilePath, edit.workingDir)) - p := edit.permissions.Request(permission.CreatePermissionRequest{ + granted, err := CheckHookPermission(edit.ctx, edit.permissions, permission.CreatePermissionRequest{ SessionID: sessionID, Path: fsext.PathOrPrefix(params.FilePath, edit.workingDir), ToolCallID: call.ID, @@ -178,12 +178,12 @@ func processMultiEditWithCreation(edit editContext, params MultiEditParams, call NewContent: currentContent, }, }) - if !p { + if !granted { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } // Write the file - err := os.WriteFile(params.FilePath, []byte(currentContent), 0o644) + err = os.WriteFile(params.FilePath, []byte(currentContent), 0o644) if err != nil { return fantasy.ToolResponse{}, fmt.Errorf("failed to write file: %w", err) } @@ -219,8 +219,7 @@ func processMultiEditWithCreation(edit editContext, params MultiEditParams, call Removals: removals, EditsApplied: editsApplied, EditsFailed: failedEdits, - }, - ), nil + }), nil } func processMultiEditExistingFile(edit editContext, params MultiEditParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) { @@ -299,7 +298,7 @@ func processMultiEditExistingFile(edit editContext, params MultiEditParams, call // Generate diff and check permissions _, additions, removals := diff.GenerateDiff(oldContent, currentContent, strings.TrimPrefix(params.FilePath, edit.workingDir)) - p := edit.permissions.Request(permission.CreatePermissionRequest{ + granted, err := CheckHookPermission(edit.ctx, edit.permissions, permission.CreatePermissionRequest{ SessionID: sessionID, Path: fsext.PathOrPrefix(params.FilePath, edit.workingDir), ToolCallID: call.ID, @@ -312,7 +311,7 @@ func processMultiEditExistingFile(edit editContext, params MultiEditParams, call NewContent: currentContent, }, }) - if !p { + if !granted { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } @@ -368,8 +367,7 @@ func processMultiEditExistingFile(edit editContext, params MultiEditParams, call Removals: removals, EditsApplied: editsApplied, EditsFailed: failedEdits, - }, - ), nil + }), nil } func applyEditToContent(content string, edit MultiEditOperation) (string, error) { diff --git a/internal/agent/tools/tools.go b/internal/agent/tools/tools.go index 365d7a33e689189096611adb9c7af7bfc76bce75..7ffce30cdbcefe275786c41e9deb31cacd19c16c 100644 --- a/internal/agent/tools/tools.go +++ b/internal/agent/tools/tools.go @@ -2,16 +2,20 @@ package tools import ( "context" + + "github.com/charmbracelet/crush/internal/permission" ) type ( - sessionIDContextKey string - messageIDContextKey string + sessionIDContextKey string + messageIDContextKey string + hookPermissionContextKey string ) const ( - SessionIDContextKey sessionIDContextKey = "session_id" - MessageIDContextKey messageIDContextKey = "message_id" + SessionIDContextKey sessionIDContextKey = "session_id" + MessageIDContextKey messageIDContextKey = "message_id" + HookPermissionContextKey hookPermissionContextKey = "hook_permission" ) func GetSessionFromContext(ctx context.Context) string { @@ -37,3 +41,49 @@ func GetMessageFromContext(ctx context.Context) string { } return s } + +// GetHookPermissionFromContext gets the hook permission decision from context. +// Returns: permission string ("approve" or "deny"), found bool +func GetHookPermissionFromContext(ctx context.Context) (string, bool) { + permission := ctx.Value(HookPermissionContextKey) + if permission == nil { + return "", false + } + s, ok := permission.(string) + if !ok { + return "", false + } + return s, true +} + +// SetHookPermissionInContext sets the hook permission decision in context. +func SetHookPermissionInContext(ctx context.Context, permission string) context.Context { + return context.WithValue(ctx, HookPermissionContextKey, permission) +} + +// CheckHookPermission checks if a hook has already made a permission decision. +// Returns true if execution should proceed, false if denied. +// If hook approved, skips the permission service call. +// If hook denied, returns ErrorPermissionDenied. +// If hook said "ask" or no decision, calls the permission service. +func CheckHookPermission(ctx context.Context, permissionService permission.Service, req permission.CreatePermissionRequest) (bool, error) { + hookPerm, hasHookPerm := GetHookPermissionFromContext(ctx) + + if hasHookPerm { + switch hookPerm { + case "approve": + // Hook auto-approved, skip permission check + return true, nil + case "deny": + // Hook denied, return error + return false, permission.ErrorPermissionDenied + } + } + + // No hook decision or hook said "ask", use normal permission flow + granted := permissionService.Request(req) + if !granted { + return false, permission.ErrorPermissionDenied + } + return true, nil +} diff --git a/internal/agent/tools/view.go b/internal/agent/tools/view.go index 39a2e9d1dbf9e7fcf63c445f9a393f5cc24b88f3..46a10c5b32428dd777f59d8808546b0eb6dcf77c 100644 --- a/internal/agent/tools/view.go +++ b/internal/agent/tools/view.go @@ -82,7 +82,7 @@ func NewViewTool(lspClients *csync.Map[string, *lsp.Client], permissions permiss return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for accessing files outside working directory") } - granted := permissions.Request( + granted, err := CheckHookPermission(ctx, permissions, permission.CreatePermissionRequest{ SessionID: sessionID, Path: absFilePath, @@ -91,8 +91,10 @@ func NewViewTool(lspClients *csync.Map[string, *lsp.Client], permissions permiss Action: "read", Description: fmt.Sprintf("Read file outside working directory: %s", absFilePath), Params: ViewPermissionsParams(params), - }, - ) + }) + if err != nil { + return fantasy.ToolResponse{}, err + } if !granted { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied diff --git a/internal/agent/tools/write.go b/internal/agent/tools/write.go index 0868b9f62306e7b43b1c7218ff68a6a7aae140eb..ec8a8d03d2dffbaef023d7bb2c06a72cefde23b1 100644 --- a/internal/agent/tools/write.go +++ b/internal/agent/tools/write.go @@ -110,22 +110,23 @@ func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permis strings.TrimPrefix(filePath, workingDir), ) - p := permissions.Request( - permission.CreatePermissionRequest{ - SessionID: sessionID, - Path: fsext.PathOrPrefix(filePath, workingDir), - ToolCallID: call.ID, - ToolName: WriteToolName, - Action: "write", - Description: fmt.Sprintf("Create file %s", filePath), - Params: WritePermissionsParams{ - FilePath: filePath, - OldContent: oldContent, - NewContent: params.Content, - }, + granted, err := CheckHookPermission(ctx, permissions, permission.CreatePermissionRequest{ + SessionID: sessionID, + Path: fsext.PathOrPrefix(filePath, workingDir), + ToolCallID: call.ID, + ToolName: WriteToolName, + Action: "write", + Description: fmt.Sprintf("Create file %s", filePath), + Params: WritePermissionsParams{ + FilePath: filePath, + OldContent: oldContent, + NewContent: params.Content, }, - ) - if !p { + }) + if err != nil { + return fantasy.ToolResponse{}, err + } + if !granted { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } diff --git a/internal/hooks/README.md b/internal/hooks/README.md index 12b314f94c67df4a130e7165cb3f0cdc774d0a41..c797083ffa3df3d6bc20170037086eec4261be8a 100644 --- a/internal/hooks/README.md +++ b/internal/hooks/README.md @@ -222,13 +222,6 @@ crush_deny "Blocked dangerous operation" # Script exits immediately with code 2 ``` -#### `crush_ask [message]` -Ask user for permission (default behavior). - -```bash -crush_ask "This command modifies files, please review" -``` - ### Context Helpers #### `crush_add_context "content"` @@ -373,7 +366,7 @@ export CRUSH_CONTEXT_FILES="/path/to/file1.md:/path/to/file2.md" ``` **Available variables**: -- `CRUSH_PERMISSION` - `approve`, `ask`, or `deny` +- `CRUSH_PERMISSION` - `approve` or `deny` - `CRUSH_MESSAGE` - User-facing message - `CRUSH_CONTINUE` - `true` or `false` (stop execution) - `CRUSH_MODIFIED_PROMPT` - New prompt text @@ -401,7 +394,7 @@ echo '{ **JSON fields**: - `continue` (bool) - Continue execution -- `permission` (string) - `approve`, `ask`, `deny` +- `permission` (string) - `approve` or `deny` - `message` (string) - User-facing message - `modified_prompt` (string) - New prompt - `modified_input` (object) - Modified tool parameters @@ -444,8 +437,10 @@ Hooks execute **sequentially** in alphabetical order. Use numeric prefixes to co When multiple hooks execute, their results are merged: ### Permission (Most Restrictive Wins) -- `deny` > `ask` > `approve` +- `deny` > `approve` - If any hook denies, the final result is deny +- If any hook approves and no denials, the result is approve +- If no hooks set permission, normal permission flow applies ### Continue (AND Logic) - All hooks must set `Continue=true` (or not set it) diff --git a/internal/hooks/helpers.sh b/internal/hooks/helpers.sh index 21ef57e4aad4b3a6944d3b3a885ac3ccbb43a32c..305c196f5116488a22db12a3e0bafc53bf53886a 100644 --- a/internal/hooks/helpers.sh +++ b/internal/hooks/helpers.sh @@ -21,13 +21,6 @@ crush_deny() { exit 2 } -# Ask user for permission (default behavior). -# Usage: crush_ask ["message"] -crush_ask() { - export CRUSH_PERMISSION=ask - [ -n "$1" ] && export CRUSH_MESSAGE="$1" -} - # Context helpers # Add raw text content to LLM context. diff --git a/internal/hooks/manager.go b/internal/hooks/manager.go index 48e214f45be665843d39b5b5123fd09b9f4feb9b..0f32cdc4bad61e20b06c61e9008a49f63c42d6e6 100644 --- a/internal/hooks/manager.go +++ b/internal/hooks/manager.go @@ -241,8 +241,6 @@ func (m *manager) mergeResults(accumulated *HookResult, new *HookResult) { if new.Permission != "" { if new.Permission == "deny" { accumulated.Permission = "deny" - } else if new.Permission == "ask" && accumulated.Permission != "deny" { - accumulated.Permission = "ask" } else if new.Permission == "approve" && accumulated.Permission == "" { accumulated.Permission = "approve" } diff --git a/internal/hooks/types.go b/internal/hooks/types.go index 24e4b30a65ab7d34fe980a892e62b6919aba3696..367308c0c0da1a8043f211223c8615623af97e4b 100644 --- a/internal/hooks/types.go +++ b/internal/hooks/types.go @@ -97,6 +97,15 @@ type Manager interface { // ExecuteUserPromptSubmit executes the UserPromptSubmit event ExecuteUserPromptSubmit(ctx context.Context, sessionID, workingDir string, data UserPromptSubmitData) (HookResult, error) + + // ExecutePreToolUse executes the PreToolUse event + ExecutePreToolUse(ctx context.Context, sessionID, workingDir string, data PreToolUseData) (HookResult, error) + + // ExecutePostToolUse executes the PostToolUse event + ExecutePostToolUse(ctx context.Context, sessionID, workingDir string, data PostToolUseData) (HookResult, error) + + // ExecuteStop executes the Stop event + ExecuteStop(ctx context.Context, sessionID, workingDir string, data StopData) (HookResult, error) } type UserPromptSubmitData struct { diff --git a/internal/message/content.go b/internal/message/content.go index da93f8ad5562ed0aba2140f30bbec58fb2e08278..20b1e9fd08722a8b5cdb52514e0f6844372288ce 100644 --- a/internal/message/content.go +++ b/internal/message/content.go @@ -96,23 +96,25 @@ func (bc BinaryContent) String(p catwalk.InferenceProvider) string { func (BinaryContent) isPart() {} type ToolCall struct { - ID string `json:"id"` - Name string `json:"name"` - Input string `json:"input"` - ProviderExecuted bool `json:"provider_executed"` - Finished bool `json:"finished"` + ID string `json:"id"` + Name string `json:"name"` + Input string `json:"input"` + ProviderExecuted bool `json:"provider_executed"` + Finished bool `json:"finished"` + HookResult *hooks.HookResult `json:"hook_result,omitempty"` } func (ToolCall) isPart() {} type ToolResult struct { - ToolCallID string `json:"tool_call_id"` - Name string `json:"name"` - Content string `json:"content"` - Data string `json:"data"` - MIMEType string `json:"mime_type"` - Metadata string `json:"metadata"` - IsError bool `json:"is_error"` + ToolCallID string `json:"tool_call_id"` + Name string `json:"name"` + Content string `json:"content"` + Data string `json:"data"` + MIMEType string `json:"mime_type"` + Metadata string `json:"metadata"` + IsError bool `json:"is_error"` + HookResult *hooks.HookResult `json:"hook_result,omitempty"` } func (ToolResult) isPart() {}