feat(agent): add the ability to stop a turn and end the agent loop

Christian Rocha created

Change summary

agent.go             |  22 ++++++-
agent_stream_test.go |  80 ++++++++++++++++++++++++++++
agent_test.go        | 130 ++++++++++++++++++++++++++++++++++++++++++++++
content.go           |   4 +
tool.go              |   1 
5 files changed, 234 insertions(+), 3 deletions(-)

Detailed changes

agent.go 🔗

@@ -485,7 +485,12 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
 
 		toolResults, err := a.executeTools(ctx, stepTools, stepExecProviderTools, stepToolCalls, nil)
 
-		// Build step content with validated tool calls and tool results.		// Provider-executed tool calls are kept as-is.
+		// If any tool result requested a stop, deliver all results but don't
+		// request another completion from the model.
+		stopTurnRequested := hasStopTurn(toolResults)
+
+		// Build step content with validated tool calls and tool results.
+		// Provider-executed tool calls are kept as-is.
 		stepContent := []Content{}
 		toolCallIndex := 0
 		for _, content := range result.Content {
@@ -523,7 +528,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
 		steps = append(steps, stepResult)
 		shouldStop := isStopConditionMet(opts.StopWhen, steps)
 
-		if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
+		if shouldStop || err != nil || stopTurnRequested || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
 			break
 		}
 	}
@@ -561,6 +566,15 @@ func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
 	return false
 }
 
+func hasStopTurn(results []ToolResultContent) bool {
+	for _, r := range results {
+		if r.StopTurn {
+			return true
+		}
+	}
+	return false
+}
+
 func toResponseMessages(content []Content) []Message {
 	var assistantParts []MessagePart
 	var toolParts []MessagePart
@@ -729,6 +743,7 @@ func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentT
 			Error: err,
 		}
 		result.ClientMetadata = toolResult.Metadata
+		result.StopTurn = toolResult.StopTurn
 		if toolResultCallback != nil {
 			_ = toolResultCallback(result)
 		}
@@ -736,6 +751,7 @@ func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentT
 	}
 
 	result.ClientMetadata = toolResult.Metadata
