agent.go

   1package ai
   2
   3import (
   4	"cmp"
   5	"context"
   6	"encoding/json"
   7	"errors"
   8	"fmt"
   9	"maps"
  10	"slices"
  11	"sync"
  12)
  13
// StepResult is the outcome of one agent step. It embeds the step's Response
// and adds Messages, the conversation messages derived from that step's
// content (Generate fills it with the output of toResponseMessages).
type StepResult struct {
	Response
	Messages []Message
}
  18
// StopCondition inspects the steps completed so far and reports whether the
// agent loop should stop. The loop stops as soon as any configured condition
// returns true.
type StopCondition = func(steps []StepResult) bool
  20
  21// StepCountIs returns a stop condition that stops after the specified number of steps.
  22func StepCountIs(stepCount int) StopCondition {
  23	return func(steps []StepResult) bool {
  24		return len(steps) >= stepCount
  25	}
  26}
  27
  28// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
  29func HasToolCall(toolName string) StopCondition {
  30	return func(steps []StepResult) bool {
  31		if len(steps) == 0 {
  32			return false
  33		}
  34		lastStep := steps[len(steps)-1]
  35		toolCalls := lastStep.Content.ToolCalls()
  36		for _, toolCall := range toolCalls {
  37			if toolCall.ToolName == toolName {
  38				return true
  39			}
  40		}
  41		return false
  42	}
  43}
  44
  45// HasContent returns a stop condition that stops when the specified content type appears in the last step.
  46func HasContent(contentType ContentType) StopCondition {
  47	return func(steps []StepResult) bool {
  48		if len(steps) == 0 {
  49			return false
  50		}
  51		lastStep := steps[len(steps)-1]
  52		for _, content := range lastStep.Content {
  53			if content.GetType() == contentType {
  54				return true
  55			}
  56		}
  57		return false
  58	}
  59}
  60
  61// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
  62func FinishReasonIs(reason FinishReason) StopCondition {
  63	return func(steps []StepResult) bool {
  64		if len(steps) == 0 {
  65			return false
  66		}
  67		lastStep := steps[len(steps)-1]
  68		return lastStep.FinishReason == reason
  69	}
  70}
  71
  72// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
  73func MaxTokensUsed(maxTokens int64) StopCondition {
  74	return func(steps []StepResult) bool {
  75		var totalTokens int64
  76		for _, step := range steps {
  77			totalTokens += step.Usage.TotalTokens
  78		}
  79		return totalTokens >= maxTokens
  80	}
  81}
  82
// PrepareStepFunctionOptions is the input handed to a PrepareStepFunction
// before each step of the agent loop. As populated by Generate and Stream:
//   - Steps holds the steps completed so far.
//   - StepNumber is the zero-based index of the step about to run.
//   - Model is the agent's configured model.
//   - Messages is the initial prompt followed by the response messages
//     accumulated from earlier steps.
type PrepareStepFunctionOptions struct {
	Steps      []StepResult
	StepNumber int
	Model      LanguageModel
	Messages   []Message
}
  89
// PrepareStepResult carries the per-step overrides returned by a
// PrepareStepFunction. As applied by Generate and Stream:
//   - Model, when non-nil, replaces the agent's model for the step.
//   - Messages, when non-nil, replaces the step's input messages.
//   - System, when non-nil, replaces the system prompt for the step.
//   - ToolChoice, when non-nil, replaces the default ToolChoiceAuto.
//   - ActiveTools, when non-empty, replaces the call's active tool list.
//   - DisableAllTools, when true, makes the step run without any tools.
type PrepareStepResult struct {
	Model           LanguageModel
	Messages        []Message
	System          *string
	ToolChoice      *ToolChoice
	ActiveTools     []string
	DisableAllTools bool
}
  98
// ToolCallRepairOptions is the input handed to a RepairToolCallFunction when a
// tool call fails validation. It carries the rejected tool call, the
// validation error, the tools the call was validated against, and the system
// prompt and messages of the step that produced the call.
type ToolCallRepairOptions struct {
	OriginalToolCall ToolCallContent
	ValidationError  error
	AvailableTools   []AgentTool
	SystemPrompt     string
	Messages         []Message
}
 106
