agent.go

   1package fantasy
   2
   3import (
   4	"cmp"
   5	"context"
   6	"encoding/json"
   7	"errors"
   8	"fmt"
   9	"maps"
  10	"slices"
  11	"sync"
  12
  13	"charm.land/fantasy/schema"
  14)
  15
// StepResult represents the result of a single step in an agent execution.
//
// The embedded Response carries the step's content, finish reason, usage,
// warnings and provider metadata. Messages holds the conversation messages
// derived from that content (in Generate they come from toResponseMessages).
type StepResult struct {
	Response
	Messages []Message
}

// stepExecutionResult encapsulates the result of executing a step with stream processing.
//
// ShouldContinue tells Stream whether another step should run; Stream stops
// the loop when it is false. NOTE(review): the value is produced by
// processStepStream (not shown here), which defines the exact rule.
type stepExecutionResult struct {
	StepResult     StepResult
	ShouldContinue bool
}

// StopCondition defines a function that determines when an agent should stop executing.
//
// It receives every step completed so far, including the one that just
// finished, and returns true to stop. Conditions are checked after each step;
// the agent loops also have their own termination rules.
type StopCondition = func(steps []StepResult) bool
  30
  31// StepCountIs returns a stop condition that stops after the specified number of steps.
  32func StepCountIs(stepCount int) StopCondition {
  33	return func(steps []StepResult) bool {
  34		return len(steps) >= stepCount
  35	}
  36}
  37
  38// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
  39func HasToolCall(toolName string) StopCondition {
  40	return func(steps []StepResult) bool {
  41		if len(steps) == 0 {
  42			return false
  43		}
  44		lastStep := steps[len(steps)-1]
  45		toolCalls := lastStep.Content.ToolCalls()
  46		for _, toolCall := range toolCalls {
  47			if toolCall.ToolName == toolName {
  48				return true
  49			}
  50		}
  51		return false
  52	}
  53}
  54
  55// HasContent returns a stop condition that stops when the specified content type appears in the last step.
  56func HasContent(contentType ContentType) StopCondition {
  57	return func(steps []StepResult) bool {
  58		if len(steps) == 0 {
  59			return false
  60		}
  61		lastStep := steps[len(steps)-1]
  62		for _, content := range lastStep.Content {
  63			if content.GetType() == contentType {
  64				return true
  65			}
  66		}
  67		return false
  68	}
  69}
  70
  71// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
  72func FinishReasonIs(reason FinishReason) StopCondition {
  73	return func(steps []StepResult) bool {
  74		if len(steps) == 0 {
  75			return false
  76		}
  77		lastStep := steps[len(steps)-1]
  78		return lastStep.FinishReason == reason
  79	}
  80}
  81
  82// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
  83func MaxTokensUsed(maxTokens int64) StopCondition {
  84	return func(steps []StepResult) bool {
  85		var totalTokens int64
  86		for _, step := range steps {
  87			totalTokens += step.Usage.TotalTokens
  88		}
  89		return totalTokens >= maxTokens
  90	}
  91}
  92
// PrepareStepFunctionOptions contains the options for preparing a step in an agent execution.
//
// Steps holds the steps completed so far, and StepNumber is the zero-based
// index of the step about to run. Model is the agent's configured model.
// Messages is the input the step will send unless it is overridden: the
// initial prompt followed by the messages produced by every previous step.
type PrepareStepFunctionOptions struct {
	Steps      []StepResult
	StepNumber int
	Model      LanguageModel
	Messages   []Message
}

// PrepareStepResult contains the result of preparing a step in an agent execution.
//
// Every field is an optional override for the upcoming step only:
//   - A non-nil Model, Messages, System, ToolChoice or Tools replaces the
//     corresponding default. Tools is also the set used to validate and
//     execute tool calls.
//   - ActiveTools replaces the call's active-tool filter only when it is
//     non-empty, so it cannot widen a filter back to "all tools".
//   - DisableAllTools, when true, sends no tools to the model.
//
// When System differs from the agent's system prompt, the first message of
// the step input is replaced by the first message of a prompt rebuilt with
// the new system text.
type PrepareStepResult struct {
	Model           LanguageModel
	Messages        []Message
	System          *string
	ToolChoice      *ToolChoice
	ActiveTools     []string
	DisableAllTools bool
	Tools           []AgentTool
}

// ToolCallRepairOptions contains the options for repairing a tool call.
//
// It is passed to a RepairToolCallFunction when OriginalToolCall failed
// validation with ValidationError. AvailableTools, SystemPrompt and Messages
// describe the step in which the call was produced: its tools, its system
// prompt and its input messages.
type ToolCallRepairOptions struct {
	OriginalToolCall ToolCallContent
	ValidationError  error
	AvailableTools   []AgentTool
	SystemPrompt     string
	Messages         []Message
}

type (
	// PrepareStepFunction defines a function that prepares a step in an agent execution.
	// It runs before every step. The returned context is used for that step
	// and all later steps, and a non-nil error aborts the run.
	PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)

	// OnStepFinishedFunction defines a function that is called when a step finishes.
	// NOTE(review): not referenced in the visible part of this file.
	OnStepFinishedFunction = func(step StepResult)

	// RepairToolCallFunction defines a function that repairs a tool call.
	// Its result is used only when it returns a nil error and a non-nil call
	// that passes validation again.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
 131
