agent.go

   1package ai
   2
   3import (
   4	"context"
   5	"encoding/json"
   6	"errors"
   7	"fmt"
   8	"maps"
   9	"slices"
  10	"sync"
  11)
  12
// StepResult records the outcome of one agent step: the model Response for
// that step (embedded, so Content, FinishReason, Usage, etc. are promoted) and
// the conversation messages the step contributes to later steps.
type StepResult struct {
	Response
	// Messages are the messages derived from this step's content (in Generate,
	// built with toResponseMessages) and appended to the running conversation.
	Messages []Message
}
  17
// StopCondition inspects all steps completed so far (oldest first) and reports
// whether the agent loop should stop. Conditions are evaluated after every
// step; the loop stops as soon as any configured condition returns true.
type StopCondition = func(steps []StepResult) bool
  19
  20// StepCountIs returns a stop condition that stops after the specified number of steps.
  21func StepCountIs(stepCount int) StopCondition {
  22	return func(steps []StepResult) bool {
  23		return len(steps) >= stepCount
  24	}
  25}
  26
  27// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
  28func HasToolCall(toolName string) StopCondition {
  29	return func(steps []StepResult) bool {
  30		if len(steps) == 0 {
  31			return false
  32		}
  33		lastStep := steps[len(steps)-1]
  34		toolCalls := lastStep.Content.ToolCalls()
  35		for _, toolCall := range toolCalls {
  36			if toolCall.ToolName == toolName {
  37				return true
  38			}
  39		}
  40		return false
  41	}
  42}
  43
  44// HasContent returns a stop condition that stops when the specified content type appears in the last step.
  45func HasContent(contentType ContentType) StopCondition {
  46	return func(steps []StepResult) bool {
  47		if len(steps) == 0 {
  48			return false
  49		}
  50		lastStep := steps[len(steps)-1]
  51		for _, content := range lastStep.Content {
  52			if content.GetType() == contentType {
  53				return true
  54			}
  55		}
  56		return false
  57	}
  58}
  59
  60// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
  61func FinishReasonIs(reason FinishReason) StopCondition {
  62	return func(steps []StepResult) bool {
  63		if len(steps) == 0 {
  64			return false
  65		}
  66		lastStep := steps[len(steps)-1]
  67		return lastStep.FinishReason == reason
  68	}
  69}
  70
  71// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
  72func MaxTokensUsed(maxTokens int64) StopCondition {
  73	return func(steps []StepResult) bool {
  74		var totalTokens int64
  75		for _, step := range steps {
  76			totalTokens += step.Usage.TotalTokens
  77		}
  78		return totalTokens >= maxTokens
  79	}
  80}
  81
// PrepareStepFunctionOptions is the input handed to a PrepareStepFunction
// right before a step runs.
type PrepareStepFunctionOptions struct {
	// Steps are the steps completed so far, oldest first.
	Steps []StepResult
	// StepNumber is the zero-based index of the step about to run.
	StepNumber int
	// Model is the agent's configured model, used unless the hook overrides it.
	Model LanguageModel
	// Messages are the input messages the step will send unless overridden.
	Messages []Message
}
  88
// PrepareStepResult holds per-step overrides returned by a
// PrepareStepFunction. Zero values mean "keep the default" for each field.
type PrepareStepResult struct {
	// Model replaces the step's model when non-nil.
	Model LanguageModel
	// Messages replaces the step's input messages when non-nil.
	Messages []Message
	// System replaces the system prompt for this step when non-nil.
	System *string
	// ToolChoice replaces the default (auto) tool choice when non-nil.
	ToolChoice *ToolChoice
	// ActiveTools, when non-empty, restricts the tools offered to the model.
	ActiveTools []string
	// DisableAllTools, when true, sends no tools at all for this step.
	DisableAllTools bool
}
  97
// ToolCallRepairOptions is passed to a RepairToolCallFunction when a tool call
// produced by the model fails validation.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call as produced by the model.
	OriginalToolCall ToolCallContent
	// ValidationError explains why OriginalToolCall was rejected.
	ValidationError error
	// AvailableTools are the agent's tools the call was validated against.
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt in effect for the step.
	SystemPrompt string
	// Messages are the input messages that were sent for the step.
	Messages []Message
}
 105
