agent.go

   1package fantasy
   2
   3import (
   4	"cmp"
   5	"context"
   6	"encoding/base64"
   7	"encoding/json"
   8	"errors"
   9	"fmt"
  10	"maps"
  11	"slices"
  12	"sync"
  13
  14	"charm.land/fantasy/schema"
  15	"github.com/charmbracelet/x/exp/slice"
  16)
  17
// StepResult represents the result of a single step in an agent execution.
// It embeds the step's Response and carries the conversation messages that
// were derived from that step's content.
type StepResult struct {
	Response
	// Messages holds the messages built from this step's content.
	Messages []Message
}
  23
// stepExecutionResult encapsulates the result of executing a step with stream processing.
type stepExecutionResult struct {
	// StepResult is the step produced by processing the stream.
	StepResult     StepResult
	// ShouldContinue reports whether another step should run; Stream leaves
	// its loop when this is false.
	ShouldContinue bool
}
  29
// StopCondition defines a function that determines when an agent should stop executing.
// It receives every step completed so far and returns true to end the agent loop.
type StopCondition = func(steps []StepResult) bool
  32
  33// StepCountIs returns a stop condition that stops after the specified number of steps.
  34func StepCountIs(stepCount int) StopCondition {
  35	return func(steps []StepResult) bool {
  36		return len(steps) >= stepCount
  37	}
  38}
  39
  40// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
  41func HasToolCall(toolName string) StopCondition {
  42	return func(steps []StepResult) bool {
  43		if len(steps) == 0 {
  44			return false
  45		}
  46		lastStep := steps[len(steps)-1]
  47		toolCalls := lastStep.Content.ToolCalls()
  48		for _, toolCall := range toolCalls {
  49			if toolCall.ToolName == toolName {
  50				return true
  51			}
  52		}
  53		return false
  54	}
  55}
  56
  57// HasContent returns a stop condition that stops when the specified content type appears in the last step.
  58func HasContent(contentType ContentType) StopCondition {
  59	return func(steps []StepResult) bool {
  60		if len(steps) == 0 {
  61			return false
  62		}
  63		lastStep := steps[len(steps)-1]
  64		for _, content := range lastStep.Content {
  65			if content.GetType() == contentType {
  66				return true
  67			}
  68		}
  69		return false
  70	}
  71}
  72
  73// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
  74func FinishReasonIs(reason FinishReason) StopCondition {
  75	return func(steps []StepResult) bool {
  76		if len(steps) == 0 {
  77			return false
  78		}
  79		lastStep := steps[len(steps)-1]
  80		return lastStep.FinishReason == reason
  81	}
  82}
  83
  84// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
  85func MaxTokensUsed(maxTokens int64) StopCondition {
  86	return func(steps []StepResult) bool {
  87		var totalTokens int64
  88		for _, step := range steps {
  89			totalTokens += step.Usage.TotalTokens
  90		}
  91		return totalTokens >= maxTokens
  92	}
  93}
  94
// PrepareStepFunctionOptions contains the options for preparing a step in an agent execution.
// It is the input handed to a PrepareStepFunction before each step runs.
type PrepareStepFunctionOptions struct {
	// Steps holds the steps completed before the one being prepared.
	Steps      []StepResult
	// StepNumber is the zero-based index of the step being prepared.
	StepNumber int
	// Model is the agent's configured model for the step.
	Model      LanguageModel
	// Messages is the input assembled for the step: the initial prompt
	// followed by the messages produced by earlier steps.
	Messages   []Message
}
 102
// PrepareStepResult contains the result of preparing a step in an agent execution.
// Each field overrides the corresponding per-step default only when it is set.
type PrepareStepResult struct {
	// Model replaces the step model when non-nil.
	Model           LanguageModel
	// Messages replaces the step input messages when non-nil.
	Messages        []Message
	// System overrides the system prompt for the step when non-nil.
	System          *string
	// ToolChoice overrides the tool choice for the step when non-nil.
	ToolChoice      *ToolChoice
	// ActiveTools overrides the active tool names when it is non-empty.
	ActiveTools     []string
	// DisableAllTools is forwarded to tool preparation for the step.
	DisableAllTools bool
	// Tools replaces the agent's tools for the step when non-nil.
	Tools           []AgentTool
}
 113
// ToolCallRepairOptions contains the options for repairing a tool call.
// It is the input handed to a RepairToolCallFunction.
// NOTE(review): the field descriptions below are inferred from the field
// names; the code that populates this struct is not in this part of the file.
type ToolCallRepairOptions struct {
	// OriginalToolCall is presumably the tool call that failed validation.
	OriginalToolCall ToolCallContent
	// ValidationError is presumably the error that made the call invalid.
	ValidationError  error
	// AvailableTools is presumably the set of tools usable in the step.
	AvailableTools   []AgentTool
	// SystemPrompt is presumably the system prompt in effect for the step.
	SystemPrompt     string
	// Messages is presumably the step's input messages.
	Messages         []Message
}
 122
type (
	// PrepareStepFunction defines a function that prepares a step in an agent execution.
	// It is called before each step. The returned context replaces the
	// agent's context from that point on, the returned PrepareStepResult
	// overrides per-step defaults, and a non-nil error aborts the run.
	PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)

	// OnStepFinishedFunction defines a function that is called when a step finishes.
	OnStepFinishedFunction = func(step StepResult)

	// RepairToolCallFunction defines a function that repairs a tool call.
	// It receives the repair options and returns a tool call or an error.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
 133
