@@ -8,6 +8,7 @@ import (
"fmt"
"maps"
"slices"
+ "sync"
)
// StepResult represents the result of a single step in an agent execution.
@@ -730,6 +731,71 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall
return results, nil
}
+// executeSingleTool executes a single tool and returns its result and a critical error flag.
+func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentTool, toolCall ToolCallContent, toolResultCallback func(result ToolResultContent) error) (ToolResultContent, bool) {
+ result := ToolResultContent{
+ ToolCallID: toolCall.ToolCallID,
+ ToolName: toolCall.ToolName,
+ ProviderExecuted: false,
+ }
+
+ // Skip invalid tool calls - create error result (not critical)
+ if toolCall.Invalid {
+ result.Result = ToolResultOutputContentError{
+ Error: toolCall.ValidationError,
+ }
+ if toolResultCallback != nil {
+ _ = toolResultCallback(result)
+ }
+ return result, false
+ }
+
+ tool, exists := toolMap[toolCall.ToolName]
+ if !exists {
+ result.Result = ToolResultOutputContentError{
+ Error: errors.New("Error: Tool not found: " + toolCall.ToolName),
+ }
+ if toolResultCallback != nil {
+ _ = toolResultCallback(result)
+ }
+ return result, false
+ }
+
+ // Execute the tool
+ toolResult, err := tool.Run(ctx, ToolCall{
+ ID: toolCall.ToolCallID,
+ Name: toolCall.ToolName,
+ Input: toolCall.Input,
+ })
+ if err != nil {
+ result.Result = ToolResultOutputContentError{
+ Error: err,
+ }
+ result.ClientMetadata = toolResult.Metadata
+ if toolResultCallback != nil {
+ _ = toolResultCallback(result)
+ }
+ // This is a critical error - tool.Run() failed
+ return result, true
+ }
+
+ result.ClientMetadata = toolResult.Metadata
+ if toolResult.IsError {
+ result.Result = ToolResultOutputContentError{
+ Error: errors.New(toolResult.Content),
+ }
+ } else {
+ result.Result = ToolResultOutputContentText{
+ Text: toolResult.Content,
+ }
+ }
+ if toolResultCallback != nil {
+ _ = toolResultCallback(result)
+ }
+ // Not a critical error - tool ran successfully (even if it reported an error state)
+ return result, false
+}
+
// Stream implements Agent.
func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
// Convert AgentStreamCall to AgentCall for preparation
@@ -1121,6 +1187,60 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
}
activeReasoningContent := make(map[string]reasoningContent)
+ // Set up concurrent tool execution
+ type toolExecutionRequest struct {
+ toolCall ToolCallContent
+ parallel bool
+ }
+ toolChan := make(chan toolExecutionRequest, 10)
+ var toolExecutionWg sync.WaitGroup
+ var toolStateMu sync.Mutex
+ toolResults := make([]ToolResultContent, 0)
+ var toolExecutionErr error
+
+ // Create a map for quick tool lookup
+ toolMap := make(map[string]AgentTool)
+ for _, tool := range stepTools {
+ toolMap[tool.Info().Name] = tool
+ }
+
+ // Semaphores for controlling parallelism
+ parallelSem := make(chan struct{}, 5)
+ var sequentialMu sync.Mutex
+
+ // Single coordinator goroutine that dispatches tools
+ toolExecutionWg.Go(func() {
+ for req := range toolChan {
+ if req.parallel {
+ parallelSem <- struct{}{}
+ toolExecutionWg.Go(func() {
+ defer func() { <-parallelSem }()
+ result, isCriticalError := a.executeSingleTool(ctx, toolMap, req.toolCall, opts.OnToolResult)
+ toolStateMu.Lock()
+ toolResults = append(toolResults, result)
+ if isCriticalError && toolExecutionErr == nil {
+ if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
+ toolExecutionErr = errorResult.Error
+ }
+ }
+ toolStateMu.Unlock()
+ })
+ } else {
+ sequentialMu.Lock()
+ result, isCriticalError := a.executeSingleTool(ctx, toolMap, req.toolCall, opts.OnToolResult)
+ toolStateMu.Lock()
+ toolResults = append(toolResults, result)
+ if isCriticalError && toolExecutionErr == nil {
+ if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
+ toolExecutionErr = errorResult.Error
+ }
+ }
+ toolStateMu.Unlock()
+ sequentialMu.Unlock()
+ }
+ }
+ })
+
// Process stream parts
for part := range stream {
// Forward all parts to chunk callback
@@ -1275,6 +1395,15 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
}
}
+ // Determine if tool can run in parallel
+ isParallel := false
+ if tool, exists := toolMap[validatedToolCall.ToolName]; exists {
+ isParallel = tool.Info().Parallel
+ }
+
+ // Send tool call to execution channel
+ toolChan <- toolExecutionRequest{toolCall: validatedToolCall, parallel: isParallel}
+
// Clean up active tool call
delete(activeToolCalls, part.ID)
@@ -1310,15 +1439,17 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
}
}
- // Execute tools if any
- var toolResults []ToolResultContent
- if len(stepToolCalls) > 0 {
- var err error
- toolResults, err = a.executeTools(ctx, stepTools, stepToolCalls, opts.OnToolResult)
- if err != nil {
- return stepExecutionResult{}, err
- }
- // Add tool results to content
+ // Close the tool execution channel and wait for all executions to complete
+ close(toolChan)
+ toolExecutionWg.Wait()
+
+ // Check for tool execution errors
+ if toolExecutionErr != nil {
+ return stepExecutionResult{}, toolExecutionErr
+ }
+
+ // Add tool results to content if any
+ if len(toolResults) > 0 {
for _, result := range toolResults {
stepContent = append(stepContent, result)
}
@@ -18,6 +18,7 @@ type ToolInfo struct {
Description string `json:"description"`
Parameters map[string]any `json:"parameters"`
Required []string `json:"required"`
+ Parallel bool `json:"parallel"` // Whether this tool can run in parallel with other tools
}
// ToolCall represents a tool invocation, matching the existing pattern.
@@ -88,9 +89,25 @@ func NewAgentTool[TInput any](
description: description,
fn: fn,
schema: schema,
+ parallel: false, // Default to sequential execution
}
}
+// NewParallelAgentTool creates a typed tool from a function with automatic schema generation.
+// This also marks a tool as safe to run in parallel with other tools.
+func NewParallelAgentTool[TInput any](
+ name string,
+ description string,
+ fn func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error),
+) AgentTool {
+ tool := NewAgentTool(name, description, fn)
+ // Try to use the SetParallel method if available
+ if setter, ok := tool.(interface{ SetParallel(bool) }); ok {
+ setter.SetParallel(true)
+ }
+ return tool
+}
+
// funcToolWrapper wraps a function to implement the AgentTool interface.
type funcToolWrapper[TInput any] struct {
name string
@@ -98,6 +115,7 @@ type funcToolWrapper[TInput any] struct {
fn func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error)
schema Schema
providerOptions ProviderOptions
+ parallel bool
}
func (w *funcToolWrapper[TInput]) SetProviderOptions(opts ProviderOptions) {
@@ -108,6 +126,10 @@ func (w *funcToolWrapper[TInput]) ProviderOptions() ProviderOptions {
return w.providerOptions
}
+func (w *funcToolWrapper[TInput]) SetParallel(parallel bool) {
+ w.parallel = parallel
+}
+
func (w *funcToolWrapper[TInput]) Info() ToolInfo {
if w.schema.Required == nil {
w.schema.Required = []string{}
@@ -117,6 +139,7 @@ func (w *funcToolWrapper[TInput]) Info() ToolInfo {
Description: w.description,
Parameters: schema.ToParameters(w.schema),
Required: w.schema.Required,
+ Parallel: w.parallel,
}
}