agent.go

   1package fantasy
   2
   3import (
   4	"cmp"
   5	"context"
   6	"encoding/base64"
   7	"encoding/json"
   8	"errors"
   9	"fmt"
  10	"maps"
  11	"slices"
  12	"sync"
  13
  14	"charm.land/fantasy/schema"
  15	"github.com/charmbracelet/x/exp/slice"
  16)
  17
// StepResult represents the result of a single step in an agent execution.
//
// The embedded Response holds the step's content, finish reason, usage,
// warnings and provider metadata. Messages holds the conversation messages
// derived from that content; the agent loop appends them to the history that
// is sent to the model on the next step.
type StepResult struct {
	Response
	Messages []Message
}

// stepExecutionResult encapsulates the result of executing a step with stream processing.
// ShouldContinue reports whether the agent loop should run another step;
// Stream stops once it is false.
type stepExecutionResult struct {
	StepResult     StepResult
	ShouldContinue bool
}

// StopCondition defines a function that determines when an agent should stop executing.
// It receives every step completed so far, oldest first, and returns true to
// stop. The agent loop stops as soon as any configured condition returns true.
type StopCondition = func(steps []StepResult) bool
  32
  33// StepCountIs returns a stop condition that stops after the specified number of steps.
  34func StepCountIs(stepCount int) StopCondition {
  35	return func(steps []StepResult) bool {
  36		return len(steps) >= stepCount
  37	}
  38}
  39
  40// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
  41func HasToolCall(toolName string) StopCondition {
  42	return func(steps []StepResult) bool {
  43		if len(steps) == 0 {
  44			return false
  45		}
  46		lastStep := steps[len(steps)-1]
  47		toolCalls := lastStep.Content.ToolCalls()
  48		for _, toolCall := range toolCalls {
  49			if toolCall.ToolName == toolName {
  50				return true
  51			}
  52		}
  53		return false
  54	}
  55}
  56
  57// HasContent returns a stop condition that stops when the specified content type appears in the last step.
  58func HasContent(contentType ContentType) StopCondition {
  59	return func(steps []StepResult) bool {
  60		if len(steps) == 0 {
  61			return false
  62		}
  63		lastStep := steps[len(steps)-1]
  64		for _, content := range lastStep.Content {
  65			if content.GetType() == contentType {
  66				return true
  67			}
  68		}
  69		return false
  70	}
  71}
  72
  73// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
  74func FinishReasonIs(reason FinishReason) StopCondition {
  75	return func(steps []StepResult) bool {
  76		if len(steps) == 0 {
  77			return false
  78		}
  79		lastStep := steps[len(steps)-1]
  80		return lastStep.FinishReason == reason
  81	}
  82}
  83
  84// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
  85func MaxTokensUsed(maxTokens int64) StopCondition {
  86	return func(steps []StepResult) bool {
  87		var totalTokens int64
  88		for _, step := range steps {
  89			totalTokens += step.Usage.TotalTokens
  90		}
  91		return totalTokens >= maxTokens
  92	}
  93}
  94
// PrepareStepFunctionOptions contains the options for preparing a step in an agent execution.
//
// Steps holds the results of all steps completed so far, and StepNumber is
// the zero-based index of the step about to run. Model is the agent's
// configured language model. Messages are the input messages the step would
// send: the initial prompt followed by every message produced by earlier
// steps.
type PrepareStepFunctionOptions struct {
	Steps      []StepResult
	StepNumber int
	Model      LanguageModel
	Messages   []Message
}

// PrepareStepResult contains the result of preparing a step in an agent execution.
//
// Zero values leave the step's defaults in place. Messages, Model, System,
// ToolChoice and Tools are applied only when non-nil, and ActiveTools only
// when non-empty. DisableAllTools is always applied as returned.
type PrepareStepResult struct {
	Model           LanguageModel
	Messages        []Message
	System          *string
	ToolChoice      *ToolChoice
	ActiveTools     []string
	DisableAllTools bool
	Tools           []AgentTool
}

// ToolCallRepairOptions contains the options for repairing a tool call.
// It is handed to a RepairToolCallFunction.
// NOTE(review): it is populated by validateAndRepairToolCall, which is not
// shown here; presumably this happens for a call that failed validation
// (ValidationError).
type ToolCallRepairOptions struct {
	OriginalToolCall ToolCallContent
	ValidationError  error
	AvailableTools   []AgentTool
	SystemPrompt     string
	Messages         []Message
}

type (
	// PrepareStepFunction defines a function that prepares a step in an agent execution.
	// It runs before every step. The context it returns replaces the run's
	// context for that step and all later ones, so it should return a
	// non-nil context (typically the one it received). A non-nil error
	// aborts the run with that error.
	PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)

	// OnStepFinishedFunction defines a function that is called when a step finishes.
	// NOTE(review): not referenced by the visible agent code; Stream's
	// per-step hook is OnStepFinishFunc.
	OnStepFinishedFunction = func(step StepResult)

	// RepairToolCallFunction defines a function that repairs a tool call.
	// NOTE(review): the meaning of a nil *ToolCallContent or a non-nil error
	// is decided by validateAndRepairToolCall, which is not shown here.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
 133
// agentSettings holds the agent-level configuration that NewAgent builds
// from its AgentOption values. Generate and Stream read it but do not modify
// it.
type agentSettings struct {
	// Generation defaults. systemPrompt is the default system prompt, which
	// PrepareStep may override per step. The pointer fields are nil when
	// unset, and a non-nil value on the call takes precedence (see
	// prepareCall). userAgent is passed on every model Call, and
	// providerOptions is merged underneath the call's options.
	// NOTE(review): headers is not forwarded to any model call in the
	// visible code.
	systemPrompt     string
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	headers          map[string]string
	userAgent        string
	providerOptions  ProviderOptions

	// Tools. providerDefinedTools are passed to prepareTools.
	// executableProviderTools are run locally when the model calls them,
	// filtered by the active tools. tools are the regular agent tools, which
	// PrepareStep may replace per step. maxRetries is the default retry count
	// for model calls.
	providerDefinedTools    []ProviderDefinedTool
	executableProviderTools []ExecutableProviderTool
	tools                   []AgentTool
	maxRetries              *int

	// model is the default language model for every step; PrepareStep may
	// override it per step.
	model LanguageModel

	// Defaults used when the call does not set its own (see prepareCall).
	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
 158
// AgentCall represents a call to an agent.
//
// Prompt, Messages and Files are combined with the system prompt into the
// initial prompt. The following fields fall back to the agent-level
// defaults: nil pointer settings, an empty StopWhen, and a nil PrepareStep,
// RepairToolCall or OnRetry. ProviderOptions is merged over the agent's
// options, and call-level keys win (see prepareCall). A non-empty
// ActiveTools restricts, by name, which tools are offered; PrepareStep may
// override it per step.
type AgentCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	// Loop control. StopWhen conditions are checked after every step, and
	// the run stops when any of them is met.
	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction
}
 179