// agentSettings holds the agent-level configuration assembled by NewAgent
// from its AgentOption arguments. Per-call values take precedence over these
// defaults when a call is prepared.
type agentSettings struct {
	// Prompt and sampling defaults.
	systemPrompt     string
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	headers          map[string]string
	userAgent        string
	providerOptions  ProviderOptions

	// Tooling and retry defaults.
	providerDefinedTools    []ProviderDefinedTool
	executableProviderTools []ExecutableProviderTool
	tools                   []AgentTool
	toolChoice              *ToolChoice
	maxRetries              *int

	// model is the language model passed to NewAgent.
	model LanguageModel

	// Default stop conditions and hooks, used when a call does not set its own.
	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
 159
 160// AgentCall represents a call to an agent.
 161type AgentCall struct {
 162	Prompt           string     `json:"prompt"`
 163	Files            []FilePart `json:"files"`
 164	Messages         []Message  `json:"messages"`
 165	MaxOutputTokens  *int64
 166	Temperature      *float64    `json:"temperature"`
 167	TopP             *float64    `json:"top_p"`
 168	TopK             *int64      `json:"top_k"`
 169	PresencePenalty  *float64    `json:"presence_penalty"`
 170	FrequencyPenalty *float64    `json:"frequency_penalty"`
 171	ActiveTools      []string    `json:"active_tools"`
 172	ToolChoice       *ToolChoice `json:"tool_choice"`
 173	ProviderOptions  ProviderOptions
 174	OnRetry          OnRetryCallback
 175	MaxRetries       *int
 176
 177	StopWhen       []StopCondition
 178	PrepareStep    PrepareStepFunction
 179	RepairToolCall RepairToolCallFunction
 180}
 181
// Agent-level callbacks.
// Stream discards the error returned by OnStepStartFunc, OnStepFinishFunc and
// OnAgentFinishFunc callbacks.
type (
	// OnAgentStartFunc is called when agent starts.
	// Stream invokes it once, before the first step.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called when agent finishes.
	// Stream invokes it after OnFinishFunc with the same result.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called when a step starts.
	// stepNumber is zero-based.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called when a step finishes.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called when entire agent completes.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called when an error occurs.
	// Stream invokes it when executing a step fails after retries.
	OnErrorFunc func(error)
)
 202
// Stream part callbacks - called for each corresponding stream part type.
// NOTE(review): these are invoked by the stream processing code, which is not
// in this part of the file — confirm there whether a non-nil returned error
// aborts the step.
type (
	// OnChunkFunc is called for each stream part (catch-all).
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when text starts.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for text deltas.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when text ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when reasoning starts.
	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error

	// OnReasoningDeltaFunc is called for reasoning deltas.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning ends.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when tool input starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for tool input deltas.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnToolResultFunc is called when tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when stream finishes.
	// It receives the usage, finish reason and provider metadata of the stream.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
 250
 251// AgentStreamCall represents a streaming call to an agent.
 252type AgentStreamCall struct {
 253	Prompt           string     `json:"prompt"`
 254	Files            []FilePart `json:"files"`
 255	Messages         []Message  `json:"messages"`
 256	MaxOutputTokens  *int64
 257	Temperature      *float64    `json:"temperature"`
 258	TopP             *float64    `json:"top_p"`
 259	TopK             *int64      `json:"top_k"`
 260	PresencePenalty  *float64    `json:"presence_penalty"`
 261	FrequencyPenalty *float64    `json:"frequency_penalty"`
 262	ActiveTools      []string    `json:"active_tools"`
 263	ToolChoice       *ToolChoice `json:"tool_choice"`
 264	Headers          map[string]string
 265	ProviderOptions  ProviderOptions
 266	OnRetry          OnRetryCallback
 267	MaxRetries       *int
 268
 269	StopWhen       []StopCondition
 270	PrepareStep    PrepareStepFunction
 271	RepairToolCall RepairToolCallFunction
 272
 273	// Agent-level callbacks
 274	OnAgentStart  OnAgentStartFunc  // Called when agent starts
 275	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
 276	OnStepStart   OnStepStartFunc   // Called when a step starts
 277	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
 278	OnFinish      OnFinishFunc      // Called when entire agent completes
 279	OnError       OnErrorFunc       // Called when an error occurs
 280
 281	// Stream part callbacks - called for each corresponding stream part type
 282	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
 283	OnWarnings       OnWarningsFunc       // Called for warnings
 284	OnTextStart      OnTextStartFunc      // Called when text starts
 285	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
 286	OnTextEnd        OnTextEndFunc        // Called when text ends
 287	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
 288	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
 289	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
 290	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
 291	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
 292	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
 293	OnToolCall       OnToolCallFunc       // Called when tool call is complete
 294	OnToolResult     OnToolResultFunc     // Called when tool execution completes
 295	OnSource         OnSourceFunc         // Called for source references
 296	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
 297}
 298
// AgentResult represents the result of an agent execution.
type AgentResult struct {
	// Steps holds every step that was executed, in order.
	Steps []StepResult
	// Final response
	// (the Response of the last entry in Steps).
	Response   Response
	// TotalUsage is the token usage summed over all steps.
	TotalUsage Usage
}
 306
// Agent represents an AI agent that can generate responses and stream responses.
type Agent interface {
	// Generate runs the agent loop without streaming and returns the
	// accumulated result.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the agent loop using the model's streaming API, invoking
	// the callbacks set on the call, and returns the accumulated result.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
 312
// AgentOption defines a function that configures agent settings.
// NewAgent applies options in the order they are passed.
type AgentOption = func(*agentSettings)
 315
// agent is the Agent implementation returned by NewAgent.
type agent struct {
	// settings holds the configuration assembled by NewAgent.
	settings agentSettings
}
 319
 320// NewAgent creates a new agent with the given language model and options.
 321func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
 322	settings := agentSettings{
 323		model: model,
 324	}
 325	for _, o := range opts {
 326		o(&settings)
 327	}
 328	return &agent{
 329		settings: settings,
 330	}
 331}
 332
 333func (a *agent) prepareCall(call AgentCall) AgentCall {
 334	call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
 335	call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
 336	call.TopP = cmp.Or(call.TopP, a.settings.topP)
 337	call.TopK = cmp.Or(call.TopK, a.settings.topK)
 338	call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
 339	call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
 340	call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
 341	call.ToolChoice = cmp.Or(call.ToolChoice, a.settings.toolChoice)
 342
 343	if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
 344		call.StopWhen = a.settings.stopWhen
 345	}
 346	if call.PrepareStep == nil && a.settings.prepareStep != nil {
 347		call.PrepareStep = a.settings.prepareStep
 348	}
 349	if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
 350		call.RepairToolCall = a.settings.repairToolCall
 351	}
 352	if call.OnRetry == nil && a.settings.onRetry != nil {
 353		call.OnRetry = a.settings.onRetry
 354	}
 355
 356	providerOptions := ProviderOptions{}
 357	if a.settings.providerOptions != nil {
 358		maps.Copy(providerOptions, a.settings.providerOptions)
 359	}
 360	if call.ProviderOptions != nil {
 361		maps.Copy(providerOptions, call.ProviderOptions)
 362	}
 363	call.ProviderOptions = providerOptions
 364
 365	headers := map[string]string{}
 366
 367	if a.settings.headers != nil {
 368		maps.Copy(headers, a.settings.headers)
 369	}
 370
 371	return call
 372}
 373
 374// Generate implements Agent.
 375func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
 376	opts = a.prepareCall(opts)
 377	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 378	if err != nil {
 379		return nil, err
 380	}
 381	var responseMessages []Message
 382	var steps []StepResult
 383
 384	for {
 385		stepInputMessages := append(initialPrompt, responseMessages...)
 386		stepModel := a.settings.model
 387		stepSystemPrompt := a.settings.systemPrompt
 388		stepActiveTools := opts.ActiveTools
 389		stepToolChoice := ToolChoiceAuto
 390		if opts.ToolChoice != nil {
 391			stepToolChoice = *opts.ToolChoice
 392		}
 393		disableAllTools := false
 394		stepTools := a.settings.tools
 395		if opts.PrepareStep != nil {
 396			updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
 397				Model:      stepModel,
 398				Steps:      steps,
 399				StepNumber: len(steps),
 400				Messages:   stepInputMessages,
 401			})
 402			if err != nil {
 403				return nil, err
 404			}
 405
 406			ctx = updatedCtx
 407
 408			// Apply prepared step modifications
 409			if prepared.Messages != nil {
 410				stepInputMessages = prepared.Messages
 411			}
 412			if prepared.Model != nil {
 413				stepModel = prepared.Model
 414			}
 415			if prepared.System != nil {
 416				stepSystemPrompt = *prepared.System
 417			}
 418			if prepared.ToolChoice != nil {
 419				stepToolChoice = *prepared.ToolChoice
 420			}
 421			if len(prepared.ActiveTools) > 0 {
 422				stepActiveTools = prepared.ActiveTools
 423			}
 424			disableAllTools = prepared.DisableAllTools
 425			if prepared.Tools != nil {
 426				stepTools = prepared.Tools
 427			}
 428		}
 429
 430		// Recreate prompt with potentially modified system prompt
 431		if stepSystemPrompt != a.settings.systemPrompt {
 432			stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 433			if err != nil {
 434				return nil, err
 435			}
 436			// Replace system message part, keep the rest
 437			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 438				stepInputMessages[0] = stepPrompt[0] // Replace system message
 439			}
 440		}
 441
 442		preparedTools := a.prepareTools(stepTools, a.settings.providerDefinedTools, stepActiveTools, disableAllTools)
 443
 444		// Filter executable provider tools by activeTools at the
 445		// step level, consistent with how stepTools (AgentTools)
 446		// are scoped before being passed to inner functions.
 447		stepExecProviderTools := a.filterExecProviderTools(stepActiveTools)
 448
 449		retryOptions := DefaultRetryOptions()
 450		if opts.MaxRetries != nil {
 451			retryOptions.MaxRetries = *opts.MaxRetries
 452		}
 453		retryOptions.OnRetry = opts.OnRetry
 454		retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
 455		result, err := retry(ctx, func() (*Response, error) {
 456			return stepModel.Generate(ctx, Call{
 457				Prompt:           stepInputMessages,
 458				MaxOutputTokens:  opts.MaxOutputTokens,
 459				Temperature:      opts.Temperature,
 460				TopP:             opts.TopP,
 461				TopK:             opts.TopK,
 462				PresencePenalty:  opts.PresencePenalty,
 463				FrequencyPenalty: opts.FrequencyPenalty,
 464				Tools:            preparedTools,
 465				ToolChoice:       &stepToolChoice,
 466				UserAgent:        a.settings.userAgent,
 467				ProviderOptions:  opts.ProviderOptions,
 468			})
 469		})
 470		if err != nil {
 471			return nil, err
 472		}
 473
 474		var stepToolCalls []ToolCallContent
 475		for _, content := range result.Content {
 476			if content.GetType() == ContentTypeToolCall {
 477				toolCall, ok := AsContentType[ToolCallContent](content)
 478				if !ok {
 479					continue
 480				}
 481				// Provider-executed tool calls (e.g. web search) are
 482				// handled by the provider and should not be validated
 483				// or executed by the agent.
 484				if toolCall.ProviderExecuted {
 485					continue
 486				}
 487				// Validate and potentially repair the tool call
 488				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, stepExecProviderTools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
 489				stepToolCalls = append(stepToolCalls, validatedToolCall)
 490			}
 491		}
 492
 493		toolResults, err := a.executeTools(ctx, stepTools, stepExecProviderTools, stepToolCalls, nil)
 494
 495		// If any tool result requested a stop, deliver all results but don't
 496		// request another completion from the model.
 497		stopTurnRequested := hasStopTurn(toolResults)
 498
 499		// Build step content with validated tool calls and tool results.
 500		// Provider-executed tool calls are kept as-is.
 501		stepContent := []Content{}
 502		toolCallIndex := 0
 503		for _, content := range result.Content {
 504			if content.GetType() == ContentTypeToolCall {
 505				tc, ok := AsContentType[ToolCallContent](content)
 506				if ok && tc.ProviderExecuted {
 507					stepContent = append(stepContent, content)
 508					continue
 509				}
 510				// Replace with validated tool call.
 511				if toolCallIndex < len(stepToolCalls) {
 512					stepContent = append(stepContent, stepToolCalls[toolCallIndex])
 513					toolCallIndex++
 514				}
 515			} else {
 516				stepContent = append(stepContent, content)
 517			}
 518		} // Add tool results
 519		for _, result := range toolResults {
 520			stepContent = append(stepContent, result)
 521		}
 522		currentStepMessages := toResponseMessages(stepContent)
 523		responseMessages = append(responseMessages, currentStepMessages...)
 524
 525		stepResult := StepResult{
 526			Response: Response{
 527				Content:          stepContent,
 528				FinishReason:     result.FinishReason,
 529				Usage:            result.Usage,
 530				Warnings:         result.Warnings,
 531				ProviderMetadata: result.ProviderMetadata,
 532			},
 533			Messages: currentStepMessages,
 534		}
 535		steps = append(steps, stepResult)
 536		shouldStop := isStopConditionMet(opts.StopWhen, steps)
 537
 538		if shouldStop || err != nil || stopTurnRequested || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
 539			break
 540		}
 541	}
 542
 543	totalUsage := Usage{}
 544
 545	for _, step := range steps {
 546		usage := step.Usage
 547		totalUsage.InputTokens += usage.InputTokens
 548		totalUsage.OutputTokens += usage.OutputTokens
 549		totalUsage.ReasoningTokens += usage.ReasoningTokens
 550		totalUsage.CacheCreationTokens += usage.CacheCreationTokens
 551		totalUsage.CacheReadTokens += usage.CacheReadTokens
 552		totalUsage.TotalTokens += usage.TotalTokens
 553	}
 554
 555	agentResult := &AgentResult{
 556		Steps:      steps,
 557		Response:   steps[len(steps)-1].Response,
 558		TotalUsage: totalUsage,
 559	}
 560	return agentResult, nil
 561}
 562
 563func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
 564	if len(conditions) == 0 {
 565		return false
 566	}
 567
 568	for _, condition := range conditions {
 569		if condition(steps) {
 570			return true
 571		}
 572	}
 573	return false
 574}
 575
 576func hasStopTurn(results []ToolResultContent) bool {
 577	for _, r := range results {
 578		if r.StopTurn {
 579			return true
 580		}
 581	}
 582	return false
 583}
 584
 585func toResponseMessages(content []Content) []Message {
 586	var assistantParts []MessagePart
 587	var toolParts []MessagePart
 588
 589	for _, c := range content {
 590		switch c.GetType() {
 591		case ContentTypeText:
 592			text, ok := AsContentType[TextContent](c)
 593			if !ok {
 594				continue
 595			}
 596			assistantParts = append(assistantParts, TextPart{
 597				Text:            text.Text,
 598				ProviderOptions: ProviderOptions(text.ProviderMetadata),
 599			})
 600		case ContentTypeReasoning:
 601			reasoning, ok := AsContentType[ReasoningContent](c)
 602			if !ok {
 603				continue
 604			}
 605			assistantParts = append(assistantParts, ReasoningPart{
 606				Text:            reasoning.Text,
 607				ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
 608			})
 609		case ContentTypeToolCall:
 610			toolCall, ok := AsContentType[ToolCallContent](c)
 611			if !ok {
 612				continue
 613			}
 614			assistantParts = append(assistantParts, ToolCallPart{
 615				ToolCallID:       toolCall.ToolCallID,
 616				ToolName:         toolCall.ToolName,
 617				Input:            toolCall.Input,
 618				ProviderExecuted: toolCall.ProviderExecuted,
 619				ProviderOptions:  ProviderOptions(toolCall.ProviderMetadata),
 620			})
 621		case ContentTypeFile:
 622			file, ok := AsContentType[FileContent](c)
 623			if !ok {
 624				continue
 625			}
 626			assistantParts = append(assistantParts, FilePart{
 627				Data:            file.Data,
 628				MediaType:       file.MediaType,
 629				ProviderOptions: ProviderOptions(file.ProviderMetadata),
 630			})
 631		case ContentTypeSource:
 632			// Sources are metadata about references used to generate the response.
 633			// They don't need to be included in the conversation messages.
 634			continue
 635		case ContentTypeToolResult:
 636			result, ok := AsContentType[ToolResultContent](c)
 637			if !ok {
 638				continue
 639			}
 640			resultPart := ToolResultPart{
 641				ToolCallID:       result.ToolCallID,
 642				Output:           result.Result,
 643				ProviderExecuted: result.ProviderExecuted,
 644				ProviderOptions:  ProviderOptions(result.ProviderMetadata),
 645			}
 646			if result.ProviderExecuted {
 647				// Provider-executed tool results (e.g. web search)
 648				// belong in the assistant message alongside the
 649				// server_tool_use block that produced them.
 650				assistantParts = append(assistantParts, resultPart)
 651			} else {
 652				toolParts = append(toolParts, resultPart)
 653			}
 654		}
 655	}
 656
 657	var messages []Message
 658	if len(assistantParts) > 0 {
 659		messages = append(messages, Message{
 660			Role:    MessageRoleAssistant,
 661			Content: assistantParts,
 662		})
 663	}
 664	if len(toolParts) > 0 {
 665		messages = append(messages, Message{
 666			Role:    MessageRoleTool,
 667			Content: toolParts,
 668		})
 669	}
 670	return messages
 671}
 672
 673func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, execProviderTools []ExecutableProviderTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
 674	if len(toolCalls) == 0 {
 675		return nil, nil
 676	}
 677
 678	// Create a map for quick tool lookup
 679	toolMap := make(map[string]AgentTool)
 680	for _, tool := range allTools {
 681		toolMap[tool.Info().Name] = tool
 682	}
 683
 684	execProviderToolMap := make(map[string]ExecutableProviderTool, len(execProviderTools))
 685	for _, ept := range execProviderTools {
 686		execProviderToolMap[ept.GetName()] = ept
 687	}
 688
 689	// Execute all tool calls sequentially in order
 690	results := make([]ToolResultContent, 0, len(toolCalls))
 691
 692	for _, toolCall := range toolCalls {
 693		result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, toolCall, toolResultCallback)
 694		results = append(results, result)
 695		if isCriticalError {
 696			if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
 697				return nil, errorResult.Error
 698			}
 699		}
 700	}
 701
 702	return results, nil
 703}
 704
 705// executeSingleTool executes a single tool and returns its result and a critical error flag.
 706func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentTool, execProviderToolMap map[string]ExecutableProviderTool, toolCall ToolCallContent, toolResultCallback func(result ToolResultContent) error) (ToolResultContent, bool) {
 707	result := ToolResultContent{
 708		ToolCallID:       toolCall.ToolCallID,
 709		ToolName:         toolCall.ToolName,
 710		ProviderExecuted: false,
 711	}
 712
 713	// Skip invalid tool calls - create error result (not critical)
 714	if toolCall.Invalid {
 715		result.Result = ToolResultOutputContentError{
 716			Error: toolCall.ValidationError,
 717		}
 718		if toolResultCallback != nil {
 719			_ = toolResultCallback(result)
 720		}
 721		return result, false
 722	}
 723
 724	// Find the run function — either from a regular AgentTool or an
 725	// executable provider tool.
 726	var runTool func(ctx context.Context, call ToolCall) (ToolResponse, error)
 727	if tool, exists := toolMap[toolCall.ToolName]; exists {
 728		runTool = tool.Run
 729	} else if ept, ok := execProviderToolMap[toolCall.ToolName]; ok {
 730		runTool = ept.Run
 731	}
 732	if runTool == nil {
 733		result.Result = ToolResultOutputContentError{
 734			Error: errors.New("tool not found: " + toolCall.ToolName),
 735		}
 736		if toolResultCallback != nil {
 737			_ = toolResultCallback(result)
 738		}
 739		return result, false
 740	}
 741
 742	// Execute the tool
 743	toolResult, err := runTool(ctx, ToolCall{
 744		ID:    toolCall.ToolCallID,
 745		Name:  toolCall.ToolName,
 746		Input: toolCall.Input,
 747	})
 748	if err != nil {
 749		result.Result = ToolResultOutputContentError{
 750			Error: err,
 751		}
 752		result.ClientMetadata = toolResult.Metadata
 753		result.StopTurn = toolResult.StopTurn
 754		if toolResultCallback != nil {
 755			_ = toolResultCallback(result)
 756		}
 757		return result, true
 758	}
 759
 760	result.ClientMetadata = toolResult.Metadata
 761	result.StopTurn = toolResult.StopTurn
 762	if toolResult.IsError {
 763		result.Result = ToolResultOutputContentError{
 764			Error: errors.New(toolResult.Content),
 765		}
 766	} else if toolResult.Type == "image" || toolResult.Type == "media" {
 767		result.Result = ToolResultOutputContentMedia{
 768			Data:      base64.StdEncoding.EncodeToString(toolResult.Data),
 769			MediaType: toolResult.MediaType,
 770			Text:      toolResult.Content,
 771		}
 772	} else {
 773		result.Result = ToolResultOutputContentText{
 774			Text: toolResult.Content,
 775		}
 776	}
 777	if toolResultCallback != nil {
 778		_ = toolResultCallback(result)
 779	}
 780	return result, false
 781}
 782
 783// Stream implements Agent.
 784func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
 785	// Convert AgentStreamCall to AgentCall for preparation
 786	call := AgentCall{
 787		Prompt:           opts.Prompt,
 788		Files:            opts.Files,
 789		Messages:         opts.Messages,
 790		MaxOutputTokens:  opts.MaxOutputTokens,
 791		Temperature:      opts.Temperature,
 792		TopP:             opts.TopP,
 793		TopK:             opts.TopK,
 794		PresencePenalty:  opts.PresencePenalty,
 795		FrequencyPenalty: opts.FrequencyPenalty,
 796		ActiveTools:      opts.ActiveTools,
 797		ToolChoice:       opts.ToolChoice,
 798		ProviderOptions:  opts.ProviderOptions,
 799		MaxRetries:       opts.MaxRetries,
 800		OnRetry:          opts.OnRetry,
 801		StopWhen:         opts.StopWhen,
 802		PrepareStep:      opts.PrepareStep,
 803		RepairToolCall:   opts.RepairToolCall,
 804	}
 805
 806	call = a.prepareCall(call)
 807
 808	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
 809	if err != nil {
 810		return nil, err
 811	}
 812
 813	var responseMessages []Message
 814	var steps []StepResult
 815	var totalUsage Usage
 816
 817	// Start agent stream
 818	if opts.OnAgentStart != nil {
 819		opts.OnAgentStart()
 820	}
 821
 822	for stepNumber := 0; ; stepNumber++ {
 823		stepInputMessages := append(initialPrompt, responseMessages...)
 824		stepModel := a.settings.model
 825		stepSystemPrompt := a.settings.systemPrompt
 826		stepActiveTools := call.ActiveTools
 827		stepToolChoice := ToolChoiceAuto
 828		if call.ToolChoice != nil {
 829			stepToolChoice = *call.ToolChoice
 830		}
 831		disableAllTools := false
 832		stepTools := a.settings.tools
 833		// Apply step preparation if provided
 834		if call.PrepareStep != nil {
 835			updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
 836				Model:      stepModel,
 837				Steps:      steps,
 838				StepNumber: stepNumber,
 839				Messages:   stepInputMessages,
 840			})
 841			if err != nil {
 842				return nil, err
 843			}
 844
 845			ctx = updatedCtx
 846
 847			if prepared.Messages != nil {
 848				stepInputMessages = prepared.Messages
 849			}
 850			if prepared.Model != nil {
 851				stepModel = prepared.Model
 852			}
 853			if prepared.System != nil {
 854				stepSystemPrompt = *prepared.System
 855			}
 856			if prepared.ToolChoice != nil {
 857				stepToolChoice = *prepared.ToolChoice
 858			}
 859			if len(prepared.ActiveTools) > 0 {
 860				stepActiveTools = prepared.ActiveTools
 861			}
 862			disableAllTools = prepared.DisableAllTools
 863			if prepared.Tools != nil {
 864				stepTools = prepared.Tools
 865			}
 866		}
 867
 868		// Recreate prompt with potentially modified system prompt
 869		if stepSystemPrompt != a.settings.systemPrompt {
 870			stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
 871			if err != nil {
 872				return nil, err
 873			}
 874			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 875				stepInputMessages[0] = stepPrompt[0]
 876			}
 877		}
 878
 879		preparedTools := a.prepareTools(stepTools, a.settings.providerDefinedTools, stepActiveTools, disableAllTools)
 880
 881		// Filter executable provider tools by activeTools at the
 882		// step level, consistent with how stepTools (AgentTools)
 883		// are scoped before being passed to inner functions.
 884		stepExecProviderTools := a.filterExecProviderTools(stepActiveTools)
 885
 886		// Start step stream
 887		if opts.OnStepStart != nil {
 888			_ = opts.OnStepStart(stepNumber)
 889		}
 890		// Create streaming call
 891		streamCall := Call{
 892			Prompt:           stepInputMessages,
 893			MaxOutputTokens:  call.MaxOutputTokens,
 894			Temperature:      call.Temperature,
 895			TopP:             call.TopP,
 896			TopK:             call.TopK,
 897			PresencePenalty:  call.PresencePenalty,
 898			FrequencyPenalty: call.FrequencyPenalty,
 899			Tools:            preparedTools,
 900			ToolChoice:       &stepToolChoice,
 901			UserAgent:        a.settings.userAgent,
 902			ProviderOptions:  call.ProviderOptions,
 903		}
 904
 905		// Execute step with retry logic wrapping both stream creation and processing
 906		retryOptions := DefaultRetryOptions()
 907		if call.MaxRetries != nil {
 908			retryOptions.MaxRetries = *call.MaxRetries
 909		}
 910		retryOptions.OnRetry = call.OnRetry
 911		retry := RetryWithExponentialBackoffRespectingRetryHeaders[stepExecutionResult](retryOptions)
 912
 913		result, err := retry(ctx, func() (stepExecutionResult, error) {
 914			// Create the stream
 915			stream, err := stepModel.Stream(ctx, streamCall)
 916			if err != nil {
 917				return stepExecutionResult{}, err
 918			}
 919
 920			// Process the stream
 921			result, err := a.processStepStream(ctx, stream, opts, steps, stepTools, stepExecProviderTools)
 922			if err != nil {
 923				return stepExecutionResult{}, err
 924			}
 925			return result, nil
 926		})
 927		if err != nil {
 928			if opts.OnError != nil {
 929				opts.OnError(err)
 930			}
 931			return nil, err
 932		}
 933
 934		steps = append(steps, result.StepResult)
 935		totalUsage = addUsage(totalUsage, result.StepResult.Usage)
 936
 937		// Call step finished callback
 938		if opts.OnStepFinish != nil {
 939			_ = opts.OnStepFinish(result.StepResult)
 940		}
 941
 942		// Add step messages to response messages
 943		stepMessages := toResponseMessages(result.StepResult.Content)
 944		responseMessages = append(responseMessages, stepMessages...)
 945
 946		// Check stop conditions
 947		shouldStop := isStopConditionMet(call.StopWhen, steps)
 948		if shouldStop || !result.ShouldContinue {
 949			break
 950		}
 951	}
 952
 953	// Finish agent stream
 954	agentResult := &AgentResult{
 955		Steps:      steps,
 956		Response:   steps[len(steps)-1].Response,
 957		TotalUsage: totalUsage,
 958	}
 959
 960	if opts.OnFinish != nil {
 961		opts.OnFinish(agentResult)
 962	}
 963
 964	if opts.OnAgentFinish != nil {
 965		_ = opts.OnAgentFinish(agentResult)
 966	}
 967
 968	return agentResult, nil
 969}
 970
 971// filterExecProviderTools returns the subset of executable provider
 972// tools permitted by activeTools. When activeTools is empty every
 973// tool is included (no filtering).
 974func (a *agent) filterExecProviderTools(activeTools []string) []ExecutableProviderTool {
 975	if len(activeTools) == 0 {
 976		return a.settings.executableProviderTools
 977	}
 978	filtered := make([]ExecutableProviderTool, 0, len(a.settings.executableProviderTools))
 979	for _, ept := range a.settings.executableProviderTools {
 980		if slices.Contains(activeTools, ept.GetName()) {
 981			filtered = append(filtered, ept)
 982		}
 983	}
 984	return filtered
 985}
 986
 987func (a *agent) prepareTools(tools []AgentTool, providerDefinedTools []ProviderDefinedTool, activeTools []string, disableAllTools bool) []Tool {
 988	preparedTools := make([]Tool, 0, len(tools)+len(providerDefinedTools))
 989
 990	// If explicitly disabling all tools, return no tools
 991	if disableAllTools {
 992		return preparedTools
 993	}
 994
 995	for _, tool := range tools {
 996		// If activeTools has items, only include tools in the list
 997		// If activeTools is empty, include all tools
 998		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
 999			continue
1000		}
1001		info := tool.Info()
1002		inputSchema := map[string]any{
1003			"type":       "object",
1004			"properties": info.Parameters,
1005			"required":   info.Required,
1006		}
1007		schema.Normalize(inputSchema)
1008		preparedTools = append(preparedTools, FunctionTool{
1009			Name:            info.Name,
1010			Description:     info.Description,
1011			InputSchema:     inputSchema,
1012			ProviderOptions: tool.ProviderOptions(),
1013		})
1014	}
1015	for _, tool := range providerDefinedTools {
1016		// If activeTools has items, only include tools in the list. If
1017		// activeTools is empty, include all tools
1018		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.GetName()) {
1019			continue
1020		}
1021		preparedTools = append(preparedTools, tool)
1022	}
1023	return preparedTools
1024}
1025
1026// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
1027func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, execProviderTools []ExecutableProviderTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
1028	if err := a.validateToolCall(toolCall, availableTools, execProviderTools); err == nil {
1029		return toolCall
1030	} else { //nolint: revive
1031		if repairFunc != nil {
1032			repairOptions := ToolCallRepairOptions{
1033				OriginalToolCall: toolCall,
1034				ValidationError:  err,
1035				AvailableTools:   availableTools,
1036				SystemPrompt:     systemPrompt,
1037				Messages:         messages,
1038			}
1039
1040			if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
1041				if validateErr := a.validateToolCall(*repairedToolCall, availableTools, execProviderTools); validateErr == nil {
1042					return *repairedToolCall
1043				}
1044			}
1045		}
1046
1047		invalidToolCall := toolCall
1048		invalidToolCall.Invalid = true
1049		invalidToolCall.ValidationError = err
1050		return invalidToolCall
1051	}
1052}
1053
1054// validateToolCall validates a tool call against available tools and their schemas.
1055// Both availableTools and execProviderTools must already be filtered by the
1056// caller (e.g. via activeTools); this function trusts that the slices
1057// represent exactly the tools permitted for the current step.
1058func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool, execProviderTools []ExecutableProviderTool) error {
1059	var tool AgentTool
1060	for _, t := range availableTools {
1061		if t.Info().Name == toolCall.ToolName {
1062			tool = t
1063			break
1064		}
1065	}
1066
1067	if tool == nil {
1068		// Check if this is an executable provider tool. Provider-
1069		// defined tools have their schema enforced server-side, so
1070		// we only validate that the input is parseable JSON.
1071		for _, ept := range execProviderTools {
1072			if ept.GetName() == toolCall.ToolName {
1073				var input map[string]any
1074				if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
1075					return fmt.Errorf("invalid JSON input: %w", err)
1076				}
1077				return nil
1078			}
1079		}
1080		return fmt.Errorf("tool not found: %s", toolCall.ToolName)
1081	}
1082
1083	// Validate JSON parsing
1084	var input map[string]any
1085	if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
1086		return fmt.Errorf("invalid JSON input: %w", err)
1087	}
1088
1089	// Basic schema validation (check required fields)
1090	// TODO: more robust schema validation using JSON Schema or similar
1091	toolInfo := tool.Info()
1092	for _, required := range toolInfo.Required {
1093		if _, exists := input[required]; !exists {
1094			return fmt.Errorf("missing required parameter: %s", required)
1095		}
1096	}
1097	return nil
1098}
1099
1100func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
1101	// Validation: empty prompt is only allowed when there are messages,
1102	// no files to attach, and the last message is a user or tool message.
1103	if prompt == "" {
1104		lastMessage, hasMessages := slice.Last(messages)
1105
1106		if !hasMessages {
1107			return nil, &Error{
1108				Title:   "invalid argument",
1109				Message: "prompt can't be empty when there are no messages",
1110			}
1111		}
1112
1113		if len(files) > 0 {
1114			return nil, &Error{
1115				Title:   "invalid argument",
1116				Message: "prompt can't be empty when there are files",
1117			}
1118		}
1119
1120		switch lastMessage.Role {
1121		case MessageRoleUser, MessageRoleTool:
1122		default:
1123			return nil, &Error{
1124				Title:   "invalid argument",
1125				Message: "prompt can't be empty when the last message is not a user or tool message",
1126			}
1127		}
1128	}
1129
1130	var preparedPrompt Prompt
1131
1132	if system != "" {
1133		preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
1134	}
1135	preparedPrompt = append(preparedPrompt, messages...)
1136	if prompt != "" {
1137		preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
1138	}
1139	return preparedPrompt, nil
1140}
1141
1142// WithSystemPrompt sets the system prompt for the agent.
1143func WithSystemPrompt(prompt string) AgentOption {
1144	return func(s *agentSettings) {
1145		s.systemPrompt = prompt
1146	}
1147}
1148
1149// WithMaxOutputTokens sets the maximum output tokens for the agent.
1150func WithMaxOutputTokens(tokens int64) AgentOption {
1151	return func(s *agentSettings) {
1152		s.maxOutputTokens = &tokens
1153	}
1154}
1155
1156// WithTemperature sets the temperature for the agent.
1157func WithTemperature(temp float64) AgentOption {
1158	return func(s *agentSettings) {
1159		s.temperature = &temp
1160	}
1161}
1162
1163// WithTopP sets the top-p value for the agent.
1164func WithTopP(topP float64) AgentOption {
1165	return func(s *agentSettings) {
1166		s.topP = &topP
1167	}
1168}
1169
1170// WithTopK sets the top-k value for the agent.
1171func WithTopK(topK int64) AgentOption {
1172	return func(s *agentSettings) {
1173		s.topK = &topK
1174	}
1175}
1176
1177// WithPresencePenalty sets the presence penalty for the agent.
1178func WithPresencePenalty(penalty float64) AgentOption {
1179	return func(s *agentSettings) {
1180		s.presencePenalty = &penalty
1181	}
1182}
1183
1184// WithFrequencyPenalty sets the frequency penalty for the agent.
1185func WithFrequencyPenalty(penalty float64) AgentOption {
1186	return func(s *agentSettings) {
1187		s.frequencyPenalty = &penalty
1188	}
1189}
1190
1191// WithTools sets the tools for the agent.
1192func WithTools(tools ...AgentTool) AgentOption {
1193	return func(s *agentSettings) {
1194		s.tools = append(s.tools, tools...)
1195	}
1196}
1197
1198// WithProviderDefinedTools registers provider-defined tools with the
1199// agent. Provider-executed tools (e.g. web search) are passed through
1200// to the API. Client-executed tools (ExecutableProviderTool) are also
1201// registered for local execution.
1202func WithProviderDefinedTools(tools ...ProviderTool) AgentOption {
1203	return func(s *agentSettings) {
1204		for _, t := range tools {
1205			// Every provider tool goes into providerDefinedTools
1206			// for wire formatting.
1207			s.providerDefinedTools = append(
1208				s.providerDefinedTools, t.providerDefinedTool(),
1209			)
1210			// Executable ones also register for local execution.
1211			if exec, ok := t.(ExecutableProviderTool); ok {
1212				s.executableProviderTools = append(
1213					s.executableProviderTools, exec,
1214				)
1215			}
1216		}
1217	}
1218}
1219
1220// WithToolChoice sets the default tool choice for the agent. It is overridden
1221// by the ToolChoice on a specific call, and by PrepareStep at the step level.
1222func WithToolChoice(choice ToolChoice) AgentOption {
1223	return func(s *agentSettings) {
1224		s.toolChoice = &choice
1225	}
1226}
1227
1228// WithStopConditions sets the stop conditions for the agent.
1229func WithStopConditions(conditions ...StopCondition) AgentOption {
1230	return func(s *agentSettings) {
1231		s.stopWhen = append(s.stopWhen, conditions...)
1232	}
1233}
1234
1235// WithPrepareStep sets the prepare step function for the agent.
1236func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1237	return func(s *agentSettings) {
1238		s.prepareStep = fn
1239	}
1240}
1241
1242// WithRepairToolCall sets the repair tool call function for the agent.
1243func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1244	return func(s *agentSettings) {
1245		s.repairToolCall = fn
1246	}
1247}
1248
1249// WithMaxRetries sets the maximum number of retries for the agent.
1250func WithMaxRetries(maxRetries int) AgentOption {
1251	return func(s *agentSettings) {
1252		s.maxRetries = &maxRetries
1253	}
1254}
1255
1256// WithOnRetry sets the retry callback for the agent.
1257func WithOnRetry(callback OnRetryCallback) AgentOption {
1258	return func(s *agentSettings) {
1259		s.onRetry = callback
1260	}
1261}
1262
// processStepStream consumes a single step's stream and assembles the step
// result.
//
// Every stream part is first forwarded to opts.OnChunk (when set) and then
// handled by type: text and reasoning deltas are accumulated per part ID and
// appended to the step content when their end part arrives; tool calls that
// are not provider-executed are validated (and optionally repaired via
// opts.RepairToolCall) and buffered for execution; provider-executed tool
// calls, provider-executed tool results and sources are appended to the
// content directly. An error returned by any stream callback, or carried by
// an error part, aborts the step and is returned.
//
// Buffered tool calls are executed only after the stream has been fully
// consumed. Calls whose tool reports Info().Parallel run in their own
// goroutines, at most five at a time; all others run one after another on
// the coordinator goroutine. Tool results are appended to the content in
// the order the executions finish.
//
// ShouldContinue is true when at least one non-provider-executed tool call
// was seen, the finish reason is FinishReasonToolCalls, and hasStopTurn
// reports false for the collected results.
//
// The fourth parameter (the previous steps) is currently unused.
func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult, stepTools []AgentTool, execProviderTools []ExecutableProviderTool) (stepExecutionResult, error) {
	var stepContent []Content
	var stepToolCalls []ToolCallContent
	var stepUsage Usage
	stepFinishReason := FinishReasonUnknown
	var stepWarnings []CallWarning
	var stepProviderMetadata ProviderMetadata

	// In-flight accumulators, keyed by stream part ID.
	// NOTE: activeToolCalls is only written to and deleted from in this
	// function; the input it accumulates is not read back here.
	activeToolCalls := make(map[string]*ToolCallContent)
	activeTextContent := make(map[string]string)
	type reasoningContent struct {
		content string
		options ProviderMetadata
	}
	activeReasoningContent := make(map[string]reasoningContent)

	// Set up concurrent tool execution
	type toolExecutionRequest struct {
		toolCall ToolCallContent
		parallel bool
	}
	var pendingDispatches []toolExecutionRequest

	// Create a map for quick tool lookup
	toolMap := make(map[string]AgentTool)
	for _, tool := range stepTools {
		toolMap[tool.Info().Name] = tool
	}

	execProviderToolMap := make(map[string]ExecutableProviderTool, len(execProviderTools))
	for _, ept := range execProviderTools {
		execProviderToolMap[ept.GetName()] = ept
	}

	// Process stream parts
	for part := range stream {
		// Forward all parts to chunk callback
		if opts.OnChunk != nil {
			err := opts.OnChunk(part)
			if err != nil {
				return stepExecutionResult{}, err
			}
		}

		switch part.Type {
		case StreamPartTypeWarnings:
			stepWarnings = part.Warnings
			if opts.OnWarnings != nil {
				err := opts.OnWarnings(part.Warnings)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeTextStart:
			activeTextContent[part.ID] = ""
			if opts.OnTextStart != nil {
				err := opts.OnTextStart(part.ID)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeTextDelta:
			// Deltas for an ID that never saw a start part are not
			// accumulated, but the callback still fires.
			if _, exists := activeTextContent[part.ID]; exists {
				activeTextContent[part.ID] += part.Delta
			}
			if opts.OnTextDelta != nil {
				err := opts.OnTextDelta(part.ID, part.Delta)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeTextEnd:
			if text, exists := activeTextContent[part.ID]; exists {
				stepContent = append(stepContent, TextContent{
					Text:             text,
					ProviderMetadata: part.ProviderMetadata,
				})
				delete(activeTextContent, part.ID)
			}
			if opts.OnTextEnd != nil {
				err := opts.OnTextEnd(part.ID)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeReasoningStart:
			// The start part may already carry text and metadata; both
			// seed the accumulator.
			activeReasoningContent[part.ID] = reasoningContent{content: part.Delta, options: part.ProviderMetadata}
			if opts.OnReasoningStart != nil {
				content := ReasoningContent{
					Text:             part.Delta,
					ProviderMetadata: part.ProviderMetadata,
				}
				err := opts.OnReasoningStart(part.ID, content)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeReasoningDelta:
			if active, exists := activeReasoningContent[part.ID]; exists {
				active.content += part.Delta
				// Later non-nil metadata replaces what was stored earlier.
				if part.ProviderMetadata != nil {
					active.options = part.ProviderMetadata
				}
				// The map holds values, so the updated copy is written back.
				activeReasoningContent[part.ID] = active
			}
			if opts.OnReasoningDelta != nil {
				err := opts.OnReasoningDelta(part.ID, part.Delta)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeReasoningEnd:
			// Unlike text end, the end callback only fires when the ID has
			// an active accumulator.
			if active, exists := activeReasoningContent[part.ID]; exists {
				if part.ProviderMetadata != nil {
					active.options = part.ProviderMetadata
				}
				content := ReasoningContent{
					Text:             active.content,
					ProviderMetadata: active.options,
				}
				stepContent = append(stepContent, content)
				if opts.OnReasoningEnd != nil {
					err := opts.OnReasoningEnd(part.ID, content)
					if err != nil {
						return stepExecutionResult{}, err
					}
				}
				delete(activeReasoningContent, part.ID)
			}

		case StreamPartTypeToolInputStart:
			activeToolCalls[part.ID] = &ToolCallContent{
				ToolCallID:       part.ID,
				ToolName:         part.ToolCallName,
				Input:            "",
				ProviderExecuted: part.ProviderExecuted,
			}
			if opts.OnToolInputStart != nil {
				err := opts.OnToolInputStart(part.ID, part.ToolCallName)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeToolInputDelta:
			if toolCall, exists := activeToolCalls[part.ID]; exists {
				toolCall.Input += part.Delta
			}
			if opts.OnToolInputDelta != nil {
				err := opts.OnToolInputDelta(part.ID, part.Delta)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeToolInputEnd:
			if opts.OnToolInputEnd != nil {
				err := opts.OnToolInputEnd(part.ID)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeToolCall:
			// The tool call is built from this part's own fields, not from
			// the input accumulated in activeToolCalls.
			toolCall := ToolCallContent{
				ToolCallID:       part.ID,
				ToolName:         part.ToolCallName,
				Input:            part.ToolCallInput,
				ProviderExecuted: part.ProviderExecuted,
				ProviderMetadata: part.ProviderMetadata,
			}

			// Provider-executed tool calls are handled by the provider
			// and should not be validated or executed by the agent.
			if toolCall.ProviderExecuted {
				stepContent = append(stepContent, toolCall)
				if opts.OnToolCall != nil {
					err := opts.OnToolCall(toolCall)
					if err != nil {
						return stepExecutionResult{}, err
					}
				}
				delete(activeToolCalls, part.ID)
			} else {
				// Validate and potentially repair the tool call
				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, execProviderTools, a.settings.systemPrompt, nil, opts.RepairToolCall)
				stepToolCalls = append(stepToolCalls, validatedToolCall)
				stepContent = append(stepContent, validatedToolCall)

				if opts.OnToolCall != nil {
					err := opts.OnToolCall(validatedToolCall)
					if err != nil {
						return stepExecutionResult{}, err
					}
				}

				// Determine if tool can run in parallel. Names not present
				// in toolMap are treated as sequential.
				isParallel := false
				if tool, exists := toolMap[validatedToolCall.ToolName]; exists {
					isParallel = tool.Info().Parallel
				}

				// Buffer dispatch until stream is fully consumed so that all
				// OnToolCall callbacks complete before any tool result is written.
				pendingDispatches = append(pendingDispatches, toolExecutionRequest{toolCall: validatedToolCall, parallel: isParallel})

				// Clean up active tool call
				delete(activeToolCalls, part.ID)
			}

		case StreamPartTypeToolResult:
			// Provider-executed tool results (e.g. web search)
			// are emitted by the provider and added directly
			// to the step content for multi-turn round-tripping.
			// Results that are not provider-executed are ignored here.
			if part.ProviderExecuted {
				resultContent := ToolResultContent{
					ToolCallID:       part.ID,
					ToolName:         part.ToolCallName,
					ProviderExecuted: true,
					ProviderMetadata: part.ProviderMetadata,
				}
				stepContent = append(stepContent, resultContent)
				if opts.OnToolResult != nil {
					err := opts.OnToolResult(resultContent)
					if err != nil {
						return stepExecutionResult{}, err
					}
				}
			}

		case StreamPartTypeSource:
			sourceContent := SourceContent{
				SourceType:       part.SourceType,
				ID:               part.ID,
				URL:              part.URL,
				Title:            part.Title,
				ProviderMetadata: part.ProviderMetadata,
			}
			stepContent = append(stepContent, sourceContent)
			if opts.OnSource != nil {
				err := opts.OnSource(sourceContent)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeFinish:
			stepUsage = part.Usage
			stepFinishReason = part.FinishReason
			stepProviderMetadata = part.ProviderMetadata
			if opts.OnStreamFinish != nil {
				err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeError:
			return stepExecutionResult{}, part.Error
		}
	}

	// All tool calls are now collected. Create the execution channel sized to
	// avoid blocking during dispatch, start the coordinator, then flush the batch.
	toolChan := make(chan toolExecutionRequest, len(pendingDispatches))
	var toolExecutionWg sync.WaitGroup
	// toolStateMu guards toolResults and toolExecutionErr, which are written
	// from the coordinator and from the parallel worker goroutines.
	var toolStateMu sync.Mutex
	toolResults := make([]ToolResultContent, 0, len(pendingDispatches))
	var toolExecutionErr error

	// Semaphores for controlling parallelism.
	// parallelSem caps the number of concurrently running parallel tools at 5.
	parallelSem := make(chan struct{}, 5)
	var sequentialMu sync.Mutex

	// Single coordinator goroutine that dispatches tools.
	toolExecutionWg.Go(func() {
		for req := range toolChan {
			if req.parallel {
				// Acquire a slot before spawning; this blocks the
				// coordinator while all slots are taken.
				parallelSem <- struct{}{}
				toolExecutionWg.Go(func() {
					defer func() { <-parallelSem }()
					result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, req.toolCall, opts.OnToolResult)
					toolStateMu.Lock()
					toolResults = append(toolResults, result)
					// Only the first critical error is kept, and only when
					// the result carries a non-nil error value.
					if isCriticalError && toolExecutionErr == nil {
						if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
							toolExecutionErr = errorResult.Error
						}
					}
					toolStateMu.Unlock()
				})
			} else {
				// Sequential tools run inline on the coordinator, so later
				// requests are not dispatched until this one returns.
				sequentialMu.Lock()
				result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, req.toolCall, opts.OnToolResult)
				toolStateMu.Lock()
				toolResults = append(toolResults, result)
				if isCriticalError && toolExecutionErr == nil {
					if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
						toolExecutionErr = errorResult.Error
					}
				}
				toolStateMu.Unlock()
				sequentialMu.Unlock()
			}
		}
	})

	// Dispatch all buffered tool calls now that every OnToolCall callback has
	// been called, then close and wait.
	for _, req := range pendingDispatches {
		toolChan <- req
	}

	// Close the tool execution channel and wait for all executions to complete.
	close(toolChan)
	toolExecutionWg.Wait()

	// Check for tool execution errors
	if toolExecutionErr != nil {
		return stepExecutionResult{}, toolExecutionErr
	}

	// Add tool results to content if any
	if len(toolResults) > 0 {
		for _, result := range toolResults {
			stepContent = append(stepContent, result)
		}
	}

	stepResult := StepResult{
		Response: Response{
			Content:          stepContent,
			FinishReason:     stepFinishReason,
			Usage:            stepUsage,
			Warnings:         stepWarnings,
			ProviderMetadata: stepProviderMetadata,
		},
		Messages: toResponseMessages(stepContent),
	}

	// Determine if we should continue (has tool calls and not stopped)
	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls && !hasStopTurn(toolResults)

	return stepExecutionResult{
		StepResult:     stepResult,
		ShouldContinue: shouldContinue,
	}, nil
}
1618
1619func addUsage(a, b Usage) Usage {
1620	return Usage{
1621		InputTokens:         a.InputTokens + b.InputTokens,
1622		OutputTokens:        a.OutputTokens + b.OutputTokens,
1623		TotalTokens:         a.TotalTokens + b.TotalTokens,
1624		ReasoningTokens:     a.ReasoningTokens + b.ReasoningTokens,
1625		CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1626		CacheReadTokens:     a.CacheReadTokens + b.CacheReadTokens,
1627	}
1628}
1629
1630// WithHeaders sets the headers for the agent.
1631func WithHeaders(headers map[string]string) AgentOption {
1632	return func(s *agentSettings) {
1633		s.headers = headers
1634	}
1635}
1636
1637// WithUserAgent sets the User-Agent header for the agent. This overrides any
1638// provider-level User-Agent setting.
1639func WithUserAgent(ua string) AgentOption {
1640	return func(s *agentSettings) {
1641		s.userAgent = ua
1642	}
1643}
1644
1645// WithProviderOptions sets the provider options for the agent.
1646func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1647	return func(s *agentSettings) {
1648		s.providerOptions = providerOptions
1649	}
1650}