wip: new hook

Kujtim Hoxha created

Change summary

agent.go | 66 ++++++++++++++++++++++++++++++++++++++++++++++++++++++---
1 file changed, 62 insertions(+), 4 deletions(-)

Detailed changes

agent.go 🔗

@@ -222,6 +222,12 @@ type (
 	// OnToolCallFunc is called when tool call is complete.
 	OnToolCallFunc func(toolCall ToolCallContent) error
 
+	// OnBeforeToolExecutionFunc is called right before tool execution.
+	// It can modify tool input and/or skip tool execution by returning (modifiedInput, skipExecution, error).
+	// If skipExecution is true, a synthetic error result is created with the provided error.
+	// If modifiedInput is not empty, it replaces the original tool input.
+	OnBeforeToolExecutionFunc func(toolCall ToolCallContent) (modifiedInput string, skipExecution bool, err error)
+
 	// OnToolResultFunc is called when tool execution completes.
 	OnToolResultFunc func(result ToolResultContent) error
 
@@ -274,6 +280,7 @@ type AgentStreamCall struct {
 	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
 	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
 	OnToolCall       OnToolCallFunc       // Called when tool call is complete
+	OnBeforeToolExecution OnBeforeToolExecutionFunc // Called right before tool execution
 	OnToolResult     OnToolResultFunc     // Called when tool execution completes
 	OnSource         OnSourceFunc         // Called for source references
 	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
@@ -453,7 +460,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
 			}
 		}
 
-		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
+		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil, nil)
 
 		// Build step content with validated tool calls and tool results
 		stepContent := []Content{}
@@ -607,7 +614,7 @@ func toResponseMessages(content []Content) []Message {
 	return messages
 }
 
-func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
+func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, beforeExecCallback OnBeforeToolExecutionFunc, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
 	if len(toolCalls) == 0 {
 		return nil, nil
 	}
@@ -660,11 +667,62 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall
 			continue
 		}
 
+		// Call OnBeforeToolExecution hook to allow modification or skipping
+		inputToUse := toolCall.Input
+		if beforeExecCallback != nil {
+			modifiedInput, skipExecution, err := beforeExecCallback(toolCall)
+			if err != nil {
+				// Callback returned error - create synthetic error result
+				result := ToolResultContent{
+					ToolCallID: toolCall.ToolCallID,
+					ToolName:   toolCall.ToolName,
+					Result: ToolResultOutputContentError{
+						Error: err,
+					},
+					ProviderExecuted: false,
+				}
+				results = append(results, result)
+				if toolResultCallback != nil {
+					if cbErr := toolResultCallback(result); cbErr != nil {
+						return nil, cbErr
+					}
+				}
+				if skipExecution {
+					// Hook wants to skip execution
+					continue
+				} else {
+					// Hook returned error but doesn't want to skip - stop execution
+					return nil, err
+				}
+			}
+			if skipExecution {
+				// Hook wants to skip execution without error
+				result := ToolResultContent{
+					ToolCallID: toolCall.ToolCallID,
+					ToolName:   toolCall.ToolName,
+					Result: ToolResultOutputContentError{
+						Error: errors.New("Tool execution skipped by hook"),
+					},
+					ProviderExecuted: false,
+				}
+				results = append(results, result)
+				if toolResultCallback != nil {
+					if cbErr := toolResultCallback(result); cbErr != nil {
+						return nil, cbErr
+					}
+				}
+				continue
+			}
+			if modifiedInput != "" {
+				inputToUse = modifiedInput
+			}
+		}
+
 		// Execute the tool
 		toolResult, err := tool.Run(ctx, ToolCall{
 			ID:    toolCall.ToolCallID,
 			Name:  toolCall.ToolName,
-			Input: toolCall.Input,
+			Input: inputToUse,
 		})
 		if err != nil {
 			result := ToolResultContent{
@@ -1274,7 +1332,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 	var toolResults []ToolResultContent
 	if len(stepToolCalls) > 0 {
 		var err error
-		toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
+		toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnBeforeToolExecution, opts.OnToolResult)
 		if err != nil {
 			return StepResult{}, false, err
 		}