agent.go

   1package ai
   2
   3import (
   4	"context"
   5	"encoding/json"
   6	"errors"
   7	"fmt"
   8	"maps"
   9	"slices"
  10	"sync"
  11)
  12
// StepResult is the outcome of one agent step: the model Response for that
// step together with the conversation messages derived from it.
type StepResult struct {
	Response
	// Messages are the messages this step contributes to the conversation
	// history (Generate builds them with toResponseMessages).
	Messages []Message
}

// StopCondition reports whether the agent loop should stop, given all steps
// completed so far in order. The agent evaluates its stop conditions after
// every step and stops as soon as any one of them returns true.
type StopCondition = func(steps []StepResult) bool
  19
  20// StepCountIs returns a stop condition that stops after the specified number of steps.
  21func StepCountIs(stepCount int) StopCondition {
  22	return func(steps []StepResult) bool {
  23		return len(steps) >= stepCount
  24	}
  25}
  26
  27// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
  28func HasToolCall(toolName string) StopCondition {
  29	return func(steps []StepResult) bool {
  30		if len(steps) == 0 {
  31			return false
  32		}
  33		lastStep := steps[len(steps)-1]
  34		toolCalls := lastStep.Content.ToolCalls()
  35		for _, toolCall := range toolCalls {
  36			if toolCall.ToolName == toolName {
  37				return true
  38			}
  39		}
  40		return false
  41	}
  42}
  43
  44// HasContent returns a stop condition that stops when the specified content type appears in the last step.
  45func HasContent(contentType ContentType) StopCondition {
  46	return func(steps []StepResult) bool {
  47		if len(steps) == 0 {
  48			return false
  49		}
  50		lastStep := steps[len(steps)-1]
  51		for _, content := range lastStep.Content {
  52			if content.GetType() == contentType {
  53				return true
  54			}
  55		}
  56		return false
  57	}
  58}
  59
  60// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
  61func FinishReasonIs(reason FinishReason) StopCondition {
  62	return func(steps []StepResult) bool {
  63		if len(steps) == 0 {
  64			return false
  65		}
  66		lastStep := steps[len(steps)-1]
  67		return lastStep.FinishReason == reason
  68	}
  69}
  70
  71// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
  72func MaxTokensUsed(maxTokens int64) StopCondition {
  73	return func(steps []StepResult) bool {
  74		var totalTokens int64
  75		for _, step := range steps {
  76			totalTokens += step.Usage.TotalTokens
  77		}
  78		return totalTokens >= maxTokens
  79	}
  80}
  81
// PrepareStepFunctionOptions is the input handed to a PrepareStepFunction
// before each step runs.
type PrepareStepFunctionOptions struct {
	// Steps holds the steps completed so far, in order.
	Steps      []StepResult
	// StepNumber is the zero-based index of the step about to run.
	StepNumber int
	// Model is the agent's configured model, the default for this step.
	Model      LanguageModel
	// Messages are the step's default input messages: the initial prompt
	// followed by every message produced by earlier steps.
	Messages   []Message
}

// PrepareStepResult lets a PrepareStepFunction override settings for a single
// step. Zero-valued fields leave the corresponding default untouched.
type PrepareStepResult struct {
	Model           LanguageModel // nil keeps the agent's model
	Messages        []Message     // nil keeps the default step input messages
	System          *string       // nil keeps the agent's system prompt
	ToolChoice      *ToolChoice   // nil keeps ToolChoiceAuto
	ActiveTools     []string      // empty keeps the call's ActiveTools
	DisableAllTools bool          // true sends no tools to the model for this step
}

// ToolCallRepairOptions is passed to a RepairToolCallFunction when a tool
// call fails validation.
type ToolCallRepairOptions struct {
	OriginalToolCall ToolCallContent // the call as produced by the model
	ValidationError  error           // why validation failed
	AvailableTools   []AgentTool     // tools the call was validated against
	SystemPrompt     string          // system prompt in effect for the step
	Messages         []Message       // input messages of the step
}

