chore: add hooks for pre/post tool call manipulation

Kujtim Hoxha created

Change summary

agent.go | 92 ++++++++++++++++++++++++++++++++++++++++++++++++++++++---
1 file changed, 86 insertions(+), 6 deletions(-)

Detailed changes

agent.go 🔗

@@ -8,6 +8,7 @@ import (
 	"fmt"
 	"maps"
 	"slices"
+	"time"
 )
 
 // StepResult represents the result of a single step in an agent execution.
@@ -228,9 +229,20 @@ type (
 	// OnToolCallFunc is called when tool call is complete.
 	OnToolCallFunc func(toolCall ToolCallContent) error
 
+	// PreToolExecuteFunc is called before tool execution.
+	// Can modify the tool call or return an error to skip execution.
+	// Returning a modified ToolCall allows changing input parameters.
+	// Returning an error creates an error result without executing the tool.
+	PreToolExecuteFunc func(ctx context.Context, toolCall ToolCall) (context.Context, *ToolCall, error)
+
 	// OnToolResultFunc is called when tool execution completes.
 	OnToolResultFunc func(result ToolResultContent) error
 
+	// PostToolExecuteFunc is called after tool execution, before sending result to LLM.
+	// Can modify the tool response or return an error to replace the response.
+	// Returning a modified ToolResponse allows filtering or redacting output.
+	PostToolExecuteFunc func(ctx context.Context, toolCall ToolCall, response ToolResponse, executionTimeMs int64) (*ToolResponse, error)
+
 	// OnSourceFunc is called for source references.
 	OnSourceFunc func(source SourceContent) error
 
@@ -280,7 +292,9 @@ type AgentStreamCall struct {
 	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
 	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
 	OnToolCall       OnToolCallFunc       // Called when tool call is complete
+	PreToolExecute   PreToolExecuteFunc   // Called before tool execution (can modify input or block)
 	OnToolResult     OnToolResultFunc     // Called when tool execution completes
+	PostToolExecute  PostToolExecuteFunc  // Called after tool execution (can modify output)
 	OnSource         OnSourceFunc         // Called for source references
 	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
 }
@@ -462,7 +476,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
 			}
 		}
 
-		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
+		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil, nil, nil)
 
 		// Build step content with validated tool calls and tool results
 		stepContent := []Content{}
@@ -616,7 +630,7 @@ func toResponseMessages(content []Content) []Message {
 	return messages
 }
 
-func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
+func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error, preToolExecute PreToolExecuteFunc, postToolExecute PostToolExecuteFunc) ([]ToolResultContent, error) {
 	if len(toolCalls) == 0 {
 		return nil, nil
 	}
@@ -669,12 +683,78 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall
 			continue
 		}
 
-		// Execute the tool
-		toolResult, err := tool.Run(ctx, ToolCall{
+		// Prepare tool call for execution
+		executionToolCall := ToolCall{
 			ID:    toolCall.ToolCallID,
 			Name:  toolCall.ToolName,
 			Input: toolCall.Input,
-		})
+		}
+
+		// Call pre-tool execute hook
+		var preHookErr error
+		toolCtx := ctx
+		if preToolExecute != nil {
+			updatedCtx, modifiedCall, err := preToolExecute(ctx, executionToolCall)
+			if err != nil {
+				preHookErr = err
+			} else {
+				toolCtx = updatedCtx
+				if modifiedCall != nil {
+					executionToolCall = *modifiedCall
+				}
+			}
+		}
+
+		// If pre-hook returned error, create error result and skip execution
+		if preHookErr != nil {
+			result := ToolResultContent{
+				ToolCallID: toolCall.ToolCallID,
+				ToolName:   toolCall.ToolName,
+				Result: ToolResultOutputContentError{
+					Error: preHookErr,
+				},
+				ProviderExecuted: false,
+			}
+			results = append(results, result)
+			if toolResultCallback != nil {
+				if err := toolResultCallback(result); err != nil {
+					return nil, err
+				}
+			}
+			// Continue to next tool call instead of returning error
+			continue
+		}
+
+		// Execute the tool with timing
+		startTime := time.Now()
+		toolResult, err := tool.Run(toolCtx, executionToolCall)
+		executionTimeMs := time.Since(startTime).Milliseconds()
+
+		// Call post-tool execute hook
+		if postToolExecute != nil && err == nil {
+			modifiedResponse, postErr := postToolExecute(ctx, executionToolCall, toolResult, executionTimeMs)
+			if postErr != nil {
+				// Post-hook error stops execution
+				result := ToolResultContent{
+					ToolCallID: toolCall.ToolCallID,
+					ToolName:   toolCall.ToolName,
+					Result: ToolResultOutputContentError{
+						Error: postErr,
+					},
+					ClientMetadata:   toolResult.Metadata,
+					ProviderExecuted: false,
+				}
+				if toolResultCallback != nil {
+					if cbErr := toolResultCallback(result); cbErr != nil {
+						return nil, cbErr
+					}
+				}
+				return nil, postErr
+			} else if modifiedResponse != nil {
+				toolResult = *modifiedResponse
+			}
+		}
+
 		if err != nil {
 			result := ToolResultContent{
 				ToolCallID: toolCall.ToolCallID,
@@ -1307,7 +1387,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 	var toolResults []ToolResultContent
 	if len(stepToolCalls) > 0 {
 		var err error
-		toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
+		toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult, opts.PreToolExecute, opts.PostToolExecute)
 		if err != nil {
 			return stepExecutionResult{}, err
 		}