From 1c6ba1faf8d9f3a1b02ed73b3500ffa1b4b8541c Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Thu, 4 Dec 2025 15:16:22 +0100 Subject: [PATCH] feat: improve tool execution (#88) * feat: improve tool execution - makes it so we execute the tool call right after it is available and than queues other tool calls. - allow tools to be called in parallel (e.x for the agent tool) * chore: simplify and rename func * refactor: simplify code * chore: change the signature --- agent.go | 149 +++++++++++++++++++++++++++++++++++++++++++++++++++---- tool.go | 23 +++++++++ 2 files changed, 163 insertions(+), 9 deletions(-) diff --git a/agent.go b/agent.go index 747ac86de5e629a4abccdabc12b1e2d8d3d5b33f..aa11f7e6f49ee7c1a3a4db764b5bae6f822b854c 100644 --- a/agent.go +++ b/agent.go @@ -8,6 +8,7 @@ import ( "fmt" "maps" "slices" + "sync" ) // StepResult represents the result of a single step in an agent execution. @@ -730,6 +731,71 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall return results, nil } +// executeSingleTool executes a single tool and returns its result and a critical error flag. +func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentTool, toolCall ToolCallContent, toolResultCallback func(result ToolResultContent) error) (ToolResultContent, bool) { + result := ToolResultContent{ + ToolCallID: toolCall.ToolCallID, + ToolName: toolCall.ToolName, + ProviderExecuted: false, + } + + // Skip invalid tool calls - create error result (not critical) + if toolCall.Invalid { + result.Result = ToolResultOutputContentError{ + Error: toolCall.ValidationError, + } + if toolResultCallback != nil { + _ = toolResultCallback(result) + } + return result, false + } + + tool, exists := toolMap[toolCall.ToolName] + if !exists { + result.Result = ToolResultOutputContentError{ + Error: errors.New("Error: Tool not found: " + toolCall.ToolName), + } + if toolResultCallback != nil { + _ = toolResultCallback(result) + } + return result, false + } + + // Execute the tool + toolResult, err := tool.Run(ctx, ToolCall{ + ID: toolCall.ToolCallID, + Name: toolCall.ToolName, + Input: toolCall.Input, + }) + if err != nil { + result.Result = ToolResultOutputContentError{ + Error: err, + } + result.ClientMetadata = toolResult.Metadata + if toolResultCallback != nil { + _ = toolResultCallback(result) + } + // This is a critical error - tool.Run() failed + return result, true + } + + result.ClientMetadata = toolResult.Metadata + if toolResult.IsError { + result.Result = ToolResultOutputContentError{ + Error: errors.New(toolResult.Content), + } + } else { + result.Result = ToolResultOutputContentText{ + Text: toolResult.Content, + } + } + if toolResultCallback != nil { + _ = toolResultCallback(result) + } + // Not a critical error - tool ran successfully (even if it reported an error state) + return result, false +} + // Stream implements Agent. func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) { // Convert AgentStreamCall to AgentCall for preparation @@ -1121,6 +1187,60 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op } activeReasoningContent := make(map[string]reasoningContent) + // Set up concurrent tool execution + type toolExecutionRequest struct { + toolCall ToolCallContent + parallel bool + } + toolChan := make(chan toolExecutionRequest, 10) + var toolExecutionWg sync.WaitGroup + var toolStateMu sync.Mutex + toolResults := make([]ToolResultContent, 0) + var toolExecutionErr error + + // Create a map for quick tool lookup + toolMap := make(map[string]AgentTool) + for _, tool := range stepTools { + toolMap[tool.Info().Name] = tool + } + + // Semaphores for controlling parallelism + parallelSem := make(chan struct{}, 5) + var sequentialMu sync.Mutex + + // Single coordinator goroutine that dispatches tools + toolExecutionWg.Go(func() { + for req := range toolChan { + if req.parallel { + parallelSem <- struct{}{} + toolExecutionWg.Go(func() { + defer func() { <-parallelSem }() + result, isCriticalError := a.executeSingleTool(ctx, toolMap, req.toolCall, opts.OnToolResult) + toolStateMu.Lock() + toolResults = append(toolResults, result) + if isCriticalError && toolExecutionErr == nil { + if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil { + toolExecutionErr = errorResult.Error + } + } + toolStateMu.Unlock() + }) + } else { + sequentialMu.Lock() + result, isCriticalError := a.executeSingleTool(ctx, toolMap, req.toolCall, opts.OnToolResult) + toolStateMu.Lock() + toolResults = append(toolResults, result) + if isCriticalError && toolExecutionErr == nil { + if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil { + toolExecutionErr = errorResult.Error + } + } + toolStateMu.Unlock() + sequentialMu.Unlock() + } + } + }) + // Process stream parts for part := range stream { // Forward all parts to chunk callback @@ -1275,6 +1395,15 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op } } + // Determine if tool can run in parallel + isParallel := false + if tool, exists := toolMap[validatedToolCall.ToolName]; exists { + isParallel = tool.Info().Parallel + } + + // Send tool call to execution channel + toolChan <- toolExecutionRequest{toolCall: validatedToolCall, parallel: isParallel} + // Clean up active tool call delete(activeToolCalls, part.ID) @@ -1310,15 +1439,17 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op } } - // Execute tools if any - var toolResults []ToolResultContent - if len(stepToolCalls) > 0 { - var err error - toolResults, err = a.executeTools(ctx, stepTools, stepToolCalls, opts.OnToolResult) - if err != nil { - return stepExecutionResult{}, err - } - // Add tool results to content + // Close the tool execution channel and wait for all executions to complete + close(toolChan) + toolExecutionWg.Wait() + + // Check for tool execution errors + if toolExecutionErr != nil { + return stepExecutionResult{}, toolExecutionErr + } + + // Add tool results to content if any + if len(toolResults) > 0 { for _, result := range toolResults { stepContent = append(stepContent, result) } diff --git a/tool.go b/tool.go index 9731739ceaf8fcf5dc666f3f0e11a8ab22d5013f..e6823062419756d7313d3dce9fb19036775fb9ec 100644 --- a/tool.go +++ b/tool.go @@ -18,6 +18,7 @@ type ToolInfo struct { Description string `json:"description"` Parameters map[string]any `json:"parameters"` Required []string `json:"required"` + Parallel bool `json:"parallel"` // Whether this tool can run in parallel with other tools } // ToolCall represents a tool invocation, matching the existing pattern. @@ -88,9 +89,25 @@ func NewAgentTool[TInput any]( description: description, fn: fn, schema: schema, + parallel: false, // Default to sequential execution } } +// NewParallelAgentTool creates a typed tool from a function with automatic schema generation. +// This also marks a tool as safe to run in parallel with other tools. +func NewParallelAgentTool[TInput any]( + name string, + description string, + fn func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error), +) AgentTool { + tool := NewAgentTool(name, description, fn) + // Try to use the SetParallel method if available + if setter, ok := tool.(interface{ SetParallel(bool) }); ok { + setter.SetParallel(true) + } + return tool +} + // funcToolWrapper wraps a function to implement the AgentTool interface. type funcToolWrapper[TInput any] struct { name string @@ -98,6 +115,7 @@ type funcToolWrapper[TInput any] struct { fn func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error) schema Schema providerOptions ProviderOptions + parallel bool } func (w *funcToolWrapper[TInput]) SetProviderOptions(opts ProviderOptions) { @@ -108,6 +126,10 @@ func (w *funcToolWrapper[TInput]) ProviderOptions() ProviderOptions { return w.providerOptions } +func (w *funcToolWrapper[TInput]) SetParallel(parallel bool) { + w.parallel = parallel +} + func (w *funcToolWrapper[TInput]) Info() ToolInfo { if w.schema.Required == nil { w.schema.Required = []string{} @@ -117,6 +139,7 @@ func (w *funcToolWrapper[TInput]) Info() ToolInfo { Description: w.description, Parameters: schema.ToParameters(w.schema), Required: w.schema.Required, + Parallel: w.parallel, } }