agent.go

   1package ai
   2
   3import (
   4	"cmp"
   5	"context"
   6	"encoding/json"
   7	"errors"
   8	"fmt"
   9	"maps"
  10	"slices"
  11	"sync"
  12)
  13
// StepResult is the outcome of a single agent step: the embedded model
// Response for that step plus the conversation messages derived from it.
// In Generate, Messages is toResponseMessages applied to the step's content
// (tool calls replaced by their validated form, followed by tool results).
type StepResult struct {
	Response
	// Messages holds the messages this step contributes to the conversation.
	Messages []Message
}
  18
// StopCondition reports whether the agent loop should stop, given every step
// completed so far in order. It is evaluated after each step; with several
// conditions, any one returning true stops the loop (see isStopConditionMet).
type StopCondition = func(steps []StepResult) bool
  20
  21// StepCountIs returns a stop condition that stops after the specified number of steps.
  22func StepCountIs(stepCount int) StopCondition {
  23	return func(steps []StepResult) bool {
  24		return len(steps) >= stepCount
  25	}
  26}
  27
  28// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
  29func HasToolCall(toolName string) StopCondition {
  30	return func(steps []StepResult) bool {
  31		if len(steps) == 0 {
  32			return false
  33		}
  34		lastStep := steps[len(steps)-1]
  35		toolCalls := lastStep.Content.ToolCalls()
  36		for _, toolCall := range toolCalls {
  37			if toolCall.ToolName == toolName {
  38				return true
  39			}
  40		}
  41		return false
  42	}
  43}
  44
  45// HasContent returns a stop condition that stops when the specified content type appears in the last step.
  46func HasContent(contentType ContentType) StopCondition {
  47	return func(steps []StepResult) bool {
  48		if len(steps) == 0 {
  49			return false
  50		}
  51		lastStep := steps[len(steps)-1]
  52		for _, content := range lastStep.Content {
  53			if content.GetType() == contentType {
  54				return true
  55			}
  56		}
  57		return false
  58	}
  59}
  60
  61// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
  62func FinishReasonIs(reason FinishReason) StopCondition {
  63	return func(steps []StepResult) bool {
  64		if len(steps) == 0 {
  65			return false
  66		}
  67		lastStep := steps[len(steps)-1]
  68		return lastStep.FinishReason == reason
  69	}
  70}
  71
  72// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
  73func MaxTokensUsed(maxTokens int64) StopCondition {
  74	return func(steps []StepResult) bool {
  75		var totalTokens int64
  76		for _, step := range steps {
  77			totalTokens += step.Usage.TotalTokens
  78		}
  79		return totalTokens >= maxTokens
  80	}
  81}
  82
// PrepareStepFunctionOptions is the input handed to a PrepareStepFunction
// right before a step is sent to the model.
type PrepareStepFunctionOptions struct {
	// Steps are the steps completed so far, in order.
	Steps []StepResult
	// StepNumber is the zero-based index of the step about to run.
	StepNumber int
	// Model is the agent's configured model (the default for this step).
	Model LanguageModel
	// Messages are the input messages the step would send if not overridden.
	Messages []Message
}
  89
// PrepareStepResult lets a PrepareStepFunction override settings for a single
// step. Zero values (nil / empty / false) keep the step's defaults.
type PrepareStepResult struct {
	// Model, when non-nil, replaces the model used for this step.
	Model LanguageModel
	// Messages, when non-nil, replaces the step's input messages.
	Messages []Message
	// System, when non-nil, replaces the system prompt for this step.
	System *string
	// ToolChoice, when non-nil, replaces the default ToolChoiceAuto.
	ToolChoice *ToolChoice
	// ActiveTools, when non-empty, restricts the tools offered to these names.
	ActiveTools []string
	// DisableAllTools, when true, offers no tools at all for this step.
	DisableAllTools bool
}
  98
// ToolCallRepairOptions is the input handed to a RepairToolCallFunction when
// a tool call fails validation.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call exactly as the model produced it.
	OriginalToolCall ToolCallContent
	// ValidationError explains why OriginalToolCall was rejected.
	ValidationError error
	// AvailableTools are the tools the call was validated against.
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt in effect for the step.
	SystemPrompt string
	// Messages are the input messages that were sent for the step.
	Messages []Message
}
 106