type (
	// PrepareStepFunction is called before each step and may override that
	// step's model, messages, system prompt, tool choice and tool set.
	PrepareStepFunction    = func(options PrepareStepFunctionOptions) PrepareStepResult
	// OnStepFinishedFunction receives each completed step (Generate invokes
	// it after recording the step).
	OnStepFinishedFunction = func(step StepResult)
	// RepairToolCallFunction may return a corrected version of a tool call
	// that failed validation. The replacement is used only when the error is
	// nil, the pointer is non-nil and the repaired call passes validation.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
 111
// AgentSettings holds the agent-level configuration assembled by NewAgent
// from the With* options. prepareCall uses these values as defaults for each
// call; headers and providerOptions are merged with the call's own maps, the
// call's entries winning on conflicting keys.
type AgentSettings struct {
	// systemPrompt, when non-empty, becomes the first message of each run.
	systemPrompt     string
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	headers          map[string]string
	providerOptions  ProviderOptions

	// TODO: add support for provider tools
	// tools are offered to the model (see prepareTools) and executed by the
	// agent when the model calls them.
	tools      []AgentTool
	maxRetries *int

	// model is the default model for every step; PrepareStep may override it.
	model LanguageModel

	// Default stop conditions and callbacks, used when a call does not set
	// its own (see prepareCall).
	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onStepFinished OnStepFinishedFunction
	onRetry        OnRetryCallback
}
 135
// AgentCall holds the inputs for one Agent.Generate run.
//
// Prompt must be non-empty. Generation parameters, MaxRetries, StopWhen and
// the callbacks fall back to the agent-level settings when left nil/empty;
// Headers and ProviderOptions are merged with the agent's, the call's entries
// winning on conflicts (see prepareCall).
type AgentCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	// Messages are appended after the user prompt when the initial prompt is
	// built.
	Messages         []Message  `json:"messages"`
	// NOTE(review): unlike its neighbours MaxOutputTokens has no json tag, so
	// it is matched by its Go field name rather than a snake_case key —
	// confirm whether `json:"max_output_tokens"` was intended.
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools restricts the tools offered to the model to these names;
	// empty offers all of the agent's tools.
	ActiveTools      []string `json:"active_tools"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction
	OnStepFinished OnStepFinishedFunction
}
 157
// AgentStreamCall holds the inputs for one Agent.Stream run.
//
// Its generation parameters mirror AgentCall. Stream converts them to an
// AgentCall and merges them with the agent-level settings through
// prepareCall. The callback fields are optional; the agent-level callbacks
// are skipped when nil.
//
// NOTE(review): Stream does not retry model calls itself, so OnRetry and
// MaxRetries appear to have no effect here — confirm.
type AgentStreamCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  func()                      // Called when agent starts
	OnAgentFinish func(result *AgentResult)   // Called when agent finishes
	OnStepStart   func(stepNumber int)        // Called when a step starts
	OnStepFinish  func(stepResult StepResult) // Called when a step finishes
	OnFinish      func(result *AgentResult)   // Called when entire agent completes
	OnError       func(error)                 // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          func(StreamPart)                                                               // Called for each stream part (catch-all)
	OnWarnings       func(warnings []CallWarning)                                                   // Called for warnings
	OnTextStart      func(id string)                                                                // Called when text starts
	OnTextDelta      func(id, text string)                                                          // Called for text deltas
	OnTextEnd        func(id string)                                                                // Called when text ends
	OnReasoningStart func(id string)                                                                // Called when reasoning starts
	OnReasoningDelta func(id, text string)                                                          // Called for reasoning deltas
	OnReasoningEnd   func(id string)                                                                // Called when reasoning ends
	OnToolInputStart func(id, toolName string)                                                      // Called when tool input starts
	OnToolInputDelta func(id, delta string)                                                         // Called for tool input deltas
	OnToolInputEnd   func(id string)                                                                // Called when tool input ends
	OnToolCall       func(toolCall ToolCallContent)                                                 // Called when tool call is complete
	OnToolResult     func(result ToolResultContent)                                                 // Called when tool execution completes
	OnSource         func(source SourceContent)                                                     // Called for source references
	OnStreamFinish   func(usage Usage, finishReason FinishReason, providerMetadata ProviderOptions) // Called when stream finishes
	OnStreamError    func(error)                                                                    // Called when stream error occurs
}
 204
// AgentResult is the outcome of a complete agent run.
type AgentResult struct {
	// Steps holds every step executed, in order.
	Steps []StepResult
	// Final response: the Response of the last step.
	Response   Response
	// TotalUsage is the token usage accumulated over all steps.
	TotalUsage Usage
}

// Agent runs a language model in a multi-step, tool-calling loop.
type Agent interface {
	// Generate runs the loop with non-streaming model calls.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the loop with streaming model calls, reporting progress
	// through the callbacks on AgentStreamCall, and returns the final result.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}

// agentOption mutates AgentSettings; see the With* constructors.
type agentOption = func(*AgentSettings)

// agent is the Agent implementation returned by NewAgent.
type agent struct {
	settings AgentSettings
}
 222
 223func NewAgent(model LanguageModel, opts ...agentOption) Agent {
 224	settings := AgentSettings{
 225		model: model,
 226	}
 227	for _, o := range opts {
 228		o(&settings)
 229	}
 230	return &agent{
 231		settings: settings,
 232	}
 233}
 234
 235func (a *agent) prepareCall(call AgentCall) AgentCall {
 236	if call.MaxOutputTokens == nil && a.settings.maxOutputTokens != nil {
 237		call.MaxOutputTokens = a.settings.maxOutputTokens
 238	}
 239	if call.Temperature == nil && a.settings.temperature != nil {
 240		call.Temperature = a.settings.temperature
 241	}
 242	if call.TopP == nil && a.settings.topP != nil {
 243		call.TopP = a.settings.topP
 244	}
 245	if call.TopK == nil && a.settings.topK != nil {
 246		call.TopK = a.settings.topK
 247	}
 248	if call.PresencePenalty == nil && a.settings.presencePenalty != nil {
 249		call.PresencePenalty = a.settings.presencePenalty
 250	}
 251	if call.FrequencyPenalty == nil && a.settings.frequencyPenalty != nil {
 252		call.FrequencyPenalty = a.settings.frequencyPenalty
 253	}
 254	if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
 255		call.StopWhen = a.settings.stopWhen
 256	}
 257	if call.PrepareStep == nil && a.settings.prepareStep != nil {
 258		call.PrepareStep = a.settings.prepareStep
 259	}
 260	if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
 261		call.RepairToolCall = a.settings.repairToolCall
 262	}
 263	if call.OnStepFinished == nil && a.settings.onStepFinished != nil {
 264		call.OnStepFinished = a.settings.onStepFinished
 265	}
 266	if call.OnRetry == nil && a.settings.onRetry != nil {
 267		call.OnRetry = a.settings.onRetry
 268	}
 269	if call.MaxRetries == nil && a.settings.maxRetries != nil {
 270		call.MaxRetries = a.settings.maxRetries
 271	}
 272
 273	providerOptions := ProviderOptions{}
 274	if a.settings.providerOptions != nil {
 275		maps.Copy(providerOptions, a.settings.providerOptions)
 276	}
 277	if call.ProviderOptions != nil {
 278		maps.Copy(providerOptions, call.ProviderOptions)
 279	}
 280	call.ProviderOptions = providerOptions
 281
 282	headers := map[string]string{}
 283
 284	if a.settings.headers != nil {
 285		maps.Copy(headers, a.settings.headers)
 286	}
 287
 288	if call.Headers != nil {
 289		maps.Copy(headers, call.Headers)
 290	}
 291	call.Headers = headers
 292	return call
 293}
 294
 295// Generate implements Agent.
 296func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
 297	opts = a.prepareCall(opts)
 298	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 299	if err != nil {
 300		return nil, err
 301	}
 302	var responseMessages []Message
 303	var steps []StepResult
 304
 305	for {
 306		stepInputMessages := append(initialPrompt, responseMessages...)
 307		stepModel := a.settings.model
 308		stepSystemPrompt := a.settings.systemPrompt
 309		stepActiveTools := opts.ActiveTools
 310		stepToolChoice := ToolChoiceAuto
 311		disableAllTools := false
 312
 313		if opts.PrepareStep != nil {
 314			prepared := opts.PrepareStep(PrepareStepFunctionOptions{
 315				Model:      stepModel,
 316				Steps:      steps,
 317				StepNumber: len(steps),
 318				Messages:   stepInputMessages,
 319			})
 320
 321			// Apply prepared step modifications
 322			if prepared.Messages != nil {
 323				stepInputMessages = prepared.Messages
 324			}
 325			if prepared.Model != nil {
 326				stepModel = prepared.Model
 327			}
 328			if prepared.System != nil {
 329				stepSystemPrompt = *prepared.System
 330			}
 331			if prepared.ToolChoice != nil {
 332				stepToolChoice = *prepared.ToolChoice
 333			}
 334			if len(prepared.ActiveTools) > 0 {
 335				stepActiveTools = prepared.ActiveTools
 336			}
 337			disableAllTools = prepared.DisableAllTools
 338		}
 339
 340		// Recreate prompt with potentially modified system prompt
 341		if stepSystemPrompt != a.settings.systemPrompt {
 342			stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 343			if err != nil {
 344				return nil, err
 345			}
 346			// Replace system message part, keep the rest
 347			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 348				stepInputMessages[0] = stepPrompt[0] // Replace system message
 349			}
 350		}
 351
 352		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 353
 354		retryOptions := DefaultRetryOptions()
 355		retryOptions.OnRetry = opts.OnRetry
 356		retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
 357
 358		result, err := retry(ctx, func() (*Response, error) {
 359			return stepModel.Generate(ctx, Call{
 360				Prompt:           stepInputMessages,
 361				MaxOutputTokens:  opts.MaxOutputTokens,
 362				Temperature:      opts.Temperature,
 363				TopP:             opts.TopP,
 364				TopK:             opts.TopK,
 365				PresencePenalty:  opts.PresencePenalty,
 366				FrequencyPenalty: opts.FrequencyPenalty,
 367				Tools:            preparedTools,
 368				ToolChoice:       &stepToolChoice,
 369				Headers:          opts.Headers,
 370				ProviderOptions:  opts.ProviderOptions,
 371			})
 372		})
 373		if err != nil {
 374			return nil, err
 375		}
 376
 377		var stepToolCalls []ToolCallContent
 378		for _, content := range result.Content {
 379			if content.GetType() == ContentTypeToolCall {
 380				toolCall, ok := AsContentType[ToolCallContent](content)
 381				if !ok {
 382					continue
 383				}
 384
 385				// Validate and potentially repair the tool call
 386				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
 387				stepToolCalls = append(stepToolCalls, validatedToolCall)
 388			}
 389		}
 390
 391		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
 392
 393		// Build step content with validated tool calls and tool results
 394		stepContent := []Content{}
 395		toolCallIndex := 0
 396		for _, content := range result.Content {
 397			if content.GetType() == ContentTypeToolCall {
 398				// Replace with validated tool call
 399				if toolCallIndex < len(stepToolCalls) {
 400					stepContent = append(stepContent, stepToolCalls[toolCallIndex])
 401					toolCallIndex++
 402				}
 403			} else {
 404				// Keep other content as-is
 405				stepContent = append(stepContent, content)
 406			}
 407		}
 408		// Add tool results
 409		for _, result := range toolResults {
 410			stepContent = append(stepContent, result)
 411		}
 412		currentStepMessages := toResponseMessages(stepContent)
 413		responseMessages = append(responseMessages, currentStepMessages...)
 414
 415		stepResult := StepResult{
 416			Response: Response{
 417				Content:          stepContent,
 418				FinishReason:     result.FinishReason,
 419				Usage:            result.Usage,
 420				Warnings:         result.Warnings,
 421				ProviderMetadata: result.ProviderMetadata,
 422			},
 423			Messages: currentStepMessages,
 424		}
 425		steps = append(steps, stepResult)
 426		if opts.OnStepFinished != nil {
 427			opts.OnStepFinished(stepResult)
 428		}
 429
 430		shouldStop := isStopConditionMet(opts.StopWhen, steps)
 431
 432		if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
 433			break
 434		}
 435	}
 436
 437	totalUsage := Usage{}
 438
 439	for _, step := range steps {
 440		usage := step.Usage
 441		totalUsage.InputTokens += usage.InputTokens
 442		totalUsage.OutputTokens += usage.OutputTokens
 443		totalUsage.ReasoningTokens += usage.ReasoningTokens
 444		totalUsage.CacheCreationTokens += usage.CacheCreationTokens
 445		totalUsage.CacheReadTokens += usage.CacheReadTokens
 446		totalUsage.TotalTokens += usage.TotalTokens
 447	}
 448
 449	agentResult := &AgentResult{
 450		Steps:      steps,
 451		Response:   steps[len(steps)-1].Response,
 452		TotalUsage: totalUsage,
 453	}
 454	return agentResult, nil
 455}
 456
 457func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
 458	if len(conditions) == 0 {
 459		return false
 460	}
 461
 462	for _, condition := range conditions {
 463		if condition(steps) {
 464			return true
 465		}
 466	}
 467	return false
 468}
 469
 470func toResponseMessages(content []Content) []Message {
 471	var assistantParts []MessagePart
 472	var toolParts []MessagePart
 473
 474	for _, c := range content {
 475		switch c.GetType() {
 476		case ContentTypeText:
 477			text, ok := AsContentType[TextContent](c)
 478			if !ok {
 479				continue
 480			}
 481			assistantParts = append(assistantParts, TextPart{
 482				Text:            text.Text,
 483				ProviderOptions: ProviderOptions(text.ProviderMetadata),
 484			})
 485		case ContentTypeReasoning:
 486			reasoning, ok := AsContentType[ReasoningContent](c)
 487			if !ok {
 488				continue
 489			}
 490			assistantParts = append(assistantParts, ReasoningPart{
 491				Text:            reasoning.Text,
 492				ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
 493			})
 494		case ContentTypeToolCall:
 495			toolCall, ok := AsContentType[ToolCallContent](c)
 496			if !ok {
 497				continue
 498			}
 499			assistantParts = append(assistantParts, ToolCallPart{
 500				ToolCallID:       toolCall.ToolCallID,
 501				ToolName:         toolCall.ToolName,
 502				Input:            toolCall.Input,
 503				ProviderExecuted: toolCall.ProviderExecuted,
 504				ProviderOptions:  ProviderOptions(toolCall.ProviderMetadata),
 505			})
 506		case ContentTypeFile:
 507			file, ok := AsContentType[FileContent](c)
 508			if !ok {
 509				continue
 510			}
 511			assistantParts = append(assistantParts, FilePart{
 512				Data:            file.Data,
 513				MediaType:       file.MediaType,
 514				ProviderOptions: ProviderOptions(file.ProviderMetadata),
 515			})
 516		case ContentTypeSource:
 517			// Sources are metadata about references used to generate the response.
 518			// They don't need to be included in the conversation messages.
 519			continue
 520		case ContentTypeToolResult:
 521			result, ok := AsContentType[ToolResultContent](c)
 522			if !ok {
 523				continue
 524			}
 525			toolParts = append(toolParts, ToolResultPart{
 526				ToolCallID:      result.ToolCallID,
 527				Output:          result.Result,
 528				ProviderOptions: ProviderOptions(result.ProviderMetadata),
 529			})
 530		}
 531	}
 532
 533	var messages []Message
 534	if len(assistantParts) > 0 {
 535		messages = append(messages, Message{
 536			Role:    MessageRoleAssistant,
 537			Content: assistantParts,
 538		})
 539	}
 540	if len(toolParts) > 0 {
 541		messages = append(messages, Message{
 542			Role:    MessageRoleTool,
 543			Content: toolParts,
 544		})
 545	}
 546	return messages
 547}
 548
 549func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent)) ([]ToolResultContent, error) {
 550	if len(toolCalls) == 0 {
 551		return nil, nil
 552	}
 553
 554	// Create a map for quick tool lookup
 555	toolMap := make(map[string]AgentTool)
 556	for _, tool := range allTools {
 557		toolMap[tool.Info().Name] = tool
 558	}
 559
 560	// Execute all tool calls in parallel
 561	results := make([]ToolResultContent, len(toolCalls))
 562	var toolExecutionError error
 563	var wg sync.WaitGroup
 564
 565	for i, toolCall := range toolCalls {
 566		wg.Add(1)
 567		go func(index int, call ToolCallContent) {
 568			defer wg.Done()
 569
 570			// Skip invalid tool calls - create error result
 571			if call.Invalid {
 572				results[index] = ToolResultContent{
 573					ToolCallID: call.ToolCallID,
 574					ToolName:   call.ToolName,
 575					Result: ToolResultOutputContentError{
 576						Error: call.ValidationError,
 577					},
 578					ProviderExecuted: false,
 579				}
 580				if toolResultCallback != nil {
 581					toolResultCallback(results[index])
 582				}
 583
 584				return
 585			}
 586
 587			tool, exists := toolMap[call.ToolName]
 588			if !exists {
 589				results[index] = ToolResultContent{
 590					ToolCallID: call.ToolCallID,
 591					ToolName:   call.ToolName,
 592					Result: ToolResultOutputContentError{
 593						Error: errors.New("Error: Tool not found: " + call.ToolName),
 594					},
 595					ProviderExecuted: false,
 596				}
 597
 598				if toolResultCallback != nil {
 599					toolResultCallback(results[index])
 600				}
 601				return
 602			}
 603
 604			// Execute the tool
 605			result, err := tool.Run(ctx, ToolCall{
 606				ID:    call.ToolCallID,
 607				Name:  call.ToolName,
 608				Input: call.Input,
 609			})
 610			if err != nil {
 611				results[index] = ToolResultContent{
 612					ToolCallID: call.ToolCallID,
 613					ToolName:   call.ToolName,
 614					Result: ToolResultOutputContentError{
 615						Error: err,
 616					},
 617					ClientMetadata:   result.Metadata,
 618					ProviderExecuted: false,
 619				}
 620				if toolResultCallback != nil {
 621					toolResultCallback(results[index])
 622				}
 623				toolExecutionError = err
 624				return
 625			}
 626
 627			if result.IsError {
 628				results[index] = ToolResultContent{
 629					ToolCallID: call.ToolCallID,
 630					ToolName:   call.ToolName,
 631					Result: ToolResultOutputContentError{
 632						Error: errors.New(result.Content),
 633					},
 634					ClientMetadata:   result.Metadata,
 635					ProviderExecuted: false,
 636				}
 637
 638				if toolResultCallback != nil {
 639					toolResultCallback(results[index])
 640				}
 641			} else {
 642				results[index] = ToolResultContent{
 643					ToolCallID: call.ToolCallID,
 644					ToolName:   toolCall.ToolName,
 645					Result: ToolResultOutputContentText{
 646						Text: result.Content,
 647					},
 648					ClientMetadata:   result.Metadata,
 649					ProviderExecuted: false,
 650				}
 651				if toolResultCallback != nil {
 652					toolResultCallback(results[index])
 653				}
 654			}
 655		}(i, toolCall)
 656	}
 657
 658	// Wait for all tool executions to complete
 659	wg.Wait()
 660
 661	return results, toolExecutionError
 662}
 663
 664// Stream implements Agent.
 665func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
 666	// Convert AgentStreamCall to AgentCall for preparation
 667	call := AgentCall{
 668		Prompt:           opts.Prompt,
 669		Files:            opts.Files,
 670		Messages:         opts.Messages,
 671		MaxOutputTokens:  opts.MaxOutputTokens,
 672		Temperature:      opts.Temperature,
 673		TopP:             opts.TopP,
 674		TopK:             opts.TopK,
 675		PresencePenalty:  opts.PresencePenalty,
 676		FrequencyPenalty: opts.FrequencyPenalty,
 677		ActiveTools:      opts.ActiveTools,
 678		Headers:          opts.Headers,
 679		ProviderOptions:  opts.ProviderOptions,
 680		MaxRetries:       opts.MaxRetries,
 681		StopWhen:         opts.StopWhen,
 682		PrepareStep:      opts.PrepareStep,
 683		RepairToolCall:   opts.RepairToolCall,
 684	}
 685
 686	call = a.prepareCall(call)
 687
 688	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
 689	if err != nil {
 690		return nil, err
 691	}
 692
 693	var responseMessages []Message
 694	var steps []StepResult
 695	var totalUsage Usage
 696
 697	// Start agent stream
 698	if opts.OnAgentStart != nil {
 699		opts.OnAgentStart()
 700	}
 701
 702	for stepNumber := 0; ; stepNumber++ {
 703		stepInputMessages := append(initialPrompt, responseMessages...)
 704		stepModel := a.settings.model
 705		stepSystemPrompt := a.settings.systemPrompt
 706		stepActiveTools := call.ActiveTools
 707		stepToolChoice := ToolChoiceAuto
 708		disableAllTools := false
 709
 710		// Apply step preparation if provided
 711		if call.PrepareStep != nil {
 712			prepared := call.PrepareStep(PrepareStepFunctionOptions{
 713				Model:      stepModel,
 714				Steps:      steps,
 715				StepNumber: stepNumber,
 716				Messages:   stepInputMessages,
 717			})
 718
 719			if prepared.Messages != nil {
 720				stepInputMessages = prepared.Messages
 721			}
 722			if prepared.Model != nil {
 723				stepModel = prepared.Model
 724			}
 725			if prepared.System != nil {
 726				stepSystemPrompt = *prepared.System
 727			}
 728			if prepared.ToolChoice != nil {
 729				stepToolChoice = *prepared.ToolChoice
 730			}
 731			if len(prepared.ActiveTools) > 0 {
 732				stepActiveTools = prepared.ActiveTools
 733			}
 734			disableAllTools = prepared.DisableAllTools
 735		}
 736
 737		// Recreate prompt with potentially modified system prompt
 738		if stepSystemPrompt != a.settings.systemPrompt {
 739			stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
 740			if err != nil {
 741				return nil, err
 742			}
 743			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 744				stepInputMessages[0] = stepPrompt[0]
 745			}
 746		}
 747
 748		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 749
 750		// Start step stream
 751		if opts.OnStepStart != nil {
 752			opts.OnStepStart(stepNumber)
 753		}
 754
 755		// Create streaming call
 756		streamCall := Call{
 757			Prompt:           stepInputMessages,
 758			MaxOutputTokens:  call.MaxOutputTokens,
 759			Temperature:      call.Temperature,
 760			TopP:             call.TopP,
 761			TopK:             call.TopK,
 762			PresencePenalty:  call.PresencePenalty,
 763			FrequencyPenalty: call.FrequencyPenalty,
 764			Tools:            preparedTools,
 765			ToolChoice:       &stepToolChoice,
 766			Headers:          call.Headers,
 767			ProviderOptions:  call.ProviderOptions,
 768		}
 769
 770		// Get streaming response
 771		stream, err := stepModel.Stream(ctx, streamCall)
 772		if err != nil {
 773			if opts.OnError != nil {
 774				opts.OnError(err)
 775			}
 776			return nil, err
 777		}
 778
 779		// Process stream with tool execution
 780		stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
 781		if err != nil {
 782			if opts.OnError != nil {
 783				opts.OnError(err)
 784			}
 785			return nil, err
 786		}
 787
 788		steps = append(steps, stepResult)
 789		totalUsage = addUsage(totalUsage, stepResult.Usage)
 790
 791		// Call step finished callback
 792		if opts.OnStepFinish != nil {
 793			opts.OnStepFinish(stepResult)
 794		}
 795
 796		// Add step messages to response messages
 797		stepMessages := toResponseMessages(stepResult.Content)
 798		responseMessages = append(responseMessages, stepMessages...)
 799
 800		// Check stop conditions
 801		shouldStop := isStopConditionMet(call.StopWhen, steps)
 802		if shouldStop || !shouldContinue {
 803			break
 804		}
 805	}
 806
 807	// Finish agent stream
 808	agentResult := &AgentResult{
 809		Steps:      steps,
 810		Response:   steps[len(steps)-1].Response,
 811		TotalUsage: totalUsage,
 812	}
 813
 814	if opts.OnFinish != nil {
 815		opts.OnFinish(agentResult)
 816	}
 817
 818	if opts.OnAgentFinish != nil {
 819		opts.OnAgentFinish(agentResult)
 820	}
 821
 822	return agentResult, nil
 823}
 824
 825func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
 826	var preparedTools []Tool
 827
 828	// If explicitly disabling all tools, return no tools
 829	if disableAllTools {
 830		return preparedTools
 831	}
 832
 833	for _, tool := range tools {
 834		// If activeTools has items, only include tools in the list
 835		// If activeTools is empty, include all tools
 836		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
 837			continue
 838		}
 839		info := tool.Info()
 840		preparedTools = append(preparedTools, FunctionTool{
 841			Name:        info.Name,
 842			Description: info.Description,
 843			InputSchema: map[string]any{
 844				"type":       "object",
 845				"properties": info.Parameters,
 846				"required":   info.Required,
 847			},
 848		})
 849	}
 850	return preparedTools
 851}
 852
 853// validateAndRepairToolCall validates a tool call and attempts repair if validation fails
 854func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
 855	if err := a.validateToolCall(toolCall, availableTools); err == nil {
 856		return toolCall
 857	} else {
 858		if repairFunc != nil {
 859			repairOptions := ToolCallRepairOptions{
 860				OriginalToolCall: toolCall,
 861				ValidationError:  err,
 862				AvailableTools:   availableTools,
 863				SystemPrompt:     systemPrompt,
 864				Messages:         messages,
 865			}
 866
 867			if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
 868				if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
 869					return *repairedToolCall
 870				}
 871			}
 872		}
 873
 874		invalidToolCall := toolCall
 875		invalidToolCall.Invalid = true
 876		invalidToolCall.ValidationError = err
 877		return invalidToolCall
 878	}
 879}
 880
 881// validateToolCall validates a tool call against available tools and their schemas
 882func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
 883	var tool AgentTool
 884	for _, t := range availableTools {
 885		if t.Info().Name == toolCall.ToolName {
 886			tool = t
 887			break
 888		}
 889	}
 890
 891	if tool == nil {
 892		return fmt.Errorf("tool not found: %s", toolCall.ToolName)
 893	}
 894
 895	// Validate JSON parsing
 896	var input map[string]any
 897	if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
 898		return fmt.Errorf("invalid JSON input: %w", err)
 899	}
 900
 901	// Basic schema validation (check required fields)
 902	// TODO: more robust schema validation using JSON Schema or similar
 903	toolInfo := tool.Info()
 904	for _, required := range toolInfo.Required {
 905		if _, exists := input[required]; !exists {
 906			return fmt.Errorf("missing required parameter: %s", required)
 907		}
 908	}
 909	return nil
 910}
 911
 912func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
 913	if prompt == "" {
 914		return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
 915	}
 916
 917	var preparedPrompt Prompt
 918
 919	if system != "" {
 920		preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
 921	}
 922
 923	preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
 924	preparedPrompt = append(preparedPrompt, messages...)
 925	return preparedPrompt, nil
 926}
 927
 928func WithSystemPrompt(prompt string) agentOption {
 929	return func(s *AgentSettings) {
 930		s.systemPrompt = prompt
 931	}
 932}
 933
 934func WithMaxOutputTokens(tokens int64) agentOption {
 935	return func(s *AgentSettings) {
 936		s.maxOutputTokens = &tokens
 937	}
 938}
 939
 940func WithTemperature(temp float64) agentOption {
 941	return func(s *AgentSettings) {
 942		s.temperature = &temp
 943	}
 944}
 945
 946func WithTopP(topP float64) agentOption {
 947	return func(s *AgentSettings) {
 948		s.topP = &topP
 949	}
 950}
 951
 952func WithTopK(topK int64) agentOption {
 953	return func(s *AgentSettings) {
 954		s.topK = &topK
 955	}
 956}
 957
 958func WithPresencePenalty(penalty float64) agentOption {
 959	return func(s *AgentSettings) {
 960		s.presencePenalty = &penalty
 961	}
 962}
 963
 964func WithFrequencyPenalty(penalty float64) agentOption {
 965	return func(s *AgentSettings) {
 966		s.frequencyPenalty = &penalty
 967	}
 968}
 969
 970func WithTools(tools ...AgentTool) agentOption {
 971	return func(s *AgentSettings) {
 972		s.tools = append(s.tools, tools...)
 973	}
 974}
 975
 976func WithStopConditions(conditions ...StopCondition) agentOption {
 977	return func(s *AgentSettings) {
 978		s.stopWhen = append(s.stopWhen, conditions...)
 979	}
 980}
 981
 982func WithPrepareStep(fn PrepareStepFunction) agentOption {
 983	return func(s *AgentSettings) {
 984		s.prepareStep = fn
 985	}
 986}
 987
 988func WithRepairToolCall(fn RepairToolCallFunction) agentOption {
 989	return func(s *AgentSettings) {
 990		s.repairToolCall = fn
 991	}
 992}
 993
 994func WithOnStepFinished(fn OnStepFinishedFunction) agentOption {
 995	return func(s *AgentSettings) {
 996		s.onStepFinished = fn
 997	}
 998}
 999
// processStepStream processes a single step's stream and returns the step result.
//
// Every part is first forwarded to opts.OnChunk and then handled by type:
//   - warnings are recorded and passed to opts.OnWarnings;
//   - text and reasoning deltas are buffered per block ID and emitted as
//     TextContent / ReasoningContent when the matching end part arrives;
//   - tool-input start/delta/end parts are only forwarded to their callbacks;
//   - a complete tool call is validated (and possibly repaired via
//     opts.RepairToolCall) and recorded;
//   - sources are recorded;
//   - the finish part supplies usage, finish reason and provider metadata;
//   - an error part aborts the step immediately.
//
// After the stream is drained, the collected tool calls are executed and
// their results are appended to the step content.
//
// It returns the assembled StepResult, whether another step should run (true
// only when this step produced tool calls and finished with
// FinishReasonToolCalls), and any stream or tool-execution error. On error
// the StepResult is the zero value and the bool is false.
func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
	var stepContent []Content
	var stepToolCalls []ToolCallContent
	var stepUsage Usage
	var stepFinishReason FinishReason = FinishReasonUnknown
	var stepWarnings []CallWarning
	var stepProviderMetadata ProviderMetadata

	// In-progress text and reasoning blocks, keyed by block ID.
	//
	// They live in separate maps so that a provider reusing an ID across a
	// text block and a reasoning block cannot reset or mix the other's
	// content. Deltas are appended to byte slices rather than concatenated as
	// strings, so long streams made of many small deltas are copied in
	// amortized linear time instead of quadratic time.
	activeText := make(map[string][]byte)
	activeReasoning := make(map[string][]byte)

	for part := range stream {
		// Forward all parts to the chunk callback before any specific handling.
		if opts.OnChunk != nil {
			opts.OnChunk(part)
		}

		switch part.Type {
		case StreamPartTypeWarnings:
			stepWarnings = part.Warnings
			if opts.OnWarnings != nil {
				opts.OnWarnings(part.Warnings)
			}

		case StreamPartTypeTextStart:
			// A repeated start for the same ID restarts the block.
			activeText[part.ID] = []byte{}
			if opts.OnTextStart != nil {
				opts.OnTextStart(part.ID)
			}

		case StreamPartTypeTextDelta:
			// Deltas for blocks that were never started are not buffered,
			// but they are still forwarded to the callback.
			if buf, ok := activeText[part.ID]; ok {
				activeText[part.ID] = append(buf, part.Delta...)
			}
			if opts.OnTextDelta != nil {
				opts.OnTextDelta(part.ID, part.Delta)
			}

		case StreamPartTypeTextEnd:
			if buf, ok := activeText[part.ID]; ok {
				stepContent = append(stepContent, TextContent{
					Text:             string(buf),
					ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
				})
				delete(activeText, part.ID)
			}
			if opts.OnTextEnd != nil {
				opts.OnTextEnd(part.ID)
			}

		case StreamPartTypeReasoningStart:
			activeReasoning[part.ID] = []byte{}
			if opts.OnReasoningStart != nil {
				opts.OnReasoningStart(part.ID)
			}

		case StreamPartTypeReasoningDelta:
			if buf, ok := activeReasoning[part.ID]; ok {
				activeReasoning[part.ID] = append(buf, part.Delta...)
			}
			if opts.OnReasoningDelta != nil {
				opts.OnReasoningDelta(part.ID, part.Delta)
			}

		case StreamPartTypeReasoningEnd:
			if buf, ok := activeReasoning[part.ID]; ok {
				stepContent = append(stepContent, ReasoningContent{
					Text:             string(buf),
					ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
				})
				delete(activeReasoning, part.ID)
			}
			if opts.OnReasoningEnd != nil {
				opts.OnReasoningEnd(part.ID)
			}

		// The complete tool input arrives with the StreamPartTypeToolCall
		// part (part.ToolCallInput), so the streamed input fragments are only
		// forwarded to the callbacks and not buffered here.
		case StreamPartTypeToolInputStart:
			if opts.OnToolInputStart != nil {
				opts.OnToolInputStart(part.ID, part.ToolCallName)
			}

		case StreamPartTypeToolInputDelta:
			if opts.OnToolInputDelta != nil {
				opts.OnToolInputDelta(part.ID, part.Delta)
			}

		case StreamPartTypeToolInputEnd:
			if opts.OnToolInputEnd != nil {
				opts.OnToolInputEnd(part.ID)
			}

		case StreamPartTypeToolCall:
			toolCall := ToolCallContent{
				ToolCallID:       part.ID,
				ToolName:         part.ToolCallName,
				Input:            part.ToolCallInput,
				ProviderExecuted: part.ProviderExecuted,
				ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
			}

			// Validate and potentially repair the tool call before recording it.
			validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
			stepToolCalls = append(stepToolCalls, validatedToolCall)
			stepContent = append(stepContent, validatedToolCall)

			if opts.OnToolCall != nil {
				opts.OnToolCall(validatedToolCall)
			}

		case StreamPartTypeSource:
			sourceContent := SourceContent{
				SourceType:       part.SourceType,
				ID:               part.ID,
				URL:              part.URL,
				Title:            part.Title,
				ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
			}
			stepContent = append(stepContent, sourceContent)
			if opts.OnSource != nil {
				opts.OnSource(sourceContent)
			}

		case StreamPartTypeFinish:
			stepUsage = part.Usage
			stepFinishReason = part.FinishReason
			stepProviderMetadata = ProviderMetadata(part.ProviderMetadata)
			if opts.OnStreamFinish != nil {
				opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
			}

		case StreamPartTypeError:
			if opts.OnStreamError != nil {
				opts.OnStreamError(part.Error)
			}
			if opts.OnError != nil {
				opts.OnError(part.Error)
			}
			return StepResult{}, false, part.Error
		}
	}

	// Execute the collected tool calls, if any, and record their results.
	if len(stepToolCalls) > 0 {
		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
		if err != nil {
			return StepResult{}, false, err
		}
		// Each result is appended individually because []ToolResultContent
		// is not assignable to []Content.
		for _, result := range toolResults {
			stepContent = append(stepContent, result)
		}
	}

	stepResult := StepResult{
		Response: Response{
			Content:          stepContent,
			FinishReason:     stepFinishReason,
			Usage:            stepUsage,
			Warnings:         stepWarnings,
			ProviderMetadata: stepProviderMetadata,
		},
		Messages: toResponseMessages(stepContent),
	}

	// Continue only if the model asked for tools and stopped to let them run.
	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls

	return stepResult, shouldContinue, nil
}
1185
1186func addUsage(a, b Usage) Usage {
1187	return Usage{
1188		InputTokens:         a.InputTokens + b.InputTokens,
1189		OutputTokens:        a.OutputTokens + b.OutputTokens,
1190		TotalTokens:         a.TotalTokens + b.TotalTokens,
1191		ReasoningTokens:     a.ReasoningTokens + b.ReasoningTokens,
1192		CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1193		CacheReadTokens:     a.CacheReadTokens + b.CacheReadTokens,
1194	}
1195}
1196
1197func WithHeaders(headers map[string]string) agentOption {
1198	return func(s *AgentSettings) {
1199		s.headers = headers
1200	}
1201}
1202
1203func WithProviderOptions(providerOptions ProviderOptions) agentOption {
1204	return func(s *AgentSettings) {
1205		s.providerOptions = providerOptions
1206	}
1207}