chore: allow updating tools in prepare step (#86)

Kujtim Hoxha created

Change summary

agent.go | 27 +++++++++++++++++----------
1 file changed, 17 insertions(+), 10 deletions(-)

Detailed changes

agent.go 🔗

@@ -103,6 +103,7 @@ type PrepareStepResult struct {
 	ToolChoice      *ToolChoice
 	ActiveTools     []string
 	DisableAllTools bool
+	Tools           []AgentTool
 }
 
 // ToolCallRepairOptions contains the options for repairing a tool call.
@@ -376,7 +377,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
 		stepActiveTools := opts.ActiveTools
 		stepToolChoice := ToolChoiceAuto
 		disableAllTools := false
-
+		stepTools := a.settings.tools
 		if opts.PrepareStep != nil {
 			updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
 				Model:      stepModel,
@@ -407,6 +408,9 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
 				stepActiveTools = prepared.ActiveTools
 			}
 			disableAllTools = prepared.DisableAllTools
+			if prepared.Tools != nil {
+				stepTools = prepared.Tools
+			}
 		}
 
 		// Recreate prompt with potentially modified system prompt
@@ -421,7 +425,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
 			}
 		}
 
-		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
+		preparedTools := a.prepareTools(stepTools, stepActiveTools, disableAllTools)
 
 		retryOptions := DefaultRetryOptions()
 		if opts.MaxRetries != nil {
@@ -457,12 +461,12 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
 				}
 
 				// Validate and potentially repair the tool call
-				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
+				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
 				stepToolCalls = append(stepToolCalls, validatedToolCall)
 			}
 		}
 
-		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
+		toolResults, err := a.executeTools(ctx, stepTools, stepToolCalls, nil)
 
 		// Build step content with validated tool calls and tool results
 		stepContent := []Content{}
@@ -771,7 +775,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
 		stepActiveTools := call.ActiveTools
 		stepToolChoice := ToolChoiceAuto
 		disableAllTools := false
-
+		stepTools := a.settings.tools
 		// Apply step preparation if provided
 		if call.PrepareStep != nil {
 			updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
@@ -802,6 +806,9 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
 				stepActiveTools = prepared.ActiveTools
 			}
 			disableAllTools = prepared.DisableAllTools
+			if prepared.Tools != nil {
+				stepTools = prepared.Tools
+			}
 		}
 
 		// Recreate prompt with potentially modified system prompt
@@ -815,7 +822,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
 			}
 		}
 
-		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
+		preparedTools := a.prepareTools(stepTools, stepActiveTools, disableAllTools)
 
 		// Start step stream
 		if opts.OnStepStart != nil {
@@ -852,7 +859,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
 			}
 
 			// Process the stream
-			result, err := a.processStepStream(ctx, stream, opts, steps)
+			result, err := a.processStepStream(ctx, stream, opts, steps, stepTools)
 			if err != nil {
 				return stepExecutionResult{}, err
 			}
@@ -1098,7 +1105,7 @@ func WithOnRetry(callback OnRetryCallback) AgentOption {
 }
 
 // processStepStream processes a single step's stream and returns the step result.
-func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (stepExecutionResult, error) {
+func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult, stepTools []AgentTool) (stepExecutionResult, error) {
 	var stepContent []Content
 	var stepToolCalls []ToolCallContent
 	var stepUsage Usage
@@ -1257,7 +1264,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 			}
 
 			// Validate and potentially repair the tool call
-			validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
+			validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, a.settings.systemPrompt, nil, opts.RepairToolCall)
 			stepToolCalls = append(stepToolCalls, validatedToolCall)
 			stepContent = append(stepContent, validatedToolCall)
 
@@ -1307,7 +1314,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 	var toolResults []ToolResultContent
 	if len(stepToolCalls) > 0 {
 		var err error
-		toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
+		toolResults, err = a.executeTools(ctx, stepTools, stepToolCalls, opts.OnToolResult)
 		if err != nil {
 			return stepExecutionResult{}, err
 		}