agent.go

   1package fantasy
   2
   3import (
   4	"cmp"
   5	"context"
   6	"encoding/json"
   7	"errors"
   8	"fmt"
   9	"maps"
  10	"slices"
  11)
  12
// StepResult represents the result of a single step in an agent execution.
//
// It embeds the step's model Response (content, finish reason, usage,
// warnings and provider metadata). In Generate, Content holds the model
// output with validated tool calls, followed by the tool results, and
// Messages holds the conversation messages derived from that content, which
// are appended to the input of the next step.
type StepResult struct {
	Response
	Messages []Message
}

// StopCondition defines a function that determines when an agent should stop executing.
// It is evaluated after each step with every step completed so far, in
// execution order; returning true ends the run after that step.
type StopCondition = func(steps []StepResult) bool
  21
  22// StepCountIs returns a stop condition that stops after the specified number of steps.
  23func StepCountIs(stepCount int) StopCondition {
  24	return func(steps []StepResult) bool {
  25		return len(steps) >= stepCount
  26	}
  27}
  28
  29// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
  30func HasToolCall(toolName string) StopCondition {
  31	return func(steps []StepResult) bool {
  32		if len(steps) == 0 {
  33			return false
  34		}
  35		lastStep := steps[len(steps)-1]
  36		toolCalls := lastStep.Content.ToolCalls()
  37		for _, toolCall := range toolCalls {
  38			if toolCall.ToolName == toolName {
  39				return true
  40			}
  41		}
  42		return false
  43	}
  44}
  45
  46// HasContent returns a stop condition that stops when the specified content type appears in the last step.
  47func HasContent(contentType ContentType) StopCondition {
  48	return func(steps []StepResult) bool {
  49		if len(steps) == 0 {
  50			return false
  51		}
  52		lastStep := steps[len(steps)-1]
  53		for _, content := range lastStep.Content {
  54			if content.GetType() == contentType {
  55				return true
  56			}
  57		}
  58		return false
  59	}
  60}
  61
  62// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
  63func FinishReasonIs(reason FinishReason) StopCondition {
  64	return func(steps []StepResult) bool {
  65		if len(steps) == 0 {
  66			return false
  67		}
  68		lastStep := steps[len(steps)-1]
  69		return lastStep.FinishReason == reason
  70	}
  71}
  72
  73// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
  74func MaxTokensUsed(maxTokens int64) StopCondition {
  75	return func(steps []StepResult) bool {
  76		var totalTokens int64
  77		for _, step := range steps {
  78			totalTokens += step.Usage.TotalTokens
  79		}
  80		return totalTokens >= maxTokens
  81	}
  82}
  83
// PrepareStepFunctionOptions contains the options for preparing a step in an agent execution.
//
// Steps holds the results of the steps completed so far, in execution order;
// StepNumber is the zero-based index of the step about to run; Model is the
// agent's configured model (the default for the step); Messages is the input
// that will be sent to the model for this step unless replaced.
type PrepareStepFunctionOptions struct {
	Steps      []StepResult
	StepNumber int
	Model      LanguageModel
	Messages   []Message
}

// PrepareStepResult contains the result of preparing a step in an agent execution.
//
// Zero values keep the step defaults: a non-nil Messages, Model, System or
// ToolChoice replaces the corresponding default (the default tool choice is
// ToolChoiceAuto), and a non-empty ActiveTools replaces the call's active-tool
// list. DisableAllTools, when true, sends no tools to the model for the step.
type PrepareStepResult struct {
	Model           LanguageModel
	Messages        []Message
	System          *string
	ToolChoice      *ToolChoice
	ActiveTools     []string
	DisableAllTools bool
}
 101
// ToolCallRepairOptions contains the options for repairing a tool call.
//
// validateAndRepairToolCall fills it in when a tool call fails validation:
// OriginalToolCall is the call as returned by the model, ValidationError is
// the validation failure, AvailableTools are the agent's tools, and
// SystemPrompt/Messages are the system prompt and input messages of the step
// that produced the call.
type ToolCallRepairOptions struct {
	OriginalToolCall ToolCallContent
	ValidationError  error
	AvailableTools   []AgentTool
	SystemPrompt     string
	Messages         []Message
}

type (
	// PrepareStepFunction defines a function that prepares a step in an agent execution.
	// It runs before every model call. The returned context (which should be
	// non-nil) is used for the remaining model calls, the PrepareStepResult
	// overrides per-step settings, and a non-nil error aborts the run.
	PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)

	// OnStepFinishedFunction defines a function that is called when a step finishes.
	// NOTE(review): not referenced in the visible part of this file; the
	// streaming API uses OnStepFinishFunc instead — confirm it is still needed.
	OnStepFinishedFunction = func(step StepResult)

	// RepairToolCallFunction defines a function that repairs a tool call.
	// NOTE(review): validateAndRepairToolCall appears to use the result only
	// when it is non-nil and the error is nil — confirm the fallback behavior.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
 121
// agentSettings holds the agent-wide configuration assembled by NewAgent from
// its AgentOption arguments.
//
// prepareCall uses the sampling parameters, maxRetries, stopWhen,
// prepareStep, repairToolCall, onRetry and providerOptions as defaults for
// each call; systemPrompt, tools and model are read directly by Generate and
// Stream.
//
// NOTE(review): headers is not forwarded to any model call in this file —
// confirm whether agent-level headers are supposed to be sent.
type agentSettings struct {
	systemPrompt     string
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	headers          map[string]string
	providerOptions  ProviderOptions

	// TODO: add support for provider tools
	tools      []AgentTool
	maxRetries *int

	model LanguageModel

	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
 144
// AgentCall represents a call to an agent.
//
// Unset options fall back to the agent-level defaults configured with
// AgentOptions (see prepareCall). Nil pointer fields and nil PrepareStep,
// RepairToolCall and OnRetry fall back, as does an empty StopWhen.
// ProviderOptions entries are merged over the agent-level ones. ActiveTools,
// when non-empty, restricts which agent tools are offered to the model.
//
// NOTE(review): MaxOutputTokens, ProviderOptions, OnRetry, MaxRetries,
// StopWhen, PrepareStep and RepairToolCall carry no json tags, unlike the
// other fields. PrepareStep and RepairToolCall are func-typed and StopWhen is
// a slice of funcs, which encoding/json cannot marshal. Encoding this struct
// would therefore fail unless they are tagged `json:"-"`. Confirm whether it
// is ever serialized.
type AgentCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction
}
 165