type (
	// PrepareStepFunction runs before every step and may override the model,
	// messages, system prompt, tool choice and tool set for that step.
	// Returning an error aborts the agent run with that error.
	PrepareStepFunction = func(options PrepareStepFunctionOptions) (PrepareStepResult, error)
	// OnStepFinishedFunction is a callback receiving a completed step.
	// NOTE(review): not referenced in the visible part of this file.
	OnStepFinishedFunction = func(step StepResult)
	// RepairToolCallFunction attempts to fix an invalid tool call. The repair
	// is used only when it returns a non-nil call, a nil error, and the
	// repaired call itself passes validation.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
 111
// agentSettings holds agent-wide configuration set through AgentOption
// functions. The sampling fields, retry settings and hooks act as defaults
// for AgentCall fields the caller leaves unset (see prepareCall).
type agentSettings struct {
	// systemPrompt, when non-empty, becomes the leading system message.
	systemPrompt     string
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	// headers and providerOptions are merged under per-call values, so
	// per-call keys win on conflict.
	// NOTE(review): no option setter for headers, providerOptions, maxRetries
	// or onRetry appears in the visible part of this file.
	headers         map[string]string
	providerOptions ProviderOptions

	// TODO: add support for provider tools
	// tools are all tools the agent can execute; per-call/per-step ActiveTools
	// only filter this list.
	tools      []AgentTool
	maxRetries *int

	// model is the default model for every step (PrepareStep may override it).
	model LanguageModel

	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
 134
// AgentCall describes a single Generate invocation. Pointer, function and
// StopWhen fields left nil/empty fall back to the agent's defaults (see
// prepareCall); Headers and ProviderOptions are merged with the agent's, with
// values set here taking precedence.
//
// NOTE(review): MaxOutputTokens has no json tag unlike its sibling sampling
// fields, and the function-typed fields cannot be JSON-encoded; confirm
// whether JSON (un)marshalling of this struct is actually relied upon.
type AgentCall struct {
	// Prompt is the new user message text; it must be non-empty.
	Prompt string `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files []FilePart `json:"files"`
	// Messages are additional conversation messages included in the prompt
	// (createPrompt decides their placement).
	Messages         []Message `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, limits which agent tools are offered.
	ActiveTools     []string `json:"active_tools"`
	Headers         map[string]string
	ProviderOptions ProviderOptions
	OnRetry         OnRetryCallback
	MaxRetries      *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction
}
 155
// AgentStreamCall describes a single Stream invocation. Its request fields
// mirror AgentCall and are defaulted/merged the same way (Stream converts it
// to an AgentCall and runs prepareCall). In addition it carries callbacks for
// agent-level lifecycle events and for individual stream parts.
type AgentStreamCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  func()                      // Called once before the first step
	OnAgentFinish func(result *AgentResult)   // Called after OnFinish with the final result
	OnStepStart   func(stepNumber int)        // Called with the zero-based step index before each model call
	OnStepFinish  func(stepResult StepResult) // Called after each step completes
	OnFinish      func(result *AgentResult)   // Called when the entire agent run completes
	OnError       func(error)                 // Called when starting or processing a step stream fails

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          func(StreamPart)                                                                // Called for each stream part (catch-all)
	OnWarnings       func(warnings []CallWarning)                                                    // Called for warnings
	OnTextStart      func(id string)                                                                 // Called when text starts
	OnTextDelta      func(id, text string)                                                           // Called for text deltas
	OnTextEnd        func(id string)                                                                 // Called when text ends
	OnReasoningStart func(id string)                                                                 // Called when reasoning starts
	OnReasoningDelta func(id, text string)                                                           // Called for reasoning deltas
	OnReasoningEnd   func(id string, reasoning ReasoningContent)                                     // Called when reasoning ends
	OnToolInputStart func(id, toolName string)                                                       // Called when tool input starts
	OnToolInputDelta func(id, delta string)                                                          // Called for tool input deltas
	OnToolInputEnd   func(id string)                                                                 // Called when tool input ends
	OnToolCall       func(toolCall ToolCallContent)                                                  // Called when tool call is complete
	OnToolResult     func(result ToolResultContent)                                                  // Called when tool execution completes
	OnSource         func(source SourceContent)                                                      // Called for source references
	OnStreamFinish   func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) // Called when stream finishes
	OnStreamError    func(error)                                                                     // Called when stream error occurs
}
 202
// AgentResult is the outcome of a complete agent run.
type AgentResult struct {
	// Steps are all executed steps, in order.
	Steps []StepResult
	// Final response
	// Response is the Response of the last step.
	Response Response
	// TotalUsage is the token usage summed over all steps.
	TotalUsage Usage
}
 209
