From 15bced7a1c8ab695bd73cf463343c79a2fb9a9ce Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Mon, 13 Apr 2026 12:08:58 -0400 Subject: [PATCH] chore(agent): move tool execution after stream loop + add test --- agent.go | 93 +++++++++++++++++++++++--------------------- agent_stream_test.go | 86 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+), 45 deletions(-) diff --git a/agent.go b/agent.go index f0b981be6c57d6d3308dafc9e73541e6b4017471..7adf521ff173923393ad6436667c21fbb876ee46 100644 --- a/agent.go +++ b/agent.go @@ -1247,12 +1247,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op toolCall ToolCallContent parallel bool } - toolChan := make(chan toolExecutionRequest, 10) var pendingDispatches []toolExecutionRequest - var toolExecutionWg sync.WaitGroup - var toolStateMu sync.Mutex - toolResults := make([]ToolResultContent, 0) - var toolExecutionErr error // Create a map for quick tool lookup toolMap := make(map[string]AgentTool) @@ -1265,43 +1260,6 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op execProviderToolMap[ept.GetName()] = ept } - // Semaphores for controlling parallelism - parallelSem := make(chan struct{}, 5) - var sequentialMu sync.Mutex - - // Single coordinator goroutine that dispatches tools - toolExecutionWg.Go(func() { - for req := range toolChan { - if req.parallel { - parallelSem <- struct{}{} - toolExecutionWg.Go(func() { - defer func() { <-parallelSem }() - result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, req.toolCall, opts.OnToolResult) - toolStateMu.Lock() - toolResults = append(toolResults, result) - if isCriticalError && toolExecutionErr == nil { - if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil { - toolExecutionErr = errorResult.Error - } - } - toolStateMu.Unlock() - }) - } else { - sequentialMu.Lock() - result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, req.toolCall, opts.OnToolResult) - toolStateMu.Lock() - toolResults = append(toolResults, result) - if isCriticalError && toolExecutionErr == nil { - if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil { - toolExecutionErr = errorResult.Error - } - } - toolStateMu.Unlock() - sequentialMu.Unlock() - } - } - }) - // Process stream parts for part := range stream { // Forward all parts to chunk callback @@ -1536,13 +1494,58 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op } } - // Dispatch all buffered tool calls now that the complete set is known and - // every OnToolCall callback has been called. + // All tool calls are now collected. Create the execution channel sized to + // avoid blocking during dispatch, start the coordinator, then flush the batch. + toolChan := make(chan toolExecutionRequest, len(pendingDispatches)) + var toolExecutionWg sync.WaitGroup + var toolStateMu sync.Mutex + toolResults := make([]ToolResultContent, 0, len(pendingDispatches)) + var toolExecutionErr error + + // Semaphores for controlling parallelism. + parallelSem := make(chan struct{}, 5) + var sequentialMu sync.Mutex + + // Single coordinator goroutine that dispatches tools. + toolExecutionWg.Go(func() { + for req := range toolChan { + if req.parallel { + parallelSem <- struct{}{} + toolExecutionWg.Go(func() { + defer func() { <-parallelSem }() + result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, req.toolCall, opts.OnToolResult) + toolStateMu.Lock() + toolResults = append(toolResults, result) + if isCriticalError && toolExecutionErr == nil { + if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil { + toolExecutionErr = errorResult.Error + } + } + toolStateMu.Unlock() + }) + } else { + sequentialMu.Lock() + result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, req.toolCall, opts.OnToolResult) + toolStateMu.Lock() + toolResults = append(toolResults, result) + if isCriticalError && toolExecutionErr == nil { + if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil { + toolExecutionErr = errorResult.Error + } + } + toolStateMu.Unlock() + sequentialMu.Unlock() + } + } + }) + + // Dispatch all buffered tool calls now that every OnToolCall callback has + // been called, then close and wait. for _, req := range pendingDispatches { toolChan <- req } - // Close the tool execution channel and wait for all executions to complete + // Close the tool execution channel and wait for all executions to complete. close(toolChan) toolExecutionWg.Wait() diff --git a/agent_stream_test.go b/agent_stream_test.go index ea9009305c9c872a44a5074e5d361136250d1910..cddc83dd5cb66792e22558b767eaa108bbb1c087 100644 --- a/agent_stream_test.go +++ b/agent_stream_test.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "fmt" + "strings" + "sync" "testing" "github.com/stretchr/testify/require" @@ -366,6 +368,90 @@ func TestStreamingAgentWithTools(t *testing.T) { require.Equal(t, "echo", toolResults[0].ToolName) } +// TestStreamingAgentToolCallBeforeResult verifies that all OnToolCall callbacks +// complete before any OnToolResult fires. This is the ordering guarantee +// provided by buffering dispatches until the stream is fully consumed. +func TestStreamingAgentToolCallBeforeResult(t *testing.T) { + t.Parallel() + + stepCount := 0 + mockModel := &mockLanguageModel{ + streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) { + stepCount++ + return func(yield func(StreamPart) bool) { + if stepCount == 1 { + // Emit two tool calls in the same step. + for _, id := range []string{"tool-1", "tool-2"} { + if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: id, ToolCallName: "echo"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: id, Delta: `{"message": "` + id + `"}`}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: id}) { + return + } + if !yield(StreamPart{ + Type: StreamPartTypeToolCall, + ID: id, + ToolCallName: "echo", + ToolCallInput: `{"message": "` + id + `"}`, + }) { + return + } + } + yield(StreamPart{ + Type: StreamPartTypeFinish, + FinishReason: FinishReasonToolCalls, + }) + } else { + yield(StreamPart{ + Type: StreamPartTypeFinish, + FinishReason: FinishReasonStop, + }) + } + }, nil + }, + } + + agent := NewAgent(mockModel, WithTools(&EchoTool{})) + + var mu sync.Mutex + var events []string + + _, err := agent.Stream(context.Background(), AgentStreamCall{ + Prompt: "echo twice", + OnToolCall: func(tc ToolCallContent) error { + mu.Lock() + events = append(events, "call:"+tc.ToolCallID) + mu.Unlock() + return nil + }, + OnToolResult: func(tr ToolResultContent) error { + mu.Lock() + events = append(events, "result:"+tr.ToolCallID) + mu.Unlock() + return nil + }, + }) + require.NoError(t, err) + + // Both OnToolCall events must appear before any OnToolResult event. + lastCallIdx := -1 + firstResultIdx := len(events) + for i, e := range events { + if strings.HasPrefix(e, "call:") { + lastCallIdx = i + } + if strings.HasPrefix(e, "result:") && i < firstResultIdx { + firstResultIdx = i + } + } + require.Equal(t, 2, stepCount) + require.Less(t, lastCallIdx, firstResultIdx, + "all OnToolCall events must complete before the first OnToolResult; got %v", events) +} + // TestStreamingAgentTextDeltas tests text streaming (mirrors TS textStream tests) func TestStreamingAgentTextDeltas(t *testing.T) { t.Parallel()