agent.go

   1package fantasy
   2
   3import (
   4	"cmp"
   5	"context"
   6	"encoding/json"
   7	"errors"
   8	"fmt"
   9	"maps"
  10	"slices"
  11	"sync"
  12
  13	"charm.land/fantasy/schema"
  14)
  15
  16// StepResult represents the result of a single step in an agent execution.
  17type StepResult struct {
  18	Response
  19	Messages []Message
  20}
  21
  22// stepExecutionResult encapsulates the result of executing a step with stream processing.
  23type stepExecutionResult struct {
  24	StepResult     StepResult
  25	ShouldContinue bool
  26}
  27
  28// StopCondition defines a function that determines when an agent should stop executing.
  29type StopCondition = func(steps []StepResult) bool
  30
  31// StepCountIs returns a stop condition that stops after the specified number of steps.
  32func StepCountIs(stepCount int) StopCondition {
  33	return func(steps []StepResult) bool {
  34		return len(steps) >= stepCount
  35	}
  36}
  37
  38// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
  39func HasToolCall(toolName string) StopCondition {
  40	return func(steps []StepResult) bool {
  41		if len(steps) == 0 {
  42			return false
  43		}
  44		lastStep := steps[len(steps)-1]
  45		toolCalls := lastStep.Content.ToolCalls()
  46		for _, toolCall := range toolCalls {
  47			if toolCall.ToolName == toolName {
  48				return true
  49			}
  50		}
  51		return false
  52	}
  53}
  54
  55// HasContent returns a stop condition that stops when the specified content type appears in the last step.
  56func HasContent(contentType ContentType) StopCondition {
  57	return func(steps []StepResult) bool {
  58		if len(steps) == 0 {
  59			return false
  60		}
  61		lastStep := steps[len(steps)-1]
  62		for _, content := range lastStep.Content {
  63			if content.GetType() == contentType {
  64				return true
  65			}
  66		}
  67		return false
  68	}
  69}
  70
  71// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
  72func FinishReasonIs(reason FinishReason) StopCondition {
  73	return func(steps []StepResult) bool {
  74		if len(steps) == 0 {
  75			return false
  76		}
  77		lastStep := steps[len(steps)-1]
  78		return lastStep.FinishReason == reason
  79	}
  80}
  81
  82// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
  83func MaxTokensUsed(maxTokens int64) StopCondition {
  84	return func(steps []StepResult) bool {
  85		var totalTokens int64
  86		for _, step := range steps {
  87			totalTokens += step.Usage.TotalTokens
  88		}
  89		return totalTokens >= maxTokens
  90	}
  91}
  92
  93// PrepareStepFunctionOptions contains the options for preparing a step in an agent execution.
  94type PrepareStepFunctionOptions struct {
  95	Steps      []StepResult
  96	StepNumber int
  97	Model      LanguageModel
  98	Messages   []Message
  99}
 100
 101// PrepareStepResult contains the result of preparing a step in an agent execution.
 102type PrepareStepResult struct {
 103	Model           LanguageModel
 104	Messages        []Message
 105	System          *string
 106	ToolChoice      *ToolChoice
 107	ActiveTools     []string
 108	DisableAllTools bool
 109	Tools           []AgentTool
 110}
 111
 112// ToolCallRepairOptions contains the options for repairing a tool call.
 113type ToolCallRepairOptions struct {
 114	OriginalToolCall ToolCallContent
 115	ValidationError  error
 116	AvailableTools   []AgentTool
 117	SystemPrompt     string
 118	Messages         []Message
 119}
 120
 121type (
 122	// PrepareStepFunction defines a function that prepares a step in an agent execution.
 123	PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)
 124
 125	// OnStepFinishedFunction defines a function that is called when a step finishes.
 126	OnStepFinishedFunction = func(step StepResult)
 127
 128	// RepairToolCallFunction defines a function that repairs a tool call.
 129	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
 130)
 131
 132type agentSettings struct {
 133	systemPrompt     string
 134	maxOutputTokens  *int64
 135	temperature      *float64
 136	topP             *float64
 137	topK             *int64
 138	presencePenalty  *float64
 139	frequencyPenalty *float64
 140	headers          map[string]string
 141	providerOptions  ProviderOptions
 142
 143	// TODO: add support for provider tools
 144	tools      []AgentTool
 145	maxRetries *int
 146
 147	model LanguageModel
 148
 149	stopWhen       []StopCondition
 150	prepareStep    PrepareStepFunction
 151	repairToolCall RepairToolCallFunction
 152	onRetry        OnRetryCallback
 153}
 154
 155// AgentCall represents a call to an agent.
 156type AgentCall struct {
 157	Prompt           string     `json:"prompt"`
 158	Files            []FilePart `json:"files"`
 159	Messages         []Message  `json:"messages"`
 160	MaxOutputTokens  *int64
 161	Temperature      *float64 `json:"temperature"`
 162	TopP             *float64 `json:"top_p"`
 163	TopK             *int64   `json:"top_k"`
 164	PresencePenalty  *float64 `json:"presence_penalty"`
 165	FrequencyPenalty *float64 `json:"frequency_penalty"`
 166	ActiveTools      []string `json:"active_tools"`
 167	ProviderOptions  ProviderOptions
 168	OnRetry          OnRetryCallback
 169	MaxRetries       *int
 170
 171	StopWhen       []StopCondition
 172	PrepareStep    PrepareStepFunction
 173	RepairToolCall RepairToolCallFunction
 174}
 175
 176// Agent-level callbacks.
 177type (
 178	// OnAgentStartFunc is called when agent starts.
 179	OnAgentStartFunc func()
 180
 181	// OnAgentFinishFunc is called when agent finishes.
 182	OnAgentFinishFunc func(result *AgentResult) error
 183
 184	// OnStepStartFunc is called when a step starts.
 185	OnStepStartFunc func(stepNumber int) error
 186
 187	// OnStepFinishFunc is called when a step finishes.
 188	OnStepFinishFunc func(stepResult StepResult) error
 189
 190	// OnFinishFunc is called when entire agent completes.
 191	OnFinishFunc func(result *AgentResult)
 192
 193	// OnErrorFunc is called when an error occurs.
 194	OnErrorFunc func(error)
 195)
 196
 197// Stream part callbacks - called for each corresponding stream part type.
 198type (
 199	// OnChunkFunc is called for each stream part (catch-all).
 200	OnChunkFunc func(StreamPart) error
 201
 202	// OnWarningsFunc is called for warnings.
 203	OnWarningsFunc func(warnings []CallWarning) error
 204
 205	// OnTextStartFunc is called when text starts.
 206	OnTextStartFunc func(id string) error
 207
 208	// OnTextDeltaFunc is called for text deltas.
 209	OnTextDeltaFunc func(id, text string) error
 210
 211	// OnTextEndFunc is called when text ends.
 212	OnTextEndFunc func(id string) error
 213
 214	// OnReasoningStartFunc is called when reasoning starts.
 215	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error
 216
 217	// OnReasoningDeltaFunc is called for reasoning deltas.
 218	OnReasoningDeltaFunc func(id, text string) error
 219
 220	// OnReasoningEndFunc is called when reasoning ends.
 221	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error
 222
 223	// OnToolInputStartFunc is called when tool input starts.
 224	OnToolInputStartFunc func(id, toolName string) error
 225
 226	// OnToolInputDeltaFunc is called for tool input deltas.
 227	OnToolInputDeltaFunc func(id, delta string) error
 228
 229	// OnToolInputEndFunc is called when tool input ends.
 230	OnToolInputEndFunc func(id string) error
 231
 232	// OnToolCallFunc is called when tool call is complete.
 233	OnToolCallFunc func(toolCall ToolCallContent) error
 234
 235	// OnToolResultFunc is called when tool execution completes.
 236	OnToolResultFunc func(result ToolResultContent) error
 237
 238	// OnSourceFunc is called for source references.
 239	OnSourceFunc func(source SourceContent) error
 240
 241	// OnStreamFinishFunc is called when stream finishes.
 242	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
 243)
 244
 245// AgentStreamCall represents a streaming call to an agent.
 246type AgentStreamCall struct {
 247	Prompt           string     `json:"prompt"`
 248	Files            []FilePart `json:"files"`
 249	Messages         []Message  `json:"messages"`
 250	MaxOutputTokens  *int64
 251	Temperature      *float64 `json:"temperature"`
 252	TopP             *float64 `json:"top_p"`
 253	TopK             *int64   `json:"top_k"`
 254	PresencePenalty  *float64 `json:"presence_penalty"`
 255	FrequencyPenalty *float64 `json:"frequency_penalty"`
 256	ActiveTools      []string `json:"active_tools"`
 257	Headers          map[string]string
 258	ProviderOptions  ProviderOptions
 259	OnRetry          OnRetryCallback
 260	MaxRetries       *int
 261
 262	StopWhen       []StopCondition
 263	PrepareStep    PrepareStepFunction
 264	RepairToolCall RepairToolCallFunction
 265
 266	// Agent-level callbacks
 267	OnAgentStart  OnAgentStartFunc  // Called when agent starts
 268	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
 269	OnStepStart   OnStepStartFunc   // Called when a step starts
 270	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
 271	OnFinish      OnFinishFunc      // Called when entire agent completes
 272	OnError       OnErrorFunc       // Called when an error occurs
 273
 274	// Stream part callbacks - called for each corresponding stream part type
 275	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
 276	OnWarnings       OnWarningsFunc       // Called for warnings
 277	OnTextStart      OnTextStartFunc      // Called when text starts
 278	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
 279	OnTextEnd        OnTextEndFunc        // Called when text ends
 280	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
 281	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
 282	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
 283	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
 284	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
 285	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
 286	OnToolCall       OnToolCallFunc       // Called when tool call is complete
 287	OnToolResult     OnToolResultFunc     // Called when tool execution completes
 288	OnSource         OnSourceFunc         // Called for source references
 289	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
 290}
 291
 292// AgentResult represents the result of an agent execution.
 293type AgentResult struct {
 294	Steps []StepResult
 295	// Final response
 296	Response   Response
 297	TotalUsage Usage
 298}
 299
 300// Agent represents an AI agent that can generate responses and stream responses.
 301type Agent interface {
 302	Generate(context.Context, AgentCall) (*AgentResult, error)
 303	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
 304}
 305
 306// AgentOption defines a function that configures agent settings.
 307type AgentOption = func(*agentSettings)
 308
 309type agent struct {
 310	settings agentSettings
 311}
 312
 313// NewAgent creates a new agent with the given language model and options.
 314func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
 315	settings := agentSettings{
 316		model: model,
 317	}
 318	for _, o := range opts {
 319		o(&settings)
 320	}
 321	return &agent{
 322		settings: settings,
 323	}
 324}
 325
 326func (a *agent) prepareCall(call AgentCall) AgentCall {
 327	call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
 328	call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
 329	call.TopP = cmp.Or(call.TopP, a.settings.topP)
 330	call.TopK = cmp.Or(call.TopK, a.settings.topK)
 331	call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
 332	call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
 333	call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
 334
 335	if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
 336		call.StopWhen = a.settings.stopWhen
 337	}
 338	if call.PrepareStep == nil && a.settings.prepareStep != nil {
 339		call.PrepareStep = a.settings.prepareStep
 340	}
 341	if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
 342		call.RepairToolCall = a.settings.repairToolCall
 343	}
 344	if call.OnRetry == nil && a.settings.onRetry != nil {
 345		call.OnRetry = a.settings.onRetry
 346	}
 347
 348	providerOptions := ProviderOptions{}
 349	if a.settings.providerOptions != nil {
 350		maps.Copy(providerOptions, a.settings.providerOptions)
 351	}
 352	if call.ProviderOptions != nil {
 353		maps.Copy(providerOptions, call.ProviderOptions)
 354	}
 355	call.ProviderOptions = providerOptions
 356
 357	headers := map[string]string{}
 358
 359	if a.settings.headers != nil {
 360		maps.Copy(headers, a.settings.headers)
 361	}
 362
 363	return call
 364}
 365
 366// Generate implements Agent.
 367func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
 368	opts = a.prepareCall(opts)
 369	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 370	if err != nil {
 371		return nil, err
 372	}
 373	var responseMessages []Message
 374	var steps []StepResult
 375
 376	for {
 377		stepInputMessages := append(initialPrompt, responseMessages...)
 378		stepModel := a.settings.model
 379		stepSystemPrompt := a.settings.systemPrompt
 380		stepActiveTools := opts.ActiveTools
 381		stepToolChoice := ToolChoiceAuto
 382		disableAllTools := false
 383		stepTools := a.settings.tools
 384		if opts.PrepareStep != nil {
 385			updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
 386				Model:      stepModel,
 387				Steps:      steps,
 388				StepNumber: len(steps),
 389				Messages:   stepInputMessages,
 390			})
 391			if err != nil {
 392				return nil, err
 393			}
 394
 395			ctx = updatedCtx
 396
 397			// Apply prepared step modifications
 398			if prepared.Messages != nil {
 399				stepInputMessages = prepared.Messages
 400			}
 401			if prepared.Model != nil {
 402				stepModel = prepared.Model
 403			}
 404			if prepared.System != nil {
 405				stepSystemPrompt = *prepared.System
 406			}
 407			if prepared.ToolChoice != nil {
 408				stepToolChoice = *prepared.ToolChoice
 409			}
 410			if len(prepared.ActiveTools) > 0 {
 411				stepActiveTools = prepared.ActiveTools
 412			}
 413			disableAllTools = prepared.DisableAllTools
 414			if prepared.Tools != nil {
 415				stepTools = prepared.Tools
 416			}
 417		}
 418
 419		// Recreate prompt with potentially modified system prompt
 420		if stepSystemPrompt != a.settings.systemPrompt {
 421			stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 422			if err != nil {
 423				return nil, err
 424			}
 425			// Replace system message part, keep the rest
 426			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 427				stepInputMessages[0] = stepPrompt[0] // Replace system message
 428			}
 429		}
 430
 431		preparedTools := a.prepareTools(stepTools, stepActiveTools, disableAllTools)
 432
 433		retryOptions := DefaultRetryOptions()
 434		if opts.MaxRetries != nil {
 435			retryOptions.MaxRetries = *opts.MaxRetries
 436		}
 437		retryOptions.OnRetry = opts.OnRetry
 438		retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
 439
 440		result, err := retry(ctx, func() (*Response, error) {
 441			return stepModel.Generate(ctx, Call{
 442				Prompt:           stepInputMessages,
 443				MaxOutputTokens:  opts.MaxOutputTokens,
 444				Temperature:      opts.Temperature,
 445				TopP:             opts.TopP,
 446				TopK:             opts.TopK,
 447				PresencePenalty:  opts.PresencePenalty,
 448				FrequencyPenalty: opts.FrequencyPenalty,
 449				Tools:            preparedTools,
 450				ToolChoice:       &stepToolChoice,
 451				ProviderOptions:  opts.ProviderOptions,
 452			})
 453		})
 454		if err != nil {
 455			return nil, err
 456		}
 457
 458		var stepToolCalls []ToolCallContent
 459		for _, content := range result.Content {
 460			if content.GetType() == ContentTypeToolCall {
 461				toolCall, ok := AsContentType[ToolCallContent](content)
 462				if !ok {
 463					continue
 464				}
 465
 466				// Validate and potentially repair the tool call
 467				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
 468				stepToolCalls = append(stepToolCalls, validatedToolCall)
 469			}
 470		}
 471
 472		toolResults, err := a.executeTools(ctx, stepTools, stepToolCalls, nil)
 473
 474		// Build step content with validated tool calls and tool results
 475		stepContent := []Content{}
 476		toolCallIndex := 0
 477		for _, content := range result.Content {
 478			if content.GetType() == ContentTypeToolCall {
 479				// Replace with validated tool call
 480				if toolCallIndex < len(stepToolCalls) {
 481					stepContent = append(stepContent, stepToolCalls[toolCallIndex])
 482					toolCallIndex++
 483				}
 484			} else {
 485				// Keep other content as-is
 486				stepContent = append(stepContent, content)
 487			}
 488		}
 489		// Add tool results
 490		for _, result := range toolResults {
 491			stepContent = append(stepContent, result)
 492		}
 493		currentStepMessages := toResponseMessages(stepContent)
 494		responseMessages = append(responseMessages, currentStepMessages...)
 495
 496		stepResult := StepResult{
 497			Response: Response{
 498				Content:          stepContent,
 499				FinishReason:     result.FinishReason,
 500				Usage:            result.Usage,
 501				Warnings:         result.Warnings,
 502				ProviderMetadata: result.ProviderMetadata,
 503			},
 504			Messages: currentStepMessages,
 505		}
 506		steps = append(steps, stepResult)
 507		shouldStop := isStopConditionMet(opts.StopWhen, steps)
 508
 509		if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
 510			break
 511		}
 512	}
 513
 514	totalUsage := Usage{}
 515
 516	for _, step := range steps {
 517		usage := step.Usage
 518		totalUsage.InputTokens += usage.InputTokens
 519		totalUsage.OutputTokens += usage.OutputTokens
 520		totalUsage.ReasoningTokens += usage.ReasoningTokens
 521		totalUsage.CacheCreationTokens += usage.CacheCreationTokens
 522		totalUsage.CacheReadTokens += usage.CacheReadTokens
 523		totalUsage.TotalTokens += usage.TotalTokens
 524	}
 525
 526	agentResult := &AgentResult{
 527		Steps:      steps,
 528		Response:   steps[len(steps)-1].Response,
 529		TotalUsage: totalUsage,
 530	}
 531	return agentResult, nil
 532}
 533
 534func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
 535	if len(conditions) == 0 {
 536		return false
 537	}
 538
 539	for _, condition := range conditions {
 540		if condition(steps) {
 541			return true
 542		}
 543	}
 544	return false
 545}
 546
 547func toResponseMessages(content []Content) []Message {
 548	var assistantParts []MessagePart
 549	var toolParts []MessagePart
 550
 551	for _, c := range content {
 552		switch c.GetType() {
 553		case ContentTypeText:
 554			text, ok := AsContentType[TextContent](c)
 555			if !ok {
 556				continue
 557			}
 558			assistantParts = append(assistantParts, TextPart{
 559				Text:            text.Text,
 560				ProviderOptions: ProviderOptions(text.ProviderMetadata),
 561			})
 562		case ContentTypeReasoning:
 563			reasoning, ok := AsContentType[ReasoningContent](c)
 564			if !ok {
 565				continue
 566			}
 567			assistantParts = append(assistantParts, ReasoningPart{
 568				Text:            reasoning.Text,
 569				ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
 570			})
 571		case ContentTypeToolCall:
 572			toolCall, ok := AsContentType[ToolCallContent](c)
 573			if !ok {
 574				continue
 575			}
 576			assistantParts = append(assistantParts, ToolCallPart{
 577				ToolCallID:       toolCall.ToolCallID,
 578				ToolName:         toolCall.ToolName,
 579				Input:            toolCall.Input,
 580				ProviderExecuted: toolCall.ProviderExecuted,
 581				ProviderOptions:  ProviderOptions(toolCall.ProviderMetadata),
 582			})
 583		case ContentTypeFile:
 584			file, ok := AsContentType[FileContent](c)
 585			if !ok {
 586				continue
 587			}
 588			assistantParts = append(assistantParts, FilePart{
 589				Data:            file.Data,
 590				MediaType:       file.MediaType,
 591				ProviderOptions: ProviderOptions(file.ProviderMetadata),
 592			})
 593		case ContentTypeSource:
 594			// Sources are metadata about references used to generate the response.
 595			// They don't need to be included in the conversation messages.
 596			continue
 597		case ContentTypeToolResult:
 598			result, ok := AsContentType[ToolResultContent](c)
 599			if !ok {
 600				continue
 601			}
 602			toolParts = append(toolParts, ToolResultPart{
 603				ToolCallID:      result.ToolCallID,
 604				Output:          result.Result,
 605				ProviderOptions: ProviderOptions(result.ProviderMetadata),
 606			})
 607		}
 608	}
 609
 610	var messages []Message
 611	if len(assistantParts) > 0 {
 612		messages = append(messages, Message{
 613			Role:    MessageRoleAssistant,
 614			Content: assistantParts,
 615		})
 616	}
 617	if len(toolParts) > 0 {
 618		messages = append(messages, Message{
 619			Role:    MessageRoleTool,
 620			Content: toolParts,
 621		})
 622	}
 623	return messages
 624}
 625
 626func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
 627	if len(toolCalls) == 0 {
 628		return nil, nil
 629	}
 630
 631	// Create a map for quick tool lookup
 632	toolMap := make(map[string]AgentTool)
 633	for _, tool := range allTools {
 634		toolMap[tool.Info().Name] = tool
 635	}
 636
 637	// Execute all tool calls sequentially in order
 638	results := make([]ToolResultContent, 0, len(toolCalls))
 639
 640	for _, toolCall := range toolCalls {
 641		result, isCriticalError := a.executeSingleTool(ctx, toolMap, toolCall, toolResultCallback)
 642		results = append(results, result)
 643		if isCriticalError {
 644			if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
 645				return nil, errorResult.Error
 646			}
 647		}
 648	}
 649
 650	return results, nil
 651}
 652
 653// executeSingleTool executes a single tool and returns its result and a critical error flag.
 654func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentTool, toolCall ToolCallContent, toolResultCallback func(result ToolResultContent) error) (ToolResultContent, bool) {
 655	result := ToolResultContent{
 656		ToolCallID:       toolCall.ToolCallID,
 657		ToolName:         toolCall.ToolName,
 658		ProviderExecuted: false,
 659	}
 660
 661	// Skip invalid tool calls - create error result (not critical)
 662	if toolCall.Invalid {
 663		result.Result = ToolResultOutputContentError{
 664			Error: toolCall.ValidationError,
 665		}
 666		if toolResultCallback != nil {
 667			_ = toolResultCallback(result)
 668		}
 669		return result, false
 670	}
 671
 672	tool, exists := toolMap[toolCall.ToolName]
 673	if !exists {
 674		result.Result = ToolResultOutputContentError{
 675			Error: errors.New("Error: Tool not found: " + toolCall.ToolName),
 676		}
 677		if toolResultCallback != nil {
 678			_ = toolResultCallback(result)
 679		}
 680		return result, false
 681	}
 682
 683	// Execute the tool
 684	toolResult, err := tool.Run(ctx, ToolCall{
 685		ID:    toolCall.ToolCallID,
 686		Name:  toolCall.ToolName,
 687		Input: toolCall.Input,
 688	})
 689	if err != nil {
 690		result.Result = ToolResultOutputContentError{
 691			Error: err,
 692		}
 693		result.ClientMetadata = toolResult.Metadata
 694		if toolResultCallback != nil {
 695			_ = toolResultCallback(result)
 696		}
 697		return result, true
 698	}
 699
 700	result.ClientMetadata = toolResult.Metadata
 701	if toolResult.IsError {
 702		result.Result = ToolResultOutputContentError{
 703			Error: errors.New(toolResult.Content),
 704		}
 705	} else if toolResult.Type == "image" || toolResult.Type == "media" {
 706		result.Result = ToolResultOutputContentMedia{
 707			Data:      string(toolResult.Data),
 708			MediaType: toolResult.MediaType,
 709			Text:      toolResult.Content,
 710		}
 711	} else {
 712		result.Result = ToolResultOutputContentText{
 713			Text: toolResult.Content,
 714		}
 715	}
 716	if toolResultCallback != nil {
 717		_ = toolResultCallback(result)
 718	}
 719	return result, false
 720}
 721
 722// Stream implements Agent.
 723func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
 724	// Convert AgentStreamCall to AgentCall for preparation
 725	call := AgentCall{
 726		Prompt:           opts.Prompt,
 727		Files:            opts.Files,
 728		Messages:         opts.Messages,
 729		MaxOutputTokens:  opts.MaxOutputTokens,
 730		Temperature:      opts.Temperature,
 731		TopP:             opts.TopP,
 732		TopK:             opts.TopK,
 733		PresencePenalty:  opts.PresencePenalty,
 734		FrequencyPenalty: opts.FrequencyPenalty,
 735		ActiveTools:      opts.ActiveTools,
 736		ProviderOptions:  opts.ProviderOptions,
 737		MaxRetries:       opts.MaxRetries,
 738		OnRetry:          opts.OnRetry,
 739		StopWhen:         opts.StopWhen,
 740		PrepareStep:      opts.PrepareStep,
 741		RepairToolCall:   opts.RepairToolCall,
 742	}
 743
 744	call = a.prepareCall(call)
 745
 746	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
 747	if err != nil {
 748		return nil, err
 749	}
 750
 751	var responseMessages []Message
 752	var steps []StepResult
 753	var totalUsage Usage
 754
 755	// Start agent stream
 756	if opts.OnAgentStart != nil {
 757		opts.OnAgentStart()
 758	}
 759
 760	for stepNumber := 0; ; stepNumber++ {
 761		stepInputMessages := append(initialPrompt, responseMessages...)
 762		stepModel := a.settings.model
 763		stepSystemPrompt := a.settings.systemPrompt
 764		stepActiveTools := call.ActiveTools
 765		stepToolChoice := ToolChoiceAuto
 766		disableAllTools := false
 767		stepTools := a.settings.tools
 768		// Apply step preparation if provided
 769		if call.PrepareStep != nil {
 770			updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
 771				Model:      stepModel,
 772				Steps:      steps,
 773				StepNumber: stepNumber,
 774				Messages:   stepInputMessages,
 775			})
 776			if err != nil {
 777				return nil, err
 778			}
 779
 780			ctx = updatedCtx
 781
 782			if prepared.Messages != nil {
 783				stepInputMessages = prepared.Messages
 784			}
 785			if prepared.Model != nil {
 786				stepModel = prepared.Model
 787			}
 788			if prepared.System != nil {
 789				stepSystemPrompt = *prepared.System
 790			}
 791			if prepared.ToolChoice != nil {
 792				stepToolChoice = *prepared.ToolChoice
 793			}
 794			if len(prepared.ActiveTools) > 0 {
 795				stepActiveTools = prepared.ActiveTools
 796			}
 797			disableAllTools = prepared.DisableAllTools
 798			if prepared.Tools != nil {
 799				stepTools = prepared.Tools
 800			}
 801		}
 802
 803		// Recreate prompt with potentially modified system prompt
 804		if stepSystemPrompt != a.settings.systemPrompt {
 805			stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
 806			if err != nil {
 807				return nil, err
 808			}
 809			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 810				stepInputMessages[0] = stepPrompt[0]
 811			}
 812		}
 813
 814		preparedTools := a.prepareTools(stepTools, stepActiveTools, disableAllTools)
 815
 816		// Start step stream
 817		if opts.OnStepStart != nil {
 818			_ = opts.OnStepStart(stepNumber)
 819		}
 820
 821		// Create streaming call
 822		streamCall := Call{
 823			Prompt:           stepInputMessages,
 824			MaxOutputTokens:  call.MaxOutputTokens,
 825			Temperature:      call.Temperature,
 826			TopP:             call.TopP,
 827			TopK:             call.TopK,
 828			PresencePenalty:  call.PresencePenalty,
 829			FrequencyPenalty: call.FrequencyPenalty,
 830			Tools:            preparedTools,
 831			ToolChoice:       &stepToolChoice,
 832			ProviderOptions:  call.ProviderOptions,
 833		}
 834
 835		// Execute step with retry logic wrapping both stream creation and processing
 836		retryOptions := DefaultRetryOptions()
 837		if call.MaxRetries != nil {
 838			retryOptions.MaxRetries = *call.MaxRetries
 839		}
 840		retryOptions.OnRetry = call.OnRetry
 841		retry := RetryWithExponentialBackoffRespectingRetryHeaders[stepExecutionResult](retryOptions)
 842
 843		result, err := retry(ctx, func() (stepExecutionResult, error) {
 844			// Create the stream
 845			stream, err := stepModel.Stream(ctx, streamCall)
 846			if err != nil {
 847				return stepExecutionResult{}, err
 848			}
 849
 850			// Process the stream
 851			result, err := a.processStepStream(ctx, stream, opts, steps, stepTools)
 852			if err != nil {
 853				return stepExecutionResult{}, err
 854			}
 855
 856			return result, nil
 857		})
 858		if err != nil {
 859			if opts.OnError != nil {
 860				opts.OnError(err)
 861			}
 862			return nil, err
 863		}
 864
 865		steps = append(steps, result.StepResult)
 866		totalUsage = addUsage(totalUsage, result.StepResult.Usage)
 867
 868		// Call step finished callback
 869		if opts.OnStepFinish != nil {
 870			_ = opts.OnStepFinish(result.StepResult)
 871		}
 872
 873		// Add step messages to response messages
 874		stepMessages := toResponseMessages(result.StepResult.Content)
 875		responseMessages = append(responseMessages, stepMessages...)
 876
 877		// Check stop conditions
 878		shouldStop := isStopConditionMet(call.StopWhen, steps)
 879		if shouldStop || !result.ShouldContinue {
 880			break
 881		}
 882	}
 883
 884	// Finish agent stream
 885	agentResult := &AgentResult{
 886		Steps:      steps,
 887		Response:   steps[len(steps)-1].Response,
 888		TotalUsage: totalUsage,
 889	}
 890
 891	if opts.OnFinish != nil {
 892		opts.OnFinish(agentResult)
 893	}
 894
 895	if opts.OnAgentFinish != nil {
 896		_ = opts.OnAgentFinish(agentResult)
 897	}
 898
 899	return agentResult, nil
 900}
 901
 902func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
 903	preparedTools := make([]Tool, 0, len(tools))
 904
 905	// If explicitly disabling all tools, return no tools
 906	if disableAllTools {
 907		return preparedTools
 908	}
 909
 910	for _, tool := range tools {
 911		// If activeTools has items, only include tools in the list
 912		// If activeTools is empty, include all tools
 913		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
 914			continue
 915		}
 916		info := tool.Info()
 917		inputSchema := map[string]any{
 918			"type":       "object",
 919			"properties": info.Parameters,
 920			"required":   info.Required,
 921		}
 922		schema.Normalize(inputSchema)
 923		preparedTools = append(preparedTools, FunctionTool{
 924			Name:            info.Name,
 925			Description:     info.Description,
 926			InputSchema:     inputSchema,
 927			ProviderOptions: tool.ProviderOptions(),
 928		})
 929	}
 930	return preparedTools
 931}
 932
 933// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
 934func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
 935	if err := a.validateToolCall(toolCall, availableTools); err == nil {
 936		return toolCall
 937	} else { //nolint: revive
 938		if repairFunc != nil {
 939			repairOptions := ToolCallRepairOptions{
 940				OriginalToolCall: toolCall,
 941				ValidationError:  err,
 942				AvailableTools:   availableTools,
 943				SystemPrompt:     systemPrompt,
 944				Messages:         messages,
 945			}
 946
 947			if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
 948				if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
 949					return *repairedToolCall
 950				}
 951			}
 952		}
 953
 954		invalidToolCall := toolCall
 955		invalidToolCall.Invalid = true
 956		invalidToolCall.ValidationError = err
 957		return invalidToolCall
 958	}
 959}
 960
 961// validateToolCall validates a tool call against available tools and their schemas.
 962func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
 963	var tool AgentTool
 964	for _, t := range availableTools {
 965		if t.Info().Name == toolCall.ToolName {
 966			tool = t
 967			break
 968		}
 969	}
 970
 971	if tool == nil {
 972		return fmt.Errorf("tool not found: %s", toolCall.ToolName)
 973	}
 974
 975	// Validate JSON parsing
 976	var input map[string]any
 977	if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
 978		return fmt.Errorf("invalid JSON input: %w", err)
 979	}
 980
 981	// Basic schema validation (check required fields)
 982	// TODO: more robust schema validation using JSON Schema or similar
 983	toolInfo := tool.Info()
 984	for _, required := range toolInfo.Required {
 985		if _, exists := input[required]; !exists {
 986			return fmt.Errorf("missing required parameter: %s", required)
 987		}
 988	}
 989	return nil
 990}
 991
 992func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
 993	if prompt == "" {
 994		return nil, &Error{Title: "invalid argument", Message: "prompt can't be empty"}
 995	}
 996
 997	var preparedPrompt Prompt
 998
 999	if system != "" {
1000		preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
1001	}
1002	preparedPrompt = append(preparedPrompt, messages...)
1003	preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
1004	return preparedPrompt, nil
1005}
1006
1007// WithSystemPrompt sets the system prompt for the agent.
1008func WithSystemPrompt(prompt string) AgentOption {
1009	return func(s *agentSettings) {
1010		s.systemPrompt = prompt
1011	}
1012}
1013
1014// WithMaxOutputTokens sets the maximum output tokens for the agent.
1015func WithMaxOutputTokens(tokens int64) AgentOption {
1016	return func(s *agentSettings) {
1017		s.maxOutputTokens = &tokens
1018	}
1019}
1020
1021// WithTemperature sets the temperature for the agent.
1022func WithTemperature(temp float64) AgentOption {
1023	return func(s *agentSettings) {
1024		s.temperature = &temp
1025	}
1026}
1027
1028// WithTopP sets the top-p value for the agent.
1029func WithTopP(topP float64) AgentOption {
1030	return func(s *agentSettings) {
1031		s.topP = &topP
1032	}
1033}
1034
1035// WithTopK sets the top-k value for the agent.
1036func WithTopK(topK int64) AgentOption {
1037	return func(s *agentSettings) {
1038		s.topK = &topK
1039	}
1040}
1041
1042// WithPresencePenalty sets the presence penalty for the agent.
1043func WithPresencePenalty(penalty float64) AgentOption {
1044	return func(s *agentSettings) {
1045		s.presencePenalty = &penalty
1046	}
1047}
1048
1049// WithFrequencyPenalty sets the frequency penalty for the agent.
1050func WithFrequencyPenalty(penalty float64) AgentOption {
1051	return func(s *agentSettings) {
1052		s.frequencyPenalty = &penalty
1053	}
1054}
1055
1056// WithTools sets the tools for the agent.
1057func WithTools(tools ...AgentTool) AgentOption {
1058	return func(s *agentSettings) {
1059		s.tools = append(s.tools, tools...)
1060	}
1061}
1062
1063// WithStopConditions sets the stop conditions for the agent.
1064func WithStopConditions(conditions ...StopCondition) AgentOption {
1065	return func(s *agentSettings) {
1066		s.stopWhen = append(s.stopWhen, conditions...)
1067	}
1068}
1069
1070// WithPrepareStep sets the prepare step function for the agent.
1071func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1072	return func(s *agentSettings) {
1073		s.prepareStep = fn
1074	}
1075}
1076
1077// WithRepairToolCall sets the repair tool call function for the agent.
1078func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1079	return func(s *agentSettings) {
1080		s.repairToolCall = fn
1081	}
1082}
1083
1084// WithMaxRetries sets the maximum number of retries for the agent.
1085func WithMaxRetries(maxRetries int) AgentOption {
1086	return func(s *agentSettings) {
1087		s.maxRetries = &maxRetries
1088	}
1089}
1090
1091// WithOnRetry sets the retry callback for the agent.
1092func WithOnRetry(callback OnRetryCallback) AgentOption {
1093	return func(s *agentSettings) {
1094		s.onRetry = callback
1095	}
1096}
1097
1098// processStepStream processes a single step's stream and returns the step result.
1099func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult, stepTools []AgentTool) (stepExecutionResult, error) {
1100	var stepContent []Content
1101	var stepToolCalls []ToolCallContent
1102	var stepUsage Usage
1103	stepFinishReason := FinishReasonUnknown
1104	var stepWarnings []CallWarning
1105	var stepProviderMetadata ProviderMetadata
1106
1107	activeToolCalls := make(map[string]*ToolCallContent)
1108	activeTextContent := make(map[string]string)
1109	type reasoningContent struct {
1110		content string
1111		options ProviderMetadata
1112	}
1113	activeReasoningContent := make(map[string]reasoningContent)
1114
1115	// Set up concurrent tool execution
1116	type toolExecutionRequest struct {
1117		toolCall ToolCallContent
1118		parallel bool
1119	}
1120	toolChan := make(chan toolExecutionRequest, 10)
1121	var toolExecutionWg sync.WaitGroup
1122	var toolStateMu sync.Mutex
1123	toolResults := make([]ToolResultContent, 0)
1124	var toolExecutionErr error
1125
1126	// Create a map for quick tool lookup
1127	toolMap := make(map[string]AgentTool)
1128	for _, tool := range stepTools {
1129		toolMap[tool.Info().Name] = tool
1130	}
1131
1132	// Semaphores for controlling parallelism
1133	parallelSem := make(chan struct{}, 5)
1134	var sequentialMu sync.Mutex
1135
1136	// Single coordinator goroutine that dispatches tools
1137	toolExecutionWg.Go(func() {
1138		for req := range toolChan {
1139			if req.parallel {
1140				parallelSem <- struct{}{}
1141				toolExecutionWg.Go(func() {
1142					defer func() { <-parallelSem }()
1143					result, isCriticalError := a.executeSingleTool(ctx, toolMap, req.toolCall, opts.OnToolResult)
1144					toolStateMu.Lock()
1145					toolResults = append(toolResults, result)
1146					if isCriticalError && toolExecutionErr == nil {
1147						if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
1148							toolExecutionErr = errorResult.Error
1149						}
1150					}
1151					toolStateMu.Unlock()
1152				})
1153			} else {
1154				sequentialMu.Lock()
1155				result, isCriticalError := a.executeSingleTool(ctx, toolMap, req.toolCall, opts.OnToolResult)
1156				toolStateMu.Lock()
1157				toolResults = append(toolResults, result)
1158				if isCriticalError && toolExecutionErr == nil {
1159					if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
1160						toolExecutionErr = errorResult.Error
1161					}
1162				}
1163				toolStateMu.Unlock()
1164				sequentialMu.Unlock()
1165			}
1166		}
1167	})
1168
1169	// Process stream parts
1170	for part := range stream {
1171		// Forward all parts to chunk callback
1172		if opts.OnChunk != nil {
1173			err := opts.OnChunk(part)
1174			if err != nil {
1175				return stepExecutionResult{}, err
1176			}
1177		}
1178
1179		switch part.Type {
1180		case StreamPartTypeWarnings:
1181			stepWarnings = part.Warnings
1182			if opts.OnWarnings != nil {
1183				err := opts.OnWarnings(part.Warnings)
1184				if err != nil {
1185					return stepExecutionResult{}, err
1186				}
1187			}
1188
1189		case StreamPartTypeTextStart:
1190			activeTextContent[part.ID] = ""
1191			if opts.OnTextStart != nil {
1192				err := opts.OnTextStart(part.ID)
1193				if err != nil {
1194					return stepExecutionResult{}, err
1195				}
1196			}
1197
1198		case StreamPartTypeTextDelta:
1199			if _, exists := activeTextContent[part.ID]; exists {
1200				activeTextContent[part.ID] += part.Delta
1201			}
1202			if opts.OnTextDelta != nil {
1203				err := opts.OnTextDelta(part.ID, part.Delta)
1204				if err != nil {
1205					return stepExecutionResult{}, err
1206				}
1207			}
1208
1209		case StreamPartTypeTextEnd:
1210			if text, exists := activeTextContent[part.ID]; exists {
1211				stepContent = append(stepContent, TextContent{
1212					Text:             text,
1213					ProviderMetadata: part.ProviderMetadata,
1214				})
1215				delete(activeTextContent, part.ID)
1216			}
1217			if opts.OnTextEnd != nil {
1218				err := opts.OnTextEnd(part.ID)
1219				if err != nil {
1220					return stepExecutionResult{}, err
1221				}
1222			}
1223
1224		case StreamPartTypeReasoningStart:
1225			activeReasoningContent[part.ID] = reasoningContent{content: part.Delta, options: part.ProviderMetadata}
1226			if opts.OnReasoningStart != nil {
1227				content := ReasoningContent{
1228					Text:             part.Delta,
1229					ProviderMetadata: part.ProviderMetadata,
1230				}
1231				err := opts.OnReasoningStart(part.ID, content)
1232				if err != nil {
1233					return stepExecutionResult{}, err
1234				}
1235			}
1236
1237		case StreamPartTypeReasoningDelta:
1238			if active, exists := activeReasoningContent[part.ID]; exists {
1239				active.content += part.Delta
1240				active.options = part.ProviderMetadata
1241				activeReasoningContent[part.ID] = active
1242			}
1243			if opts.OnReasoningDelta != nil {
1244				err := opts.OnReasoningDelta(part.ID, part.Delta)
1245				if err != nil {
1246					return stepExecutionResult{}, err
1247				}
1248			}
1249
1250		case StreamPartTypeReasoningEnd:
1251			if active, exists := activeReasoningContent[part.ID]; exists {
1252				if part.ProviderMetadata != nil {
1253					active.options = part.ProviderMetadata
1254				}
1255				content := ReasoningContent{
1256					Text:             active.content,
1257					ProviderMetadata: active.options,
1258				}
1259				stepContent = append(stepContent, content)
1260				if opts.OnReasoningEnd != nil {
1261					err := opts.OnReasoningEnd(part.ID, content)
1262					if err != nil {
1263						return stepExecutionResult{}, err
1264					}
1265				}
1266				delete(activeReasoningContent, part.ID)
1267			}
1268
1269		case StreamPartTypeToolInputStart:
1270			activeToolCalls[part.ID] = &ToolCallContent{
1271				ToolCallID:       part.ID,
1272				ToolName:         part.ToolCallName,
1273				Input:            "",
1274				ProviderExecuted: part.ProviderExecuted,
1275			}
1276			if opts.OnToolInputStart != nil {
1277				err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1278				if err != nil {
1279					return stepExecutionResult{}, err
1280				}
1281			}
1282
1283		case StreamPartTypeToolInputDelta:
1284			if toolCall, exists := activeToolCalls[part.ID]; exists {
1285				toolCall.Input += part.Delta
1286			}
1287			if opts.OnToolInputDelta != nil {
1288				err := opts.OnToolInputDelta(part.ID, part.Delta)
1289				if err != nil {
1290					return stepExecutionResult{}, err
1291				}
1292			}
1293
1294		case StreamPartTypeToolInputEnd:
1295			if opts.OnToolInputEnd != nil {
1296				err := opts.OnToolInputEnd(part.ID)
1297				if err != nil {
1298					return stepExecutionResult{}, err
1299				}
1300			}
1301
1302		case StreamPartTypeToolCall:
1303			toolCall := ToolCallContent{
1304				ToolCallID:       part.ID,
1305				ToolName:         part.ToolCallName,
1306				Input:            part.ToolCallInput,
1307				ProviderExecuted: part.ProviderExecuted,
1308				ProviderMetadata: part.ProviderMetadata,
1309			}
1310
1311			// Validate and potentially repair the tool call
1312			validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1313			stepToolCalls = append(stepToolCalls, validatedToolCall)
1314			stepContent = append(stepContent, validatedToolCall)
1315
1316			if opts.OnToolCall != nil {
1317				err := opts.OnToolCall(validatedToolCall)
1318				if err != nil {
1319					return stepExecutionResult{}, err
1320				}
1321			}
1322
1323			// Determine if tool can run in parallel
1324			isParallel := false
1325			if tool, exists := toolMap[validatedToolCall.ToolName]; exists {
1326				isParallel = tool.Info().Parallel
1327			}
1328
1329			// Send tool call to execution channel
1330			toolChan <- toolExecutionRequest{toolCall: validatedToolCall, parallel: isParallel}
1331
1332			// Clean up active tool call
1333			delete(activeToolCalls, part.ID)
1334
1335		case StreamPartTypeSource:
1336			sourceContent := SourceContent{
1337				SourceType:       part.SourceType,
1338				ID:               part.ID,
1339				URL:              part.URL,
1340				Title:            part.Title,
1341				ProviderMetadata: part.ProviderMetadata,
1342			}
1343			stepContent = append(stepContent, sourceContent)
1344			if opts.OnSource != nil {
1345				err := opts.OnSource(sourceContent)
1346				if err != nil {
1347					return stepExecutionResult{}, err
1348				}
1349			}
1350
1351		case StreamPartTypeFinish:
1352			stepUsage = part.Usage
1353			stepFinishReason = part.FinishReason
1354			stepProviderMetadata = part.ProviderMetadata
1355			if opts.OnStreamFinish != nil {
1356				err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1357				if err != nil {
1358					return stepExecutionResult{}, err
1359				}
1360			}
1361
1362		case StreamPartTypeError:
1363			return stepExecutionResult{}, part.Error
1364		}
1365	}
1366
1367	// Close the tool execution channel and wait for all executions to complete
1368	close(toolChan)
1369	toolExecutionWg.Wait()
1370
1371	// Check for tool execution errors
1372	if toolExecutionErr != nil {
1373		return stepExecutionResult{}, toolExecutionErr
1374	}
1375
1376	// Add tool results to content if any
1377	if len(toolResults) > 0 {
1378		for _, result := range toolResults {
1379			stepContent = append(stepContent, result)
1380		}
1381	}
1382
1383	stepResult := StepResult{
1384		Response: Response{
1385			Content:          stepContent,
1386			FinishReason:     stepFinishReason,
1387			Usage:            stepUsage,
1388			Warnings:         stepWarnings,
1389			ProviderMetadata: stepProviderMetadata,
1390		},
1391		Messages: toResponseMessages(stepContent),
1392	}
1393
1394	// Determine if we should continue (has tool calls and not stopped)
1395	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1396
1397	return stepExecutionResult{
1398		StepResult:     stepResult,
1399		ShouldContinue: shouldContinue,
1400	}, nil
1401}
1402
1403func addUsage(a, b Usage) Usage {
1404	return Usage{
1405		InputTokens:         a.InputTokens + b.InputTokens,
1406		OutputTokens:        a.OutputTokens + b.OutputTokens,
1407		TotalTokens:         a.TotalTokens + b.TotalTokens,
1408		ReasoningTokens:     a.ReasoningTokens + b.ReasoningTokens,
1409		CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1410		CacheReadTokens:     a.CacheReadTokens + b.CacheReadTokens,
1411	}
1412}
1413
1414// WithHeaders sets the headers for the agent.
1415func WithHeaders(headers map[string]string) AgentOption {
1416	return func(s *agentSettings) {
1417		s.headers = headers
1418	}
1419}
1420
1421// WithProviderOptions sets the provider options for the agent.
1422func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1423	return func(s *agentSettings) {
1424		s.providerOptions = providerOptions
1425	}
1426}