// Hook signatures accepted by the agent.
type (
	// PrepareStepFunction runs before each step and may override that step's
	// model, messages, system prompt, tool choice and active tools. A non-nil
	// error aborts the run.
	PrepareStepFunction = func(options PrepareStepFunctionOptions) (PrepareStepResult, error)
	// OnStepFinishedFunction receives a completed step.
	// NOTE(review): not referenced by the code visible here — confirm it is still used.
	OnStepFinishedFunction = func(step StepResult)
	// RepairToolCallFunction tries to fix a tool call that failed validation.
	// A returned call is re-validated before use; returning nil, an error, or a
	// call that still fails validation leaves the original call marked invalid.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
 112
// agentSettings holds agent-level defaults configured through AgentOption
// values passed to NewAgent. Per-call values in AgentCall / AgentStreamCall
// take precedence over these (see prepareCall).
type agentSettings struct {
	// systemPrompt is sent as a leading system message; "" means none.
	systemPrompt string
	// Default sampling / limit parameters; nil leaves them unset.
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	// headers and providerOptions are merged under the per-call values.
	headers         map[string]string
	providerOptions ProviderOptions

	// TODO: add support for provider tools
	tools      []AgentTool
	maxRetries *int

	// model is the default LanguageModel for every step.
	model LanguageModel

	// Default hooks, used when a call does not supply its own.
	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
 135
 136type AgentCall struct {
 137	Prompt           string     `json:"prompt"`
 138	Files            []FilePart `json:"files"`
 139	Messages         []Message  `json:"messages"`
 140	MaxOutputTokens  *int64
 141	Temperature      *float64 `json:"temperature"`
 142	TopP             *float64 `json:"top_p"`
 143	TopK             *int64   `json:"top_k"`
 144	PresencePenalty  *float64 `json:"presence_penalty"`
 145	FrequencyPenalty *float64 `json:"frequency_penalty"`
 146	ActiveTools      []string `json:"active_tools"`
 147	Headers          map[string]string
 148	ProviderOptions  ProviderOptions
 149	OnRetry          OnRetryCallback
 150	MaxRetries       *int
 151
 152	StopWhen       []StopCondition
 153	PrepareStep    PrepareStepFunction
 154	RepairToolCall RepairToolCallFunction
 155}
 156
 157type AgentStreamCall struct {
 158	Prompt           string     `json:"prompt"`
 159	Files            []FilePart `json:"files"`
 160	Messages         []Message  `json:"messages"`
 161	MaxOutputTokens  *int64
 162	Temperature      *float64 `json:"temperature"`
 163	TopP             *float64 `json:"top_p"`
 164	TopK             *int64   `json:"top_k"`
 165	PresencePenalty  *float64 `json:"presence_penalty"`
 166	FrequencyPenalty *float64 `json:"frequency_penalty"`
 167	ActiveTools      []string `json:"active_tools"`
 168	Headers          map[string]string
 169	ProviderOptions  ProviderOptions
 170	OnRetry          OnRetryCallback
 171	MaxRetries       *int
 172
 173	StopWhen       []StopCondition
 174	PrepareStep    PrepareStepFunction
 175	RepairToolCall RepairToolCallFunction
 176
 177	// Agent-level callbacks
 178	OnAgentStart  func()                            // Called when agent starts
 179	OnAgentFinish func(result *AgentResult) error   // Called when agent finishes
 180	OnStepStart   func(stepNumber int) error        // Called when a step starts
 181	OnStepFinish  func(stepResult StepResult) error // Called when a step finishes
 182	OnFinish      func(result *AgentResult)         // Called when entire agent completes
 183	OnError       func(error)                       // Called when an error occurs
 184
 185	// Stream part callbacks - called for each corresponding stream part type
 186	OnChunk          func(StreamPart) error                                                                // Called for each stream part (catch-all)
 187	OnWarnings       func(warnings []CallWarning) error                                                    // Called for warnings
 188	OnTextStart      func(id string) error                                                                 // Called when text starts
 189	OnTextDelta      func(id, text string) error                                                           // Called for text deltas
 190	OnTextEnd        func(id string) error                                                                 // Called when text ends
 191	OnReasoningStart func(id string) error                                                                 // Called when reasoning starts
 192	OnReasoningDelta func(id, text string) error                                                           // Called for reasoning deltas
 193	OnReasoningEnd   func(id string, reasoning ReasoningContent) error                                     // Called when reasoning ends
 194	OnToolInputStart func(id, toolName string) error                                                       // Called when tool input starts
 195	OnToolInputDelta func(id, delta string) error                                                          // Called for tool input deltas
 196	OnToolInputEnd   func(id string) error                                                                 // Called when tool input ends
 197	OnToolCall       func(toolCall ToolCallContent) error                                                  // Called when tool call is complete
 198	OnToolResult     func(result ToolResultContent) error                                                  // Called when tool execution completes
 199	OnSource         func(source SourceContent) error                                                      // Called for source references
 200	OnStreamFinish   func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error // Called when stream finishes
 201}
 202
// AgentResult is the aggregated outcome of an agent run.
type AgentResult struct {
	// Steps are all completed steps, in order.
	Steps []StepResult
	// Final response: Generate and Stream set this to the last step's Response.
	Response Response
	// TotalUsage is the token usage summed over all steps.
	TotalUsage Usage
}
 209
// Agent runs a language model in a multi-step tool loop.
// Generate is non-streaming; Stream reports incremental output through the
// callbacks in AgentStreamCall and returns the aggregated result when done.
type Agent interface {
	Generate(context.Context, AgentCall) (*AgentResult, error)
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
 214
// AgentOption configures agent-level defaults; pass them to NewAgent.
type AgentOption = func(*agentSettings)
 216
// agent is the Agent implementation returned by NewAgent.
type agent struct {
	// settings are the agent-level defaults built from the AgentOption values.
	settings agentSettings
}
 220
 221func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
 222	settings := agentSettings{
 223		model: model,
 224	}
 225	for _, o := range opts {
 226		o(&settings)
 227	}
 228	return &agent{
 229		settings: settings,
 230	}
 231}
 232
 233func (a *agent) prepareCall(call AgentCall) AgentCall {
 234	call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
 235	call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
 236	call.TopP = cmp.Or(call.TopP, a.settings.topP)
 237	call.TopK = cmp.Or(call.TopK, a.settings.topK)
 238	call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
 239	call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
 240	call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
 241
 242	if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
 243		call.StopWhen = a.settings.stopWhen
 244	}
 245	if call.PrepareStep == nil && a.settings.prepareStep != nil {
 246		call.PrepareStep = a.settings.prepareStep
 247	}
 248	if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
 249		call.RepairToolCall = a.settings.repairToolCall
 250	}
 251	if call.OnRetry == nil && a.settings.onRetry != nil {
 252		call.OnRetry = a.settings.onRetry
 253	}
 254
 255	providerOptions := ProviderOptions{}
 256	if a.settings.providerOptions != nil {
 257		maps.Copy(providerOptions, a.settings.providerOptions)
 258	}
 259	if call.ProviderOptions != nil {
 260		maps.Copy(providerOptions, call.ProviderOptions)
 261	}
 262	call.ProviderOptions = providerOptions
 263
 264	headers := map[string]string{}
 265
 266	if a.settings.headers != nil {
 267		maps.Copy(headers, a.settings.headers)
 268	}
 269
 270	if call.Headers != nil {
 271		maps.Copy(headers, call.Headers)
 272	}
 273	call.Headers = headers
 274	return call
 275}
 276
 277// Generate implements Agent.
 278func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
 279	opts = a.prepareCall(opts)
 280	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 281	if err != nil {
 282		return nil, err
 283	}
 284	var responseMessages []Message
 285	var steps []StepResult
 286
 287	for {
 288		stepInputMessages := append(initialPrompt, responseMessages...)
 289		stepModel := a.settings.model
 290		stepSystemPrompt := a.settings.systemPrompt
 291		stepActiveTools := opts.ActiveTools
 292		stepToolChoice := ToolChoiceAuto
 293		disableAllTools := false
 294
 295		if opts.PrepareStep != nil {
 296			prepared, err := opts.PrepareStep(PrepareStepFunctionOptions{
 297				Model:      stepModel,
 298				Steps:      steps,
 299				StepNumber: len(steps),
 300				Messages:   stepInputMessages,
 301			})
 302			if err != nil {
 303				return nil, err
 304			}
 305
 306			// Apply prepared step modifications
 307			if prepared.Messages != nil {
 308				stepInputMessages = prepared.Messages
 309			}
 310			if prepared.Model != nil {
 311				stepModel = prepared.Model
 312			}
 313			if prepared.System != nil {
 314				stepSystemPrompt = *prepared.System
 315			}
 316			if prepared.ToolChoice != nil {
 317				stepToolChoice = *prepared.ToolChoice
 318			}
 319			if len(prepared.ActiveTools) > 0 {
 320				stepActiveTools = prepared.ActiveTools
 321			}
 322			disableAllTools = prepared.DisableAllTools
 323		}
 324
 325		// Recreate prompt with potentially modified system prompt
 326		if stepSystemPrompt != a.settings.systemPrompt {
 327			stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 328			if err != nil {
 329				return nil, err
 330			}
 331			// Replace system message part, keep the rest
 332			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 333				stepInputMessages[0] = stepPrompt[0] // Replace system message
 334			}
 335		}
 336
 337		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 338
 339		retryOptions := DefaultRetryOptions()
 340		retryOptions.OnRetry = opts.OnRetry
 341		retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
 342
 343		result, err := retry(ctx, func() (*Response, error) {
 344			return stepModel.Generate(ctx, Call{
 345				Prompt:           stepInputMessages,
 346				MaxOutputTokens:  opts.MaxOutputTokens,
 347				Temperature:      opts.Temperature,
 348				TopP:             opts.TopP,
 349				TopK:             opts.TopK,
 350				PresencePenalty:  opts.PresencePenalty,
 351				FrequencyPenalty: opts.FrequencyPenalty,
 352				Tools:            preparedTools,
 353				ToolChoice:       &stepToolChoice,
 354				Headers:          opts.Headers,
 355				ProviderOptions:  opts.ProviderOptions,
 356			})
 357		})
 358		if err != nil {
 359			return nil, err
 360		}
 361
 362		var stepToolCalls []ToolCallContent
 363		for _, content := range result.Content {
 364			if content.GetType() == ContentTypeToolCall {
 365				toolCall, ok := AsContentType[ToolCallContent](content)
 366				if !ok {
 367					continue
 368				}
 369
 370				// Validate and potentially repair the tool call
 371				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
 372				stepToolCalls = append(stepToolCalls, validatedToolCall)
 373			}
 374		}
 375
 376		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
 377
 378		// Build step content with validated tool calls and tool results
 379		stepContent := []Content{}
 380		toolCallIndex := 0
 381		for _, content := range result.Content {
 382			if content.GetType() == ContentTypeToolCall {
 383				// Replace with validated tool call
 384				if toolCallIndex < len(stepToolCalls) {
 385					stepContent = append(stepContent, stepToolCalls[toolCallIndex])
 386					toolCallIndex++
 387				}
 388			} else {
 389				// Keep other content as-is
 390				stepContent = append(stepContent, content)
 391			}
 392		}
 393		// Add tool results
 394		for _, result := range toolResults {
 395			stepContent = append(stepContent, result)
 396		}
 397		currentStepMessages := toResponseMessages(stepContent)
 398		responseMessages = append(responseMessages, currentStepMessages...)
 399
 400		stepResult := StepResult{
 401			Response: Response{
 402				Content:          stepContent,
 403				FinishReason:     result.FinishReason,
 404				Usage:            result.Usage,
 405				Warnings:         result.Warnings,
 406				ProviderMetadata: result.ProviderMetadata,
 407			},
 408			Messages: currentStepMessages,
 409		}
 410		steps = append(steps, stepResult)
 411		shouldStop := isStopConditionMet(opts.StopWhen, steps)
 412
 413		if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
 414			break
 415		}
 416	}
 417
 418	totalUsage := Usage{}
 419
 420	for _, step := range steps {
 421		usage := step.Usage
 422		totalUsage.InputTokens += usage.InputTokens
 423		totalUsage.OutputTokens += usage.OutputTokens
 424		totalUsage.ReasoningTokens += usage.ReasoningTokens
 425		totalUsage.CacheCreationTokens += usage.CacheCreationTokens
 426		totalUsage.CacheReadTokens += usage.CacheReadTokens
 427		totalUsage.TotalTokens += usage.TotalTokens
 428	}
 429
 430	agentResult := &AgentResult{
 431		Steps:      steps,
 432		Response:   steps[len(steps)-1].Response,
 433		TotalUsage: totalUsage,
 434	}
 435	return agentResult, nil
 436}
 437
 438func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
 439	if len(conditions) == 0 {
 440		return false
 441	}
 442
 443	for _, condition := range conditions {
 444		if condition(steps) {
 445			return true
 446		}
 447	}
 448	return false
 449}
 450
 451func toResponseMessages(content []Content) []Message {
 452	var assistantParts []MessagePart
 453	var toolParts []MessagePart
 454
 455	for _, c := range content {
 456		switch c.GetType() {
 457		case ContentTypeText:
 458			text, ok := AsContentType[TextContent](c)
 459			if !ok {
 460				continue
 461			}
 462			assistantParts = append(assistantParts, TextPart{
 463				Text:            text.Text,
 464				ProviderOptions: ProviderOptions(text.ProviderMetadata),
 465			})
 466		case ContentTypeReasoning:
 467			reasoning, ok := AsContentType[ReasoningContent](c)
 468			if !ok {
 469				continue
 470			}
 471			assistantParts = append(assistantParts, ReasoningPart{
 472				Text:            reasoning.Text,
 473				ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
 474			})
 475		case ContentTypeToolCall:
 476			toolCall, ok := AsContentType[ToolCallContent](c)
 477			if !ok {
 478				continue
 479			}
 480			assistantParts = append(assistantParts, ToolCallPart{
 481				ToolCallID:       toolCall.ToolCallID,
 482				ToolName:         toolCall.ToolName,
 483				Input:            toolCall.Input,
 484				ProviderExecuted: toolCall.ProviderExecuted,
 485				ProviderOptions:  ProviderOptions(toolCall.ProviderMetadata),
 486			})
 487		case ContentTypeFile:
 488			file, ok := AsContentType[FileContent](c)
 489			if !ok {
 490				continue
 491			}
 492			assistantParts = append(assistantParts, FilePart{
 493				Data:            file.Data,
 494				MediaType:       file.MediaType,
 495				ProviderOptions: ProviderOptions(file.ProviderMetadata),
 496			})
 497		case ContentTypeSource:
 498			// Sources are metadata about references used to generate the response.
 499			// They don't need to be included in the conversation messages.
 500			continue
 501		case ContentTypeToolResult:
 502			result, ok := AsContentType[ToolResultContent](c)
 503			if !ok {
 504				continue
 505			}
 506			toolParts = append(toolParts, ToolResultPart{
 507				ToolCallID:      result.ToolCallID,
 508				Output:          result.Result,
 509				ProviderOptions: ProviderOptions(result.ProviderMetadata),
 510			})
 511		}
 512	}
 513
 514	var messages []Message
 515	if len(assistantParts) > 0 {
 516		messages = append(messages, Message{
 517			Role:    MessageRoleAssistant,
 518			Content: assistantParts,
 519		})
 520	}
 521	if len(toolParts) > 0 {
 522		messages = append(messages, Message{
 523			Role:    MessageRoleTool,
 524			Content: toolParts,
 525		})
 526	}
 527	return messages
 528}
 529
 530func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
 531	if len(toolCalls) == 0 {
 532		return nil, nil
 533	}
 534
 535	// Create a map for quick tool lookup
 536	toolMap := make(map[string]AgentTool)
 537	for _, tool := range allTools {
 538		toolMap[tool.Info().Name] = tool
 539	}
 540
 541	// Execute all tool calls in parallel
 542	results := make([]ToolResultContent, len(toolCalls))
 543	executeErrors := make([]error, len(toolCalls))
 544
 545	var wg sync.WaitGroup
 546
 547	for i, toolCall := range toolCalls {
 548		wg.Add(1)
 549		go func(index int, call ToolCallContent) {
 550			defer wg.Done()
 551
 552			// Skip invalid tool calls - create error result
 553			if call.Invalid {
 554				results[index] = ToolResultContent{
 555					ToolCallID: call.ToolCallID,
 556					ToolName:   call.ToolName,
 557					Result: ToolResultOutputContentError{
 558						Error: call.ValidationError,
 559					},
 560					ProviderExecuted: false,
 561				}
 562				if toolResultCallback != nil {
 563					err := toolResultCallback(results[index])
 564					if err != nil {
 565						executeErrors[index] = err
 566					}
 567				}
 568
 569				return
 570			}
 571
 572			tool, exists := toolMap[call.ToolName]
 573			if !exists {
 574				results[index] = ToolResultContent{
 575					ToolCallID: call.ToolCallID,
 576					ToolName:   call.ToolName,
 577					Result: ToolResultOutputContentError{
 578						Error: errors.New("Error: Tool not found: " + call.ToolName),
 579					},
 580					ProviderExecuted: false,
 581				}
 582
 583				if toolResultCallback != nil {
 584					err := toolResultCallback(results[index])
 585					if err != nil {
 586						executeErrors[index] = err
 587					}
 588				}
 589				return
 590			}
 591
 592			// Execute the tool
 593			result, err := tool.Run(ctx, ToolCall{
 594				ID:    call.ToolCallID,
 595				Name:  call.ToolName,
 596				Input: call.Input,
 597			})
 598			if err != nil {
 599				results[index] = ToolResultContent{
 600					ToolCallID: call.ToolCallID,
 601					ToolName:   call.ToolName,
 602					Result: ToolResultOutputContentError{
 603						Error: err,
 604					},
 605					ClientMetadata:   result.Metadata,
 606					ProviderExecuted: false,
 607				}
 608				if toolResultCallback != nil {
 609					cbErr := toolResultCallback(results[index])
 610					if cbErr != nil {
 611						executeErrors[index] = cbErr
 612					}
 613				}
 614				executeErrors[index] = err
 615				return
 616			}
 617
 618			if result.IsError {
 619				results[index] = ToolResultContent{
 620					ToolCallID: call.ToolCallID,
 621					ToolName:   call.ToolName,
 622					Result: ToolResultOutputContentError{
 623						Error: errors.New(result.Content),
 624					},
 625					ClientMetadata:   result.Metadata,
 626					ProviderExecuted: false,
 627				}
 628
 629				if toolResultCallback != nil {
 630					err := toolResultCallback(results[index])
 631					if err != nil {
 632						executeErrors[index] = err
 633					}
 634				}
 635			} else {
 636				results[index] = ToolResultContent{
 637					ToolCallID: call.ToolCallID,
 638					ToolName:   toolCall.ToolName,
 639					Result: ToolResultOutputContentText{
 640						Text: result.Content,
 641					},
 642					ClientMetadata:   result.Metadata,
 643					ProviderExecuted: false,
 644				}
 645				if toolResultCallback != nil {
 646					err := toolResultCallback(results[index])
 647					if err != nil {
 648						executeErrors[index] = err
 649					}
 650				}
 651			}
 652		}(i, toolCall)
 653	}
 654
 655	// Wait for all tool executions to complete
 656	wg.Wait()
 657
 658	for _, err := range executeErrors {
 659		if err != nil {
 660			return nil, err
 661		}
 662	}
 663
 664	return results, nil
 665}
 666
 667// Stream implements Agent.
 668func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
 669	// Convert AgentStreamCall to AgentCall for preparation
 670	call := AgentCall{
 671		Prompt:           opts.Prompt,
 672		Files:            opts.Files,
 673		Messages:         opts.Messages,
 674		MaxOutputTokens:  opts.MaxOutputTokens,
 675		Temperature:      opts.Temperature,
 676		TopP:             opts.TopP,
 677		TopK:             opts.TopK,
 678		PresencePenalty:  opts.PresencePenalty,
 679		FrequencyPenalty: opts.FrequencyPenalty,
 680		ActiveTools:      opts.ActiveTools,
 681		Headers:          opts.Headers,
 682		ProviderOptions:  opts.ProviderOptions,
 683		MaxRetries:       opts.MaxRetries,
 684		StopWhen:         opts.StopWhen,
 685		PrepareStep:      opts.PrepareStep,
 686		RepairToolCall:   opts.RepairToolCall,
 687	}
 688
 689	call = a.prepareCall(call)
 690
 691	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
 692	if err != nil {
 693		return nil, err
 694	}
 695
 696	var responseMessages []Message
 697	var steps []StepResult
 698	var totalUsage Usage
 699
 700	// Start agent stream
 701	if opts.OnAgentStart != nil {
 702		opts.OnAgentStart()
 703	}
 704
 705	for stepNumber := 0; ; stepNumber++ {
 706		stepInputMessages := append(initialPrompt, responseMessages...)
 707		stepModel := a.settings.model
 708		stepSystemPrompt := a.settings.systemPrompt
 709		stepActiveTools := call.ActiveTools
 710		stepToolChoice := ToolChoiceAuto
 711		disableAllTools := false
 712
 713		// Apply step preparation if provided
 714		if call.PrepareStep != nil {
 715			prepared, err := call.PrepareStep(PrepareStepFunctionOptions{
 716				Model:      stepModel,
 717				Steps:      steps,
 718				StepNumber: stepNumber,
 719				Messages:   stepInputMessages,
 720			})
 721			if err != nil {
 722				return nil, err
 723			}
 724
 725			if prepared.Messages != nil {
 726				stepInputMessages = prepared.Messages
 727			}
 728			if prepared.Model != nil {
 729				stepModel = prepared.Model
 730			}
 731			if prepared.System != nil {
 732				stepSystemPrompt = *prepared.System
 733			}
 734			if prepared.ToolChoice != nil {
 735				stepToolChoice = *prepared.ToolChoice
 736			}
 737			if len(prepared.ActiveTools) > 0 {
 738				stepActiveTools = prepared.ActiveTools
 739			}
 740			disableAllTools = prepared.DisableAllTools
 741		}
 742
 743		// Recreate prompt with potentially modified system prompt
 744		if stepSystemPrompt != a.settings.systemPrompt {
 745			stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
 746			if err != nil {
 747				return nil, err
 748			}
 749			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 750				stepInputMessages[0] = stepPrompt[0]
 751			}
 752		}
 753
 754		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 755
 756		// Start step stream
 757		if opts.OnStepStart != nil {
 758			opts.OnStepStart(stepNumber)
 759		}
 760
 761		// Create streaming call
 762		streamCall := Call{
 763			Prompt:           stepInputMessages,
 764			MaxOutputTokens:  call.MaxOutputTokens,
 765			Temperature:      call.Temperature,
 766			TopP:             call.TopP,
 767			TopK:             call.TopK,
 768			PresencePenalty:  call.PresencePenalty,
 769			FrequencyPenalty: call.FrequencyPenalty,
 770			Tools:            preparedTools,
 771			ToolChoice:       &stepToolChoice,
 772			Headers:          call.Headers,
 773			ProviderOptions:  call.ProviderOptions,
 774		}
 775
 776		// Get streaming response
 777		stream, err := stepModel.Stream(ctx, streamCall)
 778		if err != nil {
 779			if opts.OnError != nil {
 780				opts.OnError(err)
 781			}
 782			return nil, err
 783		}
 784
 785		// Process stream with tool execution
 786		stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
 787		if err != nil {
 788			if opts.OnError != nil {
 789				opts.OnError(err)
 790			}
 791			return nil, err
 792		}
 793
 794		steps = append(steps, stepResult)
 795		totalUsage = addUsage(totalUsage, stepResult.Usage)
 796
 797		// Call step finished callback
 798		if opts.OnStepFinish != nil {
 799			opts.OnStepFinish(stepResult)
 800		}
 801
 802		// Add step messages to response messages
 803		stepMessages := toResponseMessages(stepResult.Content)
 804		responseMessages = append(responseMessages, stepMessages...)
 805
 806		// Check stop conditions
 807		shouldStop := isStopConditionMet(call.StopWhen, steps)
 808		if shouldStop || !shouldContinue {
 809			break
 810		}
 811	}
 812
 813	// Finish agent stream
 814	agentResult := &AgentResult{
 815		Steps:      steps,
 816		Response:   steps[len(steps)-1].Response,
 817		TotalUsage: totalUsage,
 818	}
 819
 820	if opts.OnFinish != nil {
 821		opts.OnFinish(agentResult)
 822	}
 823
 824	if opts.OnAgentFinish != nil {
 825		opts.OnAgentFinish(agentResult)
 826	}
 827
 828	return agentResult, nil
 829}
 830
 831func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
 832	var preparedTools []Tool
 833
 834	// If explicitly disabling all tools, return no tools
 835	if disableAllTools {
 836		return preparedTools
 837	}
 838
 839	for _, tool := range tools {
 840		// If activeTools has items, only include tools in the list
 841		// If activeTools is empty, include all tools
 842		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
 843			continue
 844		}
 845		info := tool.Info()
 846		preparedTools = append(preparedTools, FunctionTool{
 847			Name:        info.Name,
 848			Description: info.Description,
 849			InputSchema: map[string]any{
 850				"type":       "object",
 851				"properties": info.Parameters,
 852				"required":   info.Required,
 853			},
 854		})
 855	}
 856	return preparedTools
 857}
 858
 859// validateAndRepairToolCall validates a tool call and attempts repair if validation fails
 860func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
 861	if err := a.validateToolCall(toolCall, availableTools); err == nil {
 862		return toolCall
 863	} else {
 864		if repairFunc != nil {
 865			repairOptions := ToolCallRepairOptions{
 866				OriginalToolCall: toolCall,
 867				ValidationError:  err,
 868				AvailableTools:   availableTools,
 869				SystemPrompt:     systemPrompt,
 870				Messages:         messages,
 871			}
 872
 873			if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
 874				if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
 875					return *repairedToolCall
 876				}
 877			}
 878		}
 879
 880		invalidToolCall := toolCall
 881		invalidToolCall.Invalid = true
 882		invalidToolCall.ValidationError = err
 883		return invalidToolCall
 884	}
 885}
 886
 887// validateToolCall validates a tool call against available tools and their schemas
 888func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
 889	var tool AgentTool
 890	for _, t := range availableTools {
 891		if t.Info().Name == toolCall.ToolName {
 892			tool = t
 893			break
 894		}
 895	}
 896
 897	if tool == nil {
 898		return fmt.Errorf("tool not found: %s", toolCall.ToolName)
 899	}
 900
 901	// Validate JSON parsing
 902	var input map[string]any
 903	if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
 904		return fmt.Errorf("invalid JSON input: %w", err)
 905	}
 906
 907	// Basic schema validation (check required fields)
 908	// TODO: more robust schema validation using JSON Schema or similar
 909	toolInfo := tool.Info()
 910	for _, required := range toolInfo.Required {
 911		if _, exists := input[required]; !exists {
 912			return fmt.Errorf("missing required parameter: %s", required)
 913		}
 914	}
 915	return nil
 916}
 917
 918func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
 919	if prompt == "" {
 920		return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
 921	}
 922
 923	var preparedPrompt Prompt
 924
 925	if system != "" {
 926		preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
 927	}
 928
 929	preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
 930	preparedPrompt = append(preparedPrompt, messages...)
 931	return preparedPrompt, nil
 932}
 933
 934func WithSystemPrompt(prompt string) AgentOption {
 935	return func(s *agentSettings) {
 936		s.systemPrompt = prompt
 937	}
 938}
 939
 940func WithMaxOutputTokens(tokens int64) AgentOption {
 941	return func(s *agentSettings) {
 942		s.maxOutputTokens = &tokens
 943	}
 944}
 945
 946func WithTemperature(temp float64) AgentOption {
 947	return func(s *agentSettings) {
 948		s.temperature = &temp
 949	}
 950}
 951
 952func WithTopP(topP float64) AgentOption {
 953	return func(s *agentSettings) {
 954		s.topP = &topP
 955	}
 956}
 957
 958func WithTopK(topK int64) AgentOption {
 959	return func(s *agentSettings) {
 960		s.topK = &topK
 961	}
 962}
 963
 964func WithPresencePenalty(penalty float64) AgentOption {
 965	return func(s *agentSettings) {
 966		s.presencePenalty = &penalty
 967	}
 968}
 969
 970func WithFrequencyPenalty(penalty float64) AgentOption {
 971	return func(s *agentSettings) {
 972		s.frequencyPenalty = &penalty
 973	}
 974}
 975
 976func WithTools(tools ...AgentTool) AgentOption {
 977	return func(s *agentSettings) {
 978		s.tools = append(s.tools, tools...)
 979	}
 980}
 981
 982func WithStopConditions(conditions ...StopCondition) AgentOption {
 983	return func(s *agentSettings) {
 984		s.stopWhen = append(s.stopWhen, conditions...)
 985	}
 986}
 987
 988func WithPrepareStep(fn PrepareStepFunction) AgentOption {
 989	return func(s *agentSettings) {
 990		s.prepareStep = fn
 991	}
 992}
 993
 994func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
 995	return func(s *agentSettings) {
 996		s.repairToolCall = fn
 997	}
 998}
 999