// Agent runs a multi-step model/tool loop until a stop condition is met.
type Agent interface {
	// Generate runs the loop using non-streaming model calls.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the loop using streaming model calls, reporting progress
	// through the callbacks in AgentStreamCall, and returns the final result.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
 214
// AgentOption configures an agent at construction time (see NewAgent).
type AgentOption = func(*agentSettings)
 216
// agent is the default Agent implementation; its settings are fixed after
// NewAgent returns.
type agent struct {
	settings agentSettings
}
 220
 221func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
 222	settings := agentSettings{
 223		model: model,
 224	}
 225	for _, o := range opts {
 226		o(&settings)
 227	}
 228	return &agent{
 229		settings: settings,
 230	}
 231}
 232
 233func (a *agent) prepareCall(call AgentCall) AgentCall {
 234	if call.MaxOutputTokens == nil && a.settings.maxOutputTokens != nil {
 235		call.MaxOutputTokens = a.settings.maxOutputTokens
 236	}
 237	if call.Temperature == nil && a.settings.temperature != nil {
 238		call.Temperature = a.settings.temperature
 239	}
 240	if call.TopP == nil && a.settings.topP != nil {
 241		call.TopP = a.settings.topP
 242	}
 243	if call.TopK == nil && a.settings.topK != nil {
 244		call.TopK = a.settings.topK
 245	}
 246	if call.PresencePenalty == nil && a.settings.presencePenalty != nil {
 247		call.PresencePenalty = a.settings.presencePenalty
 248	}
 249	if call.FrequencyPenalty == nil && a.settings.frequencyPenalty != nil {
 250		call.FrequencyPenalty = a.settings.frequencyPenalty
 251	}
 252	if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
 253		call.StopWhen = a.settings.stopWhen
 254	}
 255	if call.PrepareStep == nil && a.settings.prepareStep != nil {
 256		call.PrepareStep = a.settings.prepareStep
 257	}
 258	if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
 259		call.RepairToolCall = a.settings.repairToolCall
 260	}
 261	if call.OnRetry == nil && a.settings.onRetry != nil {
 262		call.OnRetry = a.settings.onRetry
 263	}
 264	if call.MaxRetries == nil && a.settings.maxRetries != nil {
 265		call.MaxRetries = a.settings.maxRetries
 266	}
 267
 268	providerOptions := ProviderOptions{}
 269	if a.settings.providerOptions != nil {
 270		maps.Copy(providerOptions, a.settings.providerOptions)
 271	}
 272	if call.ProviderOptions != nil {
 273		maps.Copy(providerOptions, call.ProviderOptions)
 274	}
 275	call.ProviderOptions = providerOptions
 276
 277	headers := map[string]string{}
 278
 279	if a.settings.headers != nil {
 280		maps.Copy(headers, a.settings.headers)
 281	}
 282
 283	if call.Headers != nil {
 284		maps.Copy(headers, call.Headers)
 285	}
 286	call.Headers = headers
 287	return call
 288}
 289
 290// Generate implements Agent.
 291func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
 292	opts = a.prepareCall(opts)
 293	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 294	if err != nil {
 295		return nil, err
 296	}
 297	var responseMessages []Message
 298	var steps []StepResult
 299
 300	for {
 301		stepInputMessages := append(initialPrompt, responseMessages...)
 302		stepModel := a.settings.model
 303		stepSystemPrompt := a.settings.systemPrompt
 304		stepActiveTools := opts.ActiveTools
 305		stepToolChoice := ToolChoiceAuto
 306		disableAllTools := false
 307
 308		if opts.PrepareStep != nil {
 309			prepared, err := opts.PrepareStep(PrepareStepFunctionOptions{
 310				Model:      stepModel,
 311				Steps:      steps,
 312				StepNumber: len(steps),
 313				Messages:   stepInputMessages,
 314			})
 315			if err != nil {
 316				return nil, err
 317			}
 318
 319			// Apply prepared step modifications
 320			if prepared.Messages != nil {
 321				stepInputMessages = prepared.Messages
 322			}
 323			if prepared.Model != nil {
 324				stepModel = prepared.Model
 325			}
 326			if prepared.System != nil {
 327				stepSystemPrompt = *prepared.System
 328			}
 329			if prepared.ToolChoice != nil {
 330				stepToolChoice = *prepared.ToolChoice
 331			}
 332			if len(prepared.ActiveTools) > 0 {
 333				stepActiveTools = prepared.ActiveTools
 334			}
 335			disableAllTools = prepared.DisableAllTools
 336		}
 337
 338		// Recreate prompt with potentially modified system prompt
 339		if stepSystemPrompt != a.settings.systemPrompt {
 340			stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 341			if err != nil {
 342				return nil, err
 343			}
 344			// Replace system message part, keep the rest
 345			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 346				stepInputMessages[0] = stepPrompt[0] // Replace system message
 347			}
 348		}
 349
 350		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 351
 352		retryOptions := DefaultRetryOptions()
 353		retryOptions.OnRetry = opts.OnRetry
 354		retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
 355
 356		result, err := retry(ctx, func() (*Response, error) {
 357			return stepModel.Generate(ctx, Call{
 358				Prompt:           stepInputMessages,
 359				MaxOutputTokens:  opts.MaxOutputTokens,
 360				Temperature:      opts.Temperature,
 361				TopP:             opts.TopP,
 362				TopK:             opts.TopK,
 363				PresencePenalty:  opts.PresencePenalty,
 364				FrequencyPenalty: opts.FrequencyPenalty,
 365				Tools:            preparedTools,
 366				ToolChoice:       &stepToolChoice,
 367				Headers:          opts.Headers,
 368				ProviderOptions:  opts.ProviderOptions,
 369			})
 370		})
 371		if err != nil {
 372			return nil, err
 373		}
 374
 375		var stepToolCalls []ToolCallContent
 376		for _, content := range result.Content {
 377			if content.GetType() == ContentTypeToolCall {
 378				toolCall, ok := AsContentType[ToolCallContent](content)
 379				if !ok {
 380					continue
 381				}
 382
 383				// Validate and potentially repair the tool call
 384				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
 385				stepToolCalls = append(stepToolCalls, validatedToolCall)
 386			}
 387		}
 388
 389		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
 390
 391		// Build step content with validated tool calls and tool results
 392		stepContent := []Content{}
 393		toolCallIndex := 0
 394		for _, content := range result.Content {
 395			if content.GetType() == ContentTypeToolCall {
 396				// Replace with validated tool call
 397				if toolCallIndex < len(stepToolCalls) {
 398					stepContent = append(stepContent, stepToolCalls[toolCallIndex])
 399					toolCallIndex++
 400				}
 401			} else {
 402				// Keep other content as-is
 403				stepContent = append(stepContent, content)
 404			}
 405		}
 406		// Add tool results
 407		for _, result := range toolResults {
 408			stepContent = append(stepContent, result)
 409		}
 410		currentStepMessages := toResponseMessages(stepContent)
 411		responseMessages = append(responseMessages, currentStepMessages...)
 412
 413		stepResult := StepResult{
 414			Response: Response{
 415				Content:          stepContent,
 416				FinishReason:     result.FinishReason,
 417				Usage:            result.Usage,
 418				Warnings:         result.Warnings,
 419				ProviderMetadata: result.ProviderMetadata,
 420			},
 421			Messages: currentStepMessages,
 422		}
 423		steps = append(steps, stepResult)
 424		shouldStop := isStopConditionMet(opts.StopWhen, steps)
 425
 426		if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
 427			break
 428		}
 429	}
 430
 431	totalUsage := Usage{}
 432
 433	for _, step := range steps {
 434		usage := step.Usage
 435		totalUsage.InputTokens += usage.InputTokens
 436		totalUsage.OutputTokens += usage.OutputTokens
 437		totalUsage.ReasoningTokens += usage.ReasoningTokens
 438		totalUsage.CacheCreationTokens += usage.CacheCreationTokens
 439		totalUsage.CacheReadTokens += usage.CacheReadTokens
 440		totalUsage.TotalTokens += usage.TotalTokens
 441	}
 442
 443	agentResult := &AgentResult{
 444		Steps:      steps,
 445		Response:   steps[len(steps)-1].Response,
 446		TotalUsage: totalUsage,
 447	}
 448	return agentResult, nil
 449}
 450
 451func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
 452	if len(conditions) == 0 {
 453		return false
 454	}
 455
 456	for _, condition := range conditions {
 457		if condition(steps) {
 458			return true
 459		}
 460	}
 461	return false
 462}
 463
 464func toResponseMessages(content []Content) []Message {
 465	var assistantParts []MessagePart
 466	var toolParts []MessagePart
 467
 468	for _, c := range content {
 469		switch c.GetType() {
 470		case ContentTypeText:
 471			text, ok := AsContentType[TextContent](c)
 472			if !ok {
 473				continue
 474			}
 475			assistantParts = append(assistantParts, TextPart{
 476				Text:            text.Text,
 477				ProviderOptions: ProviderOptions(text.ProviderMetadata),
 478			})
 479		case ContentTypeReasoning:
 480			reasoning, ok := AsContentType[ReasoningContent](c)
 481			if !ok {
 482				continue
 483			}
 484			assistantParts = append(assistantParts, ReasoningPart{
 485				Text:            reasoning.Text,
 486				ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
 487			})
 488		case ContentTypeToolCall:
 489			toolCall, ok := AsContentType[ToolCallContent](c)
 490			if !ok {
 491				continue
 492			}
 493			assistantParts = append(assistantParts, ToolCallPart{
 494				ToolCallID:       toolCall.ToolCallID,
 495				ToolName:         toolCall.ToolName,
 496				Input:            toolCall.Input,
 497				ProviderExecuted: toolCall.ProviderExecuted,
 498				ProviderOptions:  ProviderOptions(toolCall.ProviderMetadata),
 499			})
 500		case ContentTypeFile:
 501			file, ok := AsContentType[FileContent](c)
 502			if !ok {
 503				continue
 504			}
 505			assistantParts = append(assistantParts, FilePart{
 506				Data:            file.Data,
 507				MediaType:       file.MediaType,
 508				ProviderOptions: ProviderOptions(file.ProviderMetadata),
 509			})
 510		case ContentTypeSource:
 511			// Sources are metadata about references used to generate the response.
 512			// They don't need to be included in the conversation messages.
 513			continue
 514		case ContentTypeToolResult:
 515			result, ok := AsContentType[ToolResultContent](c)
 516			if !ok {
 517				continue
 518			}
 519			toolParts = append(toolParts, ToolResultPart{
 520				ToolCallID:      result.ToolCallID,
 521				Output:          result.Result,
 522				ProviderOptions: ProviderOptions(result.ProviderMetadata),
 523			})
 524		}
 525	}
 526
 527	var messages []Message
 528	if len(assistantParts) > 0 {
 529		messages = append(messages, Message{
 530			Role:    MessageRoleAssistant,
 531			Content: assistantParts,
 532		})
 533	}
 534	if len(toolParts) > 0 {
 535		messages = append(messages, Message{
 536			Role:    MessageRoleTool,
 537			Content: toolParts,
 538		})
 539	}
 540	return messages
 541}
 542
 543func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent)) ([]ToolResultContent, error) {
 544	if len(toolCalls) == 0 {
 545		return nil, nil
 546	}
 547
 548	// Create a map for quick tool lookup
 549	toolMap := make(map[string]AgentTool)
 550	for _, tool := range allTools {
 551		toolMap[tool.Info().Name] = tool
 552	}
 553
 554	// Execute all tool calls in parallel
 555	results := make([]ToolResultContent, len(toolCalls))
 556	var toolExecutionError error
 557	var wg sync.WaitGroup
 558
 559	for i, toolCall := range toolCalls {
 560		wg.Add(1)
 561		go func(index int, call ToolCallContent) {
 562			defer wg.Done()
 563
 564			// Skip invalid tool calls - create error result
 565			if call.Invalid {
 566				results[index] = ToolResultContent{
 567					ToolCallID: call.ToolCallID,
 568					ToolName:   call.ToolName,
 569					Result: ToolResultOutputContentError{
 570						Error: call.ValidationError,
 571					},
 572					ProviderExecuted: false,
 573				}
 574				if toolResultCallback != nil {
 575					toolResultCallback(results[index])
 576				}
 577
 578				return
 579			}
 580
 581			tool, exists := toolMap[call.ToolName]
 582			if !exists {
 583				results[index] = ToolResultContent{
 584					ToolCallID: call.ToolCallID,
 585					ToolName:   call.ToolName,
 586					Result: ToolResultOutputContentError{
 587						Error: errors.New("Error: Tool not found: " + call.ToolName),
 588					},
 589					ProviderExecuted: false,
 590				}
 591
 592				if toolResultCallback != nil {
 593					toolResultCallback(results[index])
 594				}
 595				return
 596			}
 597
 598			// Execute the tool
 599			result, err := tool.Run(ctx, ToolCall{
 600				ID:    call.ToolCallID,
 601				Name:  call.ToolName,
 602				Input: call.Input,
 603			})
 604			if err != nil {
 605				results[index] = ToolResultContent{
 606					ToolCallID: call.ToolCallID,
 607					ToolName:   call.ToolName,
 608					Result: ToolResultOutputContentError{
 609						Error: err,
 610					},
 611					ClientMetadata:   result.Metadata,
 612					ProviderExecuted: false,
 613				}
 614				if toolResultCallback != nil {
 615					toolResultCallback(results[index])
 616				}
 617				toolExecutionError = err
 618				return
 619			}
 620
 621			if result.IsError {
 622				results[index] = ToolResultContent{
 623					ToolCallID: call.ToolCallID,
 624					ToolName:   call.ToolName,
 625					Result: ToolResultOutputContentError{
 626						Error: errors.New(result.Content),
 627					},
 628					ClientMetadata:   result.Metadata,
 629					ProviderExecuted: false,
 630				}
 631
 632				if toolResultCallback != nil {
 633					toolResultCallback(results[index])
 634				}
 635			} else {
 636				results[index] = ToolResultContent{
 637					ToolCallID: call.ToolCallID,
 638					ToolName:   toolCall.ToolName,
 639					Result: ToolResultOutputContentText{
 640						Text: result.Content,
 641					},
 642					ClientMetadata:   result.Metadata,
 643					ProviderExecuted: false,
 644				}
 645				if toolResultCallback != nil {
 646					toolResultCallback(results[index])
 647				}
 648			}
 649		}(i, toolCall)
 650	}
 651
 652	// Wait for all tool executions to complete
 653	wg.Wait()
 654
 655	return results, toolExecutionError
 656}
 657
 658// Stream implements Agent.
 659func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
 660	// Convert AgentStreamCall to AgentCall for preparation
 661	call := AgentCall{
 662		Prompt:           opts.Prompt,
 663		Files:            opts.Files,
 664		Messages:         opts.Messages,
 665		MaxOutputTokens:  opts.MaxOutputTokens,
 666		Temperature:      opts.Temperature,
 667		TopP:             opts.TopP,
 668		TopK:             opts.TopK,
 669		PresencePenalty:  opts.PresencePenalty,
 670		FrequencyPenalty: opts.FrequencyPenalty,
 671		ActiveTools:      opts.ActiveTools,
 672		Headers:          opts.Headers,
 673		ProviderOptions:  opts.ProviderOptions,
 674		MaxRetries:       opts.MaxRetries,
 675		StopWhen:         opts.StopWhen,
 676		PrepareStep:      opts.PrepareStep,
 677		RepairToolCall:   opts.RepairToolCall,
 678	}
 679
 680	call = a.prepareCall(call)
 681
 682	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
 683	if err != nil {
 684		return nil, err
 685	}
 686
 687	var responseMessages []Message
 688	var steps []StepResult
 689	var totalUsage Usage
 690
 691	// Start agent stream
 692	if opts.OnAgentStart != nil {
 693		opts.OnAgentStart()
 694	}
 695
 696	for stepNumber := 0; ; stepNumber++ {
 697		stepInputMessages := append(initialPrompt, responseMessages...)
 698		stepModel := a.settings.model
 699		stepSystemPrompt := a.settings.systemPrompt
 700		stepActiveTools := call.ActiveTools
 701		stepToolChoice := ToolChoiceAuto
 702		disableAllTools := false
 703
 704		// Apply step preparation if provided
 705		if call.PrepareStep != nil {
 706			prepared, err := call.PrepareStep(PrepareStepFunctionOptions{
 707				Model:      stepModel,
 708				Steps:      steps,
 709				StepNumber: stepNumber,
 710				Messages:   stepInputMessages,
 711			})
 712			if err != nil {
 713				return nil, err
 714			}
 715
 716			if prepared.Messages != nil {
 717				stepInputMessages = prepared.Messages
 718			}
 719			if prepared.Model != nil {
 720				stepModel = prepared.Model
 721			}
 722			if prepared.System != nil {
 723				stepSystemPrompt = *prepared.System
 724			}
 725			if prepared.ToolChoice != nil {
 726				stepToolChoice = *prepared.ToolChoice
 727			}
 728			if len(prepared.ActiveTools) > 0 {
 729				stepActiveTools = prepared.ActiveTools
 730			}
 731			disableAllTools = prepared.DisableAllTools
 732		}
 733
 734		// Recreate prompt with potentially modified system prompt
 735		if stepSystemPrompt != a.settings.systemPrompt {
 736			stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
 737			if err != nil {
 738				return nil, err
 739			}
 740			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 741				stepInputMessages[0] = stepPrompt[0]
 742			}
 743		}
 744
 745		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 746
 747		// Start step stream
 748		if opts.OnStepStart != nil {
 749			opts.OnStepStart(stepNumber)
 750		}
 751
 752		// Create streaming call
 753		streamCall := Call{
 754			Prompt:           stepInputMessages,
 755			MaxOutputTokens:  call.MaxOutputTokens,
 756			Temperature:      call.Temperature,
 757			TopP:             call.TopP,
 758			TopK:             call.TopK,
 759			PresencePenalty:  call.PresencePenalty,
 760			FrequencyPenalty: call.FrequencyPenalty,
 761			Tools:            preparedTools,
 762			ToolChoice:       &stepToolChoice,
 763			Headers:          call.Headers,
 764			ProviderOptions:  call.ProviderOptions,
 765		}
 766
 767		// Get streaming response
 768		stream, err := stepModel.Stream(ctx, streamCall)
 769		if err != nil {
 770			if opts.OnError != nil {
 771				opts.OnError(err)
 772			}
 773			return nil, err
 774		}
 775
 776		// Process stream with tool execution
 777		stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
 778		if err != nil {
 779			if opts.OnError != nil {
 780				opts.OnError(err)
 781			}
 782			return nil, err
 783		}
 784
 785		steps = append(steps, stepResult)
 786		totalUsage = addUsage(totalUsage, stepResult.Usage)
 787
 788		// Call step finished callback
 789		if opts.OnStepFinish != nil {
 790			opts.OnStepFinish(stepResult)
 791		}
 792
 793		// Add step messages to response messages
 794		stepMessages := toResponseMessages(stepResult.Content)
 795		responseMessages = append(responseMessages, stepMessages...)
 796
 797		// Check stop conditions
 798		shouldStop := isStopConditionMet(call.StopWhen, steps)
 799		if shouldStop || !shouldContinue {
 800			break
 801		}
 802	}
 803
 804	// Finish agent stream
 805	agentResult := &AgentResult{
 806		Steps:      steps,
 807		Response:   steps[len(steps)-1].Response,
 808		TotalUsage: totalUsage,
 809	}
 810
 811	if opts.OnFinish != nil {
 812		opts.OnFinish(agentResult)
 813	}
 814
 815	if opts.OnAgentFinish != nil {
 816		opts.OnAgentFinish(agentResult)
 817	}
 818
 819	return agentResult, nil
 820}
 821
 822func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
 823	var preparedTools []Tool
 824
 825	// If explicitly disabling all tools, return no tools
 826	if disableAllTools {
 827		return preparedTools
 828	}
 829
 830	for _, tool := range tools {
 831		// If activeTools has items, only include tools in the list
 832		// If activeTools is empty, include all tools
 833		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
 834			continue
 835		}
 836		info := tool.Info()
 837		preparedTools = append(preparedTools, FunctionTool{
 838			Name:        info.Name,
 839			Description: info.Description,
 840			InputSchema: map[string]any{
 841				"type":       "object",
 842				"properties": info.Parameters,
 843				"required":   info.Required,
 844			},
 845		})
 846	}
 847	return preparedTools
 848}
 849
 850// validateAndRepairToolCall validates a tool call and attempts repair if validation fails
 851func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
 852	if err := a.validateToolCall(toolCall, availableTools); err == nil {
 853		return toolCall
 854	} else {
 855		if repairFunc != nil {
 856			repairOptions := ToolCallRepairOptions{
 857				OriginalToolCall: toolCall,
 858				ValidationError:  err,
 859				AvailableTools:   availableTools,
 860				SystemPrompt:     systemPrompt,
 861				Messages:         messages,
 862			}
 863
 864			if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
 865				if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
 866					return *repairedToolCall
 867				}
 868			}
 869		}
 870
 871		invalidToolCall := toolCall
 872		invalidToolCall.Invalid = true
 873		invalidToolCall.ValidationError = err
 874		return invalidToolCall
 875	}
 876}
 877
 878// validateToolCall validates a tool call against available tools and their schemas
 879func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
 880	var tool AgentTool
 881	for _, t := range availableTools {
 882		if t.Info().Name == toolCall.ToolName {
 883			tool = t
 884			break
 885		}
 886	}
 887
 888	if tool == nil {
 889		return fmt.Errorf("tool not found: %s", toolCall.ToolName)
 890	}
 891
 892	// Validate JSON parsing
 893	var input map[string]any
 894	if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
 895		return fmt.Errorf("invalid JSON input: %w", err)
 896	}
 897
 898	// Basic schema validation (check required fields)
 899	// TODO: more robust schema validation using JSON Schema or similar
 900	toolInfo := tool.Info()
 901	for _, required := range toolInfo.Required {
 902		if _, exists := input[required]; !exists {
 903			return fmt.Errorf("missing required parameter: %s", required)
 904		}
 905	}
 906	return nil
 907}
 908
 909func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
 910	if prompt == "" {
 911		return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
 912	}
 913
 914	var preparedPrompt Prompt
 915
 916	if system != "" {
 917		preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
 918	}
 919
 920	preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
 921	preparedPrompt = append(preparedPrompt, messages...)
 922	return preparedPrompt, nil
 923}
 924
 925func WithSystemPrompt(prompt string) AgentOption {
 926	return func(s *agentSettings) {
 927		s.systemPrompt = prompt
 928	}
 929}
 930
 931func WithMaxOutputTokens(tokens int64) AgentOption {
 932	return func(s *agentSettings) {
 933		s.maxOutputTokens = &tokens
 934	}
 935}
 936
 937func WithTemperature(temp float64) AgentOption {
 938	return func(s *agentSettings) {
 939		s.temperature = &temp
 940	}
 941}
 942
 943func WithTopP(topP float64) AgentOption {
 944	return func(s *agentSettings) {
 945		s.topP = &topP
 946	}
 947}
 948
 949func WithTopK(topK int64) AgentOption {
 950	return func(s *agentSettings) {
 951		s.topK = &topK
 952	}
 953}
 954
 955func WithPresencePenalty(penalty float64) AgentOption {
 956	return func(s *agentSettings) {
 957		s.presencePenalty = &penalty
 958	}
 959}
 960
 961func WithFrequencyPenalty(penalty float64) AgentOption {
 962	return func(s *agentSettings) {
 963		s.frequencyPenalty = &penalty
 964	}
 965}
 966
 967func WithTools(tools ...AgentTool) AgentOption {
 968	return func(s *agentSettings) {
 969		s.tools = append(s.tools, tools...)
 970	}
 971}
 972
 973func WithStopConditions(conditions ...StopCondition) AgentOption {
 974	return func(s *agentSettings) {
 975		s.stopWhen = append(s.stopWhen, conditions...)
 976	}
 977}
 978
 979func WithPrepareStep(fn PrepareStepFunction) AgentOption {
 980	return func(s *agentSettings) {
 981		s.prepareStep = fn
 982	}
 983}
 984
 985func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
 986	return func(s *agentSettings) {
 987		s.repairToolCall = fn
 988	}
 989}
 990
 991// processStepStream processes a single step's stream and returns the step result
 992func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
 993	var stepContent []Content
 994	var stepToolCalls []ToolCallContent
 995	var stepUsage Usage
 996	stepFinishReason := FinishReasonUnknown
 997	var stepWarnings []CallWarning
 998	var stepProviderMetadata ProviderMetadata
 999
