@@ -1247,12 +1247,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
toolCall ToolCallContent
parallel bool
}
- toolChan := make(chan toolExecutionRequest, 10)
var pendingDispatches []toolExecutionRequest
- var toolExecutionWg sync.WaitGroup
- var toolStateMu sync.Mutex
- toolResults := make([]ToolResultContent, 0)
- var toolExecutionErr error
// Create a map for quick tool lookup
toolMap := make(map[string]AgentTool)
@@ -1265,43 +1260,6 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
execProviderToolMap[ept.GetName()] = ept
}
- // Semaphores for controlling parallelism
- parallelSem := make(chan struct{}, 5)
- var sequentialMu sync.Mutex
-
- // Single coordinator goroutine that dispatches tools
- toolExecutionWg.Go(func() {
- for req := range toolChan {
- if req.parallel {
- parallelSem <- struct{}{}
- toolExecutionWg.Go(func() {
- defer func() { <-parallelSem }()
- result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, req.toolCall, opts.OnToolResult)
- toolStateMu.Lock()
- toolResults = append(toolResults, result)
- if isCriticalError && toolExecutionErr == nil {
- if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
- toolExecutionErr = errorResult.Error
- }
- }
- toolStateMu.Unlock()
- })
- } else {
- sequentialMu.Lock()
- result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, req.toolCall, opts.OnToolResult)
- toolStateMu.Lock()
- toolResults = append(toolResults, result)
- if isCriticalError && toolExecutionErr == nil {
- if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
- toolExecutionErr = errorResult.Error
- }
- }
- toolStateMu.Unlock()
- sequentialMu.Unlock()
- }
- }
- })
-
// Process stream parts
for part := range stream {
// Forward all parts to chunk callback
@@ -1536,13 +1494,58 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
}
}
- // Dispatch all buffered tool calls now that the complete set is known and
- // every OnToolCall callback has been called.
+ // All tool calls are now collected. Create the execution channel sized to
+ // avoid blocking during dispatch, start the coordinator, then flush the batch.
+ toolChan := make(chan toolExecutionRequest, len(pendingDispatches))
+ var toolExecutionWg sync.WaitGroup
+ var toolStateMu sync.Mutex
+ toolResults := make([]ToolResultContent, 0, len(pendingDispatches))
+ var toolExecutionErr error
+
+ // Semaphores for controlling parallelism.
+ parallelSem := make(chan struct{}, 5)
+ var sequentialMu sync.Mutex
+
+ // Single coordinator goroutine that dispatches tools.
+ toolExecutionWg.Go(func() {
+ for req := range toolChan {
+ if req.parallel {
+ parallelSem <- struct{}{}
+ toolExecutionWg.Go(func() {
+ defer func() { <-parallelSem }()
+ result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, req.toolCall, opts.OnToolResult)
+ toolStateMu.Lock()
+ toolResults = append(toolResults, result)
+ if isCriticalError && toolExecutionErr == nil {
+ if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
+ toolExecutionErr = errorResult.Error
+ }
+ }
+ toolStateMu.Unlock()
+ })
+ } else {
+ sequentialMu.Lock()
+ result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, req.toolCall, opts.OnToolResult)
+ toolStateMu.Lock()
+ toolResults = append(toolResults, result)
+ if isCriticalError && toolExecutionErr == nil {
+ if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
+ toolExecutionErr = errorResult.Error
+ }
+ }
+ toolStateMu.Unlock()
+ sequentialMu.Unlock()
+ }
+ }
+ })
+
+ // Dispatch all buffered tool calls now that every OnToolCall callback has
+ // been called, then close and wait.
for _, req := range pendingDispatches {
toolChan <- req
}
- // Close the tool execution channel and wait for all executions to complete
+ // Close the tool execution channel and wait for all executions to complete.
close(toolChan)
toolExecutionWg.Wait()
@@ -4,6 +4,8 @@ import (
"context"
"encoding/json"
"fmt"
+ "strings"
+ "sync"
"testing"
"github.com/stretchr/testify/require"
@@ -366,6 +368,90 @@ func TestStreamingAgentWithTools(t *testing.T) {
require.Equal(t, "echo", toolResults[0].ToolName)
}
+// TestStreamingAgentToolCallBeforeResult verifies that all OnToolCall callbacks
+// complete before any OnToolResult fires. This is the ordering guarantee
+// provided by buffering dispatches until the stream is fully consumed.
+func TestStreamingAgentToolCallBeforeResult(t *testing.T) {
+ t.Parallel()
+
+ stepCount := 0
+ mockModel := &mockLanguageModel{
+ streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
+ stepCount++
+ return func(yield func(StreamPart) bool) {
+ if stepCount == 1 {
+ // Emit two tool calls in the same step.
+ for _, id := range []string{"tool-1", "tool-2"} {
+ if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: id, ToolCallName: "echo"}) {
+ return
+ }
+ if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: id, Delta: `{"message": "` + id + `"}`}) {
+ return
+ }
+ if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: id}) {
+ return
+ }
+ if !yield(StreamPart{
+ Type: StreamPartTypeToolCall,
+ ID: id,
+ ToolCallName: "echo",
+ ToolCallInput: `{"message": "` + id + `"}`,
+ }) {
+ return
+ }
+ }
+ yield(StreamPart{
+ Type: StreamPartTypeFinish,
+ FinishReason: FinishReasonToolCalls,
+ })
+ } else {
+ yield(StreamPart{
+ Type: StreamPartTypeFinish,
+ FinishReason: FinishReasonStop,
+ })
+ }
+ }, nil
+ },
+ }
+
+ agent := NewAgent(mockModel, WithTools(&EchoTool{}))
+
+ var mu sync.Mutex
+ var events []string
+
+ _, err := agent.Stream(context.Background(), AgentStreamCall{
+ Prompt: "echo twice",
+ OnToolCall: func(tc ToolCallContent) error {
+ mu.Lock()
+ events = append(events, "call:"+tc.ToolCallID)
+ mu.Unlock()
+ return nil
+ },
+ OnToolResult: func(tr ToolResultContent) error {
+ mu.Lock()
+ events = append(events, "result:"+tr.ToolCallID)
+ mu.Unlock()
+ return nil
+ },
+ })
+ require.NoError(t, err)
+
+ // Both OnToolCall events must appear before any OnToolResult event.
+ lastCallIdx := -1
+ firstResultIdx := len(events)
+ for i, e := range events {
+ if strings.HasPrefix(e, "call:") {
+ lastCallIdx = i
+ }
+ if strings.HasPrefix(e, "result:") && i < firstResultIdx {
+ firstResultIdx = i
+ }
+ }
+ require.Equal(t, 2, stepCount)
+ require.Less(t, lastCallIdx, firstResultIdx,
+ "all OnToolCall events must complete before the first OnToolResult; got %v", events)
+}
+
// TestStreamingAgentTextDeltas tests text streaming (mirrors TS textStream tests)
func TestStreamingAgentTextDeltas(t *testing.T) {
t.Parallel()