1000// processStepStream processes a single step's stream and returns the step result
1001func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
1002	var stepContent []Content
1003	var stepToolCalls []ToolCallContent
1004	var stepUsage Usage
1005	stepFinishReason := FinishReasonUnknown
1006	var stepWarnings []CallWarning
1007	var stepProviderMetadata ProviderMetadata
1008
1009	activeToolCalls := make(map[string]*ToolCallContent)
1010	activeTextContent := make(map[string]string)
1011
1012	// Process stream parts
1013	for part := range stream {
1014		// Forward all parts to chunk callback
1015		if opts.OnChunk != nil {
1016			err := opts.OnChunk(part)
1017			if err != nil {
1018				return StepResult{}, false, err
1019			}
1020		}
1021
1022		switch part.Type {
1023		case StreamPartTypeWarnings:
1024			stepWarnings = part.Warnings
1025			if opts.OnWarnings != nil {
1026				err := opts.OnWarnings(part.Warnings)
1027				if err != nil {
1028					return StepResult{}, false, err
1029				}
1030			}
1031
1032		case StreamPartTypeTextStart:
1033			activeTextContent[part.ID] = ""
1034			if opts.OnTextStart != nil {
1035				err := opts.OnTextStart(part.ID)
1036				if err != nil {
1037					return StepResult{}, false, err
1038				}
1039			}
1040
1041		case StreamPartTypeTextDelta:
1042			if _, exists := activeTextContent[part.ID]; exists {
1043				activeTextContent[part.ID] += part.Delta
1044			}
1045			if opts.OnTextDelta != nil {
1046				err := opts.OnTextDelta(part.ID, part.Delta)
1047				if err != nil {
1048					return StepResult{}, false, err
1049				}
1050			}
1051
1052		case StreamPartTypeTextEnd:
1053			if text, exists := activeTextContent[part.ID]; exists {
1054				stepContent = append(stepContent, TextContent{
1055					Text:             text,
1056					ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1057				})
1058				delete(activeTextContent, part.ID)
1059			}
1060			if opts.OnTextEnd != nil {
1061				err := opts.OnTextEnd(part.ID)
1062				if err != nil {
1063					return StepResult{}, false, err
1064				}
1065			}
1066
1067		case StreamPartTypeReasoningStart:
1068			activeTextContent[part.ID] = ""
1069			if opts.OnReasoningStart != nil {
1070				err := opts.OnReasoningStart(part.ID)
1071				if err != nil {
1072					return StepResult{}, false, err
1073				}
1074			}
1075
1076		case StreamPartTypeReasoningDelta:
1077			if _, exists := activeTextContent[part.ID]; exists {
1078				activeTextContent[part.ID] += part.Delta
1079			}
1080			if opts.OnReasoningDelta != nil {
1081				err := opts.OnReasoningDelta(part.ID, part.Delta)
1082				if err != nil {
1083					return StepResult{}, false, err
1084				}
1085			}
1086
1087		case StreamPartTypeReasoningEnd:
1088			if text, exists := activeTextContent[part.ID]; exists {
1089				stepContent = append(stepContent, ReasoningContent{
1090					Text:             text,
1091					ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1092				})
1093				if opts.OnReasoningEnd != nil {
1094					err := opts.OnReasoningEnd(part.ID, ReasoningContent{
1095						Text:             text,
1096						ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1097					})
1098					if err != nil {
1099						return StepResult{}, false, err
1100					}
1101				}
1102				delete(activeTextContent, part.ID)
1103			}
1104
1105		case StreamPartTypeToolInputStart:
1106			activeToolCalls[part.ID] = &ToolCallContent{
1107				ToolCallID:       part.ID,
1108				ToolName:         part.ToolCallName,
1109				Input:            "",
1110				ProviderExecuted: part.ProviderExecuted,
1111			}
1112			if opts.OnToolInputStart != nil {
1113				err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1114				if err != nil {
1115					return StepResult{}, false, err
1116				}
1117			}
1118
1119		case StreamPartTypeToolInputDelta:
1120			if toolCall, exists := activeToolCalls[part.ID]; exists {
1121				toolCall.Input += part.Delta
1122			}
1123			if opts.OnToolInputDelta != nil {
1124				err := opts.OnToolInputDelta(part.ID, part.Delta)
1125				if err != nil {
1126					return StepResult{}, false, err
1127				}
1128			}
1129
1130		case StreamPartTypeToolInputEnd:
1131			if opts.OnToolInputEnd != nil {
1132				err := opts.OnToolInputEnd(part.ID)
1133				if err != nil {
1134					return StepResult{}, false, err
1135				}
1136			}
1137
1138		case StreamPartTypeToolCall:
1139			toolCall := ToolCallContent{
1140				ToolCallID:       part.ID,
1141				ToolName:         part.ToolCallName,
1142				Input:            part.ToolCallInput,
1143				ProviderExecuted: part.ProviderExecuted,
1144				ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1145			}
1146
1147			// Validate and potentially repair the tool call
1148			validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1149			stepToolCalls = append(stepToolCalls, validatedToolCall)
1150			stepContent = append(stepContent, validatedToolCall)
1151
1152			if opts.OnToolCall != nil {
1153				err := opts.OnToolCall(validatedToolCall)
1154				if err != nil {
1155					return StepResult{}, false, err
1156				}
1157			}
1158
1159			// Clean up active tool call
1160			delete(activeToolCalls, part.ID)
1161
1162		case StreamPartTypeSource:
1163			sourceContent := SourceContent{
1164				SourceType:       part.SourceType,
1165				ID:               part.ID,
1166				URL:              part.URL,
1167				Title:            part.Title,
1168				ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1169			}
1170			stepContent = append(stepContent, sourceContent)
1171			if opts.OnSource != nil {
1172				err := opts.OnSource(sourceContent)
1173				if err != nil {
1174					return StepResult{}, false, err
1175				}
1176			}
1177
1178		case StreamPartTypeFinish:
1179			stepUsage = part.Usage
1180			stepFinishReason = part.FinishReason
1181			stepProviderMetadata = ProviderMetadata(part.ProviderMetadata)
1182			if opts.OnStreamFinish != nil {
1183				err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1184				if err != nil {
1185					return StepResult{}, false, err
1186				}
1187			}
1188
1189		case StreamPartTypeError:
1190			if opts.OnError != nil {
1191				opts.OnError(part.Error)
1192			}
1193			return StepResult{}, false, part.Error
1194		}
1195	}
1196
1197	// Execute tools if any
1198	var toolResults []ToolResultContent
1199	if len(stepToolCalls) > 0 {
1200		var err error
1201		toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
1202		if err != nil {
1203			return StepResult{}, false, err
1204		}
1205		// Add tool results to content
1206		for _, result := range toolResults {
1207			stepContent = append(stepContent, result)
1208		}
1209	}
1210
1211	stepResult := StepResult{
1212		Response: Response{
1213			Content:          stepContent,
1214			FinishReason:     stepFinishReason,
1215			Usage:            stepUsage,
1216			Warnings:         stepWarnings,
1217			ProviderMetadata: stepProviderMetadata,
1218		},
1219		Messages: toResponseMessages(stepContent),
1220	}
1221
1222	// Determine if we should continue (has tool calls and not stopped)
1223	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1224
1225	return stepResult, shouldContinue, nil
1226}
1227
1228func addUsage(a, b Usage) Usage {
1229	return Usage{
1230		InputTokens:         a.InputTokens + b.InputTokens,
1231		OutputTokens:        a.OutputTokens + b.OutputTokens,
1232		TotalTokens:         a.TotalTokens + b.TotalTokens,
1233		ReasoningTokens:     a.ReasoningTokens + b.ReasoningTokens,
1234		CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1235		CacheReadTokens:     a.CacheReadTokens + b.CacheReadTokens,
1236	}
1237}
1238
1239func WithHeaders(headers map[string]string) AgentOption {
1240	return func(s *agentSettings) {
1241		s.headers = headers
1242	}
1243}
1244
1245func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1246	return func(s *agentSettings) {
1247		s.providerOptions = providerOptions
1248	}
1249}