diff --git a/agent.go b/agent.go index 214cb388918fa15115716ccf58d96fffce36cf41..19beb0ecb6e89f8f9be3a8eb224c78cb1b52cd3b 100644 --- a/agent.go +++ b/agent.go @@ -222,6 +222,12 @@ type ( // OnToolCallFunc is called when tool call is complete. OnToolCallFunc func(toolCall ToolCallContent) error + // OnBeforeToolExecutionFunc is called right before tool execution. + // It can modify tool input and/or skip tool execution by returning (modifiedInput, skipExecution, error). + // If skipExecution is true, a synthetic error result is created with the provided error. + // If modifiedInput is not empty, it replaces the original tool input. + OnBeforeToolExecutionFunc func(toolCall ToolCallContent) (modifiedInput string, skipExecution bool, err error) + // OnToolResultFunc is called when tool execution completes. OnToolResultFunc func(result ToolResultContent) error @@ -274,6 +280,7 @@ type AgentStreamCall struct { OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas OnToolInputEnd OnToolInputEndFunc // Called when tool input ends OnToolCall OnToolCallFunc // Called when tool call is complete + OnBeforeToolExecution OnBeforeToolExecutionFunc // Called right before tool execution OnToolResult OnToolResultFunc // Called when tool execution completes OnSource OnSourceFunc // Called for source references OnStreamFinish OnStreamFinishFunc // Called when stream finishes @@ -453,7 +460,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err } } - toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil) + toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil, nil) // Build step content with validated tool calls and tool results stepContent := []Content{} @@ -607,7 +614,7 @@ func toResponseMessages(content []Content) []Message { return messages } -func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) { +func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, beforeExecCallback OnBeforeToolExecutionFunc, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) { if len(toolCalls) == 0 { return nil, nil } @@ -660,11 +667,62 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall continue } + // Call OnBeforeToolExecution hook to allow modification or skipping + inputToUse := toolCall.Input + if beforeExecCallback != nil { + modifiedInput, skipExecution, err := beforeExecCallback(toolCall) + if err != nil { + // Callback returned error - create synthetic error result + result := ToolResultContent{ + ToolCallID: toolCall.ToolCallID, + ToolName: toolCall.ToolName, + Result: ToolResultOutputContentError{ + Error: err, + }, + ProviderExecuted: false, + } + results = append(results, result) + if toolResultCallback != nil { + if cbErr := toolResultCallback(result); cbErr != nil { + return nil, cbErr + } + } + if skipExecution { + // Hook wants to skip execution + continue + } else { + // Hook returned error but doesn't want to skip - stop execution + return nil, err + } + } + if skipExecution { + // Hook wants to skip execution without error + result := ToolResultContent{ + ToolCallID: toolCall.ToolCallID, + ToolName: toolCall.ToolName, + Result: ToolResultOutputContentError{ + Error: errors.New("Tool execution skipped by hook"), + }, + ProviderExecuted: false, + } + results = append(results, result) + if toolResultCallback != nil { + if cbErr := toolResultCallback(result); cbErr != nil { + return nil, cbErr + } + } + continue + } + if modifiedInput != "" { + inputToUse = modifiedInput + } + } + // Execute the tool toolResult, err := tool.Run(ctx, ToolCall{ ID: toolCall.ToolCallID, Name: toolCall.ToolName, - Input: toolCall.Input, + Input: inputToUse, }) if err != nil { result := ToolResultContent{ @@ -1274,7 +1332,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op var toolResults []ToolResultContent if len(stepToolCalls) > 0 { var err error - toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult) + toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnBeforeToolExecution, opts.OnToolResult) if err != nil { return StepResult{}, false, err }