agent.go

   1package ai
   2
   3import (
   4	"context"
   5	"encoding/json"
   6	"errors"
   7	"fmt"
   8	"maps"
   9	"slices"
  10	"sync"
  11)
  12
// StepResult is the outcome of one agent step: the Response recorded for the
// step plus the conversation messages derived from that step's content.
type StepResult struct {
	// Response is embedded, so its fields (Content, FinishReason, Usage, ...)
	// are accessible directly on the step.
	Response
	// Messages holds the messages produced for this step.
	Messages []Message
}
  17
// StopCondition reports whether the agent loop should stop, given the steps
// completed so far.
type StopCondition = func(steps []StepResult) bool
  19
  20// StepCountIs returns a stop condition that stops after the specified number of steps.
  21func StepCountIs(stepCount int) StopCondition {
  22	return func(steps []StepResult) bool {
  23		return len(steps) >= stepCount
  24	}
  25}
  26
  27// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
  28func HasToolCall(toolName string) StopCondition {
  29	return func(steps []StepResult) bool {
  30		if len(steps) == 0 {
  31			return false
  32		}
  33		lastStep := steps[len(steps)-1]
  34		toolCalls := lastStep.Content.ToolCalls()
  35		for _, toolCall := range toolCalls {
  36			if toolCall.ToolName == toolName {
  37				return true
  38			}
  39		}
  40		return false
  41	}
  42}
  43
  44// HasContent returns a stop condition that stops when the specified content type appears in the last step.
  45func HasContent(contentType ContentType) StopCondition {
  46	return func(steps []StepResult) bool {
  47		if len(steps) == 0 {
  48			return false
  49		}
  50		lastStep := steps[len(steps)-1]
  51		for _, content := range lastStep.Content {
  52			if content.GetType() == contentType {
  53				return true
  54			}
  55		}
  56		return false
  57	}
  58}
  59
  60// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
  61func FinishReasonIs(reason FinishReason) StopCondition {
  62	return func(steps []StepResult) bool {
  63		if len(steps) == 0 {
  64			return false
  65		}
  66		lastStep := steps[len(steps)-1]
  67		return lastStep.FinishReason == reason
  68	}
  69}
  70
  71// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
  72func MaxTokensUsed(maxTokens int64) StopCondition {
  73	return func(steps []StepResult) bool {
  74		var totalTokens int64
  75		for _, step := range steps {
  76			totalTokens += step.Usage.TotalTokens
  77		}
  78		return totalTokens >= maxTokens
  79	}
  80}
  81
// PrepareStepFunctionOptions is the input handed to a PrepareStepFunction
// before each step is sent to the model.
type PrepareStepFunctionOptions struct {
	// Steps are the steps completed so far.
	Steps []StepResult
	// StepNumber is the zero-based index of the step about to run.
	StepNumber int
	// Model is the agent's configured model.
	Model LanguageModel
	// Messages are the input messages assembled for the step.
	Messages []Message
}
  88
// PrepareStepResult carries per-step overrides returned by a
// PrepareStepFunction. Zero-valued fields leave the corresponding default
// untouched.
type PrepareStepResult struct {
	// Model replaces the agent's model for this step when non-nil.
	Model LanguageModel
	// Messages replaces the step's input messages when non-nil.
	Messages []Message
	// System replaces the agent's system prompt for this step when non-nil.
	System *string
	// ToolChoice replaces the default tool choice (ToolChoiceAuto) when non-nil.
	ToolChoice *ToolChoice
	// ActiveTools replaces the call's active tool list when non-empty.
	ActiveTools []string
	// DisableAllTools makes the step run without any tools.
	DisableAllTools bool
}
  97
// ToolCallRepairOptions is the input handed to a RepairToolCallFunction when
// a tool call fails validation.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call that failed validation.
	OriginalToolCall ToolCallContent
	// ValidationError is the error validation reported for OriginalToolCall.
	ValidationError error
	// AvailableTools are the tools the call was validated against.
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt in effect for the step.
	SystemPrompt string
	// Messages are the input messages of the step.
	Messages []Message
}
 105