// Hook function types used by the agent loop.
//
// PrepareStepFunction runs before every step and may override the model,
// messages, system prompt, tool choice and active tools for that step; a
// returned error aborts the run.
//
// OnStepFinishedFunction receives a finished step.
//
// RepairToolCallFunction is given a tool call that failed validation and may
// return a corrected call. The repaired call is re-validated; a nil result, an
// error, or a still-invalid call leaves the original call marked invalid.
type (
	PrepareStepFunction    = func(options PrepareStepFunctionOptions) (PrepareStepResult, error)
	OnStepFinishedFunction = func(step StepResult)
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
 112
// agentSettings holds the agent-wide defaults configured through AgentOption
// values. prepareCall uses them to fill in whatever a single call leaves
// unset: nil sampling parameters and maxRetries fall back to the values here,
// headers and providerOptions are merged with the call's entries taking
// precedence, and stopWhen, prepareStep, repairToolCall and onRetry are used
// when the call does not supply its own.
type agentSettings struct {
	systemPrompt     string
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	headers          map[string]string
	providerOptions  ProviderOptions

	// TODO: add support for provider tools
	tools      []AgentTool
	maxRetries *int

	model LanguageModel

	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
 135
// AgentCall is the per-invocation input to Agent.Generate.
//
// Prompt is required: createPrompt rejects an empty Prompt. Files are attached
// to the user message built from Prompt, and Messages are added to the prompt
// alongside it.
//
// Pointer-typed sampling parameters and MaxRetries fall back to the agent's
// settings when nil. Headers and ProviderOptions are merged with the agent's
// settings, with the call's entries winning on key conflicts. StopWhen,
// PrepareStep, RepairToolCall and OnRetry fall back to the agent's settings
// when not supplied.
//
// NOTE(review): only some fields carry json tags; confirm whether the untagged
// fields are meant to be (de)serialized under their Go names.
type AgentCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction
}
 156
// Agent-level callbacks, invoked by Stream around the agent loop.
type (
	// OnAgentStartFunc is called once, before the first step runs.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called with the final result after OnFinishFunc.
	// Stream ignores the returned error.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called with the zero-based step number right before
	// the model is invoked for that step. Stream ignores the returned error.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called with the result of each completed step.
	// Stream ignores the returned error.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called with the final result once the agent loop has ended.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called when starting the model stream or processing a
	// step's stream fails, just before Stream returns that error.
	OnErrorFunc func(error)
)
 177
// Stream part callbacks - one per stream part type, plus a catch-all.
// They are carried on AgentStreamCall; the code that dispatches them is not
// part of this section of the file.
type (
	// OnChunkFunc is called for every stream part (catch-all).
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called with the warnings reported for a call.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when a text block with the given id starts.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for each text delta of the block with the given id.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when the text block with the given id ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when a reasoning block with the given id starts.
	OnReasoningStartFunc func(id string) error

	// OnReasoningDeltaFunc is called for each reasoning delta of the block with the given id.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when the reasoning block with the given id
	// ends, together with the reasoning content.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when input for a call to toolName starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for each tool input delta.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when the tool input with the given id ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when a tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnToolResultFunc is called when a tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for each source reference.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when a stream finishes, with its usage,
	// finish reason and provider metadata.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
 225
// AgentStreamCall is the per-invocation input to Agent.Stream. It carries the
// same request fields as AgentCall plus the agent-level and stream part
// callbacks.
//
// Stream copies the request fields into an AgentCall and passes it through
// prepareCall, so unset values fall back to the agent's settings in the same
// way as for Generate. OnRetry is not part of that copy.
type AgentStreamCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called when agent starts
	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
	OnWarnings       OnWarningsFunc       // Called for warnings
	OnTextStart      OnTextStartFunc      // Called when text starts
	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
	OnTextEnd        OnTextEndFunc        // Called when text ends
	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
	OnToolCall       OnToolCallFunc       // Called when tool call is complete
	OnToolResult     OnToolResultFunc     // Called when tool execution completes
	OnSource         OnSourceFunc         // Called for source references
	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
}
 271
// AgentResult is what Generate and Stream return: every step that ran, the
// response of the last step, and the token usage summed over all steps.
type AgentResult struct {
	Steps []StepResult
	// Final response
	Response   Response
	TotalUsage Usage
}
 278
// Agent runs a model in a multi-step loop that executes the tool calls the
// model requests. Generate performs non-streaming model calls; Stream performs
// streaming calls and reports progress through the callbacks on
// AgentStreamCall. Both return the collected steps as an AgentResult.
type Agent interface {
	Generate(context.Context, AgentCall) (*AgentResult, error)
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
 283
// AgentOption mutates the agent's settings; NewAgent applies options in the
// order they are given.
type AgentOption = func(*agentSettings)
 285
// agent is the Agent implementation returned by NewAgent. It holds the
// settings assembled from the model and the applied options.
type agent struct {
	settings agentSettings
}
 289
 290func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
 291	settings := agentSettings{
 292		model: model,
 293	}
 294	for _, o := range opts {
 295		o(&settings)
 296	}
 297	return &agent{
 298		settings: settings,
 299	}
 300}
 301
 302func (a *agent) prepareCall(call AgentCall) AgentCall {
 303	call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
 304	call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
 305	call.TopP = cmp.Or(call.TopP, a.settings.topP)
 306	call.TopK = cmp.Or(call.TopK, a.settings.topK)
 307	call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
 308	call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
 309	call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
 310
 311	if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
 312		call.StopWhen = a.settings.stopWhen
 313	}
 314	if call.PrepareStep == nil && a.settings.prepareStep != nil {
 315		call.PrepareStep = a.settings.prepareStep
 316	}
 317	if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
 318		call.RepairToolCall = a.settings.repairToolCall
 319	}
 320	if call.OnRetry == nil && a.settings.onRetry != nil {
 321		call.OnRetry = a.settings.onRetry
 322	}
 323
 324	providerOptions := ProviderOptions{}
 325	if a.settings.providerOptions != nil {
 326		maps.Copy(providerOptions, a.settings.providerOptions)
 327	}
 328	if call.ProviderOptions != nil {
 329		maps.Copy(providerOptions, call.ProviderOptions)
 330	}
 331	call.ProviderOptions = providerOptions
 332
 333	headers := map[string]string{}
 334
 335	if a.settings.headers != nil {
 336		maps.Copy(headers, a.settings.headers)
 337	}
 338
 339	if call.Headers != nil {
 340		maps.Copy(headers, call.Headers)
 341	}
 342	call.Headers = headers
 343	return call
 344}
 345
 346// Generate implements Agent.
 347func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
 348	opts = a.prepareCall(opts)
 349	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 350	if err != nil {
 351		return nil, err
 352	}
 353	var responseMessages []Message
 354	var steps []StepResult
 355
 356	for {
 357		stepInputMessages := append(initialPrompt, responseMessages...)
 358		stepModel := a.settings.model
 359		stepSystemPrompt := a.settings.systemPrompt
 360		stepActiveTools := opts.ActiveTools
 361		stepToolChoice := ToolChoiceAuto
 362		disableAllTools := false
 363
 364		if opts.PrepareStep != nil {
 365			prepared, err := opts.PrepareStep(PrepareStepFunctionOptions{
 366				Model:      stepModel,
 367				Steps:      steps,
 368				StepNumber: len(steps),
 369				Messages:   stepInputMessages,
 370			})
 371			if err != nil {
 372				return nil, err
 373			}
 374
 375			// Apply prepared step modifications
 376			if prepared.Messages != nil {
 377				stepInputMessages = prepared.Messages
 378			}
 379			if prepared.Model != nil {
 380				stepModel = prepared.Model
 381			}
 382			if prepared.System != nil {
 383				stepSystemPrompt = *prepared.System
 384			}
 385			if prepared.ToolChoice != nil {
 386				stepToolChoice = *prepared.ToolChoice
 387			}
 388			if len(prepared.ActiveTools) > 0 {
 389				stepActiveTools = prepared.ActiveTools
 390			}
 391			disableAllTools = prepared.DisableAllTools
 392		}
 393
 394		// Recreate prompt with potentially modified system prompt
 395		if stepSystemPrompt != a.settings.systemPrompt {
 396			stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 397			if err != nil {
 398				return nil, err
 399			}
 400			// Replace system message part, keep the rest
 401			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 402				stepInputMessages[0] = stepPrompt[0] // Replace system message
 403			}
 404		}
 405
 406		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 407
 408		retryOptions := DefaultRetryOptions()
 409		retryOptions.OnRetry = opts.OnRetry
 410		retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
 411
 412		result, err := retry(ctx, func() (*Response, error) {
 413			return stepModel.Generate(ctx, Call{
 414				Prompt:           stepInputMessages,
 415				MaxOutputTokens:  opts.MaxOutputTokens,
 416				Temperature:      opts.Temperature,
 417				TopP:             opts.TopP,
 418				TopK:             opts.TopK,
 419				PresencePenalty:  opts.PresencePenalty,
 420				FrequencyPenalty: opts.FrequencyPenalty,
 421				Tools:            preparedTools,
 422				ToolChoice:       &stepToolChoice,
 423				Headers:          opts.Headers,
 424				ProviderOptions:  opts.ProviderOptions,
 425			})
 426		})
 427		if err != nil {
 428			return nil, err
 429		}
 430
 431		var stepToolCalls []ToolCallContent
 432		for _, content := range result.Content {
 433			if content.GetType() == ContentTypeToolCall {
 434				toolCall, ok := AsContentType[ToolCallContent](content)
 435				if !ok {
 436					continue
 437				}
 438
 439				// Validate and potentially repair the tool call
 440				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
 441				stepToolCalls = append(stepToolCalls, validatedToolCall)
 442			}
 443		}
 444
 445		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
 446
 447		// Build step content with validated tool calls and tool results
 448		stepContent := []Content{}
 449		toolCallIndex := 0
 450		for _, content := range result.Content {
 451			if content.GetType() == ContentTypeToolCall {
 452				// Replace with validated tool call
 453				if toolCallIndex < len(stepToolCalls) {
 454					stepContent = append(stepContent, stepToolCalls[toolCallIndex])
 455					toolCallIndex++
 456				}
 457			} else {
 458				// Keep other content as-is
 459				stepContent = append(stepContent, content)
 460			}
 461		}
 462		// Add tool results
 463		for _, result := range toolResults {
 464			stepContent = append(stepContent, result)
 465		}
 466		currentStepMessages := toResponseMessages(stepContent)
 467		responseMessages = append(responseMessages, currentStepMessages...)
 468
 469		stepResult := StepResult{
 470			Response: Response{
 471				Content:          stepContent,
 472				FinishReason:     result.FinishReason,
 473				Usage:            result.Usage,
 474				Warnings:         result.Warnings,
 475				ProviderMetadata: result.ProviderMetadata,
 476			},
 477			Messages: currentStepMessages,
 478		}
 479		steps = append(steps, stepResult)
 480		shouldStop := isStopConditionMet(opts.StopWhen, steps)
 481
 482		if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
 483			break
 484		}
 485	}
 486
 487	totalUsage := Usage{}
 488
 489	for _, step := range steps {
 490		usage := step.Usage
 491		totalUsage.InputTokens += usage.InputTokens
 492		totalUsage.OutputTokens += usage.OutputTokens
 493		totalUsage.ReasoningTokens += usage.ReasoningTokens
 494		totalUsage.CacheCreationTokens += usage.CacheCreationTokens
 495		totalUsage.CacheReadTokens += usage.CacheReadTokens
 496		totalUsage.TotalTokens += usage.TotalTokens
 497	}
 498
 499	agentResult := &AgentResult{
 500		Steps:      steps,
 501		Response:   steps[len(steps)-1].Response,
 502		TotalUsage: totalUsage,
 503	}
 504	return agentResult, nil
 505}
 506
 507func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
 508	if len(conditions) == 0 {
 509		return false
 510	}
 511
 512	for _, condition := range conditions {
 513		if condition(steps) {
 514			return true
 515		}
 516	}
 517	return false
 518}
 519
 520func toResponseMessages(content []Content) []Message {
 521	var assistantParts []MessagePart
 522	var toolParts []MessagePart
 523
 524	for _, c := range content {
 525		switch c.GetType() {
 526		case ContentTypeText:
 527			text, ok := AsContentType[TextContent](c)
 528			if !ok {
 529				continue
 530			}
 531			assistantParts = append(assistantParts, TextPart{
 532				Text:            text.Text,
 533				ProviderOptions: ProviderOptions(text.ProviderMetadata),
 534			})
 535		case ContentTypeReasoning:
 536			reasoning, ok := AsContentType[ReasoningContent](c)
 537			if !ok {
 538				continue
 539			}
 540			assistantParts = append(assistantParts, ReasoningPart{
 541				Text:            reasoning.Text,
 542				ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
 543			})
 544		case ContentTypeToolCall:
 545			toolCall, ok := AsContentType[ToolCallContent](c)
 546			if !ok {
 547				continue
 548			}
 549			assistantParts = append(assistantParts, ToolCallPart{
 550				ToolCallID:       toolCall.ToolCallID,
 551				ToolName:         toolCall.ToolName,
 552				Input:            toolCall.Input,
 553				ProviderExecuted: toolCall.ProviderExecuted,
 554				ProviderOptions:  ProviderOptions(toolCall.ProviderMetadata),
 555			})
 556		case ContentTypeFile:
 557			file, ok := AsContentType[FileContent](c)
 558			if !ok {
 559				continue
 560			}
 561			assistantParts = append(assistantParts, FilePart{
 562				Data:            file.Data,
 563				MediaType:       file.MediaType,
 564				ProviderOptions: ProviderOptions(file.ProviderMetadata),
 565			})
 566		case ContentTypeSource:
 567			// Sources are metadata about references used to generate the response.
 568			// They don't need to be included in the conversation messages.
 569			continue
 570		case ContentTypeToolResult:
 571			result, ok := AsContentType[ToolResultContent](c)
 572			if !ok {
 573				continue
 574			}
 575			toolParts = append(toolParts, ToolResultPart{
 576				ToolCallID:      result.ToolCallID,
 577				Output:          result.Result,
 578				ProviderOptions: ProviderOptions(result.ProviderMetadata),
 579			})
 580		}
 581	}
 582
 583	var messages []Message
 584	if len(assistantParts) > 0 {
 585		messages = append(messages, Message{
 586			Role:    MessageRoleAssistant,
 587			Content: assistantParts,
 588		})
 589	}
 590	if len(toolParts) > 0 {
 591		messages = append(messages, Message{
 592			Role:    MessageRoleTool,
 593			Content: toolParts,
 594		})
 595	}
 596	return messages
 597}
 598
 599func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
 600	if len(toolCalls) == 0 {
 601		return nil, nil
 602	}
 603
 604	// Create a map for quick tool lookup
 605	toolMap := make(map[string]AgentTool)
 606	for _, tool := range allTools {
 607		toolMap[tool.Info().Name] = tool
 608	}
 609
 610	// Execute all tool calls in parallel
 611	results := make([]ToolResultContent, len(toolCalls))
 612	executeErrors := make([]error, len(toolCalls))
 613
 614	var wg sync.WaitGroup
 615
 616	for i, toolCall := range toolCalls {
 617		wg.Add(1)
 618		go func(index int, call ToolCallContent) {
 619			defer wg.Done()
 620
 621			// Skip invalid tool calls - create error result
 622			if call.Invalid {
 623				results[index] = ToolResultContent{
 624					ToolCallID: call.ToolCallID,
 625					ToolName:   call.ToolName,
 626					Result: ToolResultOutputContentError{
 627						Error: call.ValidationError,
 628					},
 629					ProviderExecuted: false,
 630				}
 631				if toolResultCallback != nil {
 632					err := toolResultCallback(results[index])
 633					if err != nil {
 634						executeErrors[index] = err
 635					}
 636				}
 637
 638				return
 639			}
 640
 641			tool, exists := toolMap[call.ToolName]
 642			if !exists {
 643				results[index] = ToolResultContent{
 644					ToolCallID: call.ToolCallID,
 645					ToolName:   call.ToolName,
 646					Result: ToolResultOutputContentError{
 647						Error: errors.New("Error: Tool not found: " + call.ToolName),
 648					},
 649					ProviderExecuted: false,
 650				}
 651
 652				if toolResultCallback != nil {
 653					err := toolResultCallback(results[index])
 654					if err != nil {
 655						executeErrors[index] = err
 656					}
 657				}
 658				return
 659			}
 660
 661			// Execute the tool
 662			result, err := tool.Run(ctx, ToolCall{
 663				ID:    call.ToolCallID,
 664				Name:  call.ToolName,
 665				Input: call.Input,
 666			})
 667			if err != nil {
 668				results[index] = ToolResultContent{
 669					ToolCallID: call.ToolCallID,
 670					ToolName:   call.ToolName,
 671					Result: ToolResultOutputContentError{
 672						Error: err,
 673					},
 674					ClientMetadata:   result.Metadata,
 675					ProviderExecuted: false,
 676				}
 677				if toolResultCallback != nil {
 678					cbErr := toolResultCallback(results[index])
 679					if cbErr != nil {
 680						executeErrors[index] = cbErr
 681					}
 682				}
 683				executeErrors[index] = err
 684				return
 685			}
 686
 687			if result.IsError {
 688				results[index] = ToolResultContent{
 689					ToolCallID: call.ToolCallID,
 690					ToolName:   call.ToolName,
 691					Result: ToolResultOutputContentError{
 692						Error: errors.New(result.Content),
 693					},
 694					ClientMetadata:   result.Metadata,
 695					ProviderExecuted: false,
 696				}
 697
 698				if toolResultCallback != nil {
 699					err := toolResultCallback(results[index])
 700					if err != nil {
 701						executeErrors[index] = err
 702					}
 703				}
 704			} else {
 705				results[index] = ToolResultContent{
 706					ToolCallID: call.ToolCallID,
 707					ToolName:   toolCall.ToolName,
 708					Result: ToolResultOutputContentText{
 709						Text: result.Content,
 710					},
 711					ClientMetadata:   result.Metadata,
 712					ProviderExecuted: false,
 713				}
 714				if toolResultCallback != nil {
 715					err := toolResultCallback(results[index])
 716					if err != nil {
 717						executeErrors[index] = err
 718					}
 719				}
 720			}
 721		}(i, toolCall)
 722	}
 723
 724	// Wait for all tool executions to complete
 725	wg.Wait()
 726
 727	for _, err := range executeErrors {
 728		if err != nil {
 729			return nil, err
 730		}
 731	}
 732
 733	return results, nil
 734}
 735
 736// Stream implements Agent.
 737func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
 738	// Convert AgentStreamCall to AgentCall for preparation
 739	call := AgentCall{
 740		Prompt:           opts.Prompt,
 741		Files:            opts.Files,
 742		Messages:         opts.Messages,
 743		MaxOutputTokens:  opts.MaxOutputTokens,
 744		Temperature:      opts.Temperature,
 745		TopP:             opts.TopP,
 746		TopK:             opts.TopK,
 747		PresencePenalty:  opts.PresencePenalty,
 748		FrequencyPenalty: opts.FrequencyPenalty,
 749		ActiveTools:      opts.ActiveTools,
 750		Headers:          opts.Headers,
 751		ProviderOptions:  opts.ProviderOptions,
 752		MaxRetries:       opts.MaxRetries,
 753		StopWhen:         opts.StopWhen,
 754		PrepareStep:      opts.PrepareStep,
 755		RepairToolCall:   opts.RepairToolCall,
 756	}
 757
 758	call = a.prepareCall(call)
 759
 760	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
 761	if err != nil {
 762		return nil, err
 763	}
 764
 765	var responseMessages []Message
 766	var steps []StepResult
 767	var totalUsage Usage
 768
 769	// Start agent stream
 770	if opts.OnAgentStart != nil {
 771		opts.OnAgentStart()
 772	}
 773
 774	for stepNumber := 0; ; stepNumber++ {
 775		stepInputMessages := append(initialPrompt, responseMessages...)
 776		stepModel := a.settings.model
 777		stepSystemPrompt := a.settings.systemPrompt
 778		stepActiveTools := call.ActiveTools
 779		stepToolChoice := ToolChoiceAuto
 780		disableAllTools := false
 781
 782		// Apply step preparation if provided
 783		if call.PrepareStep != nil {
 784			prepared, err := call.PrepareStep(PrepareStepFunctionOptions{
 785				Model:      stepModel,
 786				Steps:      steps,
 787				StepNumber: stepNumber,
 788				Messages:   stepInputMessages,
 789			})
 790			if err != nil {
 791				return nil, err
 792			}
 793
 794			if prepared.Messages != nil {
 795				stepInputMessages = prepared.Messages
 796			}
 797			if prepared.Model != nil {
 798				stepModel = prepared.Model
 799			}
 800			if prepared.System != nil {
 801				stepSystemPrompt = *prepared.System
 802			}
 803			if prepared.ToolChoice != nil {
 804				stepToolChoice = *prepared.ToolChoice
 805			}
 806			if len(prepared.ActiveTools) > 0 {
 807				stepActiveTools = prepared.ActiveTools
 808			}
 809			disableAllTools = prepared.DisableAllTools
 810		}
 811
 812		// Recreate prompt with potentially modified system prompt
 813		if stepSystemPrompt != a.settings.systemPrompt {
 814			stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
 815			if err != nil {
 816				return nil, err
 817			}
 818			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 819				stepInputMessages[0] = stepPrompt[0]
 820			}
 821		}
 822
 823		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 824
 825		// Start step stream
 826		if opts.OnStepStart != nil {
 827			_ = opts.OnStepStart(stepNumber)
 828		}
 829
 830		// Create streaming call
 831		streamCall := Call{
 832			Prompt:           stepInputMessages,
 833			MaxOutputTokens:  call.MaxOutputTokens,
 834			Temperature:      call.Temperature,
 835			TopP:             call.TopP,
 836			TopK:             call.TopK,
 837			PresencePenalty:  call.PresencePenalty,
 838			FrequencyPenalty: call.FrequencyPenalty,
 839			Tools:            preparedTools,
 840			ToolChoice:       &stepToolChoice,
 841			Headers:          call.Headers,
 842			ProviderOptions:  call.ProviderOptions,
 843		}
 844
 845		// Get streaming response
 846		stream, err := stepModel.Stream(ctx, streamCall)
 847		if err != nil {
 848			if opts.OnError != nil {
 849				opts.OnError(err)
 850			}
 851			return nil, err
 852		}
 853
 854		// Process stream with tool execution
 855		stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
 856		if err != nil {
 857			if opts.OnError != nil {
 858				opts.OnError(err)
 859			}
 860			return nil, err
 861		}
 862
 863		steps = append(steps, stepResult)
 864		totalUsage = addUsage(totalUsage, stepResult.Usage)
 865
 866		// Call step finished callback
 867		if opts.OnStepFinish != nil {
 868			_ = opts.OnStepFinish(stepResult)
 869		}
 870
 871		// Add step messages to response messages
 872		stepMessages := toResponseMessages(stepResult.Content)
 873		responseMessages = append(responseMessages, stepMessages...)
 874
 875		// Check stop conditions
 876		shouldStop := isStopConditionMet(call.StopWhen, steps)
 877		if shouldStop || !shouldContinue {
 878			break
 879		}
 880	}
 881
 882	// Finish agent stream
 883	agentResult := &AgentResult{
 884		Steps:      steps,
 885		Response:   steps[len(steps)-1].Response,
 886		TotalUsage: totalUsage,
 887	}
 888
 889	if opts.OnFinish != nil {
 890		opts.OnFinish(agentResult)
 891	}
 892
 893	if opts.OnAgentFinish != nil {
 894		_ = opts.OnAgentFinish(agentResult)
 895	}
 896
 897	return agentResult, nil
 898}
 899
 900func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
 901	preparedTools := make([]Tool, 0, len(tools))
 902
 903	// If explicitly disabling all tools, return no tools
 904	if disableAllTools {
 905		return preparedTools
 906	}
 907
 908	for _, tool := range tools {
 909		// If activeTools has items, only include tools in the list
 910		// If activeTools is empty, include all tools
 911		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
 912			continue
 913		}
 914		info := tool.Info()
 915		preparedTools = append(preparedTools, FunctionTool{
 916			Name:        info.Name,
 917			Description: info.Description,
 918			InputSchema: map[string]any{
 919				"type":       "object",
 920				"properties": info.Parameters,
 921				"required":   info.Required,
 922			},
 923		})
 924	}
 925	return preparedTools
 926}
 927
 928// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
 929func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
 930	if err := a.validateToolCall(toolCall, availableTools); err == nil {
 931		return toolCall
 932	} else {
 933		if repairFunc != nil {
 934			repairOptions := ToolCallRepairOptions{
 935				OriginalToolCall: toolCall,
 936				ValidationError:  err,
 937				AvailableTools:   availableTools,
 938				SystemPrompt:     systemPrompt,
 939				Messages:         messages,
 940			}
 941
 942			if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
 943				if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
 944					return *repairedToolCall
 945				}
 946			}
 947		}
 948
 949		invalidToolCall := toolCall
 950		invalidToolCall.Invalid = true
 951		invalidToolCall.ValidationError = err
 952		return invalidToolCall
 953	}
 954}
 955
 956// validateToolCall validates a tool call against available tools and their schemas.
 957func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
 958	var tool AgentTool
 959	for _, t := range availableTools {
 960		if t.Info().Name == toolCall.ToolName {
 961			tool = t
 962			break
 963		}
 964	}
 965
 966	if tool == nil {
 967		return fmt.Errorf("tool not found: %s", toolCall.ToolName)
 968	}
 969
 970	// Validate JSON parsing
 971	var input map[string]any
 972	if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
 973		return fmt.Errorf("invalid JSON input: %w", err)
 974	}
 975
 976	// Basic schema validation (check required fields)
 977	// TODO: more robust schema validation using JSON Schema or similar
 978	toolInfo := tool.Info()
 979	for _, required := range toolInfo.Required {
 980		if _, exists := input[required]; !exists {
 981			return fmt.Errorf("missing required parameter: %s", required)
 982		}
 983	}
 984	return nil
 985}
 986
 987func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
 988	if prompt == "" {
 989		return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
 990	}
 991
 992	var preparedPrompt Prompt
 993
 994	if system != "" {
 995		preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
 996	}
 997
 998	preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
 999	preparedPrompt = append(preparedPrompt, messages...)