1000	activeToolCalls := make(map[string]*ToolCallContent)
1001	activeTextContent := make(map[string]string)
1002
1003	// Process stream parts
1004	for part := range stream {
1005		// Forward all parts to chunk callback
1006		if opts.OnChunk != nil {
1007			opts.OnChunk(part)
1008		}
1009
1010		switch part.Type {
1011		case StreamPartTypeWarnings:
1012			stepWarnings = part.Warnings
1013			if opts.OnWarnings != nil {
1014				opts.OnWarnings(part.Warnings)
1015			}
1016
1017		case StreamPartTypeTextStart:
1018			activeTextContent[part.ID] = ""
1019			if opts.OnTextStart != nil {
1020				opts.OnTextStart(part.ID)
1021			}
1022
1023		case StreamPartTypeTextDelta:
1024			if _, exists := activeTextContent[part.ID]; exists {
1025				activeTextContent[part.ID] += part.Delta
1026			}
1027			if opts.OnTextDelta != nil {
1028				opts.OnTextDelta(part.ID, part.Delta)
1029			}
1030
1031		case StreamPartTypeTextEnd:
1032			if text, exists := activeTextContent[part.ID]; exists {
1033				stepContent = append(stepContent, TextContent{
1034					Text:             text,
1035					ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1036				})
1037				delete(activeTextContent, part.ID)
1038			}
1039			if opts.OnTextEnd != nil {
1040				opts.OnTextEnd(part.ID)
1041			}
1042
1043		case StreamPartTypeReasoningStart:
1044			activeTextContent[part.ID] = ""
1045			if opts.OnReasoningStart != nil {
1046				opts.OnReasoningStart(part.ID)
1047			}
1048
1049		case StreamPartTypeReasoningDelta:
1050			if _, exists := activeTextContent[part.ID]; exists {
1051				activeTextContent[part.ID] += part.Delta
1052			}
1053			if opts.OnReasoningDelta != nil {
1054				opts.OnReasoningDelta(part.ID, part.Delta)
1055			}
1056
1057		case StreamPartTypeReasoningEnd:
1058			if text, exists := activeTextContent[part.ID]; exists {
1059				stepContent = append(stepContent, ReasoningContent{
1060					Text:             text,
1061					ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1062				})
1063				if opts.OnReasoningEnd != nil {
1064					opts.OnReasoningEnd(part.ID, ReasoningContent{
1065						Text:             text,
1066						ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1067					})
1068				}
1069				delete(activeTextContent, part.ID)
1070			}
1071
1072		case StreamPartTypeToolInputStart:
1073			activeToolCalls[part.ID] = &ToolCallContent{
1074				ToolCallID:       part.ID,
1075				ToolName:         part.ToolCallName,
1076				Input:            "",
1077				ProviderExecuted: part.ProviderExecuted,
1078			}
1079			if opts.OnToolInputStart != nil {
1080				opts.OnToolInputStart(part.ID, part.ToolCallName)
1081			}
1082
1083		case StreamPartTypeToolInputDelta:
1084			if toolCall, exists := activeToolCalls[part.ID]; exists {
1085				toolCall.Input += part.Delta
1086			}
1087			if opts.OnToolInputDelta != nil {
1088				opts.OnToolInputDelta(part.ID, part.Delta)
1089			}
1090
1091		case StreamPartTypeToolInputEnd:
1092			if opts.OnToolInputEnd != nil {
1093				opts.OnToolInputEnd(part.ID)
1094			}
1095
1096		case StreamPartTypeToolCall:
1097			toolCall := ToolCallContent{
1098				ToolCallID:       part.ID,
1099				ToolName:         part.ToolCallName,
1100				Input:            part.ToolCallInput,
1101				ProviderExecuted: part.ProviderExecuted,
1102				ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1103			}
1104
1105			// Validate and potentially repair the tool call
1106			validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1107			stepToolCalls = append(stepToolCalls, validatedToolCall)
1108			stepContent = append(stepContent, validatedToolCall)
1109
1110			if opts.OnToolCall != nil {
1111				opts.OnToolCall(validatedToolCall)
1112			}
1113
1114			// Clean up active tool call
1115			delete(activeToolCalls, part.ID)
1116
1117		case StreamPartTypeSource:
1118			sourceContent := SourceContent{
1119				SourceType:       part.SourceType,
1120				ID:               part.ID,
1121				URL:              part.URL,
1122				Title:            part.Title,
1123				ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1124			}
1125			stepContent = append(stepContent, sourceContent)
1126			if opts.OnSource != nil {
1127				opts.OnSource(sourceContent)
1128			}
1129
1130		case StreamPartTypeFinish:
1131			stepUsage = part.Usage
1132			stepFinishReason = part.FinishReason
1133			stepProviderMetadata = ProviderMetadata(part.ProviderMetadata)
1134			if opts.OnStreamFinish != nil {
1135				opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1136			}
1137
1138		case StreamPartTypeError:
1139			if opts.OnStreamError != nil {
1140				opts.OnStreamError(part.Error)
1141			}
1142			if opts.OnError != nil {
1143				opts.OnError(part.Error)
1144			}
1145			return StepResult{}, false, part.Error
1146		}
1147	}
1148
1149	// Execute tools if any
1150	var toolResults []ToolResultContent
1151	if len(stepToolCalls) > 0 {
1152		var err error
1153		toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
1154		if err != nil {
1155			return StepResult{}, false, err
1156		}
1157		// Add tool results to content
1158		for _, result := range toolResults {
1159			stepContent = append(stepContent, result)
1160		}
1161	}
1162
1163	stepResult := StepResult{
1164		Response: Response{
1165			Content:          stepContent,
1166			FinishReason:     stepFinishReason,
1167			Usage:            stepUsage,
1168			Warnings:         stepWarnings,
1169			ProviderMetadata: stepProviderMetadata,
1170		},
1171		Messages: toResponseMessages(stepContent),
1172	}
1173
1174	// Determine if we should continue (has tool calls and not stopped)
1175	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1176
1177	return stepResult, shouldContinue, nil
1178}
1179
1180func addUsage(a, b Usage) Usage {
1181	return Usage{
1182		InputTokens:         a.InputTokens + b.InputTokens,
1183		OutputTokens:        a.OutputTokens + b.OutputTokens,
1184		TotalTokens:         a.TotalTokens + b.TotalTokens,
1185		ReasoningTokens:     a.ReasoningTokens + b.ReasoningTokens,
1186		CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1187		CacheReadTokens:     a.CacheReadTokens + b.CacheReadTokens,
1188	}
1189}
1190
1191func WithHeaders(headers map[string]string) AgentOption {
1192	return func(s *agentSettings) {
1193		s.headers = headers
1194	}
1195}
1196
1197func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1198	return func(s *agentSettings) {
1199		s.providerOptions = providerOptions
1200	}
1201}