agent.go

   1package fantasy
   2
   3import (
   4	"cmp"
   5	"context"
   6	"encoding/json"
   7	"errors"
   8	"fmt"
   9	"maps"
  10	"slices"
  11	"sync"
  12
  13	"charm.land/fantasy/schema"
  14	"github.com/charmbracelet/x/exp/slice"
  15)
  16
// StepResult represents the result of a single step in an agent execution.
type StepResult struct {
	// Response is the model response for this step. In Generate, its Content
	// is the step's content with client-side tool calls replaced by their
	// validated versions, followed by the executed tool results.
	Response
	// Messages are the conversation messages produced by this step (built
	// with toResponseMessages in Generate) and fed into the next step.
	Messages []Message
}
  22
// stepExecutionResult encapsulates the result of executing a step with stream processing.
// Stream produces one of these per step, under its retry policy.
type stepExecutionResult struct {
	// StepResult is appended to the run's steps by Stream.
	StepResult     StepResult
	// ShouldContinue reports whether Stream should run another step; the
	// loop stops when it is false (or when a stop condition is met).
	ShouldContinue bool
}
  28
  29// StopCondition defines a function that determines when an agent should stop executing.
  30type StopCondition = func(steps []StepResult) bool
  31
  32// StepCountIs returns a stop condition that stops after the specified number of steps.
  33func StepCountIs(stepCount int) StopCondition {
  34	return func(steps []StepResult) bool {
  35		return len(steps) >= stepCount
  36	}
  37}
  38
  39// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
  40func HasToolCall(toolName string) StopCondition {
  41	return func(steps []StepResult) bool {
  42		if len(steps) == 0 {
  43			return false
  44		}
  45		lastStep := steps[len(steps)-1]
  46		toolCalls := lastStep.Content.ToolCalls()
  47		for _, toolCall := range toolCalls {
  48			if toolCall.ToolName == toolName {
  49				return true
  50			}
  51		}
  52		return false
  53	}
  54}
  55
  56// HasContent returns a stop condition that stops when the specified content type appears in the last step.
  57func HasContent(contentType ContentType) StopCondition {
  58	return func(steps []StepResult) bool {
  59		if len(steps) == 0 {
  60			return false
  61		}
  62		lastStep := steps[len(steps)-1]
  63		for _, content := range lastStep.Content {
  64			if content.GetType() == contentType {
  65				return true
  66			}
  67		}
  68		return false
  69	}
  70}
  71
  72// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
  73func FinishReasonIs(reason FinishReason) StopCondition {
  74	return func(steps []StepResult) bool {
  75		if len(steps) == 0 {
  76			return false
  77		}
  78		lastStep := steps[len(steps)-1]
  79		return lastStep.FinishReason == reason
  80	}
  81}
  82
  83// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
  84func MaxTokensUsed(maxTokens int64) StopCondition {
  85	return func(steps []StepResult) bool {
  86		var totalTokens int64
  87		for _, step := range steps {
  88			totalTokens += step.Usage.TotalTokens
  89		}
  90		return totalTokens >= maxTokens
  91	}
  92}
  93
// PrepareStepFunctionOptions contains the options for preparing a step in an agent execution.
// Generate and Stream pass it to PrepareStep before every step.
type PrepareStepFunctionOptions struct {
	// Steps holds the results of the steps completed so far.
	Steps      []StepResult
	// StepNumber is the zero-based index of the step about to run.
	StepNumber int
	// Model is the agent's configured model, i.e. the step's default model.
	Model      LanguageModel
	// Messages are the step's input messages: the initial prompt followed by
	// the messages produced by earlier steps.
	Messages   []Message
}
 101
// PrepareStepResult contains the result of preparing a step in an agent execution.
// Fields left at their zero value (nil, empty ActiveTools, false) keep the
// step's defaults.
type PrepareStepResult struct {
	// Model, when non-nil, replaces the model used for this step.
	Model           LanguageModel
	// Messages, when non-nil, replaces the step's input messages.
	Messages        []Message
	// System, when non-nil, overrides the system prompt for this step.
	System          *string
	// ToolChoice, when non-nil, replaces the default ToolChoiceAuto.
	ToolChoice      *ToolChoice
	// ActiveTools, when non-empty, restricts which tools are sent to the model.
	ActiveTools     []string
	// DisableAllTools, when true, sends no tools for this step.
	DisableAllTools bool
	// Tools, when non-nil, replaces the agent's tools for this step (used for
	// the tool list, validation and execution).
	Tools           []AgentTool
}
 112
// ToolCallRepairOptions contains the options for repairing a tool call.
// validateAndRepairToolCall builds it when a tool call fails validation and a
// repair function is configured.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call as produced by the model.
	OriginalToolCall ToolCallContent
	// ValidationError is the error returned when validating OriginalToolCall.
	ValidationError  error
	// AvailableTools are the tools available for the current step.
	AvailableTools   []AgentTool
	// SystemPrompt is the system prompt used for the current step.
	SystemPrompt     string
	// Messages are the input messages sent to the model for the current step.
	Messages         []Message
}
 121
 122type (
 123	// PrepareStepFunction defines a function that prepares a step in an agent execution.
 124	PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)
 125
 126	// OnStepFinishedFunction defines a function that is called when a step finishes.
 127	OnStepFinishedFunction = func(step StepResult)
 128
 129	// RepairToolCallFunction defines a function that repairs a tool call.
 130	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
 131)
 132