1000	return preparedPrompt, nil
1001}
1002
1003func WithSystemPrompt(prompt string) AgentOption {
1004	return func(s *agentSettings) {
1005		s.systemPrompt = prompt
1006	}
1007}
1008
1009func WithMaxOutputTokens(tokens int64) AgentOption {
1010	return func(s *agentSettings) {
1011		s.maxOutputTokens = &tokens
1012	}
1013}
1014
1015func WithTemperature(temp float64) AgentOption {
1016	return func(s *agentSettings) {
1017		s.temperature = &temp
1018	}
1019}
1020
1021func WithTopP(topP float64) AgentOption {
1022	return func(s *agentSettings) {
1023		s.topP = &topP
1024	}
1025}
1026
1027func WithTopK(topK int64) AgentOption {
1028	return func(s *agentSettings) {
1029		s.topK = &topK
1030	}
1031}
1032
1033func WithPresencePenalty(penalty float64) AgentOption {
1034	return func(s *agentSettings) {
1035		s.presencePenalty = &penalty
1036	}
1037}
1038
1039func WithFrequencyPenalty(penalty float64) AgentOption {
1040	return func(s *agentSettings) {
1041		s.frequencyPenalty = &penalty
1042	}
1043}
1044
1045func WithTools(tools ...AgentTool) AgentOption {
1046	return func(s *agentSettings) {
1047		s.tools = append(s.tools, tools...)
1048	}
1049}
1050
1051func WithStopConditions(conditions ...StopCondition) AgentOption {
1052	return func(s *agentSettings) {
1053		s.stopWhen = append(s.stopWhen, conditions...)
1054	}
1055}
1056
1057func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1058	return func(s *agentSettings) {
1059		s.prepareStep = fn
1060	}
1061}
1062
1063func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1064	return func(s *agentSettings) {
1065		s.repairToolCall = fn
1066	}
1067}
1068
1069// processStepStream processes a single step's stream and returns the step result.
1070func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
1071	var stepContent []Content
1072	var stepToolCalls []ToolCallContent
1073	var stepUsage Usage
1074	stepFinishReason := FinishReasonUnknown
1075	var stepWarnings []CallWarning
1076	var stepProviderMetadata ProviderMetadata
1077
1078	activeToolCalls := make(map[string]*ToolCallContent)
1079	activeTextContent := make(map[string]string)
1080
1081	// Process stream parts
1082	for part := range stream {
1083		// Forward all parts to chunk callback
1084		if opts.OnChunk != nil {
1085			err := opts.OnChunk(part)
1086			if err != nil {
1087				return StepResult{}, false, err
1088			}
1089		}
1090
1091		switch part.Type {
1092		case StreamPartTypeWarnings:
1093			stepWarnings = part.Warnings
1094			if opts.OnWarnings != nil {
1095				err := opts.OnWarnings(part.Warnings)
1096				if err != nil {
1097					return StepResult{}, false, err
1098				}
1099			}
1100
1101		case StreamPartTypeTextStart:
1102			activeTextContent[part.ID] = ""
1103			if opts.OnTextStart != nil {
1104				err := opts.OnTextStart(part.ID)
1105				if err != nil {
1106					return StepResult{}, false, err
1107				}
1108			}
1109
1110		case StreamPartTypeTextDelta:
1111			if _, exists := activeTextContent[part.ID]; exists {
1112				activeTextContent[part.ID] += part.Delta
1113			}
1114			if opts.OnTextDelta != nil {
1115				err := opts.OnTextDelta(part.ID, part.Delta)
1116				if err != nil {
1117					return StepResult{}, false, err
1118				}
1119			}
1120
1121		case StreamPartTypeTextEnd:
1122			if text, exists := activeTextContent[part.ID]; exists {
1123				stepContent = append(stepContent, TextContent{
1124					Text:             text,
1125					ProviderMetadata: part.ProviderMetadata,
1126				})
1127				delete(activeTextContent, part.ID)
1128			}
1129			if opts.OnTextEnd != nil {
1130				err := opts.OnTextEnd(part.ID)
1131				if err != nil {
1132					return StepResult{}, false, err
1133				}
1134			}
1135
1136		case StreamPartTypeReasoningStart:
1137			activeTextContent[part.ID] = ""
1138			if opts.OnReasoningStart != nil {
1139				err := opts.OnReasoningStart(part.ID)
1140				if err != nil {
1141					return StepResult{}, false, err
1142				}
1143			}
1144
1145		case StreamPartTypeReasoningDelta:
1146			if _, exists := activeTextContent[part.ID]; exists {
1147				activeTextContent[part.ID] += part.Delta
1148			}
1149			if opts.OnReasoningDelta != nil {
1150				err := opts.OnReasoningDelta(part.ID, part.Delta)
1151				if err != nil {
1152					return StepResult{}, false, err
1153				}
1154			}
1155
1156		case StreamPartTypeReasoningEnd:
1157			if text, exists := activeTextContent[part.ID]; exists {
1158				stepContent = append(stepContent, ReasoningContent{
1159					Text:             text,
1160					ProviderMetadata: part.ProviderMetadata,
1161				})
1162				if opts.OnReasoningEnd != nil {
1163					err := opts.OnReasoningEnd(part.ID, ReasoningContent{
1164						Text:             text,
1165						ProviderMetadata: part.ProviderMetadata,
1166					})
1167					if err != nil {
1168						return StepResult{}, false, err
1169					}
1170				}
1171				delete(activeTextContent, part.ID)
1172			}
1173
1174		case StreamPartTypeToolInputStart:
1175			activeToolCalls[part.ID] = &ToolCallContent{
1176				ToolCallID:       part.ID,
1177				ToolName:         part.ToolCallName,
1178				Input:            "",
1179				ProviderExecuted: part.ProviderExecuted,
1180			}
1181			if opts.OnToolInputStart != nil {
1182				err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1183				if err != nil {
1184					return StepResult{}, false, err
1185				}
1186			}
1187
1188		case StreamPartTypeToolInputDelta:
1189			if toolCall, exists := activeToolCalls[part.ID]; exists {
1190				toolCall.Input += part.Delta
1191			}
1192			if opts.OnToolInputDelta != nil {
1193				err := opts.OnToolInputDelta(part.ID, part.Delta)
1194				if err != nil {
1195					return StepResult{}, false, err
1196				}
1197			}
1198
1199		case StreamPartTypeToolInputEnd:
1200			if opts.OnToolInputEnd != nil {
1201				err := opts.OnToolInputEnd(part.ID)
1202				if err != nil {
1203					return StepResult{}, false, err
1204				}
1205			}
1206
1207		case StreamPartTypeToolCall:
1208			toolCall := ToolCallContent{
1209				ToolCallID:       part.ID,
1210				ToolName:         part.ToolCallName,
1211				Input:            part.ToolCallInput,
1212				ProviderExecuted: part.ProviderExecuted,
1213				ProviderMetadata: part.ProviderMetadata,
1214			}
1215
1216			// Validate and potentially repair the tool call
1217			validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1218			stepToolCalls = append(stepToolCalls, validatedToolCall)
1219			stepContent = append(stepContent, validatedToolCall)
1220
1221			if opts.OnToolCall != nil {
1222				err := opts.OnToolCall(validatedToolCall)
1223				if err != nil {
1224					return StepResult{}, false, err
1225				}
1226			}
1227
1228			// Clean up active tool call
1229			delete(activeToolCalls, part.ID)
1230
1231		case StreamPartTypeSource:
1232			sourceContent := SourceContent{
1233				SourceType:       part.SourceType,
1234				ID:               part.ID,
1235				URL:              part.URL,
1236				Title:            part.Title,
1237				ProviderMetadata: part.ProviderMetadata,
1238			}
1239			stepContent = append(stepContent, sourceContent)
1240			if opts.OnSource != nil {
1241				err := opts.OnSource(sourceContent)
1242				if err != nil {
1243					return StepResult{}, false, err
1244				}
1245			}
1246
1247		case StreamPartTypeFinish:
1248			stepUsage = part.Usage
1249			stepFinishReason = part.FinishReason
1250			stepProviderMetadata = part.ProviderMetadata
1251			if opts.OnStreamFinish != nil {
1252				err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1253				if err != nil {
1254					return StepResult{}, false, err
1255				}
1256			}
1257
1258		case StreamPartTypeError:
1259			if opts.OnError != nil {
1260				opts.OnError(part.Error)
1261			}
1262			return StepResult{}, false, part.Error
1263		}
1264	}
1265
1266	// Execute tools if any
1267	var toolResults []ToolResultContent
1268	if len(stepToolCalls) > 0 {
1269		var err error
1270		toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
1271		if err != nil {
1272			return StepResult{}, false, err
1273		}
1274		// Add tool results to content
1275		for _, result := range toolResults {
1276			stepContent = append(stepContent, result)
1277		}
1278	}
1279
1280	stepResult := StepResult{
1281		Response: Response{
1282			Content:          stepContent,
1283			FinishReason:     stepFinishReason,
1284			Usage:            stepUsage,
1285			Warnings:         stepWarnings,
1286			ProviderMetadata: stepProviderMetadata,
1287		},
1288		Messages: toResponseMessages(stepContent),
1289	}
1290
1291	// Determine if we should continue (has tool calls and not stopped)
1292	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1293
1294	return stepResult, shouldContinue, nil
1295}
1296
1297func addUsage(a, b Usage) Usage {
1298	return Usage{
1299		InputTokens:         a.InputTokens + b.InputTokens,
1300		OutputTokens:        a.OutputTokens + b.OutputTokens,
1301		TotalTokens:         a.TotalTokens + b.TotalTokens,
1302		ReasoningTokens:     a.ReasoningTokens + b.ReasoningTokens,
1303		CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1304		CacheReadTokens:     a.CacheReadTokens + b.CacheReadTokens,
1305	}
1306}
1307
1308func WithHeaders(headers map[string]string) AgentOption {
1309	return func(s *agentSettings) {
1310		s.headers = headers
1311	}
1312}
1313
1314func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1315	return func(s *agentSettings) {
1316		s.providerOptions = providerOptions
1317	}
1318}