+	result.StopTurn = toolResult.StopTurn
 	if toolResult.IsError {
 		result.Result = ToolResultOutputContentError{
 			Error: errors.New(toolResult.Content),
@@ -1573,7 +1589,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 	}
 
 	// Determine if we should continue (has tool calls and not stopped)
-	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
+	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls && !hasStopTurn(toolResults)
 
 	return stepExecutionResult{
 		StepResult:     stepResult,

agent_stream_test.go 🔗

@@ -681,3 +681,83 @@ func TestStreamingAgentSources(t *testing.T) {
 	resultSources := result.Response.Content.Sources()
 	require.Equal(t, 2, len(resultSources))
 }
+
+func TestStreamingAgent_StopTurn(t *testing.T) {
+	t.Parallel()
+
+	stepCount := 0
+	mockModel := &mockLanguageModel{
+		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
+			stepCount++
+			return func(yield func(StreamPart) bool) {
+				if stepCount == 1 {
+					if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "blocked_tool"}) {
+						return
+					}
+					if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"message"`}) {
+						return
+					}
+					if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `: "test"}`}) {
+						return
+					}
+					if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
+						return
+					}
+					if !yield(StreamPart{
+						Type:          StreamPartTypeToolCall,
+						ID:            "tool-1",
+						ToolCallName:  "blocked_tool",
+						ToolCallInput: `{"message": "test"}`,
+					}) {
+						return
+					}
+					yield(StreamPart{
+						Type:         StreamPartTypeFinish,
+						Usage:        Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15},
+						FinishReason: FinishReasonToolCalls,
+					})
+				} else {
+					// Should not be reached because StopTurn prevents a second step
+					t.Fatal("model should not be called a second time after StopTurn")
+				}
+			}, nil
+		},
+	}
+
+	type BlockedInput struct {
+		Message string `json:"message" description:"Message"`
+	}
+
+	blockedTool := NewAgentTool(
+		"blocked_tool",
+		"A tool that stops the turn",
+		func(ctx context.Context, input BlockedInput, _ ToolCall) (ToolResponse, error) {
+			resp := NewTextErrorResponse("permission denied")
+			resp.StopTurn = true
+			return resp, nil
+		},
+	)
+
+	agent := NewAgent(mockModel, WithTools(blockedTool))
+
+	result, err := agent.Stream(context.Background(), AgentStreamCall{
+		Prompt: "test stop turn",
+	})
+	require.NoError(t, err)
+	require.NotNil(t, result)
+
+	// Only one step — StopTurn prevented the second model call.
+	require.Len(t, result.Steps, 1)
+	require.Equal(t, 1, stepCount)
+
+	// Tool result should be present with StopTurn=true.
+	toolResults := result.Steps[0].Content.ToolResults()
+	require.Len(t, toolResults, 1)
+	require.Equal(t, "blocked_tool", toolResults[0].ToolName)
+	require.True(t, toolResults[0].StopTurn)
+
+	// The final response also includes the stop-marked tool result.
+	responseResults := result.Response.Content.ToolResults()
+	require.Len(t, responseResults, 1)
+	require.True(t, responseResults[0].StopTurn)
+}

agent_test.go 🔗

@@ -2428,3 +2428,133 @@ func TestAgent_Generate_ExecutableProviderTool_CriticalError(t *testing.T) {
 	require.Equal(t, 1, callCount)
 	require.Len(t, result.Steps, 1)
 }
+
+func TestAgent_Generate_StopTurn(t *testing.T) {
+	t.Parallel()
+
+	type TestInput struct {
+		Value string `json:"value" description:"Test value"`
+	}
+
+	tool1 := NewAgentTool(
+		"tool1",
+		"Test tool",
+		func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
+			resp := NewTextErrorResponse("permission denied: this tool call was blocked")
+			resp.StopTurn = true
+			return resp, nil
+		},
+	)
+
+	callCount := 0
+	model := &mockLanguageModel{
+		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+			callCount++
+			return &Response{
+				Content: []Content{
+					ToolCallContent{
+						ToolCallID: "call-1",
+						ToolName:   "tool1",
+						Input:      `{"value":"test"}`,
+					},
+				},
+				Usage: Usage{
+					InputTokens:  10,
+					OutputTokens: 5,
+					TotalTokens:  15,
+				},
+				FinishReason: FinishReasonToolCalls,
+			}, nil
+		},
+	}
+
+	agent := NewAgent(model, WithTools(tool1))
+	result, err := agent.Generate(context.Background(), AgentCall{
+		Prompt: "test-input",
+	})
+
+	require.NoError(t, err)
+	require.NotNil(t, result)
+	// The model should only be called once — StopTurn prevents the second call.
+	require.Equal(t, 1, callCount)
+	require.Len(t, result.Steps, 1)
+
+	// The tool result should still be in the step content.
+	toolResults := result.Steps[0].Content.ToolResults()
+	require.Len(t, toolResults, 1)
+	require.Equal(t, "tool1", toolResults[0].ToolName)
+	require.True(t, toolResults[0].StopTurn)
+
+	// The final response also includes the stop-marked tool result.
+	responseResults := result.Response.Content.ToolResults()
+	require.Len(t, responseResults, 1)
+	require.True(t, responseResults[0].StopTurn)
+}
+
+func TestAgent_Generate_StopTurn_NotSet(t *testing.T) {
+	t.Parallel()
+
+	type TestInput struct {
+		Value string `json:"value" description:"Test value"`
+	}
+
+	tool1 := NewAgentTool(
+		"tool1",
+		"Test tool",
+		func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
+			return NewTextErrorResponse("normal error"), nil
+		},
+	)
+
+	callCount := 0
+	model := &mockLanguageModel{
+		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+			callCount++
+			switch callCount {
+			case 1:
+				return &Response{
+					Content: []Content{
+						ToolCallContent{
+							ToolCallID: "call-1",
+							ToolName:   "tool1",
+							Input:      `{"value":"test"}`,
+						},
+					},
+					Usage: Usage{
+						InputTokens:  10,
+						OutputTokens: 5,
+						TotalTokens:  15,
+					},
+					FinishReason: FinishReasonToolCalls,
+				}, nil
+			case 2:
+				return &Response{
+					Content: []Content{
+						TextContent{Text: "Done"},
+					},
+					Usage:        Usage{InputTokens: 3, OutputTokens: 5, TotalTokens: 8},
+					FinishReason: FinishReasonStop,
+				}, nil
+			default:
+				t.Fatalf("Unexpected call count: %d", callCount)
+				return nil, nil
+			}
+		},
+	}
+
+	agent := NewAgent(model, WithTools(tool1))
+	result, err := agent.Generate(context.Background(), AgentCall{
+		Prompt: "test-input",
+	})
+
+	require.NoError(t, err)
+	require.NotNil(t, result)
+	// Without StopTurn, the model gets a second call.
+	require.Equal(t, 2, callCount)
+	require.Len(t, result.Steps, 2)
+
+	// StopTurn should be false on the tool result.
+	toolResults := result.Steps[0].Content.ToolResults()
+	require.Len(t, toolResults, 1)
+	require.False(t, toolResults[0].StopTurn)
+}

content.go 🔗

@@ -463,6 +463,10 @@ type ToolResultContent struct {
 	ProviderExecuted bool `json:"provider_executed"`
 	// Additional provider-specific metadata for the tool result.
 	ProviderMetadata ProviderMetadata `json:"provider_metadata"`
+	// StopTurn indicates that the agent loop should stop after this result.
+	// The tool result is still delivered to the model's context, but the model
+	// does not get another chance to make tool calls in the same turn.
+	StopTurn bool `json:"stop_turn,omitempty"`
 }
 
 // GetType returns the type of the tool result content.

tool.go 🔗

@@ -38,6 +38,7 @@ type ToolResponse struct {
 	MediaType string `json:"media_type,omitempty"`
 	Metadata  string `json:"metadata,omitempty"`
 	IsError   bool   `json:"is_error"`
+	StopTurn  bool   `json:"stop_turn,omitempty"`
 }
 
 // NewTextResponse creates a text response.