// agentSettings holds the agent-level configuration assembled by NewAgent
// from AgentOptions. Per-call values in AgentCall take precedence; see
// prepareCall.
type agentSettings struct {
	// systemPrompt is the default system prompt used to build the prompt.
	systemPrompt     string
	// Sampling/limit defaults, applied by prepareCall when the call leaves
	// the corresponding field nil.
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	// NOTE(review): headers is not forwarded to the model Call anywhere in
	// the code visible in this file — confirm whether it is used elsewhere.
	headers          map[string]string
	// userAgent is sent as Call.UserAgent on every model call.
	userAgent        string
	// providerOptions are merged under the call's ProviderOptions (call
	// entries win on key collisions).
	providerOptions  ProviderOptions

	// providerDefinedTools are appended after the function tools by prepareTools.
	providerDefinedTools []ProviderDefinedTool
	// tools are the default function tools for every step.
	tools                []AgentTool
	// maxRetries is the default retry count for model calls.
	maxRetries           *int

	// model is the default language model for every step.
	model LanguageModel

	// Defaults used when the call leaves the corresponding field unset.
	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
 156
// AgentCall represents a call to an agent.
// Unset (nil/empty) fields fall back to the agent-level settings via
// prepareCall.
type AgentCall struct {
	// Prompt, Files and Messages are passed to createPrompt to build the
	// initial prompt.
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	// NOTE(review): unlike its siblings, MaxOutputTokens has no json tag and
	// would serialize as "MaxOutputTokens" — confirm whether that is intended.
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, restricts which tools are sent to the model.
	ActiveTools      []string `json:"active_tools"`
	// ProviderOptions are merged over the agent-level provider options.
	ProviderOptions  ProviderOptions
	// OnRetry and MaxRetries configure the retry policy for model calls.
	OnRetry          OnRetryCallback
	MaxRetries       *int

	// StopWhen conditions are checked after every step; any true one stops the loop.
	StopWhen       []StopCondition
	// PrepareStep is called before every step.
	PrepareStep    PrepareStepFunction
	// RepairToolCall is used to fix tool calls that fail validation.
	RepairToolCall RepairToolCallFunction
}
 177
 178// Agent-level callbacks.
 179type (
 180	// OnAgentStartFunc is called when agent starts.
 181	OnAgentStartFunc func()
 182
 183	// OnAgentFinishFunc is called when agent finishes.
 184	OnAgentFinishFunc func(result *AgentResult) error
 185
 186	// OnStepStartFunc is called when a step starts.
 187	OnStepStartFunc func(stepNumber int) error
 188
 189	// OnStepFinishFunc is called when a step finishes.
 190	OnStepFinishFunc func(stepResult StepResult) error
 191
 192	// OnFinishFunc is called when entire agent completes.
 193	OnFinishFunc func(result *AgentResult)
 194
 195	// OnErrorFunc is called when an error occurs.
 196	OnErrorFunc func(error)
 197)
 198
 199// Stream part callbacks - called for each corresponding stream part type.
 200type (
 201	// OnChunkFunc is called for each stream part (catch-all).
 202	OnChunkFunc func(StreamPart) error
 203
 204	// OnWarningsFunc is called for warnings.
 205	OnWarningsFunc func(warnings []CallWarning) error
 206
 207	// OnTextStartFunc is called when text starts.
 208	OnTextStartFunc func(id string) error
 209
 210	// OnTextDeltaFunc is called for text deltas.
 211	OnTextDeltaFunc func(id, text string) error
 212
 213	// OnTextEndFunc is called when text ends.
 214	OnTextEndFunc func(id string) error
 215
 216	// OnReasoningStartFunc is called when reasoning starts.
 217	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error
 218
 219	// OnReasoningDeltaFunc is called for reasoning deltas.
 220	OnReasoningDeltaFunc func(id, text string) error
 221
 222	// OnReasoningEndFunc is called when reasoning ends.
 223	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error
 224
 225	// OnToolInputStartFunc is called when tool input starts.
 226	OnToolInputStartFunc func(id, toolName string) error
 227
 228	// OnToolInputDeltaFunc is called for tool input deltas.
 229	OnToolInputDeltaFunc func(id, delta string) error
 230
 231	// OnToolInputEndFunc is called when tool input ends.
 232	OnToolInputEndFunc func(id string) error
 233
 234	// OnToolCallFunc is called when tool call is complete.
 235	OnToolCallFunc func(toolCall ToolCallContent) error
 236
 237	// OnToolResultFunc is called when tool execution completes.
 238	OnToolResultFunc func(result ToolResultContent) error
 239
 240	// OnSourceFunc is called for source references.
 241	OnSourceFunc func(source SourceContent) error
 242
 243	// OnStreamFinishFunc is called when stream finishes.
 244	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
 245)
 246
// AgentStreamCall represents a streaming call to an agent.
// Its option fields mirror AgentCall; Stream copies them into an AgentCall
// and applies agent-level defaults via prepareCall.
type AgentStreamCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	// NOTE(review): Headers is not copied into the AgentCall or the model
	// Call in the visible Stream code — confirm whether it is used elsewhere.
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called when agent starts
	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
	OnWarnings       OnWarningsFunc       // Called for warnings
	OnTextStart      OnTextStartFunc      // Called when text starts
	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
	OnTextEnd        OnTextEndFunc        // Called when text ends
	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
	OnToolCall       OnToolCallFunc       // Called when tool call is complete
	OnToolResult     OnToolResultFunc     // Called when tool execution completes
	OnSource         OnSourceFunc         // Called for source references
	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
}
 293
// AgentResult represents the result of an agent execution.
type AgentResult struct {
	// Steps holds every executed step, in order.
	Steps []StepResult
	// Final response
	// (the Response of the last step in Steps).
	Response   Response
	// TotalUsage is the token usage summed over all steps.
	TotalUsage Usage
}
 301
 302// Agent represents an AI agent that can generate responses and stream responses.
 303type Agent interface {
 304	Generate(context.Context, AgentCall) (*AgentResult, error)
 305	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
 306}
 307
 308// AgentOption defines a function that configures agent settings.
 309type AgentOption = func(*agentSettings)
 310