// Agent-level callbacks.
type (
	// OnAgentStartFunc is called when agent starts.
	// Stream calls it once, before the first step.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called when agent finishes.
	// Stream calls it last, after OnFinish, and ignores the returned error.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called when a step starts.
	// stepNumber is zero-based. Stream ignores the returned error.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called when a step finishes.
	// Stream ignores the returned error.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called when entire agent completes.
	// Stream calls it with the final result, before OnAgentFinish.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called when an error occurs.
	// Stream calls it with the error that ends the run, just before
	// returning that error.
	OnErrorFunc func(error)
)

// Stream part callbacks - called for each corresponding stream part type.
//
// NOTE(review): these are dispatched while a step's stream is processed
// (processStepStream, not shown here). Stream retries stream creation and
// processing together, so they may fire again for a retried step.
type (
	// OnChunkFunc is called for each stream part (catch-all).
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when text starts.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for text deltas.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when text ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when reasoning starts.
	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error

	// OnReasoningDeltaFunc is called for reasoning deltas.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning ends.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when tool input starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for tool input deltas.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnToolResultFunc is called when tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when stream finishes.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
 248
// AgentStreamCall represents a streaming call to an agent.
//
// The request fields mirror AgentCall. Stream copies them into an AgentCall
// and merges in the agent-level defaults before running. The On* fields are
// optional callbacks, and the agent-level ones are skipped when nil.
// NOTE(review): Headers is not copied into that AgentCall, so it has no
// effect in the visible Stream code.
type AgentStreamCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called when agent starts
	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
	OnWarnings       OnWarningsFunc       // Called for warnings
	OnTextStart      OnTextStartFunc      // Called when text starts
	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
	OnTextEnd        OnTextEndFunc        // Called when text ends
	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
	OnToolCall       OnToolCallFunc       // Called when tool call is complete
	OnToolResult     OnToolResultFunc     // Called when tool execution completes
	OnSource         OnSourceFunc         // Called for source references
	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
}
 295
// AgentResult represents the result of an agent execution.
type AgentResult struct {
	// Steps holds every executed step, in execution order.
	Steps []StepResult
	// Final response: Response is the last step's response, and TotalUsage
	// is the usage summed over all steps.
	Response   Response
	TotalUsage Usage
}

// Agent represents an AI agent that can generate responses and stream responses.
//
// Generate runs the multi-step loop with non-streaming model calls. Stream
// runs the same loop with streaming model calls and invokes the callbacks
// configured on the AgentStreamCall.
type Agent interface {
	Generate(context.Context, AgentCall) (*AgentResult, error)
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}

// AgentOption defines a function that configures agent settings.
// NewAgent applies options in the order they are given.
type AgentOption = func(*agentSettings)

