chore(agent): move tool execution after stream loop + add test

Authored by Christian Rocha

Change summary

agent.go             | 93 +++++++++++++++++++++++----------------------
agent_stream_test.go | 86 ++++++++++++++++++++++++++++++++++++++++++
2 files changed, 134 insertions(+), 45 deletions(-)

Detailed changes

agent.go 🔗

@@ -1247,12 +1247,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 		toolCall ToolCallContent
 		parallel bool
 	}
-	toolChan := make(chan toolExecutionRequest, 10)
 	var pendingDispatches []toolExecutionRequest
-	var toolExecutionWg sync.WaitGroup
-	var toolStateMu sync.Mutex
-	toolResults := make([]ToolResultContent, 0)
-	var toolExecutionErr error
 
 	// Create a map for quick tool lookup
 	toolMap := make(map[string]AgentTool)
@@ -1265,43 +1260,6 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 		execProviderToolMap[ept.GetName()] = ept
 	}
 
-	// Semaphores for controlling parallelism
-	parallelSem := make(chan struct{}, 5)
-	var sequentialMu sync.Mutex
-
-	// Single coordinator goroutine that dispatches tools
-	toolExecutionWg.Go(func() {
-		for req := range toolChan {
-			if req.parallel {
-				parallelSem <- struct{}{}
-				toolExecutionWg.Go(func() {
-					defer func() { <-parallelSem }()
-					result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, req.toolCall, opts.OnToolResult)
-					toolStateMu.Lock()
-					toolResults = append(toolResults, result)
-					if isCriticalError && toolExecutionErr == nil {
-						if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
-							toolExecutionErr = errorResult.Error
-						}
-					}
-					toolStateMu.Unlock()
-				})
-			} else {
-				sequentialMu.Lock()
-				result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, req.toolCall, opts.OnToolResult)
-				toolStateMu.Lock()
-				toolResults = append(toolResults, result)
-				if isCriticalError && toolExecutionErr == nil {
-					if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
-						toolExecutionErr = errorResult.Error
-					}
-				}
-				toolStateMu.Unlock()
-				sequentialMu.Unlock()
-			}
-		}
-	})
-
 	// Process stream parts
 	for part := range stream {
 		// Forward all parts to chunk callback
@@ -1536,13 +1494,58 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 		}
 	}
 
-	// Dispatch all buffered tool calls now that the complete set is known and
-	// every OnToolCall callback has been called.
+	// All tool calls are now collected. Create the execution channel sized to
+	// avoid blocking during dispatch, start the coordinator, then flush the batch.
+	toolChan := make(chan toolExecutionRequest, len(pendingDispatches))
+	var toolExecutionWg sync.WaitGroup
+	var toolStateMu sync.Mutex
+	toolResults := make([]ToolResultContent, 0, len(pendingDispatches))
+	var toolExecutionErr error
+
+	// Semaphores for controlling parallelism.
+	parallelSem := make(chan struct{}, 5)
+	var sequentialMu sync.Mutex
+
+	// Single coordinator goroutine that dispatches tools.
+	toolExecutionWg.Go(func() {
+		for req := range toolChan {
+			if req.parallel {
+				parallelSem <- struct{}{}
+				toolExecutionWg.Go(func() {
+					defer func() { <-parallelSem }()
+					result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, req.toolCall, opts.OnToolResult)
+					toolStateMu.Lock()
+					toolResults = append(toolResults, result)
+					if isCriticalError && toolExecutionErr == nil {
+						if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
+							toolExecutionErr = errorResult.Error
+						}
+					}
+					toolStateMu.Unlock()
+				})
+			} else {
+				sequentialMu.Lock()
+				result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, req.toolCall, opts.OnToolResult)
+				toolStateMu.Lock()
+				toolResults = append(toolResults, result)
+				if isCriticalError && toolExecutionErr == nil {
+					if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
+						toolExecutionErr = errorResult.Error
+					}
+				}
+				toolStateMu.Unlock()
+				sequentialMu.Unlock()
+			}
+		}
+	})
+
+	// Dispatch all buffered tool calls now that every OnToolCall callback has
+	// been called, then close and wait.
 	for _, req := range pendingDispatches {
 		toolChan <- req
 	}
 
-	// Close the tool execution channel and wait for all executions to complete
+	// Close the tool execution channel and wait for all executions to complete.
 	close(toolChan)
 	toolExecutionWg.Wait()
 

agent_stream_test.go 🔗

@@ -4,6 +4,8 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"strings"
+	"sync"
 	"testing"
 
 	"github.com/stretchr/testify/require"
@@ -366,6 +368,90 @@ func TestStreamingAgentWithTools(t *testing.T) {
 	require.Equal(t, "echo", toolResults[0].ToolName)
 }
 
+// TestStreamingAgentToolCallBeforeResult verifies that all OnToolCall callbacks
+// complete before any OnToolResult fires. This is the ordering guarantee
+// provided by buffering dispatches until the stream is fully consumed.
+func TestStreamingAgentToolCallBeforeResult(t *testing.T) {
+	t.Parallel()
+
+	stepCount := 0
+	mockModel := &mockLanguageModel{
+		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
+			stepCount++
+			return func(yield func(StreamPart) bool) {
+				if stepCount == 1 {
+					// Emit two tool calls in the same step.
+					for _, id := range []string{"tool-1", "tool-2"} {
+						if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: id, ToolCallName: "echo"}) {
+							return
+						}
+						if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: id, Delta: `{"message": "` + id + `"}`}) {
+							return
+						}
+						if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: id}) {
+							return
+						}
+						if !yield(StreamPart{
+							Type:          StreamPartTypeToolCall,
+							ID:            id,
+							ToolCallName:  "echo",
+							ToolCallInput: `{"message": "` + id + `"}`,
+						}) {
+							return
+						}
+					}
+					yield(StreamPart{
+						Type:         StreamPartTypeFinish,
+						FinishReason: FinishReasonToolCalls,
+					})
+				} else {
+					yield(StreamPart{
+						Type:         StreamPartTypeFinish,
+						FinishReason: FinishReasonStop,
+					})
+				}
+			}, nil
+		},
+	}
+
+	agent := NewAgent(mockModel, WithTools(&EchoTool{}))
+
+	var mu sync.Mutex
+	var events []string
+
+	_, err := agent.Stream(context.Background(), AgentStreamCall{
+		Prompt: "echo twice",
+		OnToolCall: func(tc ToolCallContent) error {
+			mu.Lock()
+			events = append(events, "call:"+tc.ToolCallID)
+			mu.Unlock()
+			return nil
+		},
+		OnToolResult: func(tr ToolResultContent) error {
+			mu.Lock()
+			events = append(events, "result:"+tr.ToolCallID)
+			mu.Unlock()
+			return nil
+		},
+	})
+	require.NoError(t, err)
+
+	// Both OnToolCall events must appear before any OnToolResult event.
+	lastCallIdx := -1
+	firstResultIdx := len(events)
+	for i, e := range events {
+		if strings.HasPrefix(e, "call:") {
+			lastCallIdx = i
+		}
+		if strings.HasPrefix(e, "result:") && i < firstResultIdx {
+			firstResultIdx = i
+		}
+	}
+	require.Equal(t, 2, stepCount)
+	require.Less(t, lastCallIdx, firstResultIdx,
+		"all OnToolCall events must complete before the first OnToolResult; got %v", events)
+}
+
 // TestStreamingAgentTextDeltas tests text streaming (mirrors TS textStream tests)
 func TestStreamingAgentTextDeltas(t *testing.T) {
 	t.Parallel()