// agent is the Agent implementation returned by NewAgent.
type agent struct {
	// settings holds the model and defaults resolved from the AgentOptions.
	settings agentSettings
}
 314
 315// NewAgent creates a new agent with the given language model and options.
 316func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
 317	settings := agentSettings{
 318		model: model,
 319	}
 320	for _, o := range opts {
 321		o(&settings)
 322	}
 323	return &agent{
 324		settings: settings,
 325	}
 326}
 327
 328func (a *agent) prepareCall(call AgentCall) AgentCall {
 329	call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
 330	call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
 331	call.TopP = cmp.Or(call.TopP, a.settings.topP)
 332	call.TopK = cmp.Or(call.TopK, a.settings.topK)
 333	call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
 334	call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
 335	call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
 336
 337	if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
 338		call.StopWhen = a.settings.stopWhen
 339	}
 340	if call.PrepareStep == nil && a.settings.prepareStep != nil {
 341		call.PrepareStep = a.settings.prepareStep
 342	}
 343	if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
 344		call.RepairToolCall = a.settings.repairToolCall
 345	}
 346	if call.OnRetry == nil && a.settings.onRetry != nil {
 347		call.OnRetry = a.settings.onRetry
 348	}
 349
 350	providerOptions := ProviderOptions{}
 351	if a.settings.providerOptions != nil {
 352		maps.Copy(providerOptions, a.settings.providerOptions)
 353	}
 354	if call.ProviderOptions != nil {
 355		maps.Copy(providerOptions, call.ProviderOptions)
 356	}
 357	call.ProviderOptions = providerOptions
 358
 359	headers := map[string]string{}
 360
 361	if a.settings.headers != nil {
 362		maps.Copy(headers, a.settings.headers)
 363	}
 364
 365	return call
 366}
 367
 368// Generate implements Agent.
 369func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
 370	opts = a.prepareCall(opts)
 371	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 372	if err != nil {
 373		return nil, err
 374	}
 375	var responseMessages []Message
 376	var steps []StepResult
 377
 378	for {
 379		stepInputMessages := append(initialPrompt, responseMessages...)
 380		stepModel := a.settings.model
 381		stepSystemPrompt := a.settings.systemPrompt
 382		stepActiveTools := opts.ActiveTools
 383		stepToolChoice := ToolChoiceAuto
 384		disableAllTools := false
 385		stepTools := a.settings.tools
 386		if opts.PrepareStep != nil {
 387			updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
 388				Model:      stepModel,
 389				Steps:      steps,
 390				StepNumber: len(steps),
 391				Messages:   stepInputMessages,
 392			})
 393			if err != nil {
 394				return nil, err
 395			}
 396
 397			ctx = updatedCtx
 398
 399			// Apply prepared step modifications
 400			if prepared.Messages != nil {
 401				stepInputMessages = prepared.Messages
 402			}
 403			if prepared.Model != nil {
 404				stepModel = prepared.Model
 405			}
 406			if prepared.System != nil {
 407				stepSystemPrompt = *prepared.System
 408			}
 409			if prepared.ToolChoice != nil {
 410				stepToolChoice = *prepared.ToolChoice
 411			}
 412			if len(prepared.ActiveTools) > 0 {
 413				stepActiveTools = prepared.ActiveTools
 414			}
 415			disableAllTools = prepared.DisableAllTools
 416			if prepared.Tools != nil {
 417				stepTools = prepared.Tools
 418			}
 419		}
 420
 421		// Recreate prompt with potentially modified system prompt
 422		if stepSystemPrompt != a.settings.systemPrompt {
 423			stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 424			if err != nil {
 425				return nil, err
 426			}
 427			// Replace system message part, keep the rest
 428			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 429				stepInputMessages[0] = stepPrompt[0] // Replace system message
 430			}
 431		}
 432
 433		preparedTools := a.prepareTools(stepTools, a.settings.providerDefinedTools, stepActiveTools, disableAllTools)
 434
 435		retryOptions := DefaultRetryOptions()
 436		if opts.MaxRetries != nil {
 437			retryOptions.MaxRetries = *opts.MaxRetries
 438		}
 439		retryOptions.OnRetry = opts.OnRetry
 440		retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
 441
 442		result, err := retry(ctx, func() (*Response, error) {
 443			return stepModel.Generate(ctx, Call{
 444				Prompt:           stepInputMessages,
 445				MaxOutputTokens:  opts.MaxOutputTokens,
 446				Temperature:      opts.Temperature,
 447				TopP:             opts.TopP,
 448				TopK:             opts.TopK,
 449				PresencePenalty:  opts.PresencePenalty,
 450				FrequencyPenalty: opts.FrequencyPenalty,
 451				Tools:            preparedTools,
 452				ToolChoice:       &stepToolChoice,
 453				UserAgent:        a.settings.userAgent,
 454				ProviderOptions:  opts.ProviderOptions,
 455			})
 456		})
 457		if err != nil {
 458			return nil, err
 459		}
 460
 461		var stepToolCalls []ToolCallContent
 462		for _, content := range result.Content {
 463			if content.GetType() == ContentTypeToolCall {
 464				toolCall, ok := AsContentType[ToolCallContent](content)
 465				if !ok {
 466					continue
 467				}
 468				// Provider-executed tool calls (e.g. web search) are
 469				// handled by the provider and should not be validated
 470				// or executed by the agent.
 471				if toolCall.ProviderExecuted {
 472					continue
 473				}
 474				// Validate and potentially repair the tool call
 475				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
 476				stepToolCalls = append(stepToolCalls, validatedToolCall)
 477			}
 478		}
 479
 480		toolResults, err := a.executeTools(ctx, stepTools, stepToolCalls, nil)
 481
 482		// Build step content with validated tool calls and tool results.
 483		// Provider-executed tool calls are kept as-is.
 484		stepContent := []Content{}
 485		toolCallIndex := 0
 486		for _, content := range result.Content {
 487			if content.GetType() == ContentTypeToolCall {
 488				tc, ok := AsContentType[ToolCallContent](content)
 489				if ok && tc.ProviderExecuted {
 490					stepContent = append(stepContent, content)
 491					continue
 492				}
 493				// Replace with validated tool call.
 494				if toolCallIndex < len(stepToolCalls) {
 495					stepContent = append(stepContent, stepToolCalls[toolCallIndex])
 496					toolCallIndex++
 497				}
 498			} else {
 499				stepContent = append(stepContent, content)
 500			}
 501		} // Add tool results
 502		for _, result := range toolResults {
 503			stepContent = append(stepContent, result)
 504		}
 505		currentStepMessages := toResponseMessages(stepContent)
 506		responseMessages = append(responseMessages, currentStepMessages...)
 507
 508		stepResult := StepResult{
 509			Response: Response{
 510				Content:          stepContent,
 511				FinishReason:     result.FinishReason,
 512				Usage:            result.Usage,
 513				Warnings:         result.Warnings,
 514				ProviderMetadata: result.ProviderMetadata,
 515			},
 516			Messages: currentStepMessages,
 517		}
 518		steps = append(steps, stepResult)
 519		shouldStop := isStopConditionMet(opts.StopWhen, steps)
 520
 521		if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
 522			break
 523		}
 524	}
 525
 526	totalUsage := Usage{}
 527
 528	for _, step := range steps {
 529		usage := step.Usage
 530		totalUsage.InputTokens += usage.InputTokens
 531		totalUsage.OutputTokens += usage.OutputTokens
 532		totalUsage.ReasoningTokens += usage.ReasoningTokens
 533		totalUsage.CacheCreationTokens += usage.CacheCreationTokens
 534		totalUsage.CacheReadTokens += usage.CacheReadTokens
 535		totalUsage.TotalTokens += usage.TotalTokens
 536	}
 537
 538	agentResult := &AgentResult{
 539		Steps:      steps,
 540		Response:   steps[len(steps)-1].Response,
 541		TotalUsage: totalUsage,
 542	}
 543	return agentResult, nil
 544}
 545
 546func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
 547	if len(conditions) == 0 {
 548		return false
 549	}
 550
 551	for _, condition := range conditions {
 552		if condition(steps) {
 553			return true
 554		}
 555	}
 556	return false
 557}
 558
 559func toResponseMessages(content []Content) []Message {
 560	var assistantParts []MessagePart
 561	var toolParts []MessagePart
 562
 563	for _, c := range content {
 564		switch c.GetType() {
 565		case ContentTypeText:
 566			text, ok := AsContentType[TextContent](c)
 567			if !ok {
 568				continue
 569			}
 570			assistantParts = append(assistantParts, TextPart{
 571				Text:            text.Text,
 572				ProviderOptions: ProviderOptions(text.ProviderMetadata),
 573			})
 574		case ContentTypeReasoning:
 575			reasoning, ok := AsContentType[ReasoningContent](c)
 576			if !ok {
 577				continue
 578			}
 579			assistantParts = append(assistantParts, ReasoningPart{
 580				Text:            reasoning.Text,
 581				ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
 582			})
 583		case ContentTypeToolCall:
 584			toolCall, ok := AsContentType[ToolCallContent](c)
 585			if !ok {
 586				continue
 587			}
 588			assistantParts = append(assistantParts, ToolCallPart{
 589				ToolCallID:       toolCall.ToolCallID,
 590				ToolName:         toolCall.ToolName,
 591				Input:            toolCall.Input,
 592				ProviderExecuted: toolCall.ProviderExecuted,
 593				ProviderOptions:  ProviderOptions(toolCall.ProviderMetadata),
 594			})
 595		case ContentTypeFile:
 596			file, ok := AsContentType[FileContent](c)
 597			if !ok {
 598				continue
 599			}
 600			assistantParts = append(assistantParts, FilePart{
 601				Data:            file.Data,
 602				MediaType:       file.MediaType,
 603				ProviderOptions: ProviderOptions(file.ProviderMetadata),
 604			})
 605		case ContentTypeSource:
 606			// Sources are metadata about references used to generate the response.
 607			// They don't need to be included in the conversation messages.
 608			continue
 609		case ContentTypeToolResult:
 610			result, ok := AsContentType[ToolResultContent](c)
 611			if !ok {
 612				continue
 613			}
 614			resultPart := ToolResultPart{
 615				ToolCallID:       result.ToolCallID,
 616				Output:           result.Result,
 617				ProviderExecuted: result.ProviderExecuted,
 618				ProviderOptions:  ProviderOptions(result.ProviderMetadata),
 619			}
 620			if result.ProviderExecuted {
 621				// Provider-executed tool results (e.g. web search)
 622				// belong in the assistant message alongside the
 623				// server_tool_use block that produced them.
 624				assistantParts = append(assistantParts, resultPart)
 625			} else {
 626				toolParts = append(toolParts, resultPart)
 627			}
 628		}
 629	}
 630
 631	var messages []Message
 632	if len(assistantParts) > 0 {
 633		messages = append(messages, Message{
 634			Role:    MessageRoleAssistant,
 635			Content: assistantParts,
 636		})
 637	}
 638	if len(toolParts) > 0 {
 639		messages = append(messages, Message{
 640			Role:    MessageRoleTool,
 641			Content: toolParts,
 642		})
 643	}
 644	return messages
 645}
 646
 647func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
 648	if len(toolCalls) == 0 {
 649		return nil, nil
 650	}
 651
 652	// Create a map for quick tool lookup
 653	toolMap := make(map[string]AgentTool)
 654	for _, tool := range allTools {
 655		toolMap[tool.Info().Name] = tool
 656	}
 657
 658	// Execute all tool calls sequentially in order
 659	results := make([]ToolResultContent, 0, len(toolCalls))
 660
 661	for _, toolCall := range toolCalls {
 662		result, isCriticalError := a.executeSingleTool(ctx, toolMap, toolCall, toolResultCallback)
 663		results = append(results, result)
 664		if isCriticalError {
 665			if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
 666				return nil, errorResult.Error
 667			}
 668		}
 669	}
 670
 671	return results, nil
 672}
 673
 674// executeSingleTool executes a single tool and returns its result and a critical error flag.
 675func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentTool, toolCall ToolCallContent, toolResultCallback func(result ToolResultContent) error) (ToolResultContent, bool) {
 676	result := ToolResultContent{
 677		ToolCallID:       toolCall.ToolCallID,
 678		ToolName:         toolCall.ToolName,
 679		ProviderExecuted: false,
 680	}
 681
 682	// Skip invalid tool calls - create error result (not critical)
 683	if toolCall.Invalid {
 684		result.Result = ToolResultOutputContentError{
 685			Error: toolCall.ValidationError,
 686		}
 687		if toolResultCallback != nil {
 688			_ = toolResultCallback(result)
 689		}
 690		return result, false
 691	}
 692
 693	tool, exists := toolMap[toolCall.ToolName]
 694	if !exists {
 695		result.Result = ToolResultOutputContentError{
 696			Error: errors.New("Error: Tool not found: " + toolCall.ToolName),
 697		}
 698		if toolResultCallback != nil {
 699			_ = toolResultCallback(result)
 700		}
 701		return result, false
 702	}
 703
 704	// Execute the tool
 705	toolResult, err := tool.Run(ctx, ToolCall{
 706		ID:    toolCall.ToolCallID,
 707		Name:  toolCall.ToolName,
 708		Input: toolCall.Input,
 709	})
 710	if err != nil {
 711		result.Result = ToolResultOutputContentError{
 712			Error: err,
 713		}
 714		result.ClientMetadata = toolResult.Metadata
 715		if toolResultCallback != nil {
 716			_ = toolResultCallback(result)
 717		}
 718		return result, true
 719	}
 720
 721	result.ClientMetadata = toolResult.Metadata
 722	if toolResult.IsError {
 723		result.Result = ToolResultOutputContentError{
 724			Error: errors.New(toolResult.Content),
 725		}
 726	} else if toolResult.Type == "image" || toolResult.Type == "media" {
 727		result.Result = ToolResultOutputContentMedia{
 728			Data:      string(toolResult.Data),
 729			MediaType: toolResult.MediaType,
 730			Text:      toolResult.Content,
 731		}
 732	} else {
 733		result.Result = ToolResultOutputContentText{
 734			Text: toolResult.Content,
 735		}
 736	}
 737	if toolResultCallback != nil {
 738		_ = toolResultCallback(result)
 739	}
 740	return result, false
 741}
 742
 743// Stream implements Agent.
 744func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
 745	// Convert AgentStreamCall to AgentCall for preparation
 746	call := AgentCall{
 747		Prompt:           opts.Prompt,
 748		Files:            opts.Files,
 749		Messages:         opts.Messages,
 750		MaxOutputTokens:  opts.MaxOutputTokens,
 751		Temperature:      opts.Temperature,
 752		TopP:             opts.TopP,
 753		TopK:             opts.TopK,
 754		PresencePenalty:  opts.PresencePenalty,
 755		FrequencyPenalty: opts.FrequencyPenalty,
 756		ActiveTools:      opts.ActiveTools,
 757		ProviderOptions:  opts.ProviderOptions,
 758		MaxRetries:       opts.MaxRetries,
 759		OnRetry:          opts.OnRetry,
 760		StopWhen:         opts.StopWhen,
 761		PrepareStep:      opts.PrepareStep,
 762		RepairToolCall:   opts.RepairToolCall,
 763	}
 764
 765	call = a.prepareCall(call)
 766
 767	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
 768	if err != nil {
 769		return nil, err
 770	}
 771
 772	var responseMessages []Message
 773	var steps []StepResult
 774	var totalUsage Usage
 775
 776	// Start agent stream
 777	if opts.OnAgentStart != nil {
 778		opts.OnAgentStart()
 779	}
 780
 781	for stepNumber := 0; ; stepNumber++ {
 782		stepInputMessages := append(initialPrompt, responseMessages...)
 783		stepModel := a.settings.model
 784		stepSystemPrompt := a.settings.systemPrompt
 785		stepActiveTools := call.ActiveTools
 786		stepToolChoice := ToolChoiceAuto
 787		disableAllTools := false
 788		stepTools := a.settings.tools
 789		// Apply step preparation if provided
 790		if call.PrepareStep != nil {
 791			updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
 792				Model:      stepModel,
 793				Steps:      steps,
 794				StepNumber: stepNumber,
 795				Messages:   stepInputMessages,
 796			})
 797			if err != nil {
 798				return nil, err
 799			}
 800
 801			ctx = updatedCtx
 802
 803			if prepared.Messages != nil {
 804				stepInputMessages = prepared.Messages
 805			}
 806			if prepared.Model != nil {
 807				stepModel = prepared.Model
 808			}
 809			if prepared.System != nil {
 810				stepSystemPrompt = *prepared.System
 811			}
 812			if prepared.ToolChoice != nil {
 813				stepToolChoice = *prepared.ToolChoice
 814			}
 815			if len(prepared.ActiveTools) > 0 {
 816				stepActiveTools = prepared.ActiveTools
 817			}
 818			disableAllTools = prepared.DisableAllTools
 819			if prepared.Tools != nil {
 820				stepTools = prepared.Tools
 821			}
 822		}
 823
 824		// Recreate prompt with potentially modified system prompt
 825		if stepSystemPrompt != a.settings.systemPrompt {
 826			stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
 827			if err != nil {
 828				return nil, err
 829			}
 830			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 831				stepInputMessages[0] = stepPrompt[0]
 832			}
 833		}
 834
 835		preparedTools := a.prepareTools(stepTools, a.settings.providerDefinedTools, stepActiveTools, disableAllTools)
 836
 837		// Start step stream
 838		if opts.OnStepStart != nil {
 839			_ = opts.OnStepStart(stepNumber)
 840		}
 841
 842		// Create streaming call
 843		streamCall := Call{
 844			Prompt:           stepInputMessages,
 845			MaxOutputTokens:  call.MaxOutputTokens,
 846			Temperature:      call.Temperature,
 847			TopP:             call.TopP,
 848			TopK:             call.TopK,
 849			PresencePenalty:  call.PresencePenalty,
 850			FrequencyPenalty: call.FrequencyPenalty,
 851			Tools:            preparedTools,
 852			ToolChoice:       &stepToolChoice,
 853			UserAgent:        a.settings.userAgent,
 854			ProviderOptions:  call.ProviderOptions,
 855		}
 856
 857		// Execute step with retry logic wrapping both stream creation and processing
 858		retryOptions := DefaultRetryOptions()
 859		if call.MaxRetries != nil {
 860			retryOptions.MaxRetries = *call.MaxRetries
 861		}
 862		retryOptions.OnRetry = call.OnRetry
 863		retry := RetryWithExponentialBackoffRespectingRetryHeaders[stepExecutionResult](retryOptions)
 864
 865		result, err := retry(ctx, func() (stepExecutionResult, error) {
 866			// Create the stream
 867			stream, err := stepModel.Stream(ctx, streamCall)
 868			if err != nil {
 869				return stepExecutionResult{}, err
 870			}
 871
 872			// Process the stream
 873			result, err := a.processStepStream(ctx, stream, opts, steps, stepTools)
 874			if err != nil {
 875				return stepExecutionResult{}, err
 876			}
 877
 878			return result, nil
 879		})
 880		if err != nil {
 881			if opts.OnError != nil {
 882				opts.OnError(err)
 883			}
 884			return nil, err
 885		}
 886
 887		steps = append(steps, result.StepResult)
 888		totalUsage = addUsage(totalUsage, result.StepResult.Usage)
 889
 890		// Call step finished callback
 891		if opts.OnStepFinish != nil {
 892			_ = opts.OnStepFinish(result.StepResult)
 893		}
 894
 895		// Add step messages to response messages
 896		stepMessages := toResponseMessages(result.StepResult.Content)
 897		responseMessages = append(responseMessages, stepMessages...)
 898
 899		// Check stop conditions
 900		shouldStop := isStopConditionMet(call.StopWhen, steps)
 901		if shouldStop || !result.ShouldContinue {
 902			break
 903		}
 904	}
 905
 906	// Finish agent stream
 907	agentResult := &AgentResult{
 908		Steps:      steps,
 909		Response:   steps[len(steps)-1].Response,
 910		TotalUsage: totalUsage,
 911	}
 912
 913	if opts.OnFinish != nil {
 914		opts.OnFinish(agentResult)
 915	}
 916
 917	if opts.OnAgentFinish != nil {
 918		_ = opts.OnAgentFinish(agentResult)
 919	}
 920
 921	return agentResult, nil
 922}
 923
 924func (a *agent) prepareTools(tools []AgentTool, providerDefinedTools []ProviderDefinedTool, activeTools []string, disableAllTools bool) []Tool {
 925	preparedTools := make([]Tool, 0, len(tools)+len(providerDefinedTools))
 926
 927	// If explicitly disabling all tools, return no tools
 928	if disableAllTools {
 929		return preparedTools
 930	}
 931
 932	for _, tool := range tools {
 933		// If activeTools has items, only include tools in the list
 934		// If activeTools is empty, include all tools
 935		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
 936			continue
 937		}
 938		info := tool.Info()
 939		inputSchema := map[string]any{
 940			"type":       "object",
 941			"properties": info.Parameters,
 942			"required":   info.Required,
 943		}
 944		schema.Normalize(inputSchema)
 945		preparedTools = append(preparedTools, FunctionTool{
 946			Name:            info.Name,
 947			Description:     info.Description,
 948			InputSchema:     inputSchema,
 949			ProviderOptions: tool.ProviderOptions(),
 950		})
 951	}
 952	for _, tool := range providerDefinedTools {
 953		// If activeTools has items, only include tools in the list. If
 954		// activeTools is empty, include all tools
 955		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.GetName()) {
 956			continue
 957		}
 958		preparedTools = append(preparedTools, tool)
 959	}
 960	return preparedTools
 961}
 962
 963// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
 964func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
 965	if err := a.validateToolCall(toolCall, availableTools); err == nil {
 966		return toolCall
 967	} else { //nolint: revive
 968		if repairFunc != nil {
 969			repairOptions := ToolCallRepairOptions{
 970				OriginalToolCall: toolCall,
 971				ValidationError:  err,
 972				AvailableTools:   availableTools,
 973				SystemPrompt:     systemPrompt,
 974				Messages:         messages,
 975			}
 976
 977			if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
 978				if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
 979					return *repairedToolCall
 980				}
 981			}
 982		}
 983
 984		invalidToolCall := toolCall
 985		invalidToolCall.Invalid = true
 986		invalidToolCall.ValidationError = err
 987		return invalidToolCall
 988	}
 989}
 990
 991// validateToolCall validates a tool call against available tools and their schemas.
 992func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
 993	var tool AgentTool
 994	for _, t := range availableTools {
 995		if t.Info().Name == toolCall.ToolName {
 996			tool = t
 997			break
 998		}
 999	}