type (
	// PrepareStepFunction is called before each step and may override the
	// model, messages, system prompt, tool choice and tools for that step.
	// A non-nil error aborts the run.
	PrepareStepFunction    = func(options PrepareStepFunctionOptions) (PrepareStepResult, error)
	// OnStepFinishedFunction is a callback that receives a finished step.
	OnStepFinishedFunction = func(step StepResult)
	// RepairToolCallFunction is given a tool call that failed validation and
	// may return a corrected call. Returning a nil call or a non-nil error
	// leaves the original call marked invalid.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
 111
// agentSettings holds the agent-level configuration assembled from
// AgentOption values. Most fields act as defaults that prepareCall applies
// to a call when the call leaves the matching option unset.
type agentSettings struct {
	systemPrompt     string
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	// headers and providerOptions are merged with the call-level maps; the
	// call-level entry wins on key conflicts.
	headers          map[string]string
	providerOptions  ProviderOptions

	// TODO: add support for provider tools
	tools      []AgentTool
	maxRetries *int

	// model is the language model used for every step unless a
	// PrepareStepFunction overrides it.
	model LanguageModel

	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
 134
// AgentCall describes a single Generate invocation. Options left unset fall
// back to the agent's settings; Headers and ProviderOptions are merged with
// the agent-level maps, with call-level entries taking precedence.
type AgentCall struct {
	// Prompt is the user prompt; it must not be empty.
	Prompt           string     `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files            []FilePart `json:"files"`
	// Messages are appended to the prompt after the user message.
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools restricts the tools offered to the model to the named
	// ones; when empty, all of the agent's tools are offered.
	ActiveTools      []string `json:"active_tools"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction
}
 155
// AgentStreamCall describes a single Stream invocation. It carries the same
// request options as AgentCall plus callbacks that observe the agent run and
// the individual stream parts.
type AgentStreamCall struct {
	// Prompt is the user prompt; it must not be empty.
	Prompt           string     `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files            []FilePart `json:"files"`
	// Messages are appended to the prompt after the user message.
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools restricts the tools offered to the model to the named
	// ones; when empty, all of the agent's tools are offered.
	ActiveTools      []string `json:"active_tools"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  func()                            // Called when agent starts
	OnAgentFinish func(result *AgentResult) error   // Called when agent finishes
	OnStepStart   func(stepNumber int) error        // Called when a step starts
	OnStepFinish  func(stepResult StepResult) error // Called when a step finishes
	OnFinish      func(result *AgentResult)         // Called when entire agent completes
	OnError       func(error)                       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          func(StreamPart) error                                                                // Called for each stream part (catch-all)
	OnWarnings       func(warnings []CallWarning) error                                                    // Called for warnings
	OnTextStart      func(id string) error                                                                 // Called when text starts
	OnTextDelta      func(id, text string) error                                                           // Called for text deltas
	OnTextEnd        func(id string) error                                                                 // Called when text ends
	OnReasoningStart func(id string) error                                                                 // Called when reasoning starts
	OnReasoningDelta func(id, text string) error                                                           // Called for reasoning deltas
	OnReasoningEnd   func(id string, reasoning ReasoningContent) error                                     // Called when reasoning ends
	OnToolInputStart func(id, toolName string) error                                                       // Called when tool input starts
	OnToolInputDelta func(id, delta string) error                                                          // Called for tool input deltas
	OnToolInputEnd   func(id string) error                                                                 // Called when tool input ends
	OnToolCall       func(toolCall ToolCallContent) error                                                  // Called when tool call is complete
	OnToolResult     func(result ToolResultContent) error                                                  // Called when tool execution completes
	OnSource         func(source SourceContent) error                                                      // Called for source references
	OnStreamFinish   func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error // Called when stream finishes
}
 201
// AgentResult is the outcome of a complete agent run.
type AgentResult struct {
	// Steps lists every step of the run in execution order.
	Steps []StepResult
	// Final response
	// Response is the response of the last step.
	Response   Response
	// TotalUsage is the token usage summed over all steps.
	TotalUsage Usage
}
 208
// Agent runs a language model in a loop of steps, executing the tools the
// model requests between steps.
type Agent interface {
	// Generate runs the agent using non-streaming model calls.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the agent using streaming model calls, reporting progress
	// through the callbacks on AgentStreamCall.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
 213
// AgentOption configures an agent at construction time by mutating its
// settings; see NewAgent.
type AgentOption = func(*agentSettings)
 215
// agent is the Agent implementation returned by NewAgent.
type agent struct {
	// settings is the configuration assembled from the options passed to
	// NewAgent.
	settings agentSettings
}
 219
 220func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
 221	settings := agentSettings{
 222		model: model,
 223	}
 224	for _, o := range opts {
 225		o(&settings)
 226	}
 227	return &agent{
 228		settings: settings,
 229	}
 230}
 231
 232func (a *agent) prepareCall(call AgentCall) AgentCall {
 233	if call.MaxOutputTokens == nil && a.settings.maxOutputTokens != nil {
 234		call.MaxOutputTokens = a.settings.maxOutputTokens
 235	}
 236	if call.Temperature == nil && a.settings.temperature != nil {
 237		call.Temperature = a.settings.temperature
 238	}
 239	if call.TopP == nil && a.settings.topP != nil {
 240		call.TopP = a.settings.topP
 241	}
 242	if call.TopK == nil && a.settings.topK != nil {
 243		call.TopK = a.settings.topK
 244	}
 245	if call.PresencePenalty == nil && a.settings.presencePenalty != nil {
 246		call.PresencePenalty = a.settings.presencePenalty
 247	}
 248	if call.FrequencyPenalty == nil && a.settings.frequencyPenalty != nil {
 249		call.FrequencyPenalty = a.settings.frequencyPenalty
 250	}
 251	if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
 252		call.StopWhen = a.settings.stopWhen
 253	}
 254	if call.PrepareStep == nil && a.settings.prepareStep != nil {
 255		call.PrepareStep = a.settings.prepareStep
 256	}
 257	if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
 258		call.RepairToolCall = a.settings.repairToolCall
 259	}
 260	if call.OnRetry == nil && a.settings.onRetry != nil {
 261		call.OnRetry = a.settings.onRetry
 262	}
 263	if call.MaxRetries == nil && a.settings.maxRetries != nil {
 264		call.MaxRetries = a.settings.maxRetries
 265	}
 266
 267	providerOptions := ProviderOptions{}
 268	if a.settings.providerOptions != nil {
 269		maps.Copy(providerOptions, a.settings.providerOptions)
 270	}
 271	if call.ProviderOptions != nil {
 272		maps.Copy(providerOptions, call.ProviderOptions)
 273	}
 274	call.ProviderOptions = providerOptions
 275
 276	headers := map[string]string{}
 277
 278	if a.settings.headers != nil {
 279		maps.Copy(headers, a.settings.headers)
 280	}
 281
 282	if call.Headers != nil {
 283		maps.Copy(headers, call.Headers)
 284	}
 285	call.Headers = headers
 286	return call
 287}
 288
 289// Generate implements Agent.
 290func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
 291	opts = a.prepareCall(opts)
 292	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 293	if err != nil {
 294		return nil, err
 295	}
 296	var responseMessages []Message
 297	var steps []StepResult
 298
 299	for {
 300		stepInputMessages := append(initialPrompt, responseMessages...)
 301		stepModel := a.settings.model
 302		stepSystemPrompt := a.settings.systemPrompt
 303		stepActiveTools := opts.ActiveTools
 304		stepToolChoice := ToolChoiceAuto
 305		disableAllTools := false
 306
 307		if opts.PrepareStep != nil {
 308			prepared, err := opts.PrepareStep(PrepareStepFunctionOptions{
 309				Model:      stepModel,
 310				Steps:      steps,
 311				StepNumber: len(steps),
 312				Messages:   stepInputMessages,
 313			})
 314			if err != nil {
 315				return nil, err
 316			}
 317
 318			// Apply prepared step modifications
 319			if prepared.Messages != nil {
 320				stepInputMessages = prepared.Messages
 321			}
 322			if prepared.Model != nil {
 323				stepModel = prepared.Model
 324			}
 325			if prepared.System != nil {
 326				stepSystemPrompt = *prepared.System
 327			}
 328			if prepared.ToolChoice != nil {
 329				stepToolChoice = *prepared.ToolChoice
 330			}
 331			if len(prepared.ActiveTools) > 0 {
 332				stepActiveTools = prepared.ActiveTools
 333			}
 334			disableAllTools = prepared.DisableAllTools
 335		}
 336
 337		// Recreate prompt with potentially modified system prompt
 338		if stepSystemPrompt != a.settings.systemPrompt {
 339			stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 340			if err != nil {
 341				return nil, err
 342			}
 343			// Replace system message part, keep the rest
 344			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 345				stepInputMessages[0] = stepPrompt[0] // Replace system message
 346			}
 347		}
 348
 349		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 350
 351		retryOptions := DefaultRetryOptions()
 352		retryOptions.OnRetry = opts.OnRetry
 353		retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
 354
 355		result, err := retry(ctx, func() (*Response, error) {
 356			return stepModel.Generate(ctx, Call{
 357				Prompt:           stepInputMessages,
 358				MaxOutputTokens:  opts.MaxOutputTokens,
 359				Temperature:      opts.Temperature,
 360				TopP:             opts.TopP,
 361				TopK:             opts.TopK,
 362				PresencePenalty:  opts.PresencePenalty,
 363				FrequencyPenalty: opts.FrequencyPenalty,
 364				Tools:            preparedTools,
 365				ToolChoice:       &stepToolChoice,
 366				Headers:          opts.Headers,
 367				ProviderOptions:  opts.ProviderOptions,
 368			})
 369		})
 370		if err != nil {
 371			return nil, err
 372		}
 373
 374		var stepToolCalls []ToolCallContent
 375		for _, content := range result.Content {
 376			if content.GetType() == ContentTypeToolCall {
 377				toolCall, ok := AsContentType[ToolCallContent](content)
 378				if !ok {
 379					continue
 380				}
 381
 382				// Validate and potentially repair the tool call
 383				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
 384				stepToolCalls = append(stepToolCalls, validatedToolCall)
 385			}
 386		}
 387
 388		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
 389
 390		// Build step content with validated tool calls and tool results
 391		stepContent := []Content{}
 392		toolCallIndex := 0
 393		for _, content := range result.Content {
 394			if content.GetType() == ContentTypeToolCall {
 395				// Replace with validated tool call
 396				if toolCallIndex < len(stepToolCalls) {
 397					stepContent = append(stepContent, stepToolCalls[toolCallIndex])
 398					toolCallIndex++
 399				}
 400			} else {
 401				// Keep other content as-is
 402				stepContent = append(stepContent, content)
 403			}
 404		}
 405		// Add tool results
 406		for _, result := range toolResults {
 407			stepContent = append(stepContent, result)
 408		}
 409		currentStepMessages := toResponseMessages(stepContent)
 410		responseMessages = append(responseMessages, currentStepMessages...)
 411
 412		stepResult := StepResult{
 413			Response: Response{
 414				Content:          stepContent,
 415				FinishReason:     result.FinishReason,
 416				Usage:            result.Usage,
 417				Warnings:         result.Warnings,
 418				ProviderMetadata: result.ProviderMetadata,
 419			},
 420			Messages: currentStepMessages,
 421		}
 422		steps = append(steps, stepResult)
 423		shouldStop := isStopConditionMet(opts.StopWhen, steps)
 424
 425		if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
 426			break
 427		}
 428	}
 429
 430	totalUsage := Usage{}
 431
 432	for _, step := range steps {
 433		usage := step.Usage
 434		totalUsage.InputTokens += usage.InputTokens
 435		totalUsage.OutputTokens += usage.OutputTokens
 436		totalUsage.ReasoningTokens += usage.ReasoningTokens
 437		totalUsage.CacheCreationTokens += usage.CacheCreationTokens
 438		totalUsage.CacheReadTokens += usage.CacheReadTokens
 439		totalUsage.TotalTokens += usage.TotalTokens
 440	}
 441
 442	agentResult := &AgentResult{
 443		Steps:      steps,
 444		Response:   steps[len(steps)-1].Response,
 445		TotalUsage: totalUsage,
 446	}
 447	return agentResult, nil
 448}
 449
 450func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
 451	if len(conditions) == 0 {
 452		return false
 453	}
 454
 455	for _, condition := range conditions {
 456		if condition(steps) {
 457			return true
 458		}
 459	}
 460	return false
 461}
 462
 463func toResponseMessages(content []Content) []Message {
 464	var assistantParts []MessagePart
 465	var toolParts []MessagePart
 466
 467	for _, c := range content {
 468		switch c.GetType() {
 469		case ContentTypeText:
 470			text, ok := AsContentType[TextContent](c)
 471			if !ok {
 472				continue
 473			}
 474			assistantParts = append(assistantParts, TextPart{
 475				Text:            text.Text,
 476				ProviderOptions: ProviderOptions(text.ProviderMetadata),
 477			})
 478		case ContentTypeReasoning:
 479			reasoning, ok := AsContentType[ReasoningContent](c)
 480			if !ok {
 481				continue
 482			}
 483			assistantParts = append(assistantParts, ReasoningPart{
 484				Text:            reasoning.Text,
 485				ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
 486			})
 487		case ContentTypeToolCall:
 488			toolCall, ok := AsContentType[ToolCallContent](c)
 489			if !ok {
 490				continue
 491			}
 492			assistantParts = append(assistantParts, ToolCallPart{
 493				ToolCallID:       toolCall.ToolCallID,
 494				ToolName:         toolCall.ToolName,
 495				Input:            toolCall.Input,
 496				ProviderExecuted: toolCall.ProviderExecuted,
 497				ProviderOptions:  ProviderOptions(toolCall.ProviderMetadata),
 498			})
 499		case ContentTypeFile:
 500			file, ok := AsContentType[FileContent](c)
 501			if !ok {
 502				continue
 503			}
 504			assistantParts = append(assistantParts, FilePart{
 505				Data:            file.Data,
 506				MediaType:       file.MediaType,
 507				ProviderOptions: ProviderOptions(file.ProviderMetadata),
 508			})
 509		case ContentTypeSource:
 510			// Sources are metadata about references used to generate the response.
 511			// They don't need to be included in the conversation messages.
 512			continue
 513		case ContentTypeToolResult:
 514			result, ok := AsContentType[ToolResultContent](c)
 515			if !ok {
 516				continue
 517			}
 518			toolParts = append(toolParts, ToolResultPart{
 519				ToolCallID:      result.ToolCallID,
 520				Output:          result.Result,
 521				ProviderOptions: ProviderOptions(result.ProviderMetadata),
 522			})
 523		}
 524	}
 525
 526	var messages []Message
 527	if len(assistantParts) > 0 {
 528		messages = append(messages, Message{
 529			Role:    MessageRoleAssistant,
 530			Content: assistantParts,
 531		})
 532	}
 533	if len(toolParts) > 0 {
 534		messages = append(messages, Message{
 535			Role:    MessageRoleTool,
 536			Content: toolParts,
 537		})
 538	}
 539	return messages
 540}
 541
 542func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
 543	if len(toolCalls) == 0 {
 544		return nil, nil
 545	}
 546
 547	// Create a map for quick tool lookup
 548	toolMap := make(map[string]AgentTool)
 549	for _, tool := range allTools {
 550		toolMap[tool.Info().Name] = tool
 551	}
 552
 553	// Execute all tool calls in parallel
 554	results := make([]ToolResultContent, len(toolCalls))
 555	executeErrors := make([]error, len(toolCalls))
 556
 557	var wg sync.WaitGroup
 558
 559	for i, toolCall := range toolCalls {
 560		wg.Add(1)
 561		go func(index int, call ToolCallContent) {
 562			defer wg.Done()
 563
 564			// Skip invalid tool calls - create error result
 565			if call.Invalid {
 566				results[index] = ToolResultContent{
 567					ToolCallID: call.ToolCallID,
 568					ToolName:   call.ToolName,
 569					Result: ToolResultOutputContentError{
 570						Error: call.ValidationError,
 571					},
 572					ProviderExecuted: false,
 573				}
 574				if toolResultCallback != nil {
 575					err := toolResultCallback(results[index])
 576					if err != nil {
 577						executeErrors[index] = err
 578					}
 579				}
 580
 581				return
 582			}
 583
 584			tool, exists := toolMap[call.ToolName]
 585			if !exists {
 586				results[index] = ToolResultContent{
 587					ToolCallID: call.ToolCallID,
 588					ToolName:   call.ToolName,
 589					Result: ToolResultOutputContentError{
 590						Error: errors.New("Error: Tool not found: " + call.ToolName),
 591					},
 592					ProviderExecuted: false,
 593				}
 594
 595				if toolResultCallback != nil {
 596					err := toolResultCallback(results[index])
 597					if err != nil {
 598						executeErrors[index] = err
 599					}
 600				}
 601				return
 602			}
 603
 604			// Execute the tool
 605			result, err := tool.Run(ctx, ToolCall{
 606				ID:    call.ToolCallID,
 607				Name:  call.ToolName,
 608				Input: call.Input,
 609			})
 610			if err != nil {
 611				results[index] = ToolResultContent{
 612					ToolCallID: call.ToolCallID,
 613					ToolName:   call.ToolName,
 614					Result: ToolResultOutputContentError{
 615						Error: err,
 616					},
 617					ClientMetadata:   result.Metadata,
 618					ProviderExecuted: false,
 619				}
 620				if toolResultCallback != nil {
 621					cbErr := toolResultCallback(results[index])
 622					if cbErr != nil {
 623						executeErrors[index] = cbErr
 624					}
 625				}
 626				executeErrors[index] = err
 627				return
 628			}
 629
 630			if result.IsError {
 631				results[index] = ToolResultContent{
 632					ToolCallID: call.ToolCallID,
 633					ToolName:   call.ToolName,
 634					Result: ToolResultOutputContentError{
 635						Error: errors.New(result.Content),
 636					},
 637					ClientMetadata:   result.Metadata,
 638					ProviderExecuted: false,
 639				}
 640
 641				if toolResultCallback != nil {
 642					err := toolResultCallback(results[index])
 643					if err != nil {
 644						executeErrors[index] = err
 645					}
 646				}
 647			} else {
 648				results[index] = ToolResultContent{
 649					ToolCallID: call.ToolCallID,
 650					ToolName:   toolCall.ToolName,
 651					Result: ToolResultOutputContentText{
 652						Text: result.Content,
 653					},
 654					ClientMetadata:   result.Metadata,
 655					ProviderExecuted: false,
 656				}
 657				if toolResultCallback != nil {
 658					err := toolResultCallback(results[index])
 659					if err != nil {
 660						executeErrors[index] = err
 661					}
 662				}
 663			}
 664		}(i, toolCall)
 665	}
 666
 667	// Wait for all tool executions to complete
 668	wg.Wait()
 669
 670	for _, err := range executeErrors {
 671		if err != nil {
 672			return nil, err
 673		}
 674	}
 675
 676	return results, nil
 677}
 678
 679// Stream implements Agent.
 680func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
 681	// Convert AgentStreamCall to AgentCall for preparation
 682	call := AgentCall{
 683		Prompt:           opts.Prompt,
 684		Files:            opts.Files,
 685		Messages:         opts.Messages,
 686		MaxOutputTokens:  opts.MaxOutputTokens,
 687		Temperature:      opts.Temperature,
 688		TopP:             opts.TopP,
 689		TopK:             opts.TopK,
 690		PresencePenalty:  opts.PresencePenalty,
 691		FrequencyPenalty: opts.FrequencyPenalty,
 692		ActiveTools:      opts.ActiveTools,
 693		Headers:          opts.Headers,
 694		ProviderOptions:  opts.ProviderOptions,
 695		MaxRetries:       opts.MaxRetries,
 696		StopWhen:         opts.StopWhen,
 697		PrepareStep:      opts.PrepareStep,
 698		RepairToolCall:   opts.RepairToolCall,
 699	}
 700
 701	call = a.prepareCall(call)
 702
 703	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
 704	if err != nil {
 705		return nil, err
 706	}
 707
 708	var responseMessages []Message
 709	var steps []StepResult
 710	var totalUsage Usage
 711
 712	// Start agent stream
 713	if opts.OnAgentStart != nil {
 714		opts.OnAgentStart()
 715	}
 716
 717	for stepNumber := 0; ; stepNumber++ {
 718		stepInputMessages := append(initialPrompt, responseMessages...)
 719		stepModel := a.settings.model
 720		stepSystemPrompt := a.settings.systemPrompt
 721		stepActiveTools := call.ActiveTools
 722		stepToolChoice := ToolChoiceAuto
 723		disableAllTools := false
 724
 725		// Apply step preparation if provided
 726		if call.PrepareStep != nil {
 727			prepared, err := call.PrepareStep(PrepareStepFunctionOptions{
 728				Model:      stepModel,
 729				Steps:      steps,
 730				StepNumber: stepNumber,
 731				Messages:   stepInputMessages,
 732			})
 733			if err != nil {
 734				return nil, err
 735			}
 736
 737			if prepared.Messages != nil {
 738				stepInputMessages = prepared.Messages
 739			}
 740			if prepared.Model != nil {
 741				stepModel = prepared.Model
 742			}
 743			if prepared.System != nil {
 744				stepSystemPrompt = *prepared.System
 745			}
 746			if prepared.ToolChoice != nil {
 747				stepToolChoice = *prepared.ToolChoice
 748			}
 749			if len(prepared.ActiveTools) > 0 {
 750				stepActiveTools = prepared.ActiveTools
 751			}
 752			disableAllTools = prepared.DisableAllTools
 753		}
 754
 755		// Recreate prompt with potentially modified system prompt
 756		if stepSystemPrompt != a.settings.systemPrompt {
 757			stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
 758			if err != nil {
 759				return nil, err
 760			}
 761			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 762				stepInputMessages[0] = stepPrompt[0]
 763			}
 764		}
 765
 766		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 767
 768		// Start step stream
 769		if opts.OnStepStart != nil {
 770			opts.OnStepStart(stepNumber)
 771		}
 772
 773		// Create streaming call
 774		streamCall := Call{
 775			Prompt:           stepInputMessages,
 776			MaxOutputTokens:  call.MaxOutputTokens,
 777			Temperature:      call.Temperature,
 778			TopP:             call.TopP,
 779			TopK:             call.TopK,
 780			PresencePenalty:  call.PresencePenalty,
 781			FrequencyPenalty: call.FrequencyPenalty,
 782			Tools:            preparedTools,
 783			ToolChoice:       &stepToolChoice,
 784			Headers:          call.Headers,
 785			ProviderOptions:  call.ProviderOptions,
 786		}
 787
 788		// Get streaming response
 789		stream, err := stepModel.Stream(ctx, streamCall)
 790		if err != nil {
 791			if opts.OnError != nil {
 792				opts.OnError(err)
 793			}
 794			return nil, err
 795		}
 796
 797		// Process stream with tool execution
 798		stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
 799		if err != nil {
 800			if opts.OnError != nil {
 801				opts.OnError(err)
 802			}
 803			return nil, err
 804		}
 805
 806		steps = append(steps, stepResult)
 807		totalUsage = addUsage(totalUsage, stepResult.Usage)
 808
 809		// Call step finished callback
 810		if opts.OnStepFinish != nil {
 811			opts.OnStepFinish(stepResult)
 812		}
 813
 814		// Add step messages to response messages
 815		stepMessages := toResponseMessages(stepResult.Content)
 816		responseMessages = append(responseMessages, stepMessages...)
 817
 818		// Check stop conditions
 819		shouldStop := isStopConditionMet(call.StopWhen, steps)
 820		if shouldStop || !shouldContinue {
 821			break
 822		}
 823	}
 824
 825	// Finish agent stream
 826	agentResult := &AgentResult{
 827		Steps:      steps,
 828		Response:   steps[len(steps)-1].Response,
 829		TotalUsage: totalUsage,
 830	}
 831
 832	if opts.OnFinish != nil {
 833		opts.OnFinish(agentResult)
 834	}
 835
 836	if opts.OnAgentFinish != nil {
 837		opts.OnAgentFinish(agentResult)
 838	}
 839
 840	return agentResult, nil
 841}
 842
 843func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
 844	var preparedTools []Tool
 845
 846	// If explicitly disabling all tools, return no tools
 847	if disableAllTools {
 848		return preparedTools
 849	}
 850
 851	for _, tool := range tools {
 852		// If activeTools has items, only include tools in the list
 853		// If activeTools is empty, include all tools
 854		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
 855			continue
 856		}
 857		info := tool.Info()
 858		preparedTools = append(preparedTools, FunctionTool{
 859			Name:        info.Name,
 860			Description: info.Description,
 861			InputSchema: map[string]any{
 862				"type":       "object",
 863				"properties": info.Parameters,
 864				"required":   info.Required,
 865			},
 866		})
 867	}
 868	return preparedTools
 869}
 870
 871// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
 872func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
 873	if err := a.validateToolCall(toolCall, availableTools); err == nil {
 874		return toolCall
 875	} else {
 876		if repairFunc != nil {
 877			repairOptions := ToolCallRepairOptions{
 878				OriginalToolCall: toolCall,
 879				ValidationError:  err,
 880				AvailableTools:   availableTools,
 881				SystemPrompt:     systemPrompt,
 882				Messages:         messages,
 883			}
 884
 885			if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
 886				if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
 887					return *repairedToolCall
 888				}
 889			}
 890		}
 891
 892		invalidToolCall := toolCall
 893		invalidToolCall.Invalid = true
 894		invalidToolCall.ValidationError = err
 895		return invalidToolCall
 896	}
 897}
 898
 899// validateToolCall validates a tool call against available tools and their schemas.
 900func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
 901	var tool AgentTool
 902	for _, t := range availableTools {
 903		if t.Info().Name == toolCall.ToolName {
 904			tool = t
 905			break
 906		}
 907	}
 908
 909	if tool == nil {
 910		return fmt.Errorf("tool not found: %s", toolCall.ToolName)
 911	}
 912
 913	// Validate JSON parsing
 914	var input map[string]any
 915	if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
 916		return fmt.Errorf("invalid JSON input: %w", err)
 917	}
 918
 919	// Basic schema validation (check required fields)
 920	// TODO: more robust schema validation using JSON Schema or similar
 921	toolInfo := tool.Info()
 922	for _, required := range toolInfo.Required {
 923		if _, exists := input[required]; !exists {
 924			return fmt.Errorf("missing required parameter: %s", required)
 925		}
 926	}
 927	return nil
 928}
 929
 930func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
 931	if prompt == "" {
 932		return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
 933	}
 934
 935	var preparedPrompt Prompt
 936
 937	if system != "" {
 938		preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
 939	}
 940
 941	preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
 942	preparedPrompt = append(preparedPrompt, messages...)
 943	return preparedPrompt, nil
 944}
 945
 946func WithSystemPrompt(prompt string) AgentOption {
 947	return func(s *agentSettings) {
 948		s.systemPrompt = prompt
 949	}
 950}
 951
 952func WithMaxOutputTokens(tokens int64) AgentOption {
 953	return func(s *agentSettings) {
 954		s.maxOutputTokens = &tokens
 955	}
 956}
 957
 958func WithTemperature(temp float64) AgentOption {
 959	return func(s *agentSettings) {
 960		s.temperature = &temp
 961	}
 962}
 963
 964func WithTopP(topP float64) AgentOption {
 965	return func(s *agentSettings) {
 966		s.topP = &topP
 967	}
 968}
 969
 970func WithTopK(topK int64) AgentOption {
 971	return func(s *agentSettings) {
 972		s.topK = &topK
 973	}
 974}
 975
 976func WithPresencePenalty(penalty float64) AgentOption {
 977	return func(s *agentSettings) {
 978		s.presencePenalty = &penalty
 979	}
 980}
 981
 982func WithFrequencyPenalty(penalty float64) AgentOption {
 983	return func(s *agentSettings) {
 984		s.frequencyPenalty = &penalty
 985	}
 986}
 987
 988func WithTools(tools ...AgentTool) AgentOption {
 989	return func(s *agentSettings) {
 990		s.tools = append(s.tools, tools...)
 991	}
 992}
 993
 994func WithStopConditions(conditions ...StopCondition) AgentOption {
 995	return func(s *agentSettings) {
 996		s.stopWhen = append(s.stopWhen, conditions...)
 997	}
 998}
 999
