@@ -222,6 +222,12 @@ type (
// OnToolCallFunc is called when tool call is complete.
OnToolCallFunc func(toolCall ToolCallContent) error
+ // OnBeforeToolExecutionFunc is called right before tool execution.
+ // It can modify tool input and/or skip tool execution by returning (modifiedInput, skipExecution, error).
+ // If skipExecution is true, a synthetic error result is created with the provided error.
+ // If modifiedInput is not empty, it replaces the original tool input.
+ OnBeforeToolExecutionFunc func(toolCall ToolCallContent) (modifiedInput string, skipExecution bool, err error)
+
// OnToolResultFunc is called when tool execution completes.
OnToolResultFunc func(result ToolResultContent) error
@@ -274,6 +280,7 @@ type AgentStreamCall struct {
OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
OnToolInputEnd OnToolInputEndFunc // Called when tool input ends
OnToolCall OnToolCallFunc // Called when tool call is complete
+ OnBeforeToolExecution OnBeforeToolExecutionFunc // Called right before tool execution
OnToolResult OnToolResultFunc // Called when tool execution completes
OnSource OnSourceFunc // Called for source references
OnStreamFinish OnStreamFinishFunc // Called when stream finishes
@@ -453,7 +460,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
}
}
- toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
+ toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil, nil)
// Build step content with validated tool calls and tool results
stepContent := []Content{}
@@ -607,7 +614,7 @@ func toResponseMessages(content []Content) []Message {
return messages
}
-func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
+func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, beforeExecCallback OnBeforeToolExecutionFunc, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
if len(toolCalls) == 0 {
return nil, nil
}
@@ -660,11 +667,62 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall
continue
}
+ // Call OnBeforeToolExecution hook to allow modification or skipping
+ inputToUse := toolCall.Input
+ if beforeExecCallback != nil {
+ modifiedInput, skipExecution, err := beforeExecCallback(toolCall)
+ if err != nil {
+ // Callback returned error - create synthetic error result
+ result := ToolResultContent{
+ ToolCallID: toolCall.ToolCallID,
+ ToolName: toolCall.ToolName,
+ Result: ToolResultOutputContentError{
+ Error: err,
+ },
+ ProviderExecuted: false,
+ }
+ results = append(results, result)
+ if toolResultCallback != nil {
+ if cbErr := toolResultCallback(result); cbErr != nil {
+ return nil, cbErr
+ }
+ }
+ if skipExecution {
+ // Hook wants to skip execution
+ continue
+ } else {
+ // Hook returned error but doesn't want to skip - stop execution
+ return nil, err
+ }
+ }
+ if skipExecution {
+ // Hook wants to skip execution without error
+ result := ToolResultContent{
+ ToolCallID: toolCall.ToolCallID,
+ ToolName: toolCall.ToolName,
+ Result: ToolResultOutputContentError{
+ Error: errors.New("Tool execution skipped by hook"),
+ },
+ ProviderExecuted: false,
+ }
+ results = append(results, result)
+ if toolResultCallback != nil {
+ if cbErr := toolResultCallback(result); cbErr != nil {
+ return nil, cbErr
+ }
+ }
+ continue
+ }
+ if modifiedInput != "" {
+ inputToUse = modifiedInput
+ }
+ }
+
// Execute the tool
toolResult, err := tool.Run(ctx, ToolCall{
ID: toolCall.ToolCallID,
Name: toolCall.ToolName,
- Input: toolCall.Input,
+ Input: inputToUse,
})
if err != nil {
result := ToolResultContent{
@@ -1274,7 +1332,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
var toolResults []ToolResultContent
if len(stepToolCalls) > 0 {
var err error
- toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
+ toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnBeforeToolExecution, opts.OnToolResult)
if err != nil {
return StepResult{}, false, err
}