// Agent-level callbacks, set on AgentStreamCall and invoked by Stream.
type (
	// OnAgentStartFunc is called when agent starts.
	// Stream calls it once, after the initial prompt is built and before the
	// first step.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called when agent finishes.
	// Stream calls it last, after OnFinish; its returned error is currently
	// discarded.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called when a step starts, with the zero-based step
	// number. Stream currently discards its returned error.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called when a step finishes, before the stop
	// conditions are evaluated. Stream currently discards its returned error.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called when entire agent completes, just before
	// OnAgentFinish.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called when an error occurs. Stream calls it at least when
	// the model stream cannot be opened or a step's stream fails to process,
	// right before returning that error.
	OnErrorFunc func(error)
)
 186
// Stream part callbacks - called for each corresponding stream part type.
//
// NOTE(review): apart from OnBeforeToolExecutionFunc, whose contract is
// implemented by executeTools, these are presumably dispatched by
// processStepStream, which is not visible here. Confirm ordering and error
// semantics there.
type (
	// OnChunkFunc is called for each stream part (catch-all).
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when text starts.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for text deltas.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when text ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when reasoning starts.
	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error

	// OnReasoningDeltaFunc is called for reasoning deltas.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning ends.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when tool input starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for tool input deltas.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnBeforeToolExecutionFunc is called right before a tool is executed. It
	// can rewrite the tool input or skip the call. Given the returned
	// (modifiedInput, skipExecution, err):
	//   - err != nil: an error result carrying err is recorded for the call.
	//     If skipExecution is true the remaining tool calls still run;
	//     otherwise tool execution stops and err is returned.
	//   - err == nil and skipExecution: the tool is not run and an error
	//     result ("Tool execution skipped by hook") is recorded instead.
	//   - otherwise, a non-empty modifiedInput replaces the original input.
	// Calls that are invalid or name an unknown tool never reach this hook.
	OnBeforeToolExecutionFunc func(toolCall ToolCallContent) (modifiedInput string, skipExecution bool, err error)

	// OnToolResultFunc is called when tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when stream finishes.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
 240
// AgentStreamCall represents a streaming call to an agent.
//
// The request fields mirror AgentCall; unset options fall back to the
// agent-level defaults (see prepareCall). Stream checks each agent-level
// callback for nil before calling it.
//
// NOTE(review): Headers is not read anywhere in the visible Stream
// implementation, so per-call headers are currently not sent — confirm.
type AgentStreamCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called when agent starts
	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk               OnChunkFunc               // Called for each stream part (catch-all)
	OnWarnings            OnWarningsFunc            // Called for warnings
	OnTextStart           OnTextStartFunc           // Called when text starts
	OnTextDelta           OnTextDeltaFunc           // Called for text deltas
	OnTextEnd             OnTextEndFunc             // Called when text ends
	OnReasoningStart      OnReasoningStartFunc      // Called when reasoning starts
	OnReasoningDelta      OnReasoningDeltaFunc      // Called for reasoning deltas
	OnReasoningEnd        OnReasoningEndFunc        // Called when reasoning ends
	OnToolInputStart      OnToolInputStartFunc      // Called when tool input starts
	OnToolInputDelta      OnToolInputDeltaFunc      // Called for tool input deltas
	OnToolInputEnd        OnToolInputEndFunc        // Called when tool input ends
	OnToolCall            OnToolCallFunc            // Called when tool call is complete
	OnBeforeToolExecution OnBeforeToolExecutionFunc // Called right before tool execution
	OnToolResult          OnToolResultFunc          // Called when tool execution completes
	OnSource              OnSourceFunc              // Called for source references
	OnStreamFinish        OnStreamFinishFunc        // Called when stream finishes
}
 288
// AgentResult represents the result of an agent execution.
type AgentResult struct {
	// Steps holds every executed step, in execution order.
	Steps []StepResult
	// Final response: the Response of the last entry in Steps.
	Response Response
	// TotalUsage is the token usage summed over all steps.
	TotalUsage Usage
}

// Agent represents an AI agent that can generate responses and stream responses.
// Both methods run the same multi-step model/tool loop; Stream additionally
// reports progress through the callbacks in AgentStreamCall.
type Agent interface {
	Generate(context.Context, AgentCall) (*AgentResult, error)
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}

// AgentOption defines a function that configures agent settings.
// NewAgent applies options in the order they are given.
type AgentOption = func(*agentSettings)