1000func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1001	return func(s *agentSettings) {
1002		s.prepareStep = fn
1003	}
1004}
1005
1006func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1007	return func(s *agentSettings) {
1008		s.repairToolCall = fn
1009	}
1010}
1011
1012// processStepStream processes a single step's stream and returns the step result.
1013func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
1014	var stepContent []Content
1015	var stepToolCalls []ToolCallContent
1016	var stepUsage Usage
1017	stepFinishReason := FinishReasonUnknown
1018	var stepWarnings []CallWarning
1019	var stepProviderMetadata ProviderMetadata
1020
1021	activeToolCalls := make(map[string]*ToolCallContent)
1022	activeTextContent := make(map[string]string)
1023
1024	// Process stream parts
1025	for part := range stream {
1026		// Forward all parts to chunk callback
1027		if opts.OnChunk != nil {
1028			err := opts.OnChunk(part)
1029			if err != nil {
1030				return StepResult{}, false, err
1031			}
1032		}
1033
1034		switch part.Type {
1035		case StreamPartTypeWarnings:
1036			stepWarnings = part.Warnings
1037			if opts.OnWarnings != nil {
1038				err := opts.OnWarnings(part.Warnings)
1039				if err != nil {
1040					return StepResult{}, false, err
1041				}
1042			}
1043
1044		case StreamPartTypeTextStart:
1045			activeTextContent[part.ID] = ""
1046			if opts.OnTextStart != nil {
1047				err := opts.OnTextStart(part.ID)
1048				if err != nil {
1049					return StepResult{}, false, err
1050				}
1051			}
1052
1053		case StreamPartTypeTextDelta:
1054			if _, exists := activeTextContent[part.ID]; exists {
1055				activeTextContent[part.ID] += part.Delta
1056			}
1057			if opts.OnTextDelta != nil {
1058				err := opts.OnTextDelta(part.ID, part.Delta)
1059				if err != nil {
1060					return StepResult{}, false, err
1061				}
1062			}
1063
1064		case StreamPartTypeTextEnd:
1065			if text, exists := activeTextContent[part.ID]; exists {
1066				stepContent = append(stepContent, TextContent{
1067					Text:             text,
1068					ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1069				})
1070				delete(activeTextContent, part.ID)
1071			}
1072			if opts.OnTextEnd != nil {
1073				err := opts.OnTextEnd(part.ID)
1074				if err != nil {
1075					return StepResult{}, false, err
1076				}
1077			}
1078
1079		case StreamPartTypeReasoningStart:
1080			activeTextContent[part.ID] = ""
1081			if opts.OnReasoningStart != nil {
1082				err := opts.OnReasoningStart(part.ID)
1083				if err != nil {
1084					return StepResult{}, false, err
1085				}
1086			}
1087
1088		case StreamPartTypeReasoningDelta:
1089			if _, exists := activeTextContent[part.ID]; exists {
1090				activeTextContent[part.ID] += part.Delta
1091			}
1092			if opts.OnReasoningDelta != nil {
1093				err := opts.OnReasoningDelta(part.ID, part.Delta)
1094				if err != nil {
1095					return StepResult{}, false, err
1096				}
1097			}
1098
1099		case StreamPartTypeReasoningEnd:
1100			if text, exists := activeTextContent[part.ID]; exists {
1101				stepContent = append(stepContent, ReasoningContent{
1102					Text:             text,
1103					ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1104				})
1105				if opts.OnReasoningEnd != nil {
1106					err := opts.OnReasoningEnd(part.ID, ReasoningContent{
1107						Text:             text,
1108						ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1109					})
1110					if err != nil {
1111						return StepResult{}, false, err
1112					}
1113				}
1114				delete(activeTextContent, part.ID)
1115			}
1116
1117		case StreamPartTypeToolInputStart:
1118			activeToolCalls[part.ID] = &ToolCallContent{
1119				ToolCallID:       part.ID,
1120				ToolName:         part.ToolCallName,
1121				Input:            "",
1122				ProviderExecuted: part.ProviderExecuted,
1123			}
1124			if opts.OnToolInputStart != nil {
1125				err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1126				if err != nil {
1127					return StepResult{}, false, err
1128				}
1129			}
1130
1131		case StreamPartTypeToolInputDelta:
1132			if toolCall, exists := activeToolCalls[part.ID]; exists {
1133				toolCall.Input += part.Delta
1134			}
1135			if opts.OnToolInputDelta != nil {
1136				err := opts.OnToolInputDelta(part.ID, part.Delta)
1137				if err != nil {
1138					return StepResult{}, false, err
1139				}
1140			}
1141
1142		case StreamPartTypeToolInputEnd:
1143			if opts.OnToolInputEnd != nil {
1144				err := opts.OnToolInputEnd(part.ID)
1145				if err != nil {
1146					return StepResult{}, false, err
1147				}
1148			}
1149
1150		case StreamPartTypeToolCall:
1151			toolCall := ToolCallContent{
1152				ToolCallID:       part.ID,
1153				ToolName:         part.ToolCallName,
1154				Input:            part.ToolCallInput,
1155				ProviderExecuted: part.ProviderExecuted,
1156				ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1157			}
1158
1159			// Validate and potentially repair the tool call
1160			validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1161			stepToolCalls = append(stepToolCalls, validatedToolCall)
1162			stepContent = append(stepContent, validatedToolCall)
1163
1164			if opts.OnToolCall != nil {
1165				err := opts.OnToolCall(validatedToolCall)
1166				if err != nil {
1167					return StepResult{}, false, err
1168				}
1169			}
1170
1171			// Clean up active tool call
1172			delete(activeToolCalls, part.ID)
1173
1174		case StreamPartTypeSource:
1175			sourceContent := SourceContent{
1176				SourceType:       part.SourceType,
1177				ID:               part.ID,
1178				URL:              part.URL,
1179				Title:            part.Title,
1180				ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1181			}
1182			stepContent = append(stepContent, sourceContent)
1183			if opts.OnSource != nil {
1184				err := opts.OnSource(sourceContent)
1185				if err != nil {
1186					return StepResult{}, false, err
1187				}
1188			}
1189
1190		case StreamPartTypeFinish:
1191			stepUsage = part.Usage
1192			stepFinishReason = part.FinishReason
1193			stepProviderMetadata = ProviderMetadata(part.ProviderMetadata)
1194			if opts.OnStreamFinish != nil {
1195				err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1196				if err != nil {
1197					return StepResult{}, false, err
1198				}
1199			}
1200
1201		case StreamPartTypeError:
1202			if opts.OnError != nil {
1203				opts.OnError(part.Error)
1204			}
1205			return StepResult{}, false, part.Error
1206		}
1207	}
1208
1209	// Execute tools if any
1210	var toolResults []ToolResultContent
1211	if len(stepToolCalls) > 0 {
1212		var err error
1213		toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
1214		if err != nil {
1215			return StepResult{}, false, err
1216		}
1217		// Add tool results to content
1218		for _, result := range toolResults {
1219			stepContent = append(stepContent, result)
1220		}
1221	}
1222
1223	stepResult := StepResult{
1224		Response: Response{
1225			Content:          stepContent,
1226			FinishReason:     stepFinishReason,
1227			Usage:            stepUsage,
1228			Warnings:         stepWarnings,
1229			ProviderMetadata: stepProviderMetadata,
1230		},
1231		Messages: toResponseMessages(stepContent),
1232	}
1233
1234	// Determine if we should continue (has tool calls and not stopped)
1235	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1236
1237	return stepResult, shouldContinue, nil
1238}
1239
1240func addUsage(a, b Usage) Usage {
1241	return Usage{
1242		InputTokens:         a.InputTokens + b.InputTokens,
1243		OutputTokens:        a.OutputTokens + b.OutputTokens,
1244		TotalTokens:         a.TotalTokens + b.TotalTokens,
1245		ReasoningTokens:     a.ReasoningTokens + b.ReasoningTokens,
1246		CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1247		CacheReadTokens:     a.CacheReadTokens + b.CacheReadTokens,
1248	}
1249}
1250
1251func WithHeaders(headers map[string]string) AgentOption {
1252	return func(s *agentSettings) {
1253		s.headers = headers
1254	}
1255}
1256
1257func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1258	return func(s *agentSettings) {
1259		s.providerOptions = providerOptions
1260	}
1261}