// agent is the Agent implementation returned by NewAgent. Its settings are
// set once at construction.
type agent struct {
	settings agentSettings
}
 316
 317// NewAgent creates a new agent with the given language model and options.
 318func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
 319	settings := agentSettings{
 320		model: model,
 321	}
 322	for _, o := range opts {
 323		o(&settings)
 324	}
 325	return &agent{
 326		settings: settings,
 327	}
 328}
 329
 330func (a *agent) prepareCall(call AgentCall) AgentCall {
 331	call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
 332	call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
 333	call.TopP = cmp.Or(call.TopP, a.settings.topP)
 334	call.TopK = cmp.Or(call.TopK, a.settings.topK)
 335	call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
 336	call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
 337	call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
 338
 339	if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
 340		call.StopWhen = a.settings.stopWhen
 341	}
 342	if call.PrepareStep == nil && a.settings.prepareStep != nil {
 343		call.PrepareStep = a.settings.prepareStep
 344	}
 345	if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
 346		call.RepairToolCall = a.settings.repairToolCall
 347	}
 348	if call.OnRetry == nil && a.settings.onRetry != nil {
 349		call.OnRetry = a.settings.onRetry
 350	}
 351
 352	providerOptions := ProviderOptions{}
 353	if a.settings.providerOptions != nil {
 354		maps.Copy(providerOptions, a.settings.providerOptions)
 355	}
 356	if call.ProviderOptions != nil {
 357		maps.Copy(providerOptions, call.ProviderOptions)
 358	}
 359	call.ProviderOptions = providerOptions
 360
 361	headers := map[string]string{}
 362
 363	if a.settings.headers != nil {
 364		maps.Copy(headers, a.settings.headers)
 365	}
 366
 367	return call
 368}
 369
 370// Generate implements Agent.
 371func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
 372	opts = a.prepareCall(opts)
 373	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 374	if err != nil {
 375		return nil, err
 376	}
 377	var responseMessages []Message
 378	var steps []StepResult
 379
 380	for {
 381		stepInputMessages := append(initialPrompt, responseMessages...)
 382		stepModel := a.settings.model
 383		stepSystemPrompt := a.settings.systemPrompt
 384		stepActiveTools := opts.ActiveTools
 385		stepToolChoice := ToolChoiceAuto
 386		disableAllTools := false
 387		stepTools := a.settings.tools
 388		if opts.PrepareStep != nil {
 389			updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
 390				Model:      stepModel,
 391				Steps:      steps,
 392				StepNumber: len(steps),
 393				Messages:   stepInputMessages,
 394			})
 395			if err != nil {
 396				return nil, err
 397			}
 398
 399			ctx = updatedCtx
 400
 401			// Apply prepared step modifications
 402			if prepared.Messages != nil {
 403				stepInputMessages = prepared.Messages
 404			}
 405			if prepared.Model != nil {
 406				stepModel = prepared.Model
 407			}
 408			if prepared.System != nil {
 409				stepSystemPrompt = *prepared.System
 410			}
 411			if prepared.ToolChoice != nil {
 412				stepToolChoice = *prepared.ToolChoice
 413			}
 414			if len(prepared.ActiveTools) > 0 {
 415				stepActiveTools = prepared.ActiveTools
 416			}
 417			disableAllTools = prepared.DisableAllTools
 418			if prepared.Tools != nil {
 419				stepTools = prepared.Tools
 420			}
 421		}
 422
 423		// Recreate prompt with potentially modified system prompt
 424		if stepSystemPrompt != a.settings.systemPrompt {
 425			stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 426			if err != nil {
 427				return nil, err
 428			}
 429			// Replace system message part, keep the rest
 430			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 431				stepInputMessages[0] = stepPrompt[0] // Replace system message
 432			}
 433		}
 434
 435		preparedTools := a.prepareTools(stepTools, a.settings.providerDefinedTools, stepActiveTools, disableAllTools)
 436
 437		// Filter executable provider tools by activeTools at the
 438		// step level, consistent with how stepTools (AgentTools)
 439		// are scoped before being passed to inner functions.
 440		stepExecProviderTools := a.filterExecProviderTools(stepActiveTools)
 441
 442		retryOptions := DefaultRetryOptions()
 443		if opts.MaxRetries != nil {
 444			retryOptions.MaxRetries = *opts.MaxRetries
 445		}
 446		retryOptions.OnRetry = opts.OnRetry
 447		retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
 448		result, err := retry(ctx, func() (*Response, error) {
 449			return stepModel.Generate(ctx, Call{
 450				Prompt:           stepInputMessages,
 451				MaxOutputTokens:  opts.MaxOutputTokens,
 452				Temperature:      opts.Temperature,
 453				TopP:             opts.TopP,
 454				TopK:             opts.TopK,
 455				PresencePenalty:  opts.PresencePenalty,
 456				FrequencyPenalty: opts.FrequencyPenalty,
 457				Tools:            preparedTools,
 458				ToolChoice:       &stepToolChoice,
 459				UserAgent:        a.settings.userAgent,
 460				ProviderOptions:  opts.ProviderOptions,
 461			})
 462		})
 463		if err != nil {
 464			return nil, err
 465		}
 466
 467		var stepToolCalls []ToolCallContent
 468		for _, content := range result.Content {
 469			if content.GetType() == ContentTypeToolCall {
 470				toolCall, ok := AsContentType[ToolCallContent](content)
 471				if !ok {
 472					continue
 473				}
 474				// Provider-executed tool calls (e.g. web search) are
 475				// handled by the provider and should not be validated
 476				// or executed by the agent.
 477				if toolCall.ProviderExecuted {
 478					continue
 479				}
 480				// Validate and potentially repair the tool call
 481				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, stepExecProviderTools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
 482				stepToolCalls = append(stepToolCalls, validatedToolCall)
 483			}
 484		}
 485
 486		toolResults, err := a.executeTools(ctx, stepTools, stepExecProviderTools, stepToolCalls, nil)
 487
 488		// If any tool result requested a stop, deliver all results but don't
 489		// request another completion from the model.
 490		stopTurnRequested := hasStopTurn(toolResults)
 491
 492		// Build step content with validated tool calls and tool results.
 493		// Provider-executed tool calls are kept as-is.
 494		stepContent := []Content{}
 495		toolCallIndex := 0
 496		for _, content := range result.Content {
 497			if content.GetType() == ContentTypeToolCall {
 498				tc, ok := AsContentType[ToolCallContent](content)
 499				if ok && tc.ProviderExecuted {
 500					stepContent = append(stepContent, content)
 501					continue
 502				}
 503				// Replace with validated tool call.
 504				if toolCallIndex < len(stepToolCalls) {
 505					stepContent = append(stepContent, stepToolCalls[toolCallIndex])
 506					toolCallIndex++
 507				}
 508			} else {
 509				stepContent = append(stepContent, content)
 510			}
 511		} // Add tool results
 512		for _, result := range toolResults {
 513			stepContent = append(stepContent, result)
 514		}
 515		currentStepMessages := toResponseMessages(stepContent)
 516		responseMessages = append(responseMessages, currentStepMessages...)
 517
 518		stepResult := StepResult{
 519			Response: Response{
 520				Content:          stepContent,
 521				FinishReason:     result.FinishReason,
 522				Usage:            result.Usage,
 523				Warnings:         result.Warnings,
 524				ProviderMetadata: result.ProviderMetadata,
 525			},
 526			Messages: currentStepMessages,
 527		}
 528		steps = append(steps, stepResult)
 529		shouldStop := isStopConditionMet(opts.StopWhen, steps)
 530
 531		if shouldStop || err != nil || stopTurnRequested || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
 532			break
 533		}
 534	}
 535
 536	totalUsage := Usage{}
 537
 538	for _, step := range steps {
 539		usage := step.Usage
 540		totalUsage.InputTokens += usage.InputTokens
 541		totalUsage.OutputTokens += usage.OutputTokens
 542		totalUsage.ReasoningTokens += usage.ReasoningTokens
 543		totalUsage.CacheCreationTokens += usage.CacheCreationTokens
 544		totalUsage.CacheReadTokens += usage.CacheReadTokens
 545		totalUsage.TotalTokens += usage.TotalTokens
 546	}
 547
 548	agentResult := &AgentResult{
 549		Steps:      steps,
 550		Response:   steps[len(steps)-1].Response,
 551		TotalUsage: totalUsage,
 552	}
 553	return agentResult, nil
 554}
 555
 556func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
 557	if len(conditions) == 0 {
 558		return false
 559	}
 560
 561	for _, condition := range conditions {
 562		if condition(steps) {
 563			return true
 564		}
 565	}
 566	return false
 567}
 568
 569func hasStopTurn(results []ToolResultContent) bool {
 570	for _, r := range results {
 571		if r.StopTurn {
 572			return true
 573		}
 574	}
 575	return false
 576}
 577
 578func toResponseMessages(content []Content) []Message {
 579	var assistantParts []MessagePart
 580	var toolParts []MessagePart
 581
 582	for _, c := range content {
 583		switch c.GetType() {
 584		case ContentTypeText:
 585			text, ok := AsContentType[TextContent](c)
 586			if !ok {
 587				continue
 588			}
 589			assistantParts = append(assistantParts, TextPart{
 590				Text:            text.Text,
 591				ProviderOptions: ProviderOptions(text.ProviderMetadata),
 592			})
 593		case ContentTypeReasoning:
 594			reasoning, ok := AsContentType[ReasoningContent](c)
 595			if !ok {
 596				continue
 597			}
 598			assistantParts = append(assistantParts, ReasoningPart{
 599				Text:            reasoning.Text,
 600				ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
 601			})
 602		case ContentTypeToolCall:
 603			toolCall, ok := AsContentType[ToolCallContent](c)
 604			if !ok {
 605				continue
 606			}
 607			assistantParts = append(assistantParts, ToolCallPart{
 608				ToolCallID:       toolCall.ToolCallID,
 609				ToolName:         toolCall.ToolName,
 610				Input:            toolCall.Input,
 611				ProviderExecuted: toolCall.ProviderExecuted,
 612				ProviderOptions:  ProviderOptions(toolCall.ProviderMetadata),
 613			})
 614		case ContentTypeFile:
 615			file, ok := AsContentType[FileContent](c)
 616			if !ok {
 617				continue
 618			}
 619			assistantParts = append(assistantParts, FilePart{
 620				Data:            file.Data,
 621				MediaType:       file.MediaType,
 622				ProviderOptions: ProviderOptions(file.ProviderMetadata),
 623			})
 624		case ContentTypeSource:
 625			// Sources are metadata about references used to generate the response.
 626			// They don't need to be included in the conversation messages.
 627			continue
 628		case ContentTypeToolResult:
 629			result, ok := AsContentType[ToolResultContent](c)
 630			if !ok {
 631				continue
 632			}
 633			resultPart := ToolResultPart{
 634				ToolCallID:       result.ToolCallID,
 635				Output:           result.Result,
 636				ProviderExecuted: result.ProviderExecuted,
 637				ProviderOptions:  ProviderOptions(result.ProviderMetadata),
 638			}
 639			if result.ProviderExecuted {
 640				// Provider-executed tool results (e.g. web search)
 641				// belong in the assistant message alongside the
 642				// server_tool_use block that produced them.
 643				assistantParts = append(assistantParts, resultPart)
 644			} else {
 645				toolParts = append(toolParts, resultPart)
 646			}
 647		}
 648	}
 649
 650	var messages []Message
 651	if len(assistantParts) > 0 {
 652		messages = append(messages, Message{
 653			Role:    MessageRoleAssistant,
 654			Content: assistantParts,
 655		})
 656	}
 657	if len(toolParts) > 0 {
 658		messages = append(messages, Message{
 659			Role:    MessageRoleTool,
 660			Content: toolParts,
 661		})
 662	}
 663	return messages
 664}
 665
 666func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, execProviderTools []ExecutableProviderTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
 667	if len(toolCalls) == 0 {
 668		return nil, nil
 669	}
 670
 671	// Create a map for quick tool lookup
 672	toolMap := make(map[string]AgentTool)
 673	for _, tool := range allTools {
 674		toolMap[tool.Info().Name] = tool
 675	}
 676
 677	execProviderToolMap := make(map[string]ExecutableProviderTool, len(execProviderTools))
 678	for _, ept := range execProviderTools {
 679		execProviderToolMap[ept.GetName()] = ept
 680	}
 681
 682	// Execute all tool calls sequentially in order
 683	results := make([]ToolResultContent, 0, len(toolCalls))
 684
 685	for _, toolCall := range toolCalls {
 686		result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, toolCall, toolResultCallback)
 687		results = append(results, result)
 688		if isCriticalError {
 689			if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
 690				return nil, errorResult.Error
 691			}
 692		}
 693	}
 694
 695	return results, nil
 696}
 697
 698// executeSingleTool executes a single tool and returns its result and a critical error flag.
 699func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentTool, execProviderToolMap map[string]ExecutableProviderTool, toolCall ToolCallContent, toolResultCallback func(result ToolResultContent) error) (ToolResultContent, bool) {
 700	result := ToolResultContent{
 701		ToolCallID:       toolCall.ToolCallID,
 702		ToolName:         toolCall.ToolName,
 703		ProviderExecuted: false,
 704	}
 705
 706	// Skip invalid tool calls - create error result (not critical)
 707	if toolCall.Invalid {
 708		result.Result = ToolResultOutputContentError{
 709			Error: toolCall.ValidationError,
 710		}
 711		if toolResultCallback != nil {
 712			_ = toolResultCallback(result)
 713		}
 714		return result, false
 715	}
 716
 717	// Find the run function — either from a regular AgentTool or an
 718	// executable provider tool.
 719	var runTool func(ctx context.Context, call ToolCall) (ToolResponse, error)
 720	if tool, exists := toolMap[toolCall.ToolName]; exists {
 721		runTool = tool.Run
 722	} else if ept, ok := execProviderToolMap[toolCall.ToolName]; ok {
 723		runTool = ept.Run
 724	}
 725	if runTool == nil {
 726		result.Result = ToolResultOutputContentError{
 727			Error: errors.New("tool not found: " + toolCall.ToolName),
 728		}
 729		if toolResultCallback != nil {
 730			_ = toolResultCallback(result)
 731		}
 732		return result, false
 733	}
 734
 735	// Execute the tool
 736	toolResult, err := runTool(ctx, ToolCall{
 737		ID:    toolCall.ToolCallID,
 738		Name:  toolCall.ToolName,
 739		Input: toolCall.Input,
 740	})
 741	if err != nil {
 742		result.Result = ToolResultOutputContentError{
 743			Error: err,
 744		}
 745		result.ClientMetadata = toolResult.Metadata
 746		result.StopTurn = toolResult.StopTurn
 747		if toolResultCallback != nil {
 748			_ = toolResultCallback(result)
 749		}
 750		return result, true
 751	}
 752
 753	result.ClientMetadata = toolResult.Metadata
 754	result.StopTurn = toolResult.StopTurn
 755	if toolResult.IsError {
 756		result.Result = ToolResultOutputContentError{
 757			Error: errors.New(toolResult.Content),
 758		}
 759	} else if toolResult.Type == "image" || toolResult.Type == "media" {
 760		result.Result = ToolResultOutputContentMedia{
 761			Data:      base64.StdEncoding.EncodeToString(toolResult.Data),
 762			MediaType: toolResult.MediaType,
 763			Text:      toolResult.Content,
 764		}
 765	} else {
 766		result.Result = ToolResultOutputContentText{
 767			Text: toolResult.Content,
 768		}
 769	}
 770	if toolResultCallback != nil {
 771		_ = toolResultCallback(result)
 772	}
 773	return result, false
 774}
 775
 776// Stream implements Agent.
 777func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
 778	// Convert AgentStreamCall to AgentCall for preparation
 779	call := AgentCall{
 780		Prompt:           opts.Prompt,
 781		Files:            opts.Files,
 782		Messages:         opts.Messages,
 783		MaxOutputTokens:  opts.MaxOutputTokens,
 784		Temperature:      opts.Temperature,
 785		TopP:             opts.TopP,
 786		TopK:             opts.TopK,
 787		PresencePenalty:  opts.PresencePenalty,
 788		FrequencyPenalty: opts.FrequencyPenalty,
 789		ActiveTools:      opts.ActiveTools,
 790		ProviderOptions:  opts.ProviderOptions,
 791		MaxRetries:       opts.MaxRetries,
 792		OnRetry:          opts.OnRetry,
 793		StopWhen:         opts.StopWhen,
 794		PrepareStep:      opts.PrepareStep,
 795		RepairToolCall:   opts.RepairToolCall,
 796	}
 797
 798	call = a.prepareCall(call)
 799
 800	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
 801	if err != nil {
 802		return nil, err
 803	}
 804
 805	var responseMessages []Message
 806	var steps []StepResult
 807	var totalUsage Usage
 808
 809	// Start agent stream
 810	if opts.OnAgentStart != nil {
 811		opts.OnAgentStart()
 812	}
 813
 814	for stepNumber := 0; ; stepNumber++ {
 815		stepInputMessages := append(initialPrompt, responseMessages...)
 816		stepModel := a.settings.model
 817		stepSystemPrompt := a.settings.systemPrompt
 818		stepActiveTools := call.ActiveTools
 819		stepToolChoice := ToolChoiceAuto
 820		disableAllTools := false
 821		stepTools := a.settings.tools
 822		// Apply step preparation if provided
 823		if call.PrepareStep != nil {
 824			updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
 825				Model:      stepModel,
 826				Steps:      steps,
 827				StepNumber: stepNumber,
 828				Messages:   stepInputMessages,
 829			})
 830			if err != nil {
 831				return nil, err
 832			}
 833
 834			ctx = updatedCtx
 835
 836			if prepared.Messages != nil {
 837				stepInputMessages = prepared.Messages
 838			}
 839			if prepared.Model != nil {
 840				stepModel = prepared.Model
 841			}
 842			if prepared.System != nil {
 843				stepSystemPrompt = *prepared.System
 844			}
 845			if prepared.ToolChoice != nil {
 846				stepToolChoice = *prepared.ToolChoice
 847			}
 848			if len(prepared.ActiveTools) > 0 {
 849				stepActiveTools = prepared.ActiveTools
 850			}
 851			disableAllTools = prepared.DisableAllTools
 852			if prepared.Tools != nil {
 853				stepTools = prepared.Tools
 854			}
 855		}
 856
 857		// Recreate prompt with potentially modified system prompt
 858		if stepSystemPrompt != a.settings.systemPrompt {
 859			stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
 860			if err != nil {
 861				return nil, err
 862			}
 863			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 864				stepInputMessages[0] = stepPrompt[0]
 865			}
 866		}
 867
 868		preparedTools := a.prepareTools(stepTools, a.settings.providerDefinedTools, stepActiveTools, disableAllTools)
 869
 870		// Filter executable provider tools by activeTools at the
 871		// step level, consistent with how stepTools (AgentTools)
 872		// are scoped before being passed to inner functions.
 873		stepExecProviderTools := a.filterExecProviderTools(stepActiveTools)
 874
 875		// Start step stream
 876		if opts.OnStepStart != nil {
 877			_ = opts.OnStepStart(stepNumber)
 878		}
 879		// Create streaming call
 880		streamCall := Call{
 881			Prompt:           stepInputMessages,
 882			MaxOutputTokens:  call.MaxOutputTokens,
 883			Temperature:      call.Temperature,
 884			TopP:             call.TopP,
 885			TopK:             call.TopK,
 886			PresencePenalty:  call.PresencePenalty,
 887			FrequencyPenalty: call.FrequencyPenalty,
 888			Tools:            preparedTools,
 889			ToolChoice:       &stepToolChoice,
 890			UserAgent:        a.settings.userAgent,
 891			ProviderOptions:  call.ProviderOptions,
 892		}
 893
 894		// Execute step with retry logic wrapping both stream creation and processing
 895		retryOptions := DefaultRetryOptions()
 896		if call.MaxRetries != nil {
 897			retryOptions.MaxRetries = *call.MaxRetries
 898		}
 899		retryOptions.OnRetry = call.OnRetry
 900		retry := RetryWithExponentialBackoffRespectingRetryHeaders[stepExecutionResult](retryOptions)
 901
 902		result, err := retry(ctx, func() (stepExecutionResult, error) {
 903			// Create the stream
 904			stream, err := stepModel.Stream(ctx, streamCall)
 905			if err != nil {
 906				return stepExecutionResult{}, err
 907			}
 908
 909			// Process the stream
 910			result, err := a.processStepStream(ctx, stream, opts, steps, stepTools, stepExecProviderTools)
 911			if err != nil {
 912				return stepExecutionResult{}, err
 913			}
 914			return result, nil
 915		})
 916		if err != nil {
 917			if opts.OnError != nil {
 918				opts.OnError(err)
 919			}
 920			return nil, err
 921		}
 922
 923		steps = append(steps, result.StepResult)
 924		totalUsage = addUsage(totalUsage, result.StepResult.Usage)
 925
 926		// Call step finished callback
 927		if opts.OnStepFinish != nil {
 928			_ = opts.OnStepFinish(result.StepResult)
 929		}
 930
 931		// Add step messages to response messages
 932		stepMessages := toResponseMessages(result.StepResult.Content)
 933		responseMessages = append(responseMessages, stepMessages...)
 934
 935		// Check stop conditions
 936		shouldStop := isStopConditionMet(call.StopWhen, steps)
 937		if shouldStop || !result.ShouldContinue {
 938			break
 939		}
 940	}
 941
 942	// Finish agent stream
 943	agentResult := &AgentResult{
 944		Steps:      steps,
 945		Response:   steps[len(steps)-1].Response,
 946		TotalUsage: totalUsage,
 947	}
 948
 949	if opts.OnFinish != nil {
 950		opts.OnFinish(agentResult)
 951	}
 952
 953	if opts.OnAgentFinish != nil {
 954		_ = opts.OnAgentFinish(agentResult)
 955	}
 956
 957	return agentResult, nil
 958}
 959
 960// filterExecProviderTools returns the subset of executable provider
 961// tools permitted by activeTools. When activeTools is empty every
 962// tool is included (no filtering).
 963func (a *agent) filterExecProviderTools(activeTools []string) []ExecutableProviderTool {
 964	if len(activeTools) == 0 {
 965		return a.settings.executableProviderTools
 966	}
 967	filtered := make([]ExecutableProviderTool, 0, len(a.settings.executableProviderTools))
 968	for _, ept := range a.settings.executableProviderTools {
 969		if slices.Contains(activeTools, ept.GetName()) {
 970			filtered = append(filtered, ept)
 971		}
 972	}
 973	return filtered
 974}
 975
 976func (a *agent) prepareTools(tools []AgentTool, providerDefinedTools []ProviderDefinedTool, activeTools []string, disableAllTools bool) []Tool {
 977	preparedTools := make([]Tool, 0, len(tools)+len(providerDefinedTools))
 978
 979	// If explicitly disabling all tools, return no tools
 980	if disableAllTools {
 981		return preparedTools
 982	}
 983
 984	for _, tool := range tools {
 985		// If activeTools has items, only include tools in the list
 986		// If activeTools is empty, include all tools
 987		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
 988			continue
 989		}
 990		info := tool.Info()
 991		inputSchema := map[string]any{
 992			"type":       "object",
 993			"properties": info.Parameters,
 994			"required":   info.Required,
 995		}
 996		schema.Normalize(inputSchema)
 997		preparedTools = append(preparedTools, FunctionTool{
 998			Name:            info.Name,
 999			Description:     info.Description,
1000			InputSchema:     inputSchema,
1001			ProviderOptions: tool.ProviderOptions(),
1002		})
1003	}
1004	for _, tool := range providerDefinedTools {
1005		// If activeTools has items, only include tools in the list. If
1006		// activeTools is empty, include all tools
1007		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.GetName()) {
1008			continue
1009		}
1010		preparedTools = append(preparedTools, tool)
1011	}
1012	return preparedTools
1013}
1014
1015// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
1016func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, execProviderTools []ExecutableProviderTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
1017	if err := a.validateToolCall(toolCall, availableTools, execProviderTools); err == nil {
1018		return toolCall
1019	} else { //nolint: revive
1020		if repairFunc != nil {
1021			repairOptions := ToolCallRepairOptions{
1022				OriginalToolCall: toolCall,
1023				ValidationError:  err,
1024				AvailableTools:   availableTools,
1025				SystemPrompt:     systemPrompt,
1026				Messages:         messages,
1027			}
1028
1029			if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
1030				if validateErr := a.validateToolCall(*repairedToolCall, availableTools, execProviderTools); validateErr == nil {
1031					return *repairedToolCall
1032				}
1033			}
1034		}
1035
1036		invalidToolCall := toolCall
1037		invalidToolCall.Invalid = true
1038		invalidToolCall.ValidationError = err
1039		return invalidToolCall
1040	}
1041}
1042
1043// validateToolCall validates a tool call against available tools and their schemas.
1044// Both availableTools and execProviderTools must already be filtered by the
1045// caller (e.g. via activeTools); this function trusts that the slices
1046// represent exactly the tools permitted for the current step.
1047func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool, execProviderTools []ExecutableProviderTool) error {
1048	var tool AgentTool
1049	for _, t := range availableTools {
1050		if t.Info().Name == toolCall.ToolName {
1051			tool = t
1052			break
1053		}
1054	}
1055
1056	if tool == nil {
1057		// Check if this is an executable provider tool. Provider-
1058		// defined tools have their schema enforced server-side, so
1059		// we only validate that the input is parseable JSON.
1060		for _, ept := range execProviderTools {
1061			if ept.GetName() == toolCall.ToolName {
1062				var input map[string]any
1063				if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
1064					return fmt.Errorf("invalid JSON input: %w", err)
1065				}
1066				return nil
1067			}
1068		}
1069		return fmt.Errorf("tool not found: %s", toolCall.ToolName)
1070	}
1071
1072	// Validate JSON parsing
1073	var input map[string]any
1074	if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
1075		return fmt.Errorf("invalid JSON input: %w", err)
1076	}
1077
1078	// Basic schema validation (check required fields)
1079	// TODO: more robust schema validation using JSON Schema or similar
1080	toolInfo := tool.Info()
1081	for _, required := range toolInfo.Required {
1082		if _, exists := input[required]; !exists {
1083			return fmt.Errorf("missing required parameter: %s", required)
1084		}
1085	}
1086	return nil
1087}
1088
1089func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
1090	// Validation: empty prompt is only allowed when there are messages,
1091	// no files to attach, and the last message is a user or tool message.
1092	if prompt == "" {
1093		lastMessage, hasMessages := slice.Last(messages)
1094
1095		if !hasMessages {
1096			return nil, &Error{
1097				Title:   "invalid argument",
1098				Message: "prompt can't be empty when there are no messages",
1099			}
1100		}
1101
1102		if len(files) > 0 {
1103			return nil, &Error{
1104				Title:   "invalid argument",
1105				Message: "prompt can't be empty when there are files",
1106			}
1107		}
1108
1109		switch lastMessage.Role {
1110		case MessageRoleUser, MessageRoleTool:
1111		default:
1112			return nil, &Error{
1113				Title:   "invalid argument",
1114				Message: "prompt can't be empty when the last message is not a user or tool message",
1115			}
1116		}
1117	}
1118
1119	var preparedPrompt Prompt
1120
1121	if system != "" {
1122		preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
1123	}
1124	preparedPrompt = append(preparedPrompt, messages...)
1125	if prompt != "" {
1126		preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
1127	}
1128	return preparedPrompt, nil
1129}
1130
1131// WithSystemPrompt sets the system prompt for the agent.
1132func WithSystemPrompt(prompt string) AgentOption {
1133	return func(s *agentSettings) {
1134		s.systemPrompt = prompt
1135	}
1136}
1137
1138// WithMaxOutputTokens sets the maximum output tokens for the agent.
1139func WithMaxOutputTokens(tokens int64) AgentOption {
1140	return func(s *agentSettings) {
1141		s.maxOutputTokens = &tokens
1142	}
1143}
1144
1145// WithTemperature sets the temperature for the agent.
1146func WithTemperature(temp float64) AgentOption {
1147	return func(s *agentSettings) {
1148		s.temperature = &temp
1149	}
1150}
1151
1152// WithTopP sets the top-p value for the agent.
1153func WithTopP(topP float64) AgentOption {
1154	return func(s *agentSettings) {
1155		s.topP = &topP
1156	}
1157}
1158
1159// WithTopK sets the top-k value for the agent.
1160func WithTopK(topK int64) AgentOption {
1161	return func(s *agentSettings) {
1162		s.topK = &topK
1163	}
1164}
1165
1166// WithPresencePenalty sets the presence penalty for the agent.
1167func WithPresencePenalty(penalty float64) AgentOption {
1168	return func(s *agentSettings) {
1169		s.presencePenalty = &penalty
1170	}
1171}
1172
1173// WithFrequencyPenalty sets the frequency penalty for the agent.
1174func WithFrequencyPenalty(penalty float64) AgentOption {
1175	return func(s *agentSettings) {
1176		s.frequencyPenalty = &penalty
1177	}
1178}
1179
1180// WithTools sets the tools for the agent.
1181func WithTools(tools ...AgentTool) AgentOption {
1182	return func(s *agentSettings) {
1183		s.tools = append(s.tools, tools...)
1184	}
1185}
1186
1187// WithProviderDefinedTools registers provider-defined tools with the
1188// agent. Provider-executed tools (e.g. web search) are passed through
1189// to the API. Client-executed tools (ExecutableProviderTool) are also
1190// registered for local execution.
1191func WithProviderDefinedTools(tools ...ProviderTool) AgentOption {
1192	return func(s *agentSettings) {
1193		for _, t := range tools {
1194			// Every provider tool goes into providerDefinedTools
1195			// for wire formatting.
1196			s.providerDefinedTools = append(
1197				s.providerDefinedTools, t.providerDefinedTool(),
1198			)
1199			// Executable ones also register for local execution.
1200			if exec, ok := t.(ExecutableProviderTool); ok {
1201				s.executableProviderTools = append(
1202					s.executableProviderTools, exec,
1203				)
1204			}
1205		}
1206	}
1207}
1208
1209// WithStopConditions sets the stop conditions for the agent.
1210func WithStopConditions(conditions ...StopCondition) AgentOption {
1211	return func(s *agentSettings) {
1212		s.stopWhen = append(s.stopWhen, conditions...)
1213	}
1214}
1215
1216// WithPrepareStep sets the prepare step function for the agent.
1217func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1218	return func(s *agentSettings) {
1219		s.prepareStep = fn
1220	}
1221}
1222
1223// WithRepairToolCall sets the repair tool call function for the agent.
1224func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1225	return func(s *agentSettings) {
1226		s.repairToolCall = fn
1227	}
1228}
1229
1230// WithMaxRetries sets the maximum number of retries for the agent.
1231func WithMaxRetries(maxRetries int) AgentOption {
1232	return func(s *agentSettings) {
1233		s.maxRetries = &maxRetries
1234	}
1235}
1236
1237// WithOnRetry sets the retry callback for the agent.
1238func WithOnRetry(callback OnRetryCallback) AgentOption {
1239	return func(s *agentSettings) {
1240		s.onRetry = callback
1241	}
1242}
1243
// processStepStream processes a single step's stream and returns the step result.
//
// Every stream part is first forwarded to opts.OnChunk, then handled by type:
//   - text and reasoning content is accumulated per part ID;
//   - warnings, sources, usage, finish reason and provider metadata are
//     recorded;
//   - client-side (non provider-executed) tool calls are validated and
//     possibly repaired;
//   - the matching opts callback is invoked.
//
// Any error returned by one of these streaming callbacks aborts the step and
// is returned unchanged. The same applies to a StreamPartTypeError part.
//
// Client-side tool calls are buffered while streaming and executed only after
// the stream loop ends, so every OnToolCall callback fires before any tool
// result is produced. Tools whose Info().Parallel is true run on their own
// goroutines, at most 5 at a time. Other tools run inline on the single
// dispatcher goroutine. If executeSingleTool reports a critical error, the
// first such error (when it carries a non-nil error) becomes the step error.
//
// ShouldContinue is true only when the step produced client-side tool calls,
// finished with FinishReasonToolCalls, and hasStopTurn returns false for the
// tool results.
//
// The []StepResult parameter is currently unused.
func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult, stepTools []AgentTool, execProviderTools []ExecutableProviderTool) (stepExecutionResult, error) {
	var stepContent []Content
	var stepToolCalls []ToolCallContent
	var stepUsage Usage
	stepFinishReason := FinishReasonUnknown
	var stepWarnings []CallWarning
	var stepProviderMetadata ProviderMetadata

	// In-progress content keyed by stream part ID. Entries are created on
	// the *Start parts and removed once the content is finalized.
	activeToolCalls := make(map[string]*ToolCallContent)
	activeTextContent := make(map[string]string)
	type reasoningContent struct {
		content string
		options ProviderMetadata
	}
	activeReasoningContent := make(map[string]reasoningContent)

	// Set up concurrent tool execution
	type toolExecutionRequest struct {
		toolCall ToolCallContent
		parallel bool
	}
	var pendingDispatches []toolExecutionRequest

	// Create a map for quick tool lookup
	toolMap := make(map[string]AgentTool)
	for _, tool := range stepTools {
		toolMap[tool.Info().Name] = tool
	}

	execProviderToolMap := make(map[string]ExecutableProviderTool, len(execProviderTools))
	for _, ept := range execProviderTools {
		execProviderToolMap[ept.GetName()] = ept
	}

	// Process stream parts
	for part := range stream {
		// Forward all parts to chunk callback
		if opts.OnChunk != nil {
			err := opts.OnChunk(part)
			if err != nil {
				return stepExecutionResult{}, err
			}
		}

		switch part.Type {
		case StreamPartTypeWarnings:
			// Replaces (does not append to) warnings from any earlier
			// warnings part in this stream.
			stepWarnings = part.Warnings
			if opts.OnWarnings != nil {
				err := opts.OnWarnings(part.Warnings)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeTextStart:
			activeTextContent[part.ID] = ""
			if opts.OnTextStart != nil {
				err := opts.OnTextStart(part.ID)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeTextDelta:
			// Deltas for an ID with no preceding TextStart are not
			// accumulated, but OnTextDelta is still invoked for them.
			if _, exists := activeTextContent[part.ID]; exists {
				activeTextContent[part.ID] += part.Delta
			}
			if opts.OnTextDelta != nil {
				err := opts.OnTextDelta(part.ID, part.Delta)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeTextEnd:
			if text, exists := activeTextContent[part.ID]; exists {
				stepContent = append(stepContent, TextContent{
					Text:             text,
					ProviderMetadata: part.ProviderMetadata,
				})
				delete(activeTextContent, part.ID)
			}
			if opts.OnTextEnd != nil {
				err := opts.OnTextEnd(part.ID)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeReasoningStart:
			// The start part may already carry reasoning text in Delta.
			activeReasoningContent[part.ID] = reasoningContent{content: part.Delta, options: part.ProviderMetadata}
			if opts.OnReasoningStart != nil {
				content := ReasoningContent{
					Text:             part.Delta,
					ProviderMetadata: part.ProviderMetadata,
				}
				err := opts.OnReasoningStart(part.ID, content)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeReasoningDelta:
			if active, exists := activeReasoningContent[part.ID]; exists {
				active.content += part.Delta
				// A later non-nil metadata value replaces the one recorded so far.
				if part.ProviderMetadata != nil {
					active.options = part.ProviderMetadata
				}
				activeReasoningContent[part.ID] = active
			}
			if opts.OnReasoningDelta != nil {
				err := opts.OnReasoningDelta(part.ID, part.Delta)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeReasoningEnd:
			// Unlike TextEnd, OnReasoningEnd is only invoked when a matching
			// ReasoningStart was seen for this ID.
			if active, exists := activeReasoningContent[part.ID]; exists {
				if part.ProviderMetadata != nil {
					active.options = part.ProviderMetadata
				}
				content := ReasoningContent{
					Text:             active.content,
					ProviderMetadata: active.options,
				}
				stepContent = append(stepContent, content)
				if opts.OnReasoningEnd != nil {
					err := opts.OnReasoningEnd(part.ID, content)
					if err != nil {
						return stepExecutionResult{}, err
					}
				}
				delete(activeReasoningContent, part.ID)
			}

		case StreamPartTypeToolInputStart:
			// activeToolCalls only tracks the streamed input. The final
			// ToolCallContent is built from the StreamPartTypeToolCall part.
			activeToolCalls[part.ID] = &ToolCallContent{
				ToolCallID:       part.ID,
				ToolName:         part.ToolCallName,
				Input:            "",
				ProviderExecuted: part.ProviderExecuted,
			}
			if opts.OnToolInputStart != nil {
				err := opts.OnToolInputStart(part.ID, part.ToolCallName)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeToolInputDelta:
			if toolCall, exists := activeToolCalls[part.ID]; exists {
				toolCall.Input += part.Delta
			}
			if opts.OnToolInputDelta != nil {
				err := opts.OnToolInputDelta(part.ID, part.Delta)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeToolInputEnd:
			if opts.OnToolInputEnd != nil {
				err := opts.OnToolInputEnd(part.ID)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeToolCall:
			toolCall := ToolCallContent{
				ToolCallID:       part.ID,
				ToolName:         part.ToolCallName,
				Input:            part.ToolCallInput,
				ProviderExecuted: part.ProviderExecuted,
				ProviderMetadata: part.ProviderMetadata,
			}

			// Provider-executed tool calls are handled by the provider
			// and should not be validated or executed by the agent.
			if toolCall.ProviderExecuted {
				stepContent = append(stepContent, toolCall)
				if opts.OnToolCall != nil {
					err := opts.OnToolCall(toolCall)
					if err != nil {
						return stepExecutionResult{}, err
					}
				}
				delete(activeToolCalls, part.ID)
			} else {
				// Validate and potentially repair the tool call.
				// NOTE(review): the repair function receives the agent-level
				// system prompt and nil messages, not the per-step prompt or
				// conversation. Confirm this is intended.
				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, execProviderTools, a.settings.systemPrompt, nil, opts.RepairToolCall)
				stepToolCalls = append(stepToolCalls, validatedToolCall)
				stepContent = append(stepContent, validatedToolCall)

				if opts.OnToolCall != nil {
					err := opts.OnToolCall(validatedToolCall)
					if err != nil {
						return stepExecutionResult{}, err
					}
				}

				// Determine if tool can run in parallel. Tools missing from
				// toolMap (e.g. executable provider tools) run sequentially.
				isParallel := false
				if tool, exists := toolMap[validatedToolCall.ToolName]; exists {
					isParallel = tool.Info().Parallel
				}

				// Buffer dispatch until stream is fully consumed so that all
				// OnToolCall callbacks complete before any tool result is written.
				pendingDispatches = append(pendingDispatches, toolExecutionRequest{toolCall: validatedToolCall, parallel: isParallel})

				// Clean up active tool call
				delete(activeToolCalls, part.ID)
			}

		case StreamPartTypeToolResult:
			// Provider-executed tool results (e.g. web search)
			// are emitted by the provider and added directly
			// to the step content for multi-turn round-tripping.
			// Result parts for client-executed tools are ignored here; those
			// results are produced locally after the stream ends.
			if part.ProviderExecuted {
				resultContent := ToolResultContent{
					ToolCallID:       part.ID,
					ToolName:         part.ToolCallName,
					ProviderExecuted: true,
					ProviderMetadata: part.ProviderMetadata,
				}
				stepContent = append(stepContent, resultContent)
				if opts.OnToolResult != nil {
					err := opts.OnToolResult(resultContent)
					if err != nil {
						return stepExecutionResult{}, err
					}
				}
			}

		case StreamPartTypeSource:
			sourceContent := SourceContent{
				SourceType:       part.SourceType,
				ID:               part.ID,
				URL:              part.URL,
				Title:            part.Title,
				ProviderMetadata: part.ProviderMetadata,
			}
			stepContent = append(stepContent, sourceContent)
			if opts.OnSource != nil {
				err := opts.OnSource(sourceContent)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeFinish:
			stepUsage = part.Usage
			stepFinishReason = part.FinishReason
			stepProviderMetadata = part.ProviderMetadata
			if opts.OnStreamFinish != nil {
				err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeError:
			// NOTE(review): if part.Error is nil, this returns an empty
			// result together with a nil error.
			return stepExecutionResult{}, part.Error
		}
	}

	// All tool calls are now collected. Create the execution channel sized to
	// avoid blocking during dispatch, start the coordinator, then flush the batch.
	toolChan := make(chan toolExecutionRequest, len(pendingDispatches))
	var toolExecutionWg sync.WaitGroup
	// toolStateMu guards toolResults and toolExecutionErr, which are written
	// from both the coordinator and the parallel worker goroutines.
	var toolStateMu sync.Mutex
	toolResults := make([]ToolResultContent, 0, len(pendingDispatches))
	var toolExecutionErr error

	// Concurrency controls. parallelSem caps in-flight parallel tools at 5.
	// sequentialMu is only ever locked on the coordinator goroutine.
	// Non-parallel tools therefore already run one at a time. They may still
	// overlap with parallel tools that were started earlier.
	parallelSem := make(chan struct{}, 5)
	var sequentialMu sync.Mutex

	// Single coordinator goroutine that dispatches tools. Results are appended
	// in completion order, so with parallel tools the order of toolResults
	// may differ from the order of the tool calls.
	toolExecutionWg.Go(func() {
		for req := range toolChan {
			if req.parallel {
				// Blocks the coordinator until a parallel slot is free.
				parallelSem <- struct{}{}
				toolExecutionWg.Go(func() {
					defer func() { <-parallelSem }()
					result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, req.toolCall, opts.OnToolResult)
					toolStateMu.Lock()
					toolResults = append(toolResults, result)
					// Keep only the first critical error that carries an error value.
					if isCriticalError && toolExecutionErr == nil {
						if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
							toolExecutionErr = errorResult.Error
						}
					}
					toolStateMu.Unlock()
				})
			} else {
				sequentialMu.Lock()
				result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, req.toolCall, opts.OnToolResult)
				toolStateMu.Lock()
				toolResults = append(toolResults, result)
				if isCriticalError && toolExecutionErr == nil {
					if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
						toolExecutionErr = errorResult.Error
					}
				}
				toolStateMu.Unlock()
				sequentialMu.Unlock()
			}
		}
	})

	// Dispatch all buffered tool calls now that every OnToolCall callback has
	// been called, then close and wait.
	for _, req := range pendingDispatches {
		toolChan <- req
	}

	// Close the tool execution channel and wait for all executions to complete.
	// After Wait returns, toolResults and toolExecutionErr can be read without
	// holding toolStateMu.
	close(toolChan)
	toolExecutionWg.Wait()

	// Check for tool execution errors
	if toolExecutionErr != nil {
		return stepExecutionResult{}, toolExecutionErr
	}

	// Add tool results to content if any
	if len(toolResults) > 0 {
		for _, result := range toolResults {
			stepContent = append(stepContent, result)
		}
	}

	stepResult := StepResult{
		Response: Response{
			Content:          stepContent,
			FinishReason:     stepFinishReason,
			Usage:            stepUsage,
			Warnings:         stepWarnings,
			ProviderMetadata: stepProviderMetadata,
		},
		Messages: toResponseMessages(stepContent),
	}

	// Determine if we should continue (has tool calls and not stopped)
	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls && !hasStopTurn(toolResults)

	return stepExecutionResult{
		StepResult:     stepResult,
		ShouldContinue: shouldContinue,
	}, nil
}
1599
1600func addUsage(a, b Usage) Usage {
1601	return Usage{
1602		InputTokens:         a.InputTokens + b.InputTokens,
1603		OutputTokens:        a.OutputTokens + b.OutputTokens,
1604		TotalTokens:         a.TotalTokens + b.TotalTokens,
1605		ReasoningTokens:     a.ReasoningTokens + b.ReasoningTokens,
1606		CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1607		CacheReadTokens:     a.CacheReadTokens + b.CacheReadTokens,
1608	}
1609}
1610
1611// WithHeaders sets the headers for the agent.
1612func WithHeaders(headers map[string]string) AgentOption {
1613	return func(s *agentSettings) {
1614		s.headers = headers
1615	}
1616}
1617
1618// WithUserAgent sets the User-Agent header for the agent. This overrides any
1619// provider-level User-Agent setting.
1620func WithUserAgent(ua string) AgentOption {
1621	return func(s *agentSettings) {
1622		s.userAgent = ua
1623	}
1624}
1625
1626// WithProviderOptions sets the provider options for the agent.
1627func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1628	return func(s *agentSettings) {
1629		s.providerOptions = providerOptions
1630	}
1631}