@@ -8,6 +8,7 @@ import (
"fmt"
"maps"
"slices"
+ "time"
)
// StepResult represents the result of a single step in an agent execution.
@@ -228,9 +229,20 @@ type (
// OnToolCallFunc is called when tool call is complete.
OnToolCallFunc func(toolCall ToolCallContent) error
+ // PreToolExecuteFunc is called before tool execution.
+ // Can modify the tool call or return an error to skip execution.
+ // Returning a modified ToolCall allows changing input parameters.
+ // Returning an error creates an error result without executing the tool.
+ PreToolExecuteFunc func(ctx context.Context, toolCall ToolCall) (context.Context, *ToolCall, error)
+
// OnToolResultFunc is called when tool execution completes.
OnToolResultFunc func(result ToolResultContent) error
+ // PostToolExecuteFunc is called after tool execution, before sending result to LLM.
+ // Can modify the tool response or return an error to replace the response.
+ // Returning a modified ToolResponse allows filtering or redacting output.
+ PostToolExecuteFunc func(ctx context.Context, toolCall ToolCall, response ToolResponse, executionTimeMs int64) (*ToolResponse, error)
+
// OnSourceFunc is called for source references.
OnSourceFunc func(source SourceContent) error
@@ -280,7 +292,9 @@ type AgentStreamCall struct {
OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
OnToolInputEnd OnToolInputEndFunc // Called when tool input ends
OnToolCall OnToolCallFunc // Called when tool call is complete
+ PreToolExecute PreToolExecuteFunc // Called before tool execution (can modify input or block)
OnToolResult OnToolResultFunc // Called when tool execution completes
+ PostToolExecute PostToolExecuteFunc // Called after tool execution (can modify output)
OnSource OnSourceFunc // Called for source references
OnStreamFinish OnStreamFinishFunc // Called when stream finishes
}
@@ -462,7 +476,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
}
}
- toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
+ toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil, nil, nil)
// Build step content with validated tool calls and tool results
stepContent := []Content{}
@@ -616,7 +630,7 @@ func toResponseMessages(content []Content) []Message {
return messages
}
-func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
+func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error, preToolExecute PreToolExecuteFunc, postToolExecute PostToolExecuteFunc) ([]ToolResultContent, error) {
if len(toolCalls) == 0 {
return nil, nil
}
@@ -669,12 +683,78 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall
continue
}
- // Execute the tool
- toolResult, err := tool.Run(ctx, ToolCall{
+ // Prepare tool call for execution
+ executionToolCall := ToolCall{
ID: toolCall.ToolCallID,
Name: toolCall.ToolName,
Input: toolCall.Input,
- })
+ }
+
+ // Call pre-tool execute hook
+ var preHookErr error
+ toolCtx := ctx
+ if preToolExecute != nil {
+ updatedCtx, modifiedCall, err := preToolExecute(ctx, executionToolCall)
+ if err != nil {
+ preHookErr = err
+ } else {
+ toolCtx = updatedCtx
+ if modifiedCall != nil {
+ executionToolCall = *modifiedCall
+ }
+ }
+ }
+
+ // If pre-hook returned error, create error result and skip execution
+ if preHookErr != nil {
+ result := ToolResultContent{
+ ToolCallID: toolCall.ToolCallID,
+ ToolName: toolCall.ToolName,
+ Result: ToolResultOutputContentError{
+ Error: preHookErr,
+ },
+ ProviderExecuted: false,
+ }
+ results = append(results, result)
+ if toolResultCallback != nil {
+ if err := toolResultCallback(result); err != nil {
+ return nil, err
+ }
+ }
+ // Continue to next tool call instead of returning error
+ continue
+ }
+
+ // Execute the tool with timing
+ startTime := time.Now()
+ toolResult, err := tool.Run(toolCtx, executionToolCall)
+ executionTimeMs := time.Since(startTime).Milliseconds()
+
+ // Call post-tool execute hook
+ if postToolExecute != nil && err == nil {
+ modifiedResponse, postErr := postToolExecute(ctx, executionToolCall, toolResult, executionTimeMs)
+ if postErr != nil {
+ // Post-hook error stops execution
+ result := ToolResultContent{
+ ToolCallID: toolCall.ToolCallID,
+ ToolName: toolCall.ToolName,
+ Result: ToolResultOutputContentError{
+ Error: postErr,
+ },
+ ClientMetadata: toolResult.Metadata,
+ ProviderExecuted: false,
+ }
+ if toolResultCallback != nil {
+ if cbErr := toolResultCallback(result); cbErr != nil {
+ return nil, cbErr
+ }
+ }
+ return nil, postErr
+ } else if modifiedResponse != nil {
+ toolResult = *modifiedResponse
+ }
+ }
+
if err != nil {
result := ToolResultContent{
ToolCallID: toolCall.ToolCallID,
@@ -1307,7 +1387,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
var toolResults []ToolResultContent
if len(stepToolCalls) > 0 {
var err error
- toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
+ toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult, opts.PreToolExecute, opts.PostToolExecute)
if err != nil {
return stepExecutionResult{}, err
}