// agentSettings holds the agent-wide configuration that NewAgent builds by
// applying AgentOption values.
//
// The generation parameters, maxRetries, stopWhen, prepareStep,
// repairToolCall and onRetry act as defaults. prepareCall applies each one
// only when a call leaves the corresponding field unset. providerOptions is
// merged with the call's options, and call-level entries win.
//
// systemPrompt, model, tools, providerDefinedTools and userAgent are used
// directly by Generate and Stream; PrepareStep can override some of them per
// step.
//
// NOTE(review): headers is copied into a local map in prepareCall but is not
// forwarded to the model call anywhere in the visible code.
type agentSettings struct {
	systemPrompt     string
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	headers          map[string]string
	userAgent        string
	providerOptions  ProviderOptions

	providerDefinedTools []ProviderDefinedTool
	tools                []AgentTool
	maxRetries           *int

	model LanguageModel

	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
 155
 156// AgentCall represents a call to an agent.
 157type AgentCall struct {
 158	Prompt           string     `json:"prompt"`
 159	Files            []FilePart `json:"files"`
 160	Messages         []Message  `json:"messages"`
 161	MaxOutputTokens  *int64
 162	Temperature      *float64 `json:"temperature"`
 163	TopP             *float64 `json:"top_p"`
 164	TopK             *int64   `json:"top_k"`
 165	PresencePenalty  *float64 `json:"presence_penalty"`
 166	FrequencyPenalty *float64 `json:"frequency_penalty"`
 167	ActiveTools      []string `json:"active_tools"`
 168	ProviderOptions  ProviderOptions
 169	OnRetry          OnRetryCallback
 170	MaxRetries       *int
 171
 172	StopWhen       []StopCondition
 173	PrepareStep    PrepareStepFunction
 174	RepairToolCall RepairToolCallFunction
 175}
 176
// Agent-level callbacks.
//
// In Stream, OnAgentStart runs once before the first step, and OnStepStart
// and OnStepFinish run around every step. OnFinish and then OnAgentFinish run
// after the last step of a successful run.
//
// The errors returned by OnAgentFinish, OnStepStart and OnStepFinish are
// discarded by Stream. Stream calls OnError when a step fails after retries;
// processStepStream (not shown here) receives the same callbacks and may also
// use it.
type (
	// OnAgentStartFunc is called when agent starts.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called when agent finishes.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called when a step starts.
	// It receives the zero-based step number.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called when a step finishes.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called when entire agent completes.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called when an error occurs.
	OnErrorFunc func(error)
)
 197
// Stream part callbacks - called for each corresponding stream part type.
//
// These are carried on AgentStreamCall and handed to processStepStream (not
// shown here), which is where they are invoked. NOTE(review): whether a
// non-nil error returned by one of these callbacks aborts the step is decided
// in processStepStream; confirm there before relying on it.
type (
	// OnChunkFunc is called for each stream part (catch-all).
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when text starts.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for text deltas.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when text ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when reasoning starts.
	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error

	// OnReasoningDeltaFunc is called for reasoning deltas.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning ends.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when tool input starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for tool input deltas.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnToolResultFunc is called when tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when stream finishes.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
 245
 246// AgentStreamCall represents a streaming call to an agent.
 247type AgentStreamCall struct {
 248	Prompt           string     `json:"prompt"`
 249	Files            []FilePart `json:"files"`
 250	Messages         []Message  `json:"messages"`
 251	MaxOutputTokens  *int64
 252	Temperature      *float64 `json:"temperature"`
 253	TopP             *float64 `json:"top_p"`
 254	TopK             *int64   `json:"top_k"`
 255	PresencePenalty  *float64 `json:"presence_penalty"`
 256	FrequencyPenalty *float64 `json:"frequency_penalty"`
 257	ActiveTools      []string `json:"active_tools"`
 258	Headers          map[string]string
 259	ProviderOptions  ProviderOptions
 260	OnRetry          OnRetryCallback
 261	MaxRetries       *int
 262
 263	StopWhen       []StopCondition
 264	PrepareStep    PrepareStepFunction
 265	RepairToolCall RepairToolCallFunction
 266
 267	// Agent-level callbacks
 268	OnAgentStart  OnAgentStartFunc  // Called when agent starts
 269	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
 270	OnStepStart   OnStepStartFunc   // Called when a step starts
 271	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
 272	OnFinish      OnFinishFunc      // Called when entire agent completes
 273	OnError       OnErrorFunc       // Called when an error occurs
 274
 275	// Stream part callbacks - called for each corresponding stream part type
 276	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
 277	OnWarnings       OnWarningsFunc       // Called for warnings
 278	OnTextStart      OnTextStartFunc      // Called when text starts
 279	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
 280	OnTextEnd        OnTextEndFunc        // Called when text ends
 281	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
 282	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
 283	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
 284	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
 285	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
 286	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
 287	OnToolCall       OnToolCallFunc       // Called when tool call is complete
 288	OnToolResult     OnToolResultFunc     // Called when tool execution completes
 289	OnSource         OnSourceFunc         // Called for source references
 290	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
 291}
 292
// AgentResult represents the result of an agent execution.
//
// Steps lists every executed step in order. Response is the Response of the
// last step, and TotalUsage accumulates the token usage of all steps.
type AgentResult struct {
	Steps []StepResult
	// Final response
	Response   Response
	TotalUsage Usage
}
 300
// Agent represents an AI agent that can generate responses and stream responses.
//
// In the implementation returned by NewAgent, both methods run a multi-step
// loop (model call, tool execution, repeat) and return the aggregated
// AgentResult. Stream additionally reports progress through the callbacks set
// on AgentStreamCall.
type Agent interface {
	Generate(context.Context, AgentCall) (*AgentResult, error)
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
 306
// AgentOption defines a function that configures agent settings.
// NewAgent applies options in order, so an option can overwrite values set by
// earlier ones.
type AgentOption = func(*agentSettings)

// agent is the Agent implementation returned by NewAgent. In the code shown,
// settings is only read after construction.
type agent struct {
	settings agentSettings
}
 313
 314// NewAgent creates a new agent with the given language model and options.
 315func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
 316	settings := agentSettings{
 317		model: model,
 318	}
 319	for _, o := range opts {
 320		o(&settings)
 321	}
 322	return &agent{
 323		settings: settings,
 324	}
 325}
 326
 327func (a *agent) prepareCall(call AgentCall) AgentCall {
 328	call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
 329	call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
 330	call.TopP = cmp.Or(call.TopP, a.settings.topP)
 331	call.TopK = cmp.Or(call.TopK, a.settings.topK)
 332	call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
 333	call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
 334	call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
 335
 336	if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
 337		call.StopWhen = a.settings.stopWhen
 338	}
 339	if call.PrepareStep == nil && a.settings.prepareStep != nil {
 340		call.PrepareStep = a.settings.prepareStep
 341	}
 342	if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
 343		call.RepairToolCall = a.settings.repairToolCall
 344	}
 345	if call.OnRetry == nil && a.settings.onRetry != nil {
 346		call.OnRetry = a.settings.onRetry
 347	}
 348
 349	providerOptions := ProviderOptions{}
 350	if a.settings.providerOptions != nil {
 351		maps.Copy(providerOptions, a.settings.providerOptions)
 352	}
 353	if call.ProviderOptions != nil {
 354		maps.Copy(providerOptions, call.ProviderOptions)
 355	}
 356	call.ProviderOptions = providerOptions
 357
 358	headers := map[string]string{}
 359
 360	if a.settings.headers != nil {
 361		maps.Copy(headers, a.settings.headers)
 362	}
 363
 364	return call
 365}
 366
 367// Generate implements Agent.
 368func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
 369	opts = a.prepareCall(opts)
 370	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 371	if err != nil {
 372		return nil, err
 373	}
 374	var responseMessages []Message
 375	var steps []StepResult
 376
 377	for {
 378		stepInputMessages := append(initialPrompt, responseMessages...)
 379		stepModel := a.settings.model
 380		stepSystemPrompt := a.settings.systemPrompt
 381		stepActiveTools := opts.ActiveTools
 382		stepToolChoice := ToolChoiceAuto
 383		disableAllTools := false
 384		stepTools := a.settings.tools
 385		if opts.PrepareStep != nil {
 386			updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
 387				Model:      stepModel,
 388				Steps:      steps,
 389				StepNumber: len(steps),
 390				Messages:   stepInputMessages,
 391			})
 392			if err != nil {
 393				return nil, err
 394			}
 395
 396			ctx = updatedCtx
 397
 398			// Apply prepared step modifications
 399			if prepared.Messages != nil {
 400				stepInputMessages = prepared.Messages
 401			}
 402			if prepared.Model != nil {
 403				stepModel = prepared.Model
 404			}
 405			if prepared.System != nil {
 406				stepSystemPrompt = *prepared.System
 407			}
 408			if prepared.ToolChoice != nil {
 409				stepToolChoice = *prepared.ToolChoice
 410			}
 411			if len(prepared.ActiveTools) > 0 {
 412				stepActiveTools = prepared.ActiveTools
 413			}
 414			disableAllTools = prepared.DisableAllTools
 415			if prepared.Tools != nil {
 416				stepTools = prepared.Tools
 417			}
 418		}
 419
 420		// Recreate prompt with potentially modified system prompt
 421		if stepSystemPrompt != a.settings.systemPrompt {
 422			stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 423			if err != nil {
 424				return nil, err
 425			}
 426			// Replace system message part, keep the rest
 427			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 428				stepInputMessages[0] = stepPrompt[0] // Replace system message
 429			}
 430		}
 431
 432		preparedTools := a.prepareTools(stepTools, a.settings.providerDefinedTools, stepActiveTools, disableAllTools)
 433
 434		retryOptions := DefaultRetryOptions()
 435		if opts.MaxRetries != nil {
 436			retryOptions.MaxRetries = *opts.MaxRetries
 437		}
 438		retryOptions.OnRetry = opts.OnRetry
 439		retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
 440
 441		result, err := retry(ctx, func() (*Response, error) {
 442			return stepModel.Generate(ctx, Call{
 443				Prompt:           stepInputMessages,
 444				MaxOutputTokens:  opts.MaxOutputTokens,
 445				Temperature:      opts.Temperature,
 446				TopP:             opts.TopP,
 447				TopK:             opts.TopK,
 448				PresencePenalty:  opts.PresencePenalty,
 449				FrequencyPenalty: opts.FrequencyPenalty,
 450				Tools:            preparedTools,
 451				ToolChoice:       &stepToolChoice,
 452				UserAgent:        a.settings.userAgent,
 453				ProviderOptions:  opts.ProviderOptions,
 454			})
 455		})
 456		if err != nil {
 457			return nil, err
 458		}
 459
 460		var stepToolCalls []ToolCallContent
 461		for _, content := range result.Content {
 462			if content.GetType() == ContentTypeToolCall {
 463				toolCall, ok := AsContentType[ToolCallContent](content)
 464				if !ok {
 465					continue
 466				}
 467				// Provider-executed tool calls (e.g. web search) are
 468				// handled by the provider and should not be validated
 469				// or executed by the agent.
 470				if toolCall.ProviderExecuted {
 471					continue
 472				}
 473				// Validate and potentially repair the tool call
 474				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
 475				stepToolCalls = append(stepToolCalls, validatedToolCall)
 476			}
 477		}
 478
 479		toolResults, err := a.executeTools(ctx, stepTools, stepToolCalls, nil)
 480
 481		// Build step content with validated tool calls and tool results.
 482		// Provider-executed tool calls are kept as-is.
 483		stepContent := []Content{}
 484		toolCallIndex := 0
 485		for _, content := range result.Content {
 486			if content.GetType() == ContentTypeToolCall {
 487				tc, ok := AsContentType[ToolCallContent](content)
 488				if ok && tc.ProviderExecuted {
 489					stepContent = append(stepContent, content)
 490					continue
 491				}
 492				// Replace with validated tool call.
 493				if toolCallIndex < len(stepToolCalls) {
 494					stepContent = append(stepContent, stepToolCalls[toolCallIndex])
 495					toolCallIndex++
 496				}
 497			} else {
 498				stepContent = append(stepContent, content)
 499			}
 500		} // Add tool results
 501		for _, result := range toolResults {
 502			stepContent = append(stepContent, result)
 503		}
 504		currentStepMessages := toResponseMessages(stepContent)
 505		responseMessages = append(responseMessages, currentStepMessages...)
 506
 507		stepResult := StepResult{
 508			Response: Response{
 509				Content:          stepContent,
 510				FinishReason:     result.FinishReason,
 511				Usage:            result.Usage,
 512				Warnings:         result.Warnings,
 513				ProviderMetadata: result.ProviderMetadata,
 514			},
 515			Messages: currentStepMessages,
 516		}
 517		steps = append(steps, stepResult)
 518		shouldStop := isStopConditionMet(opts.StopWhen, steps)
 519
 520		if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
 521			break
 522		}
 523	}
 524
 525	totalUsage := Usage{}
 526
 527	for _, step := range steps {
 528		usage := step.Usage
 529		totalUsage.InputTokens += usage.InputTokens
 530		totalUsage.OutputTokens += usage.OutputTokens
 531		totalUsage.ReasoningTokens += usage.ReasoningTokens
 532		totalUsage.CacheCreationTokens += usage.CacheCreationTokens
 533		totalUsage.CacheReadTokens += usage.CacheReadTokens
 534		totalUsage.TotalTokens += usage.TotalTokens
 535	}
 536
 537	agentResult := &AgentResult{
 538		Steps:      steps,
 539		Response:   steps[len(steps)-1].Response,
 540		TotalUsage: totalUsage,
 541	}
 542	return agentResult, nil
 543}
 544
 545func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
 546	if len(conditions) == 0 {
 547		return false
 548	}
 549
 550	for _, condition := range conditions {
 551		if condition(steps) {
 552			return true
 553		}
 554	}
 555	return false
 556}
 557
 558func toResponseMessages(content []Content) []Message {
 559	var assistantParts []MessagePart
 560	var toolParts []MessagePart
 561
 562	for _, c := range content {
 563		switch c.GetType() {
 564		case ContentTypeText:
 565			text, ok := AsContentType[TextContent](c)
 566			if !ok {
 567				continue
 568			}
 569			assistantParts = append(assistantParts, TextPart{
 570				Text:            text.Text,
 571				ProviderOptions: ProviderOptions(text.ProviderMetadata),
 572			})
 573		case ContentTypeReasoning:
 574			reasoning, ok := AsContentType[ReasoningContent](c)
 575			if !ok {
 576				continue
 577			}
 578			assistantParts = append(assistantParts, ReasoningPart{
 579				Text:            reasoning.Text,
 580				ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
 581			})
 582		case ContentTypeToolCall:
 583			toolCall, ok := AsContentType[ToolCallContent](c)
 584			if !ok {
 585				continue
 586			}
 587			assistantParts = append(assistantParts, ToolCallPart{
 588				ToolCallID:       toolCall.ToolCallID,
 589				ToolName:         toolCall.ToolName,
 590				Input:            toolCall.Input,
 591				ProviderExecuted: toolCall.ProviderExecuted,
 592				ProviderOptions:  ProviderOptions(toolCall.ProviderMetadata),
 593			})
 594		case ContentTypeFile:
 595			file, ok := AsContentType[FileContent](c)
 596			if !ok {
 597				continue
 598			}
 599			assistantParts = append(assistantParts, FilePart{
 600				Data:            file.Data,
 601				MediaType:       file.MediaType,
 602				ProviderOptions: ProviderOptions(file.ProviderMetadata),
 603			})
 604		case ContentTypeSource:
 605			// Sources are metadata about references used to generate the response.
 606			// They don't need to be included in the conversation messages.
 607			continue
 608		case ContentTypeToolResult:
 609			result, ok := AsContentType[ToolResultContent](c)
 610			if !ok {
 611				continue
 612			}
 613			resultPart := ToolResultPart{
 614				ToolCallID:       result.ToolCallID,
 615				Output:           result.Result,
 616				ProviderExecuted: result.ProviderExecuted,
 617				ProviderOptions:  ProviderOptions(result.ProviderMetadata),
 618			}
 619			if result.ProviderExecuted {
 620				// Provider-executed tool results (e.g. web search)
 621				// belong in the assistant message alongside the
 622				// server_tool_use block that produced them.
 623				assistantParts = append(assistantParts, resultPart)
 624			} else {
 625				toolParts = append(toolParts, resultPart)
 626			}
 627		}
 628	}
 629
 630	var messages []Message
 631	if len(assistantParts) > 0 {
 632		messages = append(messages, Message{
 633			Role:    MessageRoleAssistant,
 634			Content: assistantParts,
 635		})
 636	}
 637	if len(toolParts) > 0 {
 638		messages = append(messages, Message{
 639			Role:    MessageRoleTool,
 640			Content: toolParts,
 641		})
 642	}
 643	return messages
 644}
 645
 646func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
 647	if len(toolCalls) == 0 {
 648		return nil, nil
 649	}
 650
 651	// Create a map for quick tool lookup
 652	toolMap := make(map[string]AgentTool)
 653	for _, tool := range allTools {
 654		toolMap[tool.Info().Name] = tool
 655	}
 656
 657	// Execute all tool calls sequentially in order
 658	results := make([]ToolResultContent, 0, len(toolCalls))
 659
 660	for _, toolCall := range toolCalls {
 661		result, isCriticalError := a.executeSingleTool(ctx, toolMap, toolCall, toolResultCallback)
 662		results = append(results, result)
 663		if isCriticalError {
 664			if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
 665				return nil, errorResult.Error
 666			}
 667		}
 668	}
 669
 670	return results, nil
 671}
 672
 673// executeSingleTool executes a single tool and returns its result and a critical error flag.
 674func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentTool, toolCall ToolCallContent, toolResultCallback func(result ToolResultContent) error) (ToolResultContent, bool) {
 675	result := ToolResultContent{
 676		ToolCallID:       toolCall.ToolCallID,
 677		ToolName:         toolCall.ToolName,
 678		ProviderExecuted: false,
 679	}
 680
 681	// Skip invalid tool calls - create error result (not critical)
 682	if toolCall.Invalid {
 683		result.Result = ToolResultOutputContentError{
 684			Error: toolCall.ValidationError,
 685		}
 686		if toolResultCallback != nil {
 687			_ = toolResultCallback(result)
 688		}
 689		return result, false
 690	}
 691
 692	tool, exists := toolMap[toolCall.ToolName]
 693	if !exists {
 694		result.Result = ToolResultOutputContentError{
 695			Error: errors.New("Error: Tool not found: " + toolCall.ToolName),
 696		}
 697		if toolResultCallback != nil {
 698			_ = toolResultCallback(result)
 699		}
 700		return result, false
 701	}
 702
 703	// Execute the tool
 704	toolResult, err := tool.Run(ctx, ToolCall{
 705		ID:    toolCall.ToolCallID,
 706		Name:  toolCall.ToolName,
 707		Input: toolCall.Input,
 708	})
 709	if err != nil {
 710		result.Result = ToolResultOutputContentError{
 711			Error: err,
 712		}
 713		result.ClientMetadata = toolResult.Metadata
 714		if toolResultCallback != nil {
 715			_ = toolResultCallback(result)
 716		}
 717		return result, true
 718	}
 719
 720	result.ClientMetadata = toolResult.Metadata
 721	if toolResult.IsError {
 722		result.Result = ToolResultOutputContentError{
 723			Error: errors.New(toolResult.Content),
 724		}
 725	} else if toolResult.Type == "image" || toolResult.Type == "media" {
 726		result.Result = ToolResultOutputContentMedia{
 727			Data:      string(toolResult.Data),
 728			MediaType: toolResult.MediaType,
 729			Text:      toolResult.Content,
 730		}
 731	} else {
 732		result.Result = ToolResultOutputContentText{
 733			Text: toolResult.Content,
 734		}
 735	}
 736	if toolResultCallback != nil {
 737		_ = toolResultCallback(result)
 738	}
 739	return result, false
 740}
 741
 742// Stream implements Agent.
 743func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
 744	// Convert AgentStreamCall to AgentCall for preparation
 745	call := AgentCall{
 746		Prompt:           opts.Prompt,
 747		Files:            opts.Files,
 748		Messages:         opts.Messages,
 749		MaxOutputTokens:  opts.MaxOutputTokens,
 750		Temperature:      opts.Temperature,
 751		TopP:             opts.TopP,
 752		TopK:             opts.TopK,
 753		PresencePenalty:  opts.PresencePenalty,
 754		FrequencyPenalty: opts.FrequencyPenalty,
 755		ActiveTools:      opts.ActiveTools,
 756		ProviderOptions:  opts.ProviderOptions,
 757		MaxRetries:       opts.MaxRetries,
 758		OnRetry:          opts.OnRetry,
 759		StopWhen:         opts.StopWhen,
 760		PrepareStep:      opts.PrepareStep,
 761		RepairToolCall:   opts.RepairToolCall,
 762	}
 763
 764	call = a.prepareCall(call)
 765
 766	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
 767	if err != nil {
 768		return nil, err
 769	}
 770
 771	var responseMessages []Message
 772	var steps []StepResult
 773	var totalUsage Usage
 774
 775	// Start agent stream
 776	if opts.OnAgentStart != nil {
 777		opts.OnAgentStart()
 778	}
 779
 780	for stepNumber := 0; ; stepNumber++ {
 781		stepInputMessages := append(initialPrompt, responseMessages...)
 782		stepModel := a.settings.model
 783		stepSystemPrompt := a.settings.systemPrompt
 784		stepActiveTools := call.ActiveTools
 785		stepToolChoice := ToolChoiceAuto
 786		disableAllTools := false
 787		stepTools := a.settings.tools
 788		// Apply step preparation if provided
 789		if call.PrepareStep != nil {
 790			updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
 791				Model:      stepModel,
 792				Steps:      steps,
 793				StepNumber: stepNumber,
 794				Messages:   stepInputMessages,
 795			})
 796			if err != nil {
 797				return nil, err
 798			}
 799
 800			ctx = updatedCtx
 801
 802			if prepared.Messages != nil {
 803				stepInputMessages = prepared.Messages
 804			}
 805			if prepared.Model != nil {
 806				stepModel = prepared.Model
 807			}
 808			if prepared.System != nil {
 809				stepSystemPrompt = *prepared.System
 810			}
 811			if prepared.ToolChoice != nil {
 812				stepToolChoice = *prepared.ToolChoice
 813			}
 814			if len(prepared.ActiveTools) > 0 {
 815				stepActiveTools = prepared.ActiveTools
 816			}
 817			disableAllTools = prepared.DisableAllTools
 818			if prepared.Tools != nil {
 819				stepTools = prepared.Tools
 820			}
 821		}
 822
 823		// Recreate prompt with potentially modified system prompt
 824		if stepSystemPrompt != a.settings.systemPrompt {
 825			stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
 826			if err != nil {
 827				return nil, err
 828			}
 829			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 830				stepInputMessages[0] = stepPrompt[0]
 831			}
 832		}
 833
 834		preparedTools := a.prepareTools(stepTools, a.settings.providerDefinedTools, stepActiveTools, disableAllTools)
 835
 836		// Start step stream
 837		if opts.OnStepStart != nil {
 838			_ = opts.OnStepStart(stepNumber)
 839		}
 840
 841		// Create streaming call
 842		streamCall := Call{
 843			Prompt:           stepInputMessages,
 844			MaxOutputTokens:  call.MaxOutputTokens,
 845			Temperature:      call.Temperature,
 846			TopP:             call.TopP,
 847			TopK:             call.TopK,
 848			PresencePenalty:  call.PresencePenalty,
 849			FrequencyPenalty: call.FrequencyPenalty,
 850			Tools:            preparedTools,
 851			ToolChoice:       &stepToolChoice,
 852			UserAgent:        a.settings.userAgent,
 853			ProviderOptions:  call.ProviderOptions,
 854		}
 855
 856		// Execute step with retry logic wrapping both stream creation and processing
 857		retryOptions := DefaultRetryOptions()
 858		if call.MaxRetries != nil {
 859			retryOptions.MaxRetries = *call.MaxRetries
 860		}
 861		retryOptions.OnRetry = call.OnRetry
 862		retry := RetryWithExponentialBackoffRespectingRetryHeaders[stepExecutionResult](retryOptions)
 863
 864		result, err := retry(ctx, func() (stepExecutionResult, error) {
 865			// Create the stream
 866			stream, err := stepModel.Stream(ctx, streamCall)
 867			if err != nil {
 868				return stepExecutionResult{}, err
 869			}
 870
 871			// Process the stream
 872			result, err := a.processStepStream(ctx, stream, opts, steps, stepTools)
 873			if err != nil {
 874				return stepExecutionResult{}, err
 875			}
 876
 877			return result, nil
 878		})
 879		if err != nil {
 880			if opts.OnError != nil {
 881				opts.OnError(err)
 882			}
 883			return nil, err
 884		}
 885
 886		steps = append(steps, result.StepResult)
 887		totalUsage = addUsage(totalUsage, result.StepResult.Usage)
 888
 889		// Call step finished callback
 890		if opts.OnStepFinish != nil {
 891			_ = opts.OnStepFinish(result.StepResult)
 892		}
 893
 894		// Add step messages to response messages
 895		stepMessages := toResponseMessages(result.StepResult.Content)
 896		responseMessages = append(responseMessages, stepMessages...)
 897
 898		// Check stop conditions
 899		shouldStop := isStopConditionMet(call.StopWhen, steps)
 900		if shouldStop || !result.ShouldContinue {
 901			break
 902		}
 903	}
 904
 905	// Finish agent stream
 906	agentResult := &AgentResult{
 907		Steps:      steps,
 908		Response:   steps[len(steps)-1].Response,
 909		TotalUsage: totalUsage,
 910	}
 911
 912	if opts.OnFinish != nil {
 913		opts.OnFinish(agentResult)
 914	}
 915
 916	if opts.OnAgentFinish != nil {
 917		_ = opts.OnAgentFinish(agentResult)
 918	}
 919
 920	return agentResult, nil
 921}
 922
 923func (a *agent) prepareTools(tools []AgentTool, providerDefinedTools []ProviderDefinedTool, activeTools []string, disableAllTools bool) []Tool {
 924	preparedTools := make([]Tool, 0, len(tools)+len(providerDefinedTools))
 925
 926	// If explicitly disabling all tools, return no tools
 927	if disableAllTools {
 928		return preparedTools
 929	}
 930
 931	for _, tool := range tools {
 932		// If activeTools has items, only include tools in the list
 933		// If activeTools is empty, include all tools
 934		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
 935			continue
 936		}
 937		info := tool.Info()
 938		inputSchema := map[string]any{
 939			"type":       "object",
 940			"properties": info.Parameters,
 941			"required":   info.Required,
 942		}
 943		schema.Normalize(inputSchema)
 944		preparedTools = append(preparedTools, FunctionTool{
 945			Name:            info.Name,
 946			Description:     info.Description,
 947			InputSchema:     inputSchema,
 948			ProviderOptions: tool.ProviderOptions(),
 949		})
 950	}
 951	for _, tool := range providerDefinedTools {
 952		// If activeTools has items, only include tools in the list. If
 953		// activeTools is empty, include all tools
 954		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.GetName()) {
 955			continue
 956		}
 957		preparedTools = append(preparedTools, tool)
 958	}
 959	return preparedTools
 960}
 961
 962// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
 963func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
 964	if err := a.validateToolCall(toolCall, availableTools); err == nil {
 965		return toolCall
 966	} else { //nolint: revive
 967		if repairFunc != nil {
 968			repairOptions := ToolCallRepairOptions{
 969				OriginalToolCall: toolCall,
 970				ValidationError:  err,
 971				AvailableTools:   availableTools,
 972				SystemPrompt:     systemPrompt,
 973				Messages:         messages,
 974			}
 975
 976			if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
 977				if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
 978					return *repairedToolCall
 979				}
 980			}
 981		}
 982
 983		invalidToolCall := toolCall
 984		invalidToolCall.Invalid = true
 985		invalidToolCall.ValidationError = err
 986		return invalidToolCall
 987	}
 988}
 989
 990// validateToolCall validates a tool call against available tools and their schemas.
 991func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
 992	var tool AgentTool
 993	for _, t := range availableTools {
 994		if t.Info().Name == toolCall.ToolName {
 995			tool = t
 996			break
 997		}
 998	}
 999
