feat: improve tool execution (#88)

Kujtim Hoxha created

* feat: improve tool execution

- makes it so we execute the tool call right after it is available and
than queues other tool calls.
- allow tools to be called in parallel (e.x for the agent tool)

* chore: simplify and rename func

* refactor: simplify code

* chore: change the signature

Change summary

agent.go | 149 ++++++++++++++++++++++++++++++++++++++++++++++++++++++---
tool.go  |  23 ++++++++
2 files changed, 163 insertions(+), 9 deletions(-)

Detailed changes

agent.go 🔗

@@ -8,6 +8,7 @@ import (
 	"fmt"
 	"maps"
 	"slices"
+	"sync"
 )
 
 // StepResult represents the result of a single step in an agent execution.
@@ -730,6 +731,71 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall
 	return results, nil
 }
 
+// executeSingleTool executes a single tool and returns its result and a critical error flag.
+func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentTool, toolCall ToolCallContent, toolResultCallback func(result ToolResultContent) error) (ToolResultContent, bool) {
+	result := ToolResultContent{
+		ToolCallID:       toolCall.ToolCallID,
+		ToolName:         toolCall.ToolName,
+		ProviderExecuted: false,
+	}
+
+	// Skip invalid tool calls - create error result (not critical)
+	if toolCall.Invalid {
+		result.Result = ToolResultOutputContentError{
+			Error: toolCall.ValidationError,
+		}
+		if toolResultCallback != nil {
+			_ = toolResultCallback(result)
+		}
+		return result, false
+	}
+
+	tool, exists := toolMap[toolCall.ToolName]
+	if !exists {
+		result.Result = ToolResultOutputContentError{
+			Error: errors.New("Error: Tool not found: " + toolCall.ToolName),
+		}
+		if toolResultCallback != nil {
+			_ = toolResultCallback(result)
+		}
+		return result, false
+	}
+
+	// Execute the tool
+	toolResult, err := tool.Run(ctx, ToolCall{
+		ID:    toolCall.ToolCallID,
+		Name:  toolCall.ToolName,
+		Input: toolCall.Input,
+	})
+	if err != nil {
+		result.Result = ToolResultOutputContentError{
+			Error: err,
+		}
+		result.ClientMetadata = toolResult.Metadata
+		if toolResultCallback != nil {
+			_ = toolResultCallback(result)
+		}
+		// This is a critical error - tool.Run() failed
+		return result, true
+	}
+
+	result.ClientMetadata = toolResult.Metadata
+	if toolResult.IsError {
+		result.Result = ToolResultOutputContentError{
+			Error: errors.New(toolResult.Content),
+		}
+	} else {
+		result.Result = ToolResultOutputContentText{
+			Text: toolResult.Content,
+		}
+	}
+	if toolResultCallback != nil {
+		_ = toolResultCallback(result)
+	}
+	// Not a critical error - tool ran successfully (even if it reported an error state)
+	return result, false
+}
+
 // Stream implements Agent.
 func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
 	// Convert AgentStreamCall to AgentCall for preparation
@@ -1121,6 +1187,60 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 	}
 	activeReasoningContent := make(map[string]reasoningContent)
 
+	// Set up concurrent tool execution
+	type toolExecutionRequest struct {
+		toolCall ToolCallContent
+		parallel bool
+	}
+	toolChan := make(chan toolExecutionRequest, 10)
+	var toolExecutionWg sync.WaitGroup
+	var toolStateMu sync.Mutex
+	toolResults := make([]ToolResultContent, 0)
+	var toolExecutionErr error
+
+	// Create a map for quick tool lookup
+	toolMap := make(map[string]AgentTool)
+	for _, tool := range stepTools {
+		toolMap[tool.Info().Name] = tool
+	}
+
+	// Semaphores for controlling parallelism
+	parallelSem := make(chan struct{}, 5)
+	var sequentialMu sync.Mutex
+
+	// Single coordinator goroutine that dispatches tools
+	toolExecutionWg.Go(func() {
+		for req := range toolChan {
+			if req.parallel {
+				parallelSem <- struct{}{}
+				toolExecutionWg.Go(func() {
+					defer func() { <-parallelSem }()
+					result, isCriticalError := a.executeSingleTool(ctx, toolMap, req.toolCall, opts.OnToolResult)
+					toolStateMu.Lock()
+					toolResults = append(toolResults, result)
+					if isCriticalError && toolExecutionErr == nil {
+						if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
+							toolExecutionErr = errorResult.Error
+						}
+					}
+					toolStateMu.Unlock()
+				})
+			} else {
+				sequentialMu.Lock()
+				result, isCriticalError := a.executeSingleTool(ctx, toolMap, req.toolCall, opts.OnToolResult)
+				toolStateMu.Lock()
+				toolResults = append(toolResults, result)
+				if isCriticalError && toolExecutionErr == nil {
+					if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
+						toolExecutionErr = errorResult.Error
+					}
+				}
+				toolStateMu.Unlock()
+				sequentialMu.Unlock()
+			}
+		}
+	})
+
 	// Process stream parts
 	for part := range stream {
 		// Forward all parts to chunk callback
@@ -1275,6 +1395,15 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 				}
 			}
 