// agent is the Agent implementation returned by NewAgent.
type agent struct {
	settings agentSettings
}
 309
 310// NewAgent creates a new agent with the given language model and options.
 311func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
 312	settings := agentSettings{
 313		model: model,
 314	}
 315	for _, o := range opts {
 316		o(&settings)
 317	}
 318	return &agent{
 319		settings: settings,
 320	}
 321}
 322
 323func (a *agent) prepareCall(call AgentCall) AgentCall {
 324	call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
 325	call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
 326	call.TopP = cmp.Or(call.TopP, a.settings.topP)
 327	call.TopK = cmp.Or(call.TopK, a.settings.topK)
 328	call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
 329	call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
 330	call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
 331
 332	if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
 333		call.StopWhen = a.settings.stopWhen
 334	}
 335	if call.PrepareStep == nil && a.settings.prepareStep != nil {
 336		call.PrepareStep = a.settings.prepareStep
 337	}
 338	if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
 339		call.RepairToolCall = a.settings.repairToolCall
 340	}
 341	if call.OnRetry == nil && a.settings.onRetry != nil {
 342		call.OnRetry = a.settings.onRetry
 343	}
 344
 345	providerOptions := ProviderOptions{}
 346	if a.settings.providerOptions != nil {
 347		maps.Copy(providerOptions, a.settings.providerOptions)
 348	}
 349	if call.ProviderOptions != nil {
 350		maps.Copy(providerOptions, call.ProviderOptions)
 351	}
 352	call.ProviderOptions = providerOptions
 353
 354	headers := map[string]string{}
 355
 356	if a.settings.headers != nil {
 357		maps.Copy(headers, a.settings.headers)
 358	}
 359
 360	return call
 361}
 362
 363// Generate implements Agent.
 364func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
 365	opts = a.prepareCall(opts)
 366	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 367	if err != nil {
 368		return nil, err
 369	}
 370	var responseMessages []Message
 371	var steps []StepResult
 372
 373	for {
 374		stepInputMessages := append(initialPrompt, responseMessages...)
 375		stepModel := a.settings.model
 376		stepSystemPrompt := a.settings.systemPrompt
 377		stepActiveTools := opts.ActiveTools
 378		stepToolChoice := ToolChoiceAuto
 379		disableAllTools := false
 380
 381		if opts.PrepareStep != nil {
 382			updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
 383				Model:      stepModel,
 384				Steps:      steps,
 385				StepNumber: len(steps),
 386				Messages:   stepInputMessages,
 387			})
 388			if err != nil {
 389				return nil, err
 390			}
 391
 392			ctx = updatedCtx
 393
 394			// Apply prepared step modifications
 395			if prepared.Messages != nil {
 396				stepInputMessages = prepared.Messages
 397			}
 398			if prepared.Model != nil {
 399				stepModel = prepared.Model
 400			}
 401			if prepared.System != nil {
 402				stepSystemPrompt = *prepared.System
 403			}
 404			if prepared.ToolChoice != nil {
 405				stepToolChoice = *prepared.ToolChoice
 406			}
 407			if len(prepared.ActiveTools) > 0 {
 408				stepActiveTools = prepared.ActiveTools
 409			}
 410			disableAllTools = prepared.DisableAllTools
 411		}
 412
 413		// Recreate prompt with potentially modified system prompt
 414		if stepSystemPrompt != a.settings.systemPrompt {
 415			stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 416			if err != nil {
 417				return nil, err
 418			}
 419			// Replace system message part, keep the rest
 420			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 421				stepInputMessages[0] = stepPrompt[0] // Replace system message
 422			}
 423		}
 424
 425		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 426
 427		retryOptions := DefaultRetryOptions()
 428		if opts.MaxRetries != nil {
 429			retryOptions.MaxRetries = *opts.MaxRetries
 430		}
 431		retryOptions.OnRetry = opts.OnRetry
 432		retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
 433
 434		result, err := retry(ctx, func() (*Response, error) {
 435			return stepModel.Generate(ctx, Call{
 436				Prompt:           stepInputMessages,
 437				MaxOutputTokens:  opts.MaxOutputTokens,
 438				Temperature:      opts.Temperature,
 439				TopP:             opts.TopP,
 440				TopK:             opts.TopK,
 441				PresencePenalty:  opts.PresencePenalty,
 442				FrequencyPenalty: opts.FrequencyPenalty,
 443				Tools:            preparedTools,
 444				ToolChoice:       &stepToolChoice,
 445				ProviderOptions:  opts.ProviderOptions,
 446			})
 447		})
 448		if err != nil {
 449			return nil, err
 450		}
 451
 452		var stepToolCalls []ToolCallContent
 453		for _, content := range result.Content {
 454			if content.GetType() == ContentTypeToolCall {
 455				toolCall, ok := AsContentType[ToolCallContent](content)
 456				if !ok {
 457					continue
 458				}
 459
 460				// Validate and potentially repair the tool call
 461				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
 462				stepToolCalls = append(stepToolCalls, validatedToolCall)
 463			}
 464		}
 465
 466		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil, nil)
 467
 468		// Build step content with validated tool calls and tool results
 469		stepContent := []Content{}
 470		toolCallIndex := 0
 471		for _, content := range result.Content {
 472			if content.GetType() == ContentTypeToolCall {
 473				// Replace with validated tool call
 474				if toolCallIndex < len(stepToolCalls) {
 475					stepContent = append(stepContent, stepToolCalls[toolCallIndex])
 476					toolCallIndex++
 477				}
 478			} else {
 479				// Keep other content as-is
 480				stepContent = append(stepContent, content)
 481			}
 482		}
 483		// Add tool results
 484		for _, result := range toolResults {
 485			stepContent = append(stepContent, result)
 486		}
 487		currentStepMessages := toResponseMessages(stepContent)
 488		responseMessages = append(responseMessages, currentStepMessages...)
 489
 490		stepResult := StepResult{
 491			Response: Response{
 492				Content:          stepContent,
 493				FinishReason:     result.FinishReason,
 494				Usage:            result.Usage,
 495				Warnings:         result.Warnings,
 496				ProviderMetadata: result.ProviderMetadata,
 497			},
 498			Messages: currentStepMessages,
 499		}
 500		steps = append(steps, stepResult)
 501		shouldStop := isStopConditionMet(opts.StopWhen, steps)
 502
 503		if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
 504			break
 505		}
 506	}
 507
 508	totalUsage := Usage{}
 509
 510	for _, step := range steps {
 511		usage := step.Usage
 512		totalUsage.InputTokens += usage.InputTokens
 513		totalUsage.OutputTokens += usage.OutputTokens
 514		totalUsage.ReasoningTokens += usage.ReasoningTokens
 515		totalUsage.CacheCreationTokens += usage.CacheCreationTokens
 516		totalUsage.CacheReadTokens += usage.CacheReadTokens
 517		totalUsage.TotalTokens += usage.TotalTokens
 518	}
 519
 520	agentResult := &AgentResult{
 521		Steps:      steps,
 522		Response:   steps[len(steps)-1].Response,
 523		TotalUsage: totalUsage,
 524	}
 525	return agentResult, nil
 526}
 527
 528func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
 529	if len(conditions) == 0 {
 530		return false
 531	}
 532
 533	for _, condition := range conditions {
 534		if condition(steps) {
 535			return true
 536		}
 537	}
 538	return false
 539}
 540
 541func toResponseMessages(content []Content) []Message {
 542	var assistantParts []MessagePart
 543	var toolParts []MessagePart
 544
 545	for _, c := range content {
 546		switch c.GetType() {
 547		case ContentTypeText:
 548			text, ok := AsContentType[TextContent](c)
 549			if !ok {
 550				continue
 551			}
 552			assistantParts = append(assistantParts, TextPart{
 553				Text:            text.Text,
 554				ProviderOptions: ProviderOptions(text.ProviderMetadata),
 555			})
 556		case ContentTypeReasoning:
 557			reasoning, ok := AsContentType[ReasoningContent](c)
 558			if !ok {
 559				continue
 560			}
 561			assistantParts = append(assistantParts, ReasoningPart{
 562				Text:            reasoning.Text,
 563				ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
 564			})
 565		case ContentTypeToolCall:
 566			toolCall, ok := AsContentType[ToolCallContent](c)
 567			if !ok {
 568				continue
 569			}
 570			assistantParts = append(assistantParts, ToolCallPart{
 571				ToolCallID:       toolCall.ToolCallID,
 572				ToolName:         toolCall.ToolName,
 573				Input:            toolCall.Input,
 574				ProviderExecuted: toolCall.ProviderExecuted,
 575				ProviderOptions:  ProviderOptions(toolCall.ProviderMetadata),
 576			})
 577		case ContentTypeFile:
 578			file, ok := AsContentType[FileContent](c)
 579			if !ok {
 580				continue
 581			}
 582			assistantParts = append(assistantParts, FilePart{
 583				Data:            file.Data,
 584				MediaType:       file.MediaType,
 585				ProviderOptions: ProviderOptions(file.ProviderMetadata),
 586			})
 587		case ContentTypeSource:
 588			// Sources are metadata about references used to generate the response.
 589			// They don't need to be included in the conversation messages.
 590			continue
 591		case ContentTypeToolResult:
 592			result, ok := AsContentType[ToolResultContent](c)
 593			if !ok {
 594				continue
 595			}
 596			toolParts = append(toolParts, ToolResultPart{
 597				ToolCallID:      result.ToolCallID,
 598				Output:          result.Result,
 599				ProviderOptions: ProviderOptions(result.ProviderMetadata),
 600			})
 601		}
 602	}
 603
 604	var messages []Message
 605	if len(assistantParts) > 0 {
 606		messages = append(messages, Message{
 607			Role:    MessageRoleAssistant,
 608			Content: assistantParts,
 609		})
 610	}
 611	if len(toolParts) > 0 {
 612		messages = append(messages, Message{
 613			Role:    MessageRoleTool,
 614			Content: toolParts,
 615		})
 616	}
 617	return messages
 618}
 619
 620func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, beforeExecCallback OnBeforeToolExecutionFunc, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
 621	if len(toolCalls) == 0 {
 622		return nil, nil
 623	}
 624
 625	// Create a map for quick tool lookup
 626	toolMap := make(map[string]AgentTool)
 627	for _, tool := range allTools {
 628		toolMap[tool.Info().Name] = tool
 629	}
 630
 631	// Execute all tool calls sequentially in order
 632	results := make([]ToolResultContent, 0, len(toolCalls))
 633
 634	for _, toolCall := range toolCalls {
 635		// Skip invalid tool calls - create error result
 636		if toolCall.Invalid {
 637			result := ToolResultContent{
 638				ToolCallID: toolCall.ToolCallID,
 639				ToolName:   toolCall.ToolName,
 640				Result: ToolResultOutputContentError{
 641					Error: toolCall.ValidationError,
 642				},
 643				ProviderExecuted: false,
 644			}
 645			results = append(results, result)
 646			if toolResultCallback != nil {
 647				if err := toolResultCallback(result); err != nil {
 648					return nil, err
 649				}
 650			}
 651			continue
 652		}
 653
 654		tool, exists := toolMap[toolCall.ToolName]
 655		if !exists {
 656			result := ToolResultContent{
 657				ToolCallID: toolCall.ToolCallID,
 658				ToolName:   toolCall.ToolName,
 659				Result: ToolResultOutputContentError{
 660					Error: errors.New("Error: Tool not found: " + toolCall.ToolName),
 661				},
 662				ProviderExecuted: false,
 663			}
 664			results = append(results, result)
 665			if toolResultCallback != nil {
 666				if err := toolResultCallback(result); err != nil {
 667					return nil, err
 668				}
 669			}
 670			continue
 671		}
 672
 673		// Call OnBeforeToolExecution hook to allow modification or skipping
 674		inputToUse := toolCall.Input
 675		if beforeExecCallback != nil {
 676			modifiedInput, skipExecution, err := beforeExecCallback(toolCall)
 677			if err != nil {
 678				// Callback returned error - create synthetic error result
 679				result := ToolResultContent{
 680					ToolCallID: toolCall.ToolCallID,
 681					ToolName:   toolCall.ToolName,
 682					Result: ToolResultOutputContentError{
 683						Error: err,
 684					},
 685					ProviderExecuted: false,
 686				}
 687				results = append(results, result)
 688				if toolResultCallback != nil {
 689					if cbErr := toolResultCallback(result); cbErr != nil {
 690						return nil, cbErr
 691					}
 692				}
 693				if skipExecution {
 694					// Hook wants to skip execution
 695					continue
 696				} else {
 697					// Hook returned error but doesn't want to skip - stop execution
 698					return nil, err
 699				}
 700			}
 701			if skipExecution {
 702				// Hook wants to skip execution without error
 703				result := ToolResultContent{
 704					ToolCallID: toolCall.ToolCallID,
 705					ToolName:   toolCall.ToolName,
 706					Result: ToolResultOutputContentError{
 707						Error: errors.New("Tool execution skipped by hook"),
 708					},
 709					ProviderExecuted: false,
 710				}
 711				results = append(results, result)
 712				if toolResultCallback != nil {
 713					if cbErr := toolResultCallback(result); cbErr != nil {
 714						return nil, cbErr
 715					}
 716				}
 717				continue
 718			}
 719			if modifiedInput != "" {
 720				inputToUse = modifiedInput
 721			}
 722		}
 723
 724		// Execute the tool
 725		toolResult, err := tool.Run(ctx, ToolCall{
 726			ID:    toolCall.ToolCallID,
 727			Name:  toolCall.ToolName,
 728			Input: inputToUse,
 729		})
 730		if err != nil {
 731			result := ToolResultContent{
 732				ToolCallID: toolCall.ToolCallID,
 733				ToolName:   toolCall.ToolName,
 734				Result: ToolResultOutputContentError{
 735					Error: err,
 736				},
 737				ClientMetadata:   toolResult.Metadata,
 738				ProviderExecuted: false,
 739			}
 740			if toolResultCallback != nil {
 741				if cbErr := toolResultCallback(result); cbErr != nil {
 742					return nil, cbErr
 743				}
 744			}
 745			return nil, err
 746		}
 747
 748		var result ToolResultContent
 749		if toolResult.IsError {
 750			result = ToolResultContent{
 751				ToolCallID: toolCall.ToolCallID,
 752				ToolName:   toolCall.ToolName,
 753				Result: ToolResultOutputContentError{
 754					Error: errors.New(toolResult.Content),
 755				},
 756				ClientMetadata:   toolResult.Metadata,
 757				ProviderExecuted: false,
 758			}
 759		} else {
 760			result = ToolResultContent{
 761				ToolCallID: toolCall.ToolCallID,
 762				ToolName:   toolCall.ToolName,
 763				Result: ToolResultOutputContentText{
 764					Text: toolResult.Content,
 765				},
 766				ClientMetadata:   toolResult.Metadata,
 767				ProviderExecuted: false,
 768			}
 769		}
 770		results = append(results, result)
 771		if toolResultCallback != nil {
 772			if err := toolResultCallback(result); err != nil {
 773				return nil, err
 774			}
 775		}
 776	}
 777
 778	return results, nil
 779}
 780
 781// Stream implements Agent.
 782func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
 783	// Convert AgentStreamCall to AgentCall for preparation
 784	call := AgentCall{
 785		Prompt:           opts.Prompt,
 786		Files:            opts.Files,
 787		Messages:         opts.Messages,
 788		MaxOutputTokens:  opts.MaxOutputTokens,
 789		Temperature:      opts.Temperature,
 790		TopP:             opts.TopP,
 791		TopK:             opts.TopK,
 792		PresencePenalty:  opts.PresencePenalty,
 793		FrequencyPenalty: opts.FrequencyPenalty,
 794		ActiveTools:      opts.ActiveTools,
 795		ProviderOptions:  opts.ProviderOptions,
 796		MaxRetries:       opts.MaxRetries,
 797		StopWhen:         opts.StopWhen,
 798		PrepareStep:      opts.PrepareStep,
 799		RepairToolCall:   opts.RepairToolCall,
 800	}
 801
 802	call = a.prepareCall(call)
 803
 804	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
 805	if err != nil {
 806		return nil, err
 807	}
 808
 809	var responseMessages []Message
 810	var steps []StepResult
 811	var totalUsage Usage
 812
 813	// Start agent stream
 814	if opts.OnAgentStart != nil {
 815		opts.OnAgentStart()
 816	}
 817
 818	for stepNumber := 0; ; stepNumber++ {
 819		stepInputMessages := append(initialPrompt, responseMessages...)
 820		stepModel := a.settings.model
 821		stepSystemPrompt := a.settings.systemPrompt
 822		stepActiveTools := call.ActiveTools
 823		stepToolChoice := ToolChoiceAuto
 824		disableAllTools := false
 825
 826		// Apply step preparation if provided
 827		if call.PrepareStep != nil {
 828			updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
 829				Model:      stepModel,
 830				Steps:      steps,
 831				StepNumber: stepNumber,
 832				Messages:   stepInputMessages,
 833			})
 834			if err != nil {
 835				return nil, err
 836			}
 837
 838			ctx = updatedCtx
 839
 840			if prepared.Messages != nil {
 841				stepInputMessages = prepared.Messages
 842			}
 843			if prepared.Model != nil {
 844				stepModel = prepared.Model
 845			}
 846			if prepared.System != nil {
 847				stepSystemPrompt = *prepared.System
 848			}
 849			if prepared.ToolChoice != nil {
 850				stepToolChoice = *prepared.ToolChoice
 851			}
 852			if len(prepared.ActiveTools) > 0 {
 853				stepActiveTools = prepared.ActiveTools
 854			}
 855			disableAllTools = prepared.DisableAllTools
 856		}
 857
 858		// Recreate prompt with potentially modified system prompt
 859		if stepSystemPrompt != a.settings.systemPrompt {
 860			stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
 861			if err != nil {
 862				return nil, err
 863			}
 864			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 865				stepInputMessages[0] = stepPrompt[0]
 866			}
 867		}
 868
 869		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 870
 871		// Start step stream
 872		if opts.OnStepStart != nil {
 873			_ = opts.OnStepStart(stepNumber)
 874		}
 875
 876		// Create streaming call
 877		streamCall := Call{
 878			Prompt:           stepInputMessages,
 879			MaxOutputTokens:  call.MaxOutputTokens,
 880			Temperature:      call.Temperature,
 881			TopP:             call.TopP,
 882			TopK:             call.TopK,
 883			PresencePenalty:  call.PresencePenalty,
 884			FrequencyPenalty: call.FrequencyPenalty,
 885			Tools:            preparedTools,
 886			ToolChoice:       &stepToolChoice,
 887			ProviderOptions:  call.ProviderOptions,
 888		}
 889
 890		// Get streaming response with retry logic
 891		retryOptions := DefaultRetryOptions()
 892		if call.MaxRetries != nil {
 893			retryOptions.MaxRetries = *call.MaxRetries
 894		}
 895		retryOptions.OnRetry = call.OnRetry
 896		retry := RetryWithExponentialBackoffRespectingRetryHeaders[StreamResponse](retryOptions)
 897
 898		stream, err := retry(ctx, func() (StreamResponse, error) {
 899			return stepModel.Stream(ctx, streamCall)
 900		})
 901		if err != nil {
 902			if opts.OnError != nil {
 903				opts.OnError(err)
 904			}
 905			return nil, err
 906		}
 907
 908		// Process stream with tool execution
 909		stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
 910		if err != nil {
 911			if opts.OnError != nil {
 912				opts.OnError(err)
 913			}
 914			return nil, err
 915		}
 916
 917		steps = append(steps, stepResult)
 918		totalUsage = addUsage(totalUsage, stepResult.Usage)
 919
 920		// Call step finished callback
 921		if opts.OnStepFinish != nil {
 922			_ = opts.OnStepFinish(stepResult)
 923		}
 924
 925		// Add step messages to response messages
 926		stepMessages := toResponseMessages(stepResult.Content)
 927		responseMessages = append(responseMessages, stepMessages...)
 928
 929		// Check stop conditions
 930		shouldStop := isStopConditionMet(call.StopWhen, steps)
 931		if shouldStop || !shouldContinue {
 932			break
 933		}
 934	}
 935
 936	// Finish agent stream
 937	agentResult := &AgentResult{
 938		Steps:      steps,
 939		Response:   steps[len(steps)-1].Response,
 940		TotalUsage: totalUsage,
 941	}
 942
 943	if opts.OnFinish != nil {
 944		opts.OnFinish(agentResult)
 945	}
 946
 947	if opts.OnAgentFinish != nil {
 948		_ = opts.OnAgentFinish(agentResult)
 949	}
 950
 951	return agentResult, nil
 952}
 953
 954func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
 955	preparedTools := make([]Tool, 0, len(tools))
 956
 957	// If explicitly disabling all tools, return no tools
 958	if disableAllTools {
 959		return preparedTools
 960	}
 961
 962	for _, tool := range tools {
 963		// If activeTools has items, only include tools in the list
 964		// If activeTools is empty, include all tools
 965		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
 966			continue
 967		}
 968		info := tool.Info()
 969		preparedTools = append(preparedTools, FunctionTool{
 970			Name:        info.Name,
 971			Description: info.Description,
 972			InputSchema: map[string]any{
 973				"type":       "object",
 974				"properties": info.Parameters,
 975				"required":   info.Required,
 976			},
 977			ProviderOptions: tool.ProviderOptions(),
 978		})
 979	}
 980	return preparedTools
 981}
 982
 983// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
 984func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
 985	if err := a.validateToolCall(toolCall, availableTools); err == nil {
 986		return toolCall
 987	} else { //nolint: revive
 988		if repairFunc != nil {
 989			repairOptions := ToolCallRepairOptions{
 990				OriginalToolCall: toolCall,
 991				ValidationError:  err,
 992				AvailableTools:   availableTools,
 993				SystemPrompt:     systemPrompt,
 994				Messages:         messages,
 995			}
 996
 997			if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
 998				if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
 999					return *repairedToolCall
1000				}
1001			}
1002		}
1003
1004		invalidToolCall := toolCall
1005		invalidToolCall.Invalid = true
1006		invalidToolCall.ValidationError = err
1007		return invalidToolCall
1008	}
1009}
1010
1011// validateToolCall validates a tool call against available tools and their schemas.
1012func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
1013	var tool AgentTool
1014	for _, t := range availableTools {
1015		if t.Info().Name == toolCall.ToolName {
1016			tool = t
1017			break
1018		}
1019	}
1020
1021	if tool == nil {
1022		return fmt.Errorf("tool not found: %s", toolCall.ToolName)
1023	}
1024
1025	// Validate JSON parsing
1026	var input map[string]any
1027	if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
1028		return fmt.Errorf("invalid JSON input: %w", err)
1029	}
1030
1031	// Basic schema validation (check required fields)
1032	// TODO: more robust schema validation using JSON Schema or similar
1033	toolInfo := tool.Info()
1034	for _, required := range toolInfo.Required {
1035		if _, exists := input[required]; !exists {
1036			return fmt.Errorf("missing required parameter: %s", required)
1037		}
1038	}
1039	return nil
1040}
1041
1042func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
1043	if prompt == "" {
1044		return nil, &Error{Title: "invalid argument", Message: "prompt can't be empty"}
1045	}
1046
1047	var preparedPrompt Prompt
1048
1049	if system != "" {
1050		preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
1051	}
1052	preparedPrompt = append(preparedPrompt, messages...)
1053	preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
1054	return preparedPrompt, nil
1055}
1056
1057// WithSystemPrompt sets the system prompt for the agent.
1058func WithSystemPrompt(prompt string) AgentOption {
1059	return func(s *agentSettings) {
1060		s.systemPrompt = prompt
1061	}
1062}
1063
1064// WithMaxOutputTokens sets the maximum output tokens for the agent.
1065func WithMaxOutputTokens(tokens int64) AgentOption {
1066	return func(s *agentSettings) {
1067		s.maxOutputTokens = &tokens
1068	}
1069}
1070
1071// WithTemperature sets the temperature for the agent.
1072func WithTemperature(temp float64) AgentOption {
1073	return func(s *agentSettings) {
1074		s.temperature = &temp
1075	}
1076}
1077
1078// WithTopP sets the top-p value for the agent.
1079func WithTopP(topP float64) AgentOption {
1080	return func(s *agentSettings) {
1081		s.topP = &topP
1082	}
1083}
1084
1085// WithTopK sets the top-k value for the agent.
1086func WithTopK(topK int64) AgentOption {
1087	return func(s *agentSettings) {
1088		s.topK = &topK
1089	}
1090}
1091
1092// WithPresencePenalty sets the presence penalty for the agent.
1093func WithPresencePenalty(penalty float64) AgentOption {
1094	return func(s *agentSettings) {
1095		s.presencePenalty = &penalty
1096	}
1097}
1098
1099// WithFrequencyPenalty sets the frequency penalty for the agent.
1100func WithFrequencyPenalty(penalty float64) AgentOption {
1101	return func(s *agentSettings) {
1102		s.frequencyPenalty = &penalty
1103	}
1104}
1105
1106// WithTools sets the tools for the agent.
1107func WithTools(tools ...AgentTool) AgentOption {
1108	return func(s *agentSettings) {
1109		s.tools = append(s.tools, tools...)
1110	}
1111}
1112
1113// WithStopConditions sets the stop conditions for the agent.
1114func WithStopConditions(conditions ...StopCondition) AgentOption {
1115	return func(s *agentSettings) {
1116		s.stopWhen = append(s.stopWhen, conditions...)
1117	}
1118}
1119
1120// WithPrepareStep sets the prepare step function for the agent.
1121func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1122	return func(s *agentSettings) {
1123		s.prepareStep = fn
1124	}
1125}
1126
1127// WithRepairToolCall sets the repair tool call function for the agent.
1128func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1129	return func(s *agentSettings) {
1130		s.repairToolCall = fn
1131	}
1132}
1133
1134// WithMaxRetries sets the maximum number of retries for the agent.
1135func WithMaxRetries(maxRetries int) AgentOption {
1136	return func(s *agentSettings) {
1137		s.maxRetries = &maxRetries
1138	}
1139}
1140
1141// WithOnRetry sets the retry callback for the agent.
1142func WithOnRetry(callback OnRetryCallback) AgentOption {
1143	return func(s *agentSettings) {
1144		s.onRetry = callback
1145	}
1146}
1147
1148// processStepStream processes a single step's stream and returns the step result.
1149func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
1150	var stepContent []Content
1151	var stepToolCalls []ToolCallContent
1152	var stepUsage Usage
1153	stepFinishReason := FinishReasonUnknown
1154	var stepWarnings []CallWarning
1155	var stepProviderMetadata ProviderMetadata
1156
1157	activeToolCalls := make(map[string]*ToolCallContent)
1158	activeTextContent := make(map[string]string)
1159	type reasoningContent struct {
1160		content string
1161		options ProviderMetadata
1162	}
1163	activeReasoningContent := make(map[string]reasoningContent)
1164
1165	// Process stream parts
1166	for part := range stream {
1167		// Forward all parts to chunk callback
1168		if opts.OnChunk != nil {
1169			err := opts.OnChunk(part)
1170			if err != nil {
1171				return StepResult{}, false, err
1172			}
1173		}
1174
1175		switch part.Type {
1176		case StreamPartTypeWarnings:
1177			stepWarnings = part.Warnings
1178			if opts.OnWarnings != nil {
1179				err := opts.OnWarnings(part.Warnings)
1180				if err != nil {
1181					return StepResult{}, false, err
1182				}
1183			}
1184
1185		case StreamPartTypeTextStart:
1186			activeTextContent[part.ID] = ""
1187			if opts.OnTextStart != nil {
1188				err := opts.OnTextStart(part.ID)
1189				if err != nil {
1190					return StepResult{}, false, err
1191				}
1192			}
1193
1194		case StreamPartTypeTextDelta:
1195			if _, exists := activeTextContent[part.ID]; exists {
1196				activeTextContent[part.ID] += part.Delta
1197			}
1198			if opts.OnTextDelta != nil {
1199				err := opts.OnTextDelta(part.ID, part.Delta)
1200				if err != nil {
1201					return StepResult{}, false, err
1202				}
1203			}
1204
1205		case StreamPartTypeTextEnd:
1206			if text, exists := activeTextContent[part.ID]; exists {
1207				stepContent = append(stepContent, TextContent{
1208					Text:             text,
1209					ProviderMetadata: part.ProviderMetadata,
1210				})
1211				delete(activeTextContent, part.ID)
1212			}
1213			if opts.OnTextEnd != nil {
1214				err := opts.OnTextEnd(part.ID)
1215				if err != nil {
1216					return StepResult{}, false, err
1217				}
1218			}
1219
1220		case StreamPartTypeReasoningStart:
1221			activeReasoningContent[part.ID] = reasoningContent{content: part.Delta, options: part.ProviderMetadata}
1222			if opts.OnReasoningStart != nil {
1223				content := ReasoningContent{
1224					Text:             part.Delta,
1225					ProviderMetadata: part.ProviderMetadata,
1226				}
1227				err := opts.OnReasoningStart(part.ID, content)
1228				if err != nil {
1229					return StepResult{}, false, err
1230				}
1231			}
1232
1233		case StreamPartTypeReasoningDelta:
1234			if active, exists := activeReasoningContent[part.ID]; exists {
1235				active.content += part.Delta
1236				active.options = part.ProviderMetadata
1237				activeReasoningContent[part.ID] = active
1238			}
1239			if opts.OnReasoningDelta != nil {
1240				err := opts.OnReasoningDelta(part.ID, part.Delta)
1241				if err != nil {
1242					return StepResult{}, false, err
1243				}
1244			}
1245
1246		case StreamPartTypeReasoningEnd:
1247			if active, exists := activeReasoningContent[part.ID]; exists {
1248				if part.ProviderMetadata != nil {
1249					active.options = part.ProviderMetadata
1250				}
1251				content := ReasoningContent{
1252					Text:             active.content,
1253					ProviderMetadata: active.options,
1254				}
1255				stepContent = append(stepContent, content)
1256				if opts.OnReasoningEnd != nil {
1257					err := opts.OnReasoningEnd(part.ID, content)
1258					if err != nil {
1259						return StepResult{}, false, err
1260					}
1261				}
1262				delete(activeReasoningContent, part.ID)
1263			}
1264
1265		case StreamPartTypeToolInputStart:
1266			activeToolCalls[part.ID] = &ToolCallContent{
1267				ToolCallID:       part.ID,
1268				ToolName:         part.ToolCallName,
1269				Input:            "",
1270				ProviderExecuted: part.ProviderExecuted,
1271			}
1272			if opts.OnToolInputStart != nil {
1273				err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1274				if err != nil {
1275					return StepResult{}, false, err
1276				}
1277			}
1278
1279		case StreamPartTypeToolInputDelta:
1280			if toolCall, exists := activeToolCalls[part.ID]; exists {
1281				toolCall.Input += part.Delta
1282			}
1283			if opts.OnToolInputDelta != nil {
1284				err := opts.OnToolInputDelta(part.ID, part.Delta)
1285				if err != nil {
1286					return StepResult{}, false, err
1287				}
1288			}
1289
1290		case StreamPartTypeToolInputEnd:
1291			if opts.OnToolInputEnd != nil {
1292				err := opts.OnToolInputEnd(part.ID)
1293				if err != nil {
1294					return StepResult{}, false, err
1295				}
1296			}
1297
1298		case StreamPartTypeToolCall:
1299			toolCall := ToolCallContent{
1300				ToolCallID:       part.ID,
1301				ToolName:         part.ToolCallName,
1302				Input:            part.ToolCallInput,
1303				ProviderExecuted: part.ProviderExecuted,
1304				ProviderMetadata: part.ProviderMetadata,
1305			}
1306
1307			// Validate and potentially repair the tool call
1308			validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1309			stepToolCalls = append(stepToolCalls, validatedToolCall)
1310			stepContent = append(stepContent, validatedToolCall)
1311
1312			if opts.OnToolCall != nil {
1313				err := opts.OnToolCall(validatedToolCall)
1314				if err != nil {
1315					return StepResult{}, false, err
1316				}
1317			}
1318
1319			// Clean up active tool call
1320			delete(activeToolCalls, part.ID)
1321
1322		case StreamPartTypeSource:
1323			sourceContent := SourceContent{
1324				SourceType:       part.SourceType,
1325				ID:               part.ID,
1326				URL:              part.URL,
1327				Title:            part.Title,
1328				ProviderMetadata: part.ProviderMetadata,
1329			}
1330			stepContent = append(stepContent, sourceContent)
1331			if opts.OnSource != nil {
1332				err := opts.OnSource(sourceContent)
1333				if err != nil {
1334					return StepResult{}, false, err
1335				}
1336			}
1337
1338		case StreamPartTypeFinish:
1339			stepUsage = part.Usage
1340			stepFinishReason = part.FinishReason
1341			stepProviderMetadata = part.ProviderMetadata
1342			if opts.OnStreamFinish != nil {
1343				err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1344				if err != nil {
1345					return StepResult{}, false, err
1346				}
1347			}
1348
1349		case StreamPartTypeError:
1350			if opts.OnError != nil {
1351				opts.OnError(part.Error)
1352			}
1353			return StepResult{}, false, part.Error
1354		}
1355	}
1356
1357	// Execute tools if any
1358	var toolResults []ToolResultContent
1359	if len(stepToolCalls) > 0 {
1360		var err error
1361		toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnBeforeToolExecution, opts.OnToolResult)
1362		if err != nil {
1363			return StepResult{}, false, err
1364		}
1365		// Add tool results to content
1366		for _, result := range toolResults {
1367			stepContent = append(stepContent, result)
1368		}
1369	}
1370
1371	stepResult := StepResult{
1372		Response: Response{
1373			Content:          stepContent,
1374			FinishReason:     stepFinishReason,
1375			Usage:            stepUsage,
1376			Warnings:         stepWarnings,
1377			ProviderMetadata: stepProviderMetadata,
1378		},
1379		Messages: toResponseMessages(stepContent),
1380	}
1381
1382	// Determine if we should continue (has tool calls and not stopped)
1383	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1384
1385	return stepResult, shouldContinue, nil
1386}
1387
1388func addUsage(a, b Usage) Usage {
1389	return Usage{
1390		InputTokens:         a.InputTokens + b.InputTokens,
1391		OutputTokens:        a.OutputTokens + b.OutputTokens,
1392		TotalTokens:         a.TotalTokens + b.TotalTokens,
1393		ReasoningTokens:     a.ReasoningTokens + b.ReasoningTokens,
1394		CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1395		CacheReadTokens:     a.CacheReadTokens + b.CacheReadTokens,
1396	}
1397}
1398
1399// WithHeaders sets the headers for the agent.
1400func WithHeaders(headers map[string]string) AgentOption {
1401	return func(s *agentSettings) {
1402		s.headers = headers
1403	}
1404}
1405
1406// WithProviderOptions sets the provider options for the agent.
1407func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1408	return func(s *agentSettings) {
1409		s.providerOptions = providerOptions
1410	}
1411}