diff --git a/agent.go b/agent.go index 7adf521ff173923393ad6436667c21fbb876ee46..01912bbf3e8c3adf893d0430b0deadd38701a600 100644 --- a/agent.go +++ b/agent.go @@ -485,7 +485,12 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err toolResults, err := a.executeTools(ctx, stepTools, stepExecProviderTools, stepToolCalls, nil) - // Build step content with validated tool calls and tool results. // Provider-executed tool calls are kept as-is. + // If any tool result requested a stop, deliver all results but don't + // request another completion from the model. + stopTurnRequested := hasStopTurn(toolResults) + + // Build step content with validated tool calls and tool results. + // Provider-executed tool calls are kept as-is. stepContent := []Content{} toolCallIndex := 0 for _, content := range result.Content { @@ -523,7 +528,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err steps = append(steps, stepResult) shouldStop := isStopConditionMet(opts.StopWhen, steps) - if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls { + if shouldStop || err != nil || stopTurnRequested || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls { break } } @@ -561,6 +566,15 @@ func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool { return false } +func hasStopTurn(results []ToolResultContent) bool { + for _, r := range results { + if r.StopTurn { + return true + } + } + return false +} + func toResponseMessages(content []Content) []Message { var assistantParts []MessagePart var toolParts []MessagePart @@ -729,6 +743,7 @@ func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentT Error: err, } result.ClientMetadata = toolResult.Metadata + result.StopTurn = toolResult.StopTurn if toolResultCallback != nil { _ = toolResultCallback(result) } @@ -736,6 +751,7 @@ func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentT } result.ClientMetadata = toolResult.Metadata + result.StopTurn = toolResult.StopTurn if toolResult.IsError { result.Result = ToolResultOutputContentError{ Error: errors.New(toolResult.Content), @@ -1573,7 +1589,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op } // Determine if we should continue (has tool calls and not stopped) - shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls + shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls && !hasStopTurn(toolResults) return stepExecutionResult{ StepResult: stepResult, diff --git a/agent_stream_test.go b/agent_stream_test.go index cddc83dd5cb66792e22558b767eaa108bbb1c087..13e0ac3d25a26a2b6d725bfd1b7e2a0890512003 100644 --- a/agent_stream_test.go +++ b/agent_stream_test.go @@ -681,3 +681,83 @@ func TestStreamingAgentSources(t *testing.T) { resultSources := result.Response.Content.Sources() require.Equal(t, 2, len(resultSources)) } + +func TestStreamingAgent_StopTurn(t *testing.T) { + t.Parallel() + + stepCount := 0 + mockModel := &mockLanguageModel{ + streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) { + stepCount++ + return func(yield func(StreamPart) bool) { + if stepCount == 1 { + if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "blocked_tool"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"message"`}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `: "test"}`}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) { + return + } + if !yield(StreamPart{ + Type: StreamPartTypeToolCall, + ID: "tool-1", + ToolCallName: "blocked_tool", + ToolCallInput: `{"message": "test"}`, + }) { + return + } + yield(StreamPart{ + Type: StreamPartTypeFinish, + Usage: Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15}, + FinishReason: FinishReasonToolCalls, + }) + } else { + // Should not be reached because StopTurn prevents a second step + t.Fatal("model should not be called a second time after StopTurn") + } + }, nil + }, + } + + type BlockedInput struct { + Message string `json:"message" description:"Message"` + } + + blockedTool := NewAgentTool( + "blocked_tool", + "A tool that stops the turn", + func(ctx context.Context, input BlockedInput, _ ToolCall) (ToolResponse, error) { + resp := NewTextErrorResponse("permission denied") + resp.StopTurn = true + return resp, nil + }, + ) + + agent := NewAgent(mockModel, WithTools(blockedTool)) + + result, err := agent.Stream(context.Background(), AgentStreamCall{ + Prompt: "test stop turn", + }) + require.NoError(t, err) + require.NotNil(t, result) + + // Only one step — StopTurn prevented the second model call. + require.Len(t, result.Steps, 1) + require.Equal(t, 1, stepCount) + + // Tool result should be present with StopTurn=true. + toolResults := result.Steps[0].Content.ToolResults() + require.Len(t, toolResults, 1) + require.Equal(t, "blocked_tool", toolResults[0].ToolName) + require.True(t, toolResults[0].StopTurn) + + // The final response also includes the stop-marked tool result. + responseResults := result.Response.Content.ToolResults() + require.Len(t, responseResults, 1) + require.True(t, responseResults[0].StopTurn) +} diff --git a/agent_test.go b/agent_test.go index 8d45429d284e88dac4fa7e16f9793ff7a375314c..da10747cd099a0ba7e807d56f0d81b6a0e80bd48 100644 --- a/agent_test.go +++ b/agent_test.go @@ -2428,3 +2428,133 @@ func TestAgent_Generate_ExecutableProviderTool_CriticalError(t *testing.T) { require.Equal(t, 1, callCount) require.Len(t, result.Steps, 1) } + +func TestAgent_Generate_StopTurn(t *testing.T) { + t.Parallel() + + type TestInput struct { + Value string `json:"value" description:"Test value"` + } + + tool1 := NewAgentTool( + "tool1", + "Test tool", + func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) { + resp := NewTextErrorResponse("permission denied: this tool call was blocked") + resp.StopTurn = true + return resp, nil + }, + ) + + callCount := 0 + model := &mockLanguageModel{ + generateFunc: func(ctx context.Context, call Call) (*Response, error) { + callCount++ + return &Response{ + Content: []Content{ + ToolCallContent{ + ToolCallID: "call-1", + ToolName: "tool1", + Input: `{"value":"test"}`, + }, + }, + Usage: Usage{ + InputTokens: 10, + OutputTokens: 5, + TotalTokens: 15, + }, + FinishReason: FinishReasonToolCalls, + }, nil + }, + } + + agent := NewAgent(model, WithTools(tool1)) + result, err := agent.Generate(context.Background(), AgentCall{ + Prompt: "test-input", + }) + + require.NoError(t, err) + require.NotNil(t, result) + // The model should only be called once — StopTurn prevents the second call. + require.Equal(t, 1, callCount) + require.Len(t, result.Steps, 1) + + // The tool result should still be in the step content. + toolResults := result.Steps[0].Content.ToolResults() + require.Len(t, toolResults, 1) + require.Equal(t, "tool1", toolResults[0].ToolName) + require.True(t, toolResults[0].StopTurn) + + // The final response also includes the stop-marked tool result. + responseResults := result.Response.Content.ToolResults() + require.Len(t, responseResults, 1) + require.True(t, responseResults[0].StopTurn) +} + +func TestAgent_Generate_StopTurn_NotSet(t *testing.T) { + t.Parallel() + + type TestInput struct { + Value string `json:"value" description:"Test value"` + } + + tool1 := NewAgentTool( + "tool1", + "Test tool", + func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) { + return NewTextErrorResponse("normal error"), nil + }, + ) + + callCount := 0 + model := &mockLanguageModel{ + generateFunc: func(ctx context.Context, call Call) (*Response, error) { + callCount++ + switch callCount { + case 1: + return &Response{ + Content: []Content{ + ToolCallContent{ + ToolCallID: "call-1", + ToolName: "tool1", + Input: `{"value":"test"}`, + }, + }, + Usage: Usage{ + InputTokens: 10, + OutputTokens: 5, + TotalTokens: 15, + }, + FinishReason: FinishReasonToolCalls, + }, nil + case 2: + return &Response{ + Content: []Content{ + TextContent{Text: "Done"}, + }, + Usage: Usage{InputTokens: 3, OutputTokens: 5, TotalTokens: 8}, + FinishReason: FinishReasonStop, + }, nil + default: + t.Fatalf("Unexpected call count: %d", callCount) + return nil, nil + } + }, + } + + agent := NewAgent(model, WithTools(tool1)) + result, err := agent.Generate(context.Background(), AgentCall{ + Prompt: "test-input", + }) + + require.NoError(t, err) + require.NotNil(t, result) + // Without StopTurn, the model gets a second call. + require.Equal(t, 2, callCount) + require.Len(t, result.Steps, 2) + + // StopTurn should be false on the tool result. + toolResults := result.Steps[0].Content.ToolResults() + require.Len(t, toolResults, 1) + require.False(t, toolResults[0].StopTurn) +} diff --git a/content.go b/content.go index 8787f7cd0bd9b44a35903c674ea0e3b09dda9e78..f2a122ec491388f40558a392c9ae7031f9c8a7d3 100644 --- a/content.go +++ b/content.go @@ -463,6 +463,10 @@ type ToolResultContent struct { ProviderExecuted bool `json:"provider_executed"` // Additional provider-specific metadata for the tool result. ProviderMetadata ProviderMetadata `json:"provider_metadata"` + // StopTurn indicates that the agent loop should stop after this result. + // The tool result is still delivered to the model's context, but the model + // does not get another chance to make tool calls in the same turn. + StopTurn bool `json:"stop_turn,omitempty"` } // GetType returns the type of the tool result content. diff --git a/tool.go b/tool.go index be146ef3fce25310be2a2a569ab4c2965c40eaa7..b856d55fa7f1059ad8ff247712fd14383fe8c075 100644 --- a/tool.go +++ b/tool.go @@ -38,6 +38,7 @@ type ToolResponse struct { MediaType string `json:"media_type,omitempty"` Metadata string `json:"metadata,omitempty"` IsError bool `json:"is_error"` + StopTurn bool `json:"stop_turn,omitempty"` } // NewTextResponse creates a text response.