+			// Determine if tool can run in parallel
+			isParallel := false
+			if tool, exists := toolMap[validatedToolCall.ToolName]; exists {
+				isParallel = tool.Info().Parallel
+			}
+
+			// Send tool call to execution channel
+			toolChan <- toolExecutionRequest{toolCall: validatedToolCall, parallel: isParallel}
+
 			// Clean up active tool call
 			delete(activeToolCalls, part.ID)
 
@@ -1310,15 +1439,17 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 		}
 	}
 
-	// Execute tools if any
-	var toolResults []ToolResultContent
-	if len(stepToolCalls) > 0 {
-		var err error
-		toolResults, err = a.executeTools(ctx, stepTools, stepToolCalls, opts.OnToolResult)
-		if err != nil {
-			return stepExecutionResult{}, err
-		}
-		// Add tool results to content
+	// Close the tool execution channel and wait for all executions to complete
+	close(toolChan)
+	toolExecutionWg.Wait()
+
+	// Check for tool execution errors
+	if toolExecutionErr != nil {
+		return stepExecutionResult{}, toolExecutionErr
+	}
+
+	// Add tool results to content if any
+	if len(toolResults) > 0 {
 		for _, result := range toolResults {
 			stepContent = append(stepContent, result)
 		}

tool.go 🔗

@@ -18,6 +18,7 @@ type ToolInfo struct {
 	Description string         `json:"description"`
 	Parameters  map[string]any `json:"parameters"`
 	Required    []string       `json:"required"`
+	Parallel    bool           `json:"parallel"` // Whether this tool can run in parallel with other tools
 }
 
 // ToolCall represents a tool invocation, matching the existing pattern.
@@ -88,9 +89,25 @@ func NewAgentTool[TInput any](
 		description: description,
 		fn:          fn,
 		schema:      schema,
+		parallel:    false, // Default to sequential execution
 	}
 }
 
+// NewParallelAgentTool creates a typed tool from a function with automatic schema generation.
+// This also marks a tool as safe to run in parallel with other tools.
+func NewParallelAgentTool[TInput any](
+	name string,
+	description string,
+	fn func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error),
+) AgentTool {
+	tool := NewAgentTool(name, description, fn)
+	// Try to use the SetParallel method if available
+	if setter, ok := tool.(interface{ SetParallel(bool) }); ok {
+		setter.SetParallel(true)
+	}
+	return tool
+}
+
 // funcToolWrapper wraps a function to implement the AgentTool interface.
 type funcToolWrapper[TInput any] struct {
 	name            string
@@ -98,6 +115,7 @@ type funcToolWrapper[TInput any] struct {
 	fn              func(ctx context.Context, input TInput, call ToolCall) (ToolResponse, error)
 	schema          Schema
 	providerOptions ProviderOptions
+	parallel        bool
 }
 
 func (w *funcToolWrapper[TInput]) SetProviderOptions(opts ProviderOptions) {
@@ -108,6 +126,10 @@ func (w *funcToolWrapper[TInput]) ProviderOptions() ProviderOptions {
 	return w.providerOptions
 }
 
+func (w *funcToolWrapper[TInput]) SetParallel(parallel bool) {
+	w.parallel = parallel
+}
+
 func (w *funcToolWrapper[TInput]) Info() ToolInfo {
 	if w.schema.Required == nil {
 		w.schema.Required = []string{}
@@ -117,6 +139,7 @@ func (w *funcToolWrapper[TInput]) Info() ToolInfo {
 		Description: w.description,
 		Parameters:  schema.ToParameters(w.schema),
 		Required:    w.schema.Required,
+		Parallel:    w.parallel,
 	}
 }