1000
1001	if tool == nil {
1002		return fmt.Errorf("tool not found: %s", toolCall.ToolName)
1003	}
1004
1005	// Validate JSON parsing
1006	var input map[string]any
1007	if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
1008		return fmt.Errorf("invalid JSON input: %w", err)
1009	}
1010
1011	// Basic schema validation (check required fields)
1012	// TODO: more robust schema validation using JSON Schema or similar
1013	toolInfo := tool.Info()
1014	for _, required := range toolInfo.Required {
1015		if _, exists := input[required]; !exists {
1016			return fmt.Errorf("missing required parameter: %s", required)
1017		}
1018	}
1019	return nil
1020}
1021
1022func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
1023	// Validation: empty prompt is only allowed when there are messages,
1024	// no files to attach, and the last message is a user or tool message.
1025	if prompt == "" {
1026		lastMessage, hasMessages := slice.Last(messages)
1027
1028		if !hasMessages {
1029			return nil, &Error{
1030				Title:   "invalid argument",
1031				Message: "prompt can't be empty when there are no messages",
1032			}
1033		}
1034
1035		if len(files) > 0 {
1036			return nil, &Error{
1037				Title:   "invalid argument",
1038				Message: "prompt can't be empty when there are files",
1039			}
1040		}
1041
1042		switch lastMessage.Role {
1043		case MessageRoleUser, MessageRoleTool:
1044		default:
1045			return nil, &Error{
1046				Title:   "invalid argument",
1047				Message: "prompt can't be empty when the last message is not a user or tool message",
1048			}
1049		}
1050	}
1051
1052	var preparedPrompt Prompt
1053
1054	if system != "" {
1055		preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
1056	}
1057	preparedPrompt = append(preparedPrompt, messages...)
1058	if prompt != "" {
1059		preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
1060	}
1061	return preparedPrompt, nil
1062}
1063
1064// WithSystemPrompt sets the system prompt for the agent.
1065func WithSystemPrompt(prompt string) AgentOption {
1066	return func(s *agentSettings) {
1067		s.systemPrompt = prompt
1068	}
1069}
1070
1071// WithMaxOutputTokens sets the maximum output tokens for the agent.
1072func WithMaxOutputTokens(tokens int64) AgentOption {
1073	return func(s *agentSettings) {
1074		s.maxOutputTokens = &tokens
1075	}
1076}
1077
1078// WithTemperature sets the temperature for the agent.
1079func WithTemperature(temp float64) AgentOption {
1080	return func(s *agentSettings) {
1081		s.temperature = &temp
1082	}
1083}
1084
1085// WithTopP sets the top-p value for the agent.
1086func WithTopP(topP float64) AgentOption {
1087	return func(s *agentSettings) {
1088		s.topP = &topP
1089	}
1090}
1091
1092// WithTopK sets the top-k value for the agent.
1093func WithTopK(topK int64) AgentOption {
1094	return func(s *agentSettings) {
1095		s.topK = &topK
1096	}
1097}
1098
1099// WithPresencePenalty sets the presence penalty for the agent.
1100func WithPresencePenalty(penalty float64) AgentOption {
1101	return func(s *agentSettings) {
1102		s.presencePenalty = &penalty
1103	}
1104}
1105
1106// WithFrequencyPenalty sets the frequency penalty for the agent.
1107func WithFrequencyPenalty(penalty float64) AgentOption {
1108	return func(s *agentSettings) {
1109		s.frequencyPenalty = &penalty
1110	}
1111}
1112
1113// WithTools sets the tools for the agent.
1114func WithTools(tools ...AgentTool) AgentOption {
1115	return func(s *agentSettings) {
1116		s.tools = append(s.tools, tools...)
1117	}
1118}
1119
1120// WithProviderDefinedTools sets the provider-defined tools for the agent.
1121// These tools are executed by the provider (e.g. web search) rather
1122// than by the client.
1123func WithProviderDefinedTools(tools ...ProviderDefinedTool) AgentOption {
1124	return func(s *agentSettings) {
1125		s.providerDefinedTools = append(s.providerDefinedTools, tools...)
1126	}
1127}
1128
1129// WithStopConditions sets the stop conditions for the agent.
1130func WithStopConditions(conditions ...StopCondition) AgentOption {
1131	return func(s *agentSettings) {
1132		s.stopWhen = append(s.stopWhen, conditions...)
1133	}
1134}
1135
1136// WithPrepareStep sets the prepare step function for the agent.
1137func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1138	return func(s *agentSettings) {
1139		s.prepareStep = fn
1140	}
1141}
1142
1143// WithRepairToolCall sets the repair tool call function for the agent.
1144func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1145	return func(s *agentSettings) {
1146		s.repairToolCall = fn
1147	}
1148}
1149
1150// WithMaxRetries sets the maximum number of retries for the agent.
1151func WithMaxRetries(maxRetries int) AgentOption {
1152	return func(s *agentSettings) {
1153		s.maxRetries = &maxRetries
1154	}
1155}
1156
1157// WithOnRetry sets the retry callback for the agent.
1158func WithOnRetry(callback OnRetryCallback) AgentOption {
1159	return func(s *agentSettings) {
1160		s.onRetry = callback
1161	}
1162}
1163
1164// processStepStream processes a single step's stream and returns the step result.
1165func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult, stepTools []AgentTool) (stepExecutionResult, error) {
1166	var stepContent []Content
1167	var stepToolCalls []ToolCallContent
1168	var stepUsage Usage
1169	stepFinishReason := FinishReasonUnknown
1170	var stepWarnings []CallWarning
1171	var stepProviderMetadata ProviderMetadata
1172
1173	activeToolCalls := make(map[string]*ToolCallContent)
1174	activeTextContent := make(map[string]string)
1175	type reasoningContent struct {
1176		content string
1177		options ProviderMetadata
1178	}
1179	activeReasoningContent := make(map[string]reasoningContent)
1180
1181	// Set up concurrent tool execution
1182	type toolExecutionRequest struct {
1183		toolCall ToolCallContent
1184		parallel bool
1185	}
1186	toolChan := make(chan toolExecutionRequest, 10)
1187	var toolExecutionWg sync.WaitGroup
1188	var toolStateMu sync.Mutex
1189	toolResults := make([]ToolResultContent, 0)
1190	var toolExecutionErr error
1191
1192	// Create a map for quick tool lookup
1193	toolMap := make(map[string]AgentTool)
1194	for _, tool := range stepTools {
1195		toolMap[tool.Info().Name] = tool
1196	}
1197
1198	// Semaphores for controlling parallelism
1199	parallelSem := make(chan struct{}, 5)
1200	var sequentialMu sync.Mutex
1201
1202	// Single coordinator goroutine that dispatches tools
1203	toolExecutionWg.Go(func() {
1204		for req := range toolChan {
1205			if req.parallel {
1206				parallelSem <- struct{}{}
1207				toolExecutionWg.Go(func() {
1208					defer func() { <-parallelSem }()
1209					result, isCriticalError := a.executeSingleTool(ctx, toolMap, req.toolCall, opts.OnToolResult)
1210					toolStateMu.Lock()
1211					toolResults = append(toolResults, result)
1212					if isCriticalError && toolExecutionErr == nil {
1213						if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
1214							toolExecutionErr = errorResult.Error
1215						}
1216					}
1217					toolStateMu.Unlock()
1218				})
1219			} else {
1220				sequentialMu.Lock()
1221				result, isCriticalError := a.executeSingleTool(ctx, toolMap, req.toolCall, opts.OnToolResult)
1222				toolStateMu.Lock()
1223				toolResults = append(toolResults, result)
1224				if isCriticalError && toolExecutionErr == nil {
1225					if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
1226						toolExecutionErr = errorResult.Error
1227					}
1228				}
1229				toolStateMu.Unlock()
1230				sequentialMu.Unlock()
1231			}
1232		}
1233	})
1234
1235	// Process stream parts
1236	for part := range stream {
1237		// Forward all parts to chunk callback
1238		if opts.OnChunk != nil {
1239			err := opts.OnChunk(part)
1240			if err != nil {
1241				return stepExecutionResult{}, err
1242			}
1243		}
1244
1245		switch part.Type {
1246		case StreamPartTypeWarnings:
1247			stepWarnings = part.Warnings
1248			if opts.OnWarnings != nil {
1249				err := opts.OnWarnings(part.Warnings)
1250				if err != nil {
1251					return stepExecutionResult{}, err
1252				}
1253			}
1254
1255		case StreamPartTypeTextStart:
1256			activeTextContent[part.ID] = ""
1257			if opts.OnTextStart != nil {
1258				err := opts.OnTextStart(part.ID)
1259				if err != nil {
1260					return stepExecutionResult{}, err
1261				}
1262			}
1263
1264		case StreamPartTypeTextDelta:
1265			if _, exists := activeTextContent[part.ID]; exists {
1266				activeTextContent[part.ID] += part.Delta
1267			}
1268			if opts.OnTextDelta != nil {
1269				err := opts.OnTextDelta(part.ID, part.Delta)
1270				if err != nil {
1271					return stepExecutionResult{}, err
1272				}
1273			}
1274
1275		case StreamPartTypeTextEnd:
1276			if text, exists := activeTextContent[part.ID]; exists {
1277				stepContent = append(stepContent, TextContent{
1278					Text:             text,
1279					ProviderMetadata: part.ProviderMetadata,
1280				})
1281				delete(activeTextContent, part.ID)
1282			}
1283			if opts.OnTextEnd != nil {
1284				err := opts.OnTextEnd(part.ID)
1285				if err != nil {
1286					return stepExecutionResult{}, err
1287				}
1288			}
1289
1290		case StreamPartTypeReasoningStart:
1291			activeReasoningContent[part.ID] = reasoningContent{content: part.Delta, options: part.ProviderMetadata}
1292			if opts.OnReasoningStart != nil {
1293				content := ReasoningContent{
1294					Text:             part.Delta,
1295					ProviderMetadata: part.ProviderMetadata,
1296				}
1297				err := opts.OnReasoningStart(part.ID, content)
1298				if err != nil {
1299					return stepExecutionResult{}, err
1300				}
1301			}
1302
1303		case StreamPartTypeReasoningDelta:
1304			if active, exists := activeReasoningContent[part.ID]; exists {
1305				active.content += part.Delta
1306				if part.ProviderMetadata != nil {
1307					active.options = part.ProviderMetadata
1308				}
1309				activeReasoningContent[part.ID] = active
1310			}
1311			if opts.OnReasoningDelta != nil {
1312				err := opts.OnReasoningDelta(part.ID, part.Delta)
1313				if err != nil {
1314					return stepExecutionResult{}, err
1315				}
1316			}
1317
1318		case StreamPartTypeReasoningEnd:
1319			if active, exists := activeReasoningContent[part.ID]; exists {
1320				if part.ProviderMetadata != nil {
1321					active.options = part.ProviderMetadata
1322				}
1323				content := ReasoningContent{
1324					Text:             active.content,
1325					ProviderMetadata: active.options,
1326				}
1327				stepContent = append(stepContent, content)
1328				if opts.OnReasoningEnd != nil {
1329					err := opts.OnReasoningEnd(part.ID, content)
1330					if err != nil {
1331						return stepExecutionResult{}, err
1332					}
1333				}
1334				delete(activeReasoningContent, part.ID)
1335			}
1336
1337		case StreamPartTypeToolInputStart:
1338			activeToolCalls[part.ID] = &ToolCallContent{
1339				ToolCallID:       part.ID,
1340				ToolName:         part.ToolCallName,
1341				Input:            "",
1342				ProviderExecuted: part.ProviderExecuted,
1343			}
1344			if opts.OnToolInputStart != nil {
1345				err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1346				if err != nil {
1347					return stepExecutionResult{}, err
1348				}
1349			}
1350
1351		case StreamPartTypeToolInputDelta:
1352			if toolCall, exists := activeToolCalls[part.ID]; exists {
1353				toolCall.Input += part.Delta
1354			}
1355			if opts.OnToolInputDelta != nil {
1356				err := opts.OnToolInputDelta(part.ID, part.Delta)
1357				if err != nil {
1358					return stepExecutionResult{}, err
1359				}
1360			}
1361
1362		case StreamPartTypeToolInputEnd:
1363			if opts.OnToolInputEnd != nil {
1364				err := opts.OnToolInputEnd(part.ID)
1365				if err != nil {
1366					return stepExecutionResult{}, err
1367				}
1368			}
1369
1370		case StreamPartTypeToolCall:
1371			toolCall := ToolCallContent{
1372				ToolCallID:       part.ID,
1373				ToolName:         part.ToolCallName,
1374				Input:            part.ToolCallInput,
1375				ProviderExecuted: part.ProviderExecuted,
1376				ProviderMetadata: part.ProviderMetadata,
1377			}
1378
1379			// Provider-executed tool calls are handled by the provider
1380			// and should not be validated or executed by the agent.
1381			if toolCall.ProviderExecuted {
1382				stepContent = append(stepContent, toolCall)
1383				if opts.OnToolCall != nil {
1384					err := opts.OnToolCall(toolCall)
1385					if err != nil {
1386						return stepExecutionResult{}, err
1387					}
1388				}
1389				delete(activeToolCalls, part.ID)
1390			} else {
1391				// Validate and potentially repair the tool call
1392				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1393				stepToolCalls = append(stepToolCalls, validatedToolCall)
1394				stepContent = append(stepContent, validatedToolCall)
1395
1396				if opts.OnToolCall != nil {
1397					err := opts.OnToolCall(validatedToolCall)
1398					if err != nil {
1399						return stepExecutionResult{}, err
1400					}
1401				}
1402
1403				// Determine if tool can run in parallel
1404				isParallel := false
1405				if tool, exists := toolMap[validatedToolCall.ToolName]; exists {
1406					isParallel = tool.Info().Parallel
1407				}
1408
1409				// Send tool call to execution channel
1410				toolChan <- toolExecutionRequest{toolCall: validatedToolCall, parallel: isParallel}
1411
1412				// Clean up active tool call
1413				delete(activeToolCalls, part.ID)
1414			}
1415
1416		case StreamPartTypeToolResult:
1417			// Provider-executed tool results (e.g. web search)
1418			// are emitted by the provider and added directly
1419			// to the step content for multi-turn round-tripping.
1420			if part.ProviderExecuted {
1421				resultContent := ToolResultContent{
1422					ToolCallID:       part.ID,
1423					ToolName:         part.ToolCallName,
1424					ProviderExecuted: true,
1425					ProviderMetadata: part.ProviderMetadata,
1426				}
1427				stepContent = append(stepContent, resultContent)
1428				if opts.OnToolResult != nil {
1429					err := opts.OnToolResult(resultContent)
1430					if err != nil {
1431						return stepExecutionResult{}, err
1432					}
1433				}
1434			}
1435
1436		case StreamPartTypeSource:
1437			sourceContent := SourceContent{
1438				SourceType:       part.SourceType,
1439				ID:               part.ID,
1440				URL:              part.URL,
1441				Title:            part.Title,
1442				ProviderMetadata: part.ProviderMetadata,
1443			}
1444			stepContent = append(stepContent, sourceContent)
1445			if opts.OnSource != nil {
1446				err := opts.OnSource(sourceContent)
1447				if err != nil {
1448					return stepExecutionResult{}, err
1449				}
1450			}
1451
1452		case StreamPartTypeFinish:
1453			stepUsage = part.Usage
1454			stepFinishReason = part.FinishReason
1455			stepProviderMetadata = part.ProviderMetadata
1456			if opts.OnStreamFinish != nil {
1457				err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1458				if err != nil {
1459					return stepExecutionResult{}, err
1460				}
1461			}
1462
1463		case StreamPartTypeError:
1464			return stepExecutionResult{}, part.Error
1465		}
1466	}
1467
1468	// Close the tool execution channel and wait for all executions to complete
1469	close(toolChan)
1470	toolExecutionWg.Wait()
1471
1472	// Check for tool execution errors
1473	if toolExecutionErr != nil {
1474		return stepExecutionResult{}, toolExecutionErr
1475	}
1476
1477	// Add tool results to content if any
1478	if len(toolResults) > 0 {
1479		for _, result := range toolResults {
1480			stepContent = append(stepContent, result)
1481		}
1482	}
1483
1484	stepResult := StepResult{
1485		Response: Response{
1486			Content:          stepContent,
1487			FinishReason:     stepFinishReason,
1488			Usage:            stepUsage,
1489			Warnings:         stepWarnings,
1490			ProviderMetadata: stepProviderMetadata,
1491		},
1492		Messages: toResponseMessages(stepContent),
1493	}
1494
1495	// Determine if we should continue (has tool calls and not stopped)
1496	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1497
1498	return stepExecutionResult{
1499		StepResult:     stepResult,
1500		ShouldContinue: shouldContinue,
1501	}, nil
1502}
1503
1504func addUsage(a, b Usage) Usage {
1505	return Usage{
1506		InputTokens:         a.InputTokens + b.InputTokens,
1507		OutputTokens:        a.OutputTokens + b.OutputTokens,
1508		TotalTokens:         a.TotalTokens + b.TotalTokens,
1509		ReasoningTokens:     a.ReasoningTokens + b.ReasoningTokens,
1510		CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1511		CacheReadTokens:     a.CacheReadTokens + b.CacheReadTokens,
1512	}
1513}
1514
1515// WithHeaders sets the headers for the agent.
1516func WithHeaders(headers map[string]string) AgentOption {
1517	return func(s *agentSettings) {
1518		s.headers = headers
1519	}
1520}
1521
1522// WithUserAgent sets the User-Agent header for the agent. This overrides any
1523// provider-level User-Agent setting.
1524func WithUserAgent(ua string) AgentOption {
1525	return func(s *agentSettings) {
1526		s.userAgent = ua
1527	}
1528}
1529
1530// WithProviderOptions sets the provider options for the agent.
1531func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1532	return func(s *agentSettings) {
1533		s.providerOptions = providerOptions
1534	}
1535}