1000	if tool == nil {
1001		return fmt.Errorf("tool not found: %s", toolCall.ToolName)
1002	}
1003
1004	// Validate JSON parsing
1005	var input map[string]any
1006	if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
1007		return fmt.Errorf("invalid JSON input: %w", err)
1008	}
1009
1010	// Basic schema validation (check required fields)
1011	// TODO: more robust schema validation using JSON Schema or similar
1012	toolInfo := tool.Info()
1013	for _, required := range toolInfo.Required {
1014		if _, exists := input[required]; !exists {
1015			return fmt.Errorf("missing required parameter: %s", required)
1016		}
1017	}
1018	return nil
1019}
1020
1021func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
1022	if prompt == "" {
1023		return nil, &Error{Title: "invalid argument", Message: "prompt can't be empty"}
1024	}
1025
1026	var preparedPrompt Prompt
1027
1028	if system != "" {
1029		preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
1030	}
1031	preparedPrompt = append(preparedPrompt, messages...)
1032	preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
1033	return preparedPrompt, nil
1034}
1035
1036// WithSystemPrompt sets the system prompt for the agent.
1037func WithSystemPrompt(prompt string) AgentOption {
1038	return func(s *agentSettings) {
1039		s.systemPrompt = prompt
1040	}
1041}
1042
1043// WithMaxOutputTokens sets the maximum output tokens for the agent.
1044func WithMaxOutputTokens(tokens int64) AgentOption {
1045	return func(s *agentSettings) {
1046		s.maxOutputTokens = &tokens
1047	}
1048}
1049
1050// WithTemperature sets the temperature for the agent.
1051func WithTemperature(temp float64) AgentOption {
1052	return func(s *agentSettings) {
1053		s.temperature = &temp
1054	}
1055}
1056
1057// WithTopP sets the top-p value for the agent.
1058func WithTopP(topP float64) AgentOption {
1059	return func(s *agentSettings) {
1060		s.topP = &topP
1061	}
1062}
1063
1064// WithTopK sets the top-k value for the agent.
1065func WithTopK(topK int64) AgentOption {
1066	return func(s *agentSettings) {
1067		s.topK = &topK
1068	}
1069}
1070
1071// WithPresencePenalty sets the presence penalty for the agent.
1072func WithPresencePenalty(penalty float64) AgentOption {
1073	return func(s *agentSettings) {
1074		s.presencePenalty = &penalty
1075	}
1076}
1077
1078// WithFrequencyPenalty sets the frequency penalty for the agent.
1079func WithFrequencyPenalty(penalty float64) AgentOption {
1080	return func(s *agentSettings) {
1081		s.frequencyPenalty = &penalty
1082	}
1083}
1084
1085// WithTools sets the tools for the agent.
1086func WithTools(tools ...AgentTool) AgentOption {
1087	return func(s *agentSettings) {
1088		s.tools = append(s.tools, tools...)
1089	}
1090}
1091
1092// WithProviderDefinedTools sets the provider-defined tools for the agent.
1093// These tools are executed by the provider (e.g. web search) rather
1094// than by the client.
1095func WithProviderDefinedTools(tools ...ProviderDefinedTool) AgentOption {
1096	return func(s *agentSettings) {
1097		s.providerDefinedTools = append(s.providerDefinedTools, tools...)
1098	}
1099}
1100
1101// WithStopConditions sets the stop conditions for the agent.
1102func WithStopConditions(conditions ...StopCondition) AgentOption {
1103	return func(s *agentSettings) {
1104		s.stopWhen = append(s.stopWhen, conditions...)
1105	}
1106}
1107
1108// WithPrepareStep sets the prepare step function for the agent.
1109func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1110	return func(s *agentSettings) {
1111		s.prepareStep = fn
1112	}
1113}
1114
1115// WithRepairToolCall sets the repair tool call function for the agent.
1116func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1117	return func(s *agentSettings) {
1118		s.repairToolCall = fn
1119	}
1120}
1121
1122// WithMaxRetries sets the maximum number of retries for the agent.
1123func WithMaxRetries(maxRetries int) AgentOption {
1124	return func(s *agentSettings) {
1125		s.maxRetries = &maxRetries
1126	}
1127}
1128
1129// WithOnRetry sets the retry callback for the agent.
1130func WithOnRetry(callback OnRetryCallback) AgentOption {
1131	return func(s *agentSettings) {
1132		s.onRetry = callback
1133	}
1134}
1135
1136// processStepStream processes a single step's stream and returns the step result.
1137func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult, stepTools []AgentTool) (stepExecutionResult, error) {
1138	var stepContent []Content
1139	var stepToolCalls []ToolCallContent
1140	var stepUsage Usage
1141	stepFinishReason := FinishReasonUnknown
1142	var stepWarnings []CallWarning
1143	var stepProviderMetadata ProviderMetadata
1144
1145	activeToolCalls := make(map[string]*ToolCallContent)
1146	activeTextContent := make(map[string]string)
1147	type reasoningContent struct {
1148		content string
1149		options ProviderMetadata
1150	}
1151	activeReasoningContent := make(map[string]reasoningContent)
1152
1153	// Set up concurrent tool execution
1154	type toolExecutionRequest struct {
1155		toolCall ToolCallContent
1156		parallel bool
1157	}
1158	toolChan := make(chan toolExecutionRequest, 10)
1159	var toolExecutionWg sync.WaitGroup
1160	var toolStateMu sync.Mutex
1161	toolResults := make([]ToolResultContent, 0)
1162	var toolExecutionErr error
1163
1164	// Create a map for quick tool lookup
1165	toolMap := make(map[string]AgentTool)
1166	for _, tool := range stepTools {
1167		toolMap[tool.Info().Name] = tool
1168	}
1169
1170	// Semaphores for controlling parallelism
1171	parallelSem := make(chan struct{}, 5)
1172	var sequentialMu sync.Mutex
1173
1174	// Single coordinator goroutine that dispatches tools
1175	toolExecutionWg.Go(func() {
1176		for req := range toolChan {
1177			if req.parallel {
1178				parallelSem <- struct{}{}
1179				toolExecutionWg.Go(func() {
1180					defer func() { <-parallelSem }()
1181					result, isCriticalError := a.executeSingleTool(ctx, toolMap, req.toolCall, opts.OnToolResult)
1182					toolStateMu.Lock()
1183					toolResults = append(toolResults, result)
1184					if isCriticalError && toolExecutionErr == nil {
1185						if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
1186							toolExecutionErr = errorResult.Error
1187						}
1188					}
1189					toolStateMu.Unlock()
1190				})
1191			} else {
1192				sequentialMu.Lock()
1193				result, isCriticalError := a.executeSingleTool(ctx, toolMap, req.toolCall, opts.OnToolResult)
1194				toolStateMu.Lock()
1195				toolResults = append(toolResults, result)
1196				if isCriticalError && toolExecutionErr == nil {
1197					if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
1198						toolExecutionErr = errorResult.Error
1199					}
1200				}
1201				toolStateMu.Unlock()
1202				sequentialMu.Unlock()
1203			}
1204		}
1205	})
1206
1207	// Process stream parts
1208	for part := range stream {
1209		// Forward all parts to chunk callback
1210		if opts.OnChunk != nil {
1211			err := opts.OnChunk(part)
1212			if err != nil {
1213				return stepExecutionResult{}, err
1214			}
1215		}
1216
1217		switch part.Type {
1218		case StreamPartTypeWarnings:
1219			stepWarnings = part.Warnings
1220			if opts.OnWarnings != nil {
1221				err := opts.OnWarnings(part.Warnings)
1222				if err != nil {
1223					return stepExecutionResult{}, err
1224				}
1225			}
1226
1227		case StreamPartTypeTextStart:
1228			activeTextContent[part.ID] = ""
1229			if opts.OnTextStart != nil {
1230				err := opts.OnTextStart(part.ID)
1231				if err != nil {
1232					return stepExecutionResult{}, err
1233				}
1234			}
1235
1236		case StreamPartTypeTextDelta:
1237			if _, exists := activeTextContent[part.ID]; exists {
1238				activeTextContent[part.ID] += part.Delta
1239			}
1240			if opts.OnTextDelta != nil {
1241				err := opts.OnTextDelta(part.ID, part.Delta)
1242				if err != nil {
1243					return stepExecutionResult{}, err
1244				}
1245			}
1246
1247		case StreamPartTypeTextEnd:
1248			if text, exists := activeTextContent[part.ID]; exists {
1249				stepContent = append(stepContent, TextContent{
1250					Text:             text,
1251					ProviderMetadata: part.ProviderMetadata,
1252				})
1253				delete(activeTextContent, part.ID)
1254			}
1255			if opts.OnTextEnd != nil {
1256				err := opts.OnTextEnd(part.ID)
1257				if err != nil {
1258					return stepExecutionResult{}, err
1259				}
1260			}
1261
1262		case StreamPartTypeReasoningStart:
1263			activeReasoningContent[part.ID] = reasoningContent{content: part.Delta, options: part.ProviderMetadata}
1264			if opts.OnReasoningStart != nil {
1265				content := ReasoningContent{
1266					Text:             part.Delta,
1267					ProviderMetadata: part.ProviderMetadata,
1268				}
1269				err := opts.OnReasoningStart(part.ID, content)
1270				if err != nil {
1271					return stepExecutionResult{}, err
1272				}
1273			}
1274
1275		case StreamPartTypeReasoningDelta:
1276			if active, exists := activeReasoningContent[part.ID]; exists {
1277				active.content += part.Delta
1278				if part.ProviderMetadata != nil {
1279					active.options = part.ProviderMetadata
1280				}
1281				activeReasoningContent[part.ID] = active
1282			}
1283			if opts.OnReasoningDelta != nil {
1284				err := opts.OnReasoningDelta(part.ID, part.Delta)
1285				if err != nil {
1286					return stepExecutionResult{}, err
1287				}
1288			}
1289
1290		case StreamPartTypeReasoningEnd:
1291			if active, exists := activeReasoningContent[part.ID]; exists {
1292				if part.ProviderMetadata != nil {
1293					active.options = part.ProviderMetadata
1294				}
1295				content := ReasoningContent{
1296					Text:             active.content,
1297					ProviderMetadata: active.options,
1298				}
1299				stepContent = append(stepContent, content)
1300				if opts.OnReasoningEnd != nil {
1301					err := opts.OnReasoningEnd(part.ID, content)
1302					if err != nil {
1303						return stepExecutionResult{}, err
1304					}
1305				}
1306				delete(activeReasoningContent, part.ID)
1307			}
1308
1309		case StreamPartTypeToolInputStart:
1310			activeToolCalls[part.ID] = &ToolCallContent{
1311				ToolCallID:       part.ID,
1312				ToolName:         part.ToolCallName,
1313				Input:            "",
1314				ProviderExecuted: part.ProviderExecuted,
1315			}
1316			if opts.OnToolInputStart != nil {
1317				err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1318				if err != nil {
1319					return stepExecutionResult{}, err
1320				}
1321			}
1322
1323		case StreamPartTypeToolInputDelta:
1324			if toolCall, exists := activeToolCalls[part.ID]; exists {
1325				toolCall.Input += part.Delta
1326			}
1327			if opts.OnToolInputDelta != nil {
1328				err := opts.OnToolInputDelta(part.ID, part.Delta)
1329				if err != nil {
1330					return stepExecutionResult{}, err
1331				}
1332			}
1333
1334		case StreamPartTypeToolInputEnd:
1335			if opts.OnToolInputEnd != nil {
1336				err := opts.OnToolInputEnd(part.ID)
1337				if err != nil {
1338					return stepExecutionResult{}, err
1339				}
1340			}
1341
1342		case StreamPartTypeToolCall:
1343			toolCall := ToolCallContent{
1344				ToolCallID:       part.ID,
1345				ToolName:         part.ToolCallName,
1346				Input:            part.ToolCallInput,
1347				ProviderExecuted: part.ProviderExecuted,
1348				ProviderMetadata: part.ProviderMetadata,
1349			}
1350
1351			// Provider-executed tool calls are handled by the provider
1352			// and should not be validated or executed by the agent.
1353			if toolCall.ProviderExecuted {
1354				stepContent = append(stepContent, toolCall)
1355				if opts.OnToolCall != nil {
1356					err := opts.OnToolCall(toolCall)
1357					if err != nil {
1358						return stepExecutionResult{}, err
1359					}
1360				}
1361				delete(activeToolCalls, part.ID)
1362			} else {
1363				// Validate and potentially repair the tool call
1364				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1365				stepToolCalls = append(stepToolCalls, validatedToolCall)
1366				stepContent = append(stepContent, validatedToolCall)
1367
1368				if opts.OnToolCall != nil {
1369					err := opts.OnToolCall(validatedToolCall)
1370					if err != nil {
1371						return stepExecutionResult{}, err
1372					}
1373				}
1374
1375				// Determine if tool can run in parallel
1376				isParallel := false
1377				if tool, exists := toolMap[validatedToolCall.ToolName]; exists {
1378					isParallel = tool.Info().Parallel
1379				}
1380
1381				// Send tool call to execution channel
1382				toolChan <- toolExecutionRequest{toolCall: validatedToolCall, parallel: isParallel}
1383
1384				// Clean up active tool call
1385				delete(activeToolCalls, part.ID)
1386			}
1387
1388		case StreamPartTypeToolResult:
1389			// Provider-executed tool results (e.g. web search)
1390			// are emitted by the provider and added directly
1391			// to the step content for multi-turn round-tripping.
1392			if part.ProviderExecuted {
1393				resultContent := ToolResultContent{
1394					ToolCallID:       part.ID,
1395					ToolName:         part.ToolCallName,
1396					ProviderExecuted: true,
1397					ProviderMetadata: part.ProviderMetadata,
1398				}
1399				stepContent = append(stepContent, resultContent)
1400				if opts.OnToolResult != nil {
1401					err := opts.OnToolResult(resultContent)
1402					if err != nil {
1403						return stepExecutionResult{}, err
1404					}
1405				}
1406			}
1407
1408		case StreamPartTypeSource:
1409			sourceContent := SourceContent{
1410				SourceType:       part.SourceType,
1411				ID:               part.ID,
1412				URL:              part.URL,
1413				Title:            part.Title,
1414				ProviderMetadata: part.ProviderMetadata,
1415			}
1416			stepContent = append(stepContent, sourceContent)
1417			if opts.OnSource != nil {
1418				err := opts.OnSource(sourceContent)
1419				if err != nil {
1420					return stepExecutionResult{}, err
1421				}
1422			}
1423
1424		case StreamPartTypeFinish:
1425			stepUsage = part.Usage
1426			stepFinishReason = part.FinishReason
1427			stepProviderMetadata = part.ProviderMetadata
1428			if opts.OnStreamFinish != nil {
1429				err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1430				if err != nil {
1431					return stepExecutionResult{}, err
1432				}
1433			}
1434
1435		case StreamPartTypeError:
1436			return stepExecutionResult{}, part.Error
1437		}
1438	}
1439
1440	// Close the tool execution channel and wait for all executions to complete
1441	close(toolChan)
1442	toolExecutionWg.Wait()
1443
1444	// Check for tool execution errors
1445	if toolExecutionErr != nil {
1446		return stepExecutionResult{}, toolExecutionErr
1447	}
1448
1449	// Add tool results to content if any
1450	if len(toolResults) > 0 {
1451		for _, result := range toolResults {
1452			stepContent = append(stepContent, result)
1453		}
1454	}
1455
1456	stepResult := StepResult{
1457		Response: Response{
1458			Content:          stepContent,
1459			FinishReason:     stepFinishReason,
1460			Usage:            stepUsage,
1461			Warnings:         stepWarnings,
1462			ProviderMetadata: stepProviderMetadata,
1463		},
1464		Messages: toResponseMessages(stepContent),
1465	}
1466
1467	// Determine if we should continue (has tool calls and not stopped)
1468	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1469
1470	return stepExecutionResult{
1471		StepResult:     stepResult,
1472		ShouldContinue: shouldContinue,
1473	}, nil
1474}
1475
1476func addUsage(a, b Usage) Usage {
1477	return Usage{
1478		InputTokens:         a.InputTokens + b.InputTokens,
1479		OutputTokens:        a.OutputTokens + b.OutputTokens,
1480		TotalTokens:         a.TotalTokens + b.TotalTokens,
1481		ReasoningTokens:     a.ReasoningTokens + b.ReasoningTokens,
1482		CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1483		CacheReadTokens:     a.CacheReadTokens + b.CacheReadTokens,
1484	}
1485}
1486
1487// WithHeaders sets the headers for the agent.
1488func WithHeaders(headers map[string]string) AgentOption {
1489	return func(s *agentSettings) {
1490		s.headers = headers
1491	}
1492}
1493
1494// WithUserAgent sets the User-Agent header for the agent. This overrides any
1495// provider-level User-Agent setting.
1496func WithUserAgent(ua string) AgentOption {
1497	return func(s *agentSettings) {
1498		s.userAgent = ua
1499	}
1500}
1501
1502// WithProviderOptions sets the provider options for the agent.
1503func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1504	return func(s *agentSettings) {
1505		s.providerOptions = providerOptions
1506	}
1507}