1package ai
   2
   3import (
   4	"cmp"
   5	"context"
   6	"encoding/json"
   7	"errors"
   8	"fmt"
   9	"maps"
  10	"slices"
  11	"sync"
  12)
  13
// StepResult is the outcome of one iteration of the agent loop: the model
// Response for that step (embedded, with its content, finish reason, usage,
// warnings and provider metadata) together with Messages, the conversation
// messages this step contributed to the running history.
type StepResult struct {
	Response
	Messages []Message
}

// StopCondition inspects all steps completed so far and reports whether the
// agent loop should stop. The loop checks its conditions after each step and
// stops as soon as any one of them returns true.
type StopCondition = func(steps []StepResult) bool
  20
  21// StepCountIs returns a stop condition that stops after the specified number of steps.
  22func StepCountIs(stepCount int) StopCondition {
  23	return func(steps []StepResult) bool {
  24		return len(steps) >= stepCount
  25	}
  26}
  27
  28// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
  29func HasToolCall(toolName string) StopCondition {
  30	return func(steps []StepResult) bool {
  31		if len(steps) == 0 {
  32			return false
  33		}
  34		lastStep := steps[len(steps)-1]
  35		toolCalls := lastStep.Content.ToolCalls()
  36		for _, toolCall := range toolCalls {
  37			if toolCall.ToolName == toolName {
  38				return true
  39			}
  40		}
  41		return false
  42	}
  43}
  44
  45// HasContent returns a stop condition that stops when the specified content type appears in the last step.
  46func HasContent(contentType ContentType) StopCondition {
  47	return func(steps []StepResult) bool {
  48		if len(steps) == 0 {
  49			return false
  50		}
  51		lastStep := steps[len(steps)-1]
  52		for _, content := range lastStep.Content {
  53			if content.GetType() == contentType {
  54				return true
  55			}
  56		}
  57		return false
  58	}
  59}
  60
  61// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
  62func FinishReasonIs(reason FinishReason) StopCondition {
  63	return func(steps []StepResult) bool {
  64		if len(steps) == 0 {
  65			return false
  66		}
  67		lastStep := steps[len(steps)-1]
  68		return lastStep.FinishReason == reason
  69	}
  70}
  71
  72// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
  73func MaxTokensUsed(maxTokens int64) StopCondition {
  74	return func(steps []StepResult) bool {
  75		var totalTokens int64
  76		for _, step := range steps {
  77			totalTokens += step.Usage.TotalTokens
  78		}
  79		return totalTokens >= maxTokens
  80	}
  81}
  82
// PrepareStepFunctionOptions is the input to a PrepareStepFunction, which is
// called before each step of the agent loop.
//
// Steps holds the results of all steps completed so far. StepNumber is the
// zero-based index of the step about to run (it equals len(Steps)). Model is
// the agent's configured model. Messages is the model input the step will
// send if the function does not replace it.
type PrepareStepFunctionOptions struct {
	Steps      []StepResult
	StepNumber int
	Model      LanguageModel
	Messages   []Message
}
  89
// PrepareStepResult carries per-step overrides returned by a
// PrepareStepFunction. Zero values mean "keep the default":
//   - a nil Model, Messages, System or ToolChoice leaves that setting
//     unchanged (the default tool choice is ToolChoiceAuto);
//   - an empty ActiveTools keeps the call-level active-tool list.
//
// DisableAllTools, when true, sends no tools to the model for that step.
type PrepareStepResult struct {
	Model           LanguageModel
	Messages        []Message
	System          *string
	ToolChoice      *ToolChoice
	ActiveTools     []string
	DisableAllTools bool
}
  98
// ToolCallRepairOptions is passed to a RepairToolCallFunction when a tool call
// fails validation. It contains:
//   - OriginalToolCall: the call as produced by the model.
//   - ValidationError: why that call was rejected.
//   - AvailableTools: the tools the call was validated against.
//   - SystemPrompt and Messages: the system prompt and model input of the
//     step that produced the call.
type ToolCallRepairOptions struct {
	OriginalToolCall ToolCallContent
	ValidationError  error
	AvailableTools   []AgentTool
	SystemPrompt     string
	Messages         []Message
}
 106
// Function types for customizing the agent loop:
//   - PrepareStepFunction runs before each step and may override that step's
//     model, messages, system prompt, tool choice and active tools; a non-nil
//     error aborts the run.
//   - OnStepFinishedFunction receives a finished step. NOTE(review): it is not
//     referenced by the loop code in this file; confirm where it is used.
//   - RepairToolCallFunction is given a tool call that failed validation and
//     may return a corrected one. The correction is re-validated and used only
//     if it passes; otherwise the original call is marked invalid.
type (
	PrepareStepFunction    = func(options PrepareStepFunctionOptions) (PrepareStepResult, error)
	OnStepFinishedFunction = func(step StepResult)
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
 112
// agentSettings holds the agent-level configuration that NewAgent assembles
// from AgentOption functions. Where AgentCall has a matching field,
// prepareCall lets the per-call value take precedence over the value stored
// here.
//
// NOTE(review): headers is not forwarded to the model Call built in Generate
// or Stream; confirm whether that is intended.
type agentSettings struct {
	systemPrompt     string
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	headers          map[string]string
	providerOptions  ProviderOptions

	// TODO: add support for provider tools
	tools      []AgentTool
	maxRetries *int

	// model is the default model for every step; a PrepareStepFunction may
	// substitute another one for a single step.
	model LanguageModel

	// Loop-control hooks; the matching AgentCall fields override them.
	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
 135
// AgentCall is the per-invocation input to Agent.Generate.
//
// Prompt must be non-empty. It becomes the final user message, with Files
// attached. Messages is prior history, inserted between the agent's system
// prompt and that user message. ActiveTools, when non-empty, restricts which
// of the agent's tools are offered to the model.
//
// Fallbacks to the agent-level settings are applied in prepareCall:
//   - the sampling fields, MaxOutputTokens and MaxRetries fall back when nil;
//   - StopWhen falls back when it is empty;
//   - PrepareStep, RepairToolCall and OnRetry fall back when nil;
//   - ProviderOptions is merged over the agent's provider options, with this
//     call's keys winning.
type AgentCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction
}
 155
// Agent-level callbacks, set on AgentStreamCall and invoked by Agent.Stream.
// The error results of OnAgentFinishFunc, OnStepStartFunc and
// OnStepFinishFunc are currently discarded by Stream.
type (
	// OnAgentStartFunc is called once, before the first step runs.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called once with the final result, after OnFinishFunc.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called before each step's model call; stepNumber is zero-based.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called after each step with that step's result.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called once with the final result when the agent loop completes.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called with an error that is about to abort the stream run.
	OnErrorFunc func(error)
)
 176
// Stream part callbacks - called for each corresponding stream part type.
// They are set on AgentStreamCall and handed to the step-stream processing
// code. NOTE(review): presumably a non-nil error returned from one of these
// aborts the stream; confirm in processStepStream.
type (
	// OnChunkFunc is called for each stream part (catch-all).
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when text starts; id identifies the text block.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for text deltas.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when text ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when reasoning starts.
	OnReasoningStartFunc func(id string) error

	// OnReasoningDeltaFunc is called for reasoning deltas.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning ends, with the reasoning content.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when tool input starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for tool input deltas.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnToolResultFunc is called when tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when stream finishes.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
 224
// AgentStreamCall is the per-invocation input to Agent.Stream.
//
// The request fields mirror AgentCall. Stream copies them into an AgentCall,
// so the same agent-level fallbacks and the ProviderOptions merge apply. The
// callback fields are invoked as the run progresses.
//
// NOTE(review): Headers is not copied into that AgentCall (which has no
// Headers field), so it is currently not sent to the model.
type AgentStreamCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called when agent starts
	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
	OnWarnings       OnWarningsFunc       // Called for warnings
	OnTextStart      OnTextStartFunc      // Called when text starts
	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
	OnTextEnd        OnTextEndFunc        // Called when text ends
	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
	OnToolCall       OnToolCallFunc       // Called when tool call is complete
	OnToolResult     OnToolResultFunc     // Called when tool execution completes
	OnSource         OnSourceFunc         // Called for source references
	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
}
 270
// AgentResult is the outcome of a complete agent run. Steps lists every step
// in execution order, Response is the last step's response, and TotalUsage is
// the token usage summed over all steps.
type AgentResult struct {
	Steps []StepResult
	// Final response
	Response   Response
	TotalUsage Usage
}
 277
// Agent runs a multi-step model/tool loop.
//   - Generate uses non-streaming model calls and returns when the loop finishes.
//   - Stream uses streaming model calls and reports progress through the
//     callbacks in AgentStreamCall.
type Agent interface {
	Generate(context.Context, AgentCall) (*AgentResult, error)
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}

// AgentOption configures the agent-level defaults when passed to NewAgent.
type AgentOption = func(*agentSettings)

// agent is the Agent implementation returned by NewAgent.
type agent struct {
	settings agentSettings
}
 288
 289func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
 290	settings := agentSettings{
 291		model: model,
 292	}
 293	for _, o := range opts {
 294		o(&settings)
 295	}
 296	return &agent{
 297		settings: settings,
 298	}
 299}
 300
 301func (a *agent) prepareCall(call AgentCall) AgentCall {
 302	call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
 303	call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
 304	call.TopP = cmp.Or(call.TopP, a.settings.topP)
 305	call.TopK = cmp.Or(call.TopK, a.settings.topK)
 306	call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
 307	call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
 308	call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
 309
 310	if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
 311		call.StopWhen = a.settings.stopWhen
 312	}
 313	if call.PrepareStep == nil && a.settings.prepareStep != nil {
 314		call.PrepareStep = a.settings.prepareStep
 315	}
 316	if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
 317		call.RepairToolCall = a.settings.repairToolCall
 318	}
 319	if call.OnRetry == nil && a.settings.onRetry != nil {
 320		call.OnRetry = a.settings.onRetry
 321	}
 322
 323	providerOptions := ProviderOptions{}
 324	if a.settings.providerOptions != nil {
 325		maps.Copy(providerOptions, a.settings.providerOptions)
 326	}
 327	if call.ProviderOptions != nil {
 328		maps.Copy(providerOptions, call.ProviderOptions)
 329	}
 330	call.ProviderOptions = providerOptions
 331
 332	headers := map[string]string{}
 333
 334	if a.settings.headers != nil {
 335		maps.Copy(headers, a.settings.headers)
 336	}
 337
 338	return call
 339}
 340
 341// Generate implements Agent.
 342func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
 343	opts = a.prepareCall(opts)
 344	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 345	if err != nil {
 346		return nil, err
 347	}
 348	var responseMessages []Message
 349	var steps []StepResult
 350
 351	for {
 352		stepInputMessages := append(initialPrompt, responseMessages...)
 353		stepModel := a.settings.model
 354		stepSystemPrompt := a.settings.systemPrompt
 355		stepActiveTools := opts.ActiveTools
 356		stepToolChoice := ToolChoiceAuto
 357		disableAllTools := false
 358
 359		if opts.PrepareStep != nil {
 360			prepared, err := opts.PrepareStep(PrepareStepFunctionOptions{
 361				Model:      stepModel,
 362				Steps:      steps,
 363				StepNumber: len(steps),
 364				Messages:   stepInputMessages,
 365			})
 366			if err != nil {
 367				return nil, err
 368			}
 369
 370			// Apply prepared step modifications
 371			if prepared.Messages != nil {
 372				stepInputMessages = prepared.Messages
 373			}
 374			if prepared.Model != nil {
 375				stepModel = prepared.Model
 376			}
 377			if prepared.System != nil {
 378				stepSystemPrompt = *prepared.System
 379			}
 380			if prepared.ToolChoice != nil {
 381				stepToolChoice = *prepared.ToolChoice
 382			}
 383			if len(prepared.ActiveTools) > 0 {
 384				stepActiveTools = prepared.ActiveTools
 385			}
 386			disableAllTools = prepared.DisableAllTools
 387		}
 388
 389		// Recreate prompt with potentially modified system prompt
 390		if stepSystemPrompt != a.settings.systemPrompt {
 391			stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 392			if err != nil {
 393				return nil, err
 394			}
 395			// Replace system message part, keep the rest
 396			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 397				stepInputMessages[0] = stepPrompt[0] // Replace system message
 398			}
 399		}
 400
 401		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 402
 403		retryOptions := DefaultRetryOptions()
 404		retryOptions.OnRetry = opts.OnRetry
 405		retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
 406
 407		result, err := retry(ctx, func() (*Response, error) {
 408			return stepModel.Generate(ctx, Call{
 409				Prompt:           stepInputMessages,
 410				MaxOutputTokens:  opts.MaxOutputTokens,
 411				Temperature:      opts.Temperature,
 412				TopP:             opts.TopP,
 413				TopK:             opts.TopK,
 414				PresencePenalty:  opts.PresencePenalty,
 415				FrequencyPenalty: opts.FrequencyPenalty,
 416				Tools:            preparedTools,
 417				ToolChoice:       &stepToolChoice,
 418				ProviderOptions:  opts.ProviderOptions,
 419			})
 420		})
 421		if err != nil {
 422			return nil, err
 423		}
 424
 425		var stepToolCalls []ToolCallContent
 426		for _, content := range result.Content {
 427			if content.GetType() == ContentTypeToolCall {
 428				toolCall, ok := AsContentType[ToolCallContent](content)
 429				if !ok {
 430					continue
 431				}
 432
 433				// Validate and potentially repair the tool call
 434				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
 435				stepToolCalls = append(stepToolCalls, validatedToolCall)
 436			}
 437		}
 438
 439		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
 440
 441		// Build step content with validated tool calls and tool results
 442		stepContent := []Content{}
 443		toolCallIndex := 0
 444		for _, content := range result.Content {
 445			if content.GetType() == ContentTypeToolCall {
 446				// Replace with validated tool call
 447				if toolCallIndex < len(stepToolCalls) {
 448					stepContent = append(stepContent, stepToolCalls[toolCallIndex])
 449					toolCallIndex++
 450				}
 451			} else {
 452				// Keep other content as-is
 453				stepContent = append(stepContent, content)
 454			}
 455		}
 456		// Add tool results
 457		for _, result := range toolResults {
 458			stepContent = append(stepContent, result)
 459		}
 460		currentStepMessages := toResponseMessages(stepContent)
 461		responseMessages = append(responseMessages, currentStepMessages...)
 462
 463		stepResult := StepResult{
 464			Response: Response{
 465				Content:          stepContent,
 466				FinishReason:     result.FinishReason,
 467				Usage:            result.Usage,
 468				Warnings:         result.Warnings,
 469				ProviderMetadata: result.ProviderMetadata,
 470			},
 471			Messages: currentStepMessages,
 472		}
 473		steps = append(steps, stepResult)
 474		shouldStop := isStopConditionMet(opts.StopWhen, steps)
 475
 476		if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
 477			break
 478		}
 479	}
 480
 481	totalUsage := Usage{}
 482
 483	for _, step := range steps {
 484		usage := step.Usage
 485		totalUsage.InputTokens += usage.InputTokens
 486		totalUsage.OutputTokens += usage.OutputTokens
 487		totalUsage.ReasoningTokens += usage.ReasoningTokens
 488		totalUsage.CacheCreationTokens += usage.CacheCreationTokens
 489		totalUsage.CacheReadTokens += usage.CacheReadTokens
 490		totalUsage.TotalTokens += usage.TotalTokens
 491	}
 492
 493	agentResult := &AgentResult{
 494		Steps:      steps,
 495		Response:   steps[len(steps)-1].Response,
 496		TotalUsage: totalUsage,
 497	}
 498	return agentResult, nil
 499}
 500
 501func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
 502	if len(conditions) == 0 {
 503		return false
 504	}
 505
 506	for _, condition := range conditions {
 507		if condition(steps) {
 508			return true
 509		}
 510	}
 511	return false
 512}
 513
 514func toResponseMessages(content []Content) []Message {
 515	var assistantParts []MessagePart
 516	var toolParts []MessagePart
 517
 518	for _, c := range content {
 519		switch c.GetType() {
 520		case ContentTypeText:
 521			text, ok := AsContentType[TextContent](c)
 522			if !ok {
 523				continue
 524			}
 525			assistantParts = append(assistantParts, TextPart{
 526				Text:            text.Text,
 527				ProviderOptions: ProviderOptions(text.ProviderMetadata),
 528			})
 529		case ContentTypeReasoning:
 530			reasoning, ok := AsContentType[ReasoningContent](c)
 531			if !ok {
 532				continue
 533			}
 534			assistantParts = append(assistantParts, ReasoningPart{
 535				Text:            reasoning.Text,
 536				ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
 537			})
 538		case ContentTypeToolCall:
 539			toolCall, ok := AsContentType[ToolCallContent](c)
 540			if !ok {
 541				continue
 542			}
 543			assistantParts = append(assistantParts, ToolCallPart{
 544				ToolCallID:       toolCall.ToolCallID,
 545				ToolName:         toolCall.ToolName,
 546				Input:            toolCall.Input,
 547				ProviderExecuted: toolCall.ProviderExecuted,
 548				ProviderOptions:  ProviderOptions(toolCall.ProviderMetadata),
 549			})
 550		case ContentTypeFile:
 551			file, ok := AsContentType[FileContent](c)
 552			if !ok {
 553				continue
 554			}
 555			assistantParts = append(assistantParts, FilePart{
 556				Data:            file.Data,
 557				MediaType:       file.MediaType,
 558				ProviderOptions: ProviderOptions(file.ProviderMetadata),
 559			})
 560		case ContentTypeSource:
 561			// Sources are metadata about references used to generate the response.
 562			// They don't need to be included in the conversation messages.
 563			continue
 564		case ContentTypeToolResult:
 565			result, ok := AsContentType[ToolResultContent](c)
 566			if !ok {
 567				continue
 568			}
 569			toolParts = append(toolParts, ToolResultPart{
 570				ToolCallID:      result.ToolCallID,
 571				Output:          result.Result,
 572				ProviderOptions: ProviderOptions(result.ProviderMetadata),
 573			})
 574		}
 575	}
 576
 577	var messages []Message
 578	if len(assistantParts) > 0 {
 579		messages = append(messages, Message{
 580			Role:    MessageRoleAssistant,
 581			Content: assistantParts,
 582		})
 583	}
 584	if len(toolParts) > 0 {
 585		messages = append(messages, Message{
 586			Role:    MessageRoleTool,
 587			Content: toolParts,
 588		})
 589	}
 590	return messages
 591}
 592
 593func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
 594	if len(toolCalls) == 0 {
 595		return nil, nil
 596	}
 597
 598	// Create a map for quick tool lookup
 599	toolMap := make(map[string]AgentTool)
 600	for _, tool := range allTools {
 601		toolMap[tool.Info().Name] = tool
 602	}
 603
 604	// Execute all tool calls in parallel
 605	results := make([]ToolResultContent, len(toolCalls))
 606	executeErrors := make([]error, len(toolCalls))
 607
 608	var wg sync.WaitGroup
 609
 610	for i, toolCall := range toolCalls {
 611		wg.Add(1)
 612		go func(index int, call ToolCallContent) {
 613			defer wg.Done()
 614
 615			// Skip invalid tool calls - create error result
 616			if call.Invalid {
 617				results[index] = ToolResultContent{
 618					ToolCallID: call.ToolCallID,
 619					ToolName:   call.ToolName,
 620					Result: ToolResultOutputContentError{
 621						Error: call.ValidationError,
 622					},
 623					ProviderExecuted: false,
 624				}
 625				if toolResultCallback != nil {
 626					err := toolResultCallback(results[index])
 627					if err != nil {
 628						executeErrors[index] = err
 629					}
 630				}
 631
 632				return
 633			}
 634
 635			tool, exists := toolMap[call.ToolName]
 636			if !exists {
 637				results[index] = ToolResultContent{
 638					ToolCallID: call.ToolCallID,
 639					ToolName:   call.ToolName,
 640					Result: ToolResultOutputContentError{
 641						Error: errors.New("Error: Tool not found: " + call.ToolName),
 642					},
 643					ProviderExecuted: false,
 644				}
 645
 646				if toolResultCallback != nil {
 647					err := toolResultCallback(results[index])
 648					if err != nil {
 649						executeErrors[index] = err
 650					}
 651				}
 652				return
 653			}
 654
 655			// Execute the tool
 656			result, err := tool.Run(ctx, ToolCall{
 657				ID:    call.ToolCallID,
 658				Name:  call.ToolName,
 659				Input: call.Input,
 660			})
 661			if err != nil {
 662				results[index] = ToolResultContent{
 663					ToolCallID: call.ToolCallID,
 664					ToolName:   call.ToolName,
 665					Result: ToolResultOutputContentError{
 666						Error: err,
 667					},
 668					ClientMetadata:   result.Metadata,
 669					ProviderExecuted: false,
 670				}
 671				if toolResultCallback != nil {
 672					cbErr := toolResultCallback(results[index])
 673					if cbErr != nil {
 674						executeErrors[index] = cbErr
 675					}
 676				}
 677				executeErrors[index] = err
 678				return
 679			}
 680
 681			if result.IsError {
 682				results[index] = ToolResultContent{
 683					ToolCallID: call.ToolCallID,
 684					ToolName:   call.ToolName,
 685					Result: ToolResultOutputContentError{
 686						Error: errors.New(result.Content),
 687					},
 688					ClientMetadata:   result.Metadata,
 689					ProviderExecuted: false,
 690				}
 691
 692				if toolResultCallback != nil {
 693					err := toolResultCallback(results[index])
 694					if err != nil {
 695						executeErrors[index] = err
 696					}
 697				}
 698			} else {
 699				results[index] = ToolResultContent{
 700					ToolCallID: call.ToolCallID,
 701					ToolName:   toolCall.ToolName,
 702					Result: ToolResultOutputContentText{
 703						Text: result.Content,
 704					},
 705					ClientMetadata:   result.Metadata,
 706					ProviderExecuted: false,
 707				}
 708				if toolResultCallback != nil {
 709					err := toolResultCallback(results[index])
 710					if err != nil {
 711						executeErrors[index] = err
 712					}
 713				}
 714			}
 715		}(i, toolCall)
 716	}
 717
 718	// Wait for all tool executions to complete
 719	wg.Wait()
 720
 721	for _, err := range executeErrors {
 722		if err != nil {
 723			return nil, err
 724		}
 725	}
 726
 727	return results, nil
 728}
 729
 730// Stream implements Agent.
 731func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
 732	// Convert AgentStreamCall to AgentCall for preparation
 733	call := AgentCall{
 734		Prompt:           opts.Prompt,
 735		Files:            opts.Files,
 736		Messages:         opts.Messages,
 737		MaxOutputTokens:  opts.MaxOutputTokens,
 738		Temperature:      opts.Temperature,
 739		TopP:             opts.TopP,
 740		TopK:             opts.TopK,
 741		PresencePenalty:  opts.PresencePenalty,
 742		FrequencyPenalty: opts.FrequencyPenalty,
 743		ActiveTools:      opts.ActiveTools,
 744		ProviderOptions:  opts.ProviderOptions,
 745		MaxRetries:       opts.MaxRetries,
 746		StopWhen:         opts.StopWhen,
 747		PrepareStep:      opts.PrepareStep,
 748		RepairToolCall:   opts.RepairToolCall,
 749	}
 750
 751	call = a.prepareCall(call)
 752
 753	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
 754	if err != nil {
 755		return nil, err
 756	}
 757
 758	var responseMessages []Message
 759	var steps []StepResult
 760	var totalUsage Usage
 761
 762	// Start agent stream
 763	if opts.OnAgentStart != nil {
 764		opts.OnAgentStart()
 765	}
 766
 767	for stepNumber := 0; ; stepNumber++ {
 768		stepInputMessages := append(initialPrompt, responseMessages...)
 769		stepModel := a.settings.model
 770		stepSystemPrompt := a.settings.systemPrompt
 771		stepActiveTools := call.ActiveTools
 772		stepToolChoice := ToolChoiceAuto
 773		disableAllTools := false
 774
 775		// Apply step preparation if provided
 776		if call.PrepareStep != nil {
 777			prepared, err := call.PrepareStep(PrepareStepFunctionOptions{
 778				Model:      stepModel,
 779				Steps:      steps,
 780				StepNumber: stepNumber,
 781				Messages:   stepInputMessages,
 782			})
 783			if err != nil {
 784				return nil, err
 785			}
 786
 787			if prepared.Messages != nil {
 788				stepInputMessages = prepared.Messages
 789			}
 790			if prepared.Model != nil {
 791				stepModel = prepared.Model
 792			}
 793			if prepared.System != nil {
 794				stepSystemPrompt = *prepared.System
 795			}
 796			if prepared.ToolChoice != nil {
 797				stepToolChoice = *prepared.ToolChoice
 798			}
 799			if len(prepared.ActiveTools) > 0 {
 800				stepActiveTools = prepared.ActiveTools
 801			}
 802			disableAllTools = prepared.DisableAllTools
 803		}
 804
 805		// Recreate prompt with potentially modified system prompt
 806		if stepSystemPrompt != a.settings.systemPrompt {
 807			stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
 808			if err != nil {
 809				return nil, err
 810			}
 811			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 812				stepInputMessages[0] = stepPrompt[0]
 813			}
 814		}
 815
 816		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 817
 818		// Start step stream
 819		if opts.OnStepStart != nil {
 820			_ = opts.OnStepStart(stepNumber)
 821		}
 822
 823		// Create streaming call
 824		streamCall := Call{
 825			Prompt:           stepInputMessages,
 826			MaxOutputTokens:  call.MaxOutputTokens,
 827			Temperature:      call.Temperature,
 828			TopP:             call.TopP,
 829			TopK:             call.TopK,
 830			PresencePenalty:  call.PresencePenalty,
 831			FrequencyPenalty: call.FrequencyPenalty,
 832			Tools:            preparedTools,
 833			ToolChoice:       &stepToolChoice,
 834			ProviderOptions:  call.ProviderOptions,
 835		}
 836
 837		// Get streaming response
 838		stream, err := stepModel.Stream(ctx, streamCall)
 839		if err != nil {
 840			if opts.OnError != nil {
 841				opts.OnError(err)
 842			}
 843			return nil, err
 844		}
 845
 846		// Process stream with tool execution
 847		stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
 848		if err != nil {
 849			if opts.OnError != nil {
 850				opts.OnError(err)
 851			}
 852			return nil, err
 853		}
 854
 855		steps = append(steps, stepResult)
 856		totalUsage = addUsage(totalUsage, stepResult.Usage)
 857
 858		// Call step finished callback
 859		if opts.OnStepFinish != nil {
 860			_ = opts.OnStepFinish(stepResult)
 861		}
 862
 863		// Add step messages to response messages
 864		stepMessages := toResponseMessages(stepResult.Content)
 865		responseMessages = append(responseMessages, stepMessages...)
 866
 867		// Check stop conditions
 868		shouldStop := isStopConditionMet(call.StopWhen, steps)
 869		if shouldStop || !shouldContinue {
 870			break
 871		}
 872	}
 873
 874	// Finish agent stream
 875	agentResult := &AgentResult{
 876		Steps:      steps,
 877		Response:   steps[len(steps)-1].Response,
 878		TotalUsage: totalUsage,
 879	}
 880
 881	if opts.OnFinish != nil {
 882		opts.OnFinish(agentResult)
 883	}
 884
 885	if opts.OnAgentFinish != nil {
 886		_ = opts.OnAgentFinish(agentResult)
 887	}
 888
 889	return agentResult, nil
 890}
 891
 892func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
 893	preparedTools := make([]Tool, 0, len(tools))
 894
 895	// If explicitly disabling all tools, return no tools
 896	if disableAllTools {
 897		return preparedTools
 898	}
 899
 900	for _, tool := range tools {
 901		// If activeTools has items, only include tools in the list
 902		// If activeTools is empty, include all tools
 903		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
 904			continue
 905		}
 906		info := tool.Info()
 907		preparedTools = append(preparedTools, FunctionTool{
 908			Name:        info.Name,
 909			Description: info.Description,
 910			InputSchema: map[string]any{
 911				"type":       "object",
 912				"properties": info.Parameters,
 913				"required":   info.Required,
 914			},
 915		})
 916	}
 917	return preparedTools
 918}
 919
 920// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
 921func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
 922	if err := a.validateToolCall(toolCall, availableTools); err == nil {
 923		return toolCall
 924	} else {
 925		if repairFunc != nil {
 926			repairOptions := ToolCallRepairOptions{
 927				OriginalToolCall: toolCall,
 928				ValidationError:  err,
 929				AvailableTools:   availableTools,
 930				SystemPrompt:     systemPrompt,
 931				Messages:         messages,
 932			}
 933
 934			if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
 935				if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
 936					return *repairedToolCall
 937				}
 938			}
 939		}
 940
 941		invalidToolCall := toolCall
 942		invalidToolCall.Invalid = true
 943		invalidToolCall.ValidationError = err
 944		return invalidToolCall
 945	}
 946}
 947
 948// validateToolCall validates a tool call against available tools and their schemas.
 949func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
 950	var tool AgentTool
 951	for _, t := range availableTools {
 952		if t.Info().Name == toolCall.ToolName {
 953			tool = t
 954			break
 955		}
 956	}
 957
 958	if tool == nil {
 959		return fmt.Errorf("tool not found: %s", toolCall.ToolName)
 960	}
 961
 962	// Validate JSON parsing
 963	var input map[string]any
 964	if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
 965		return fmt.Errorf("invalid JSON input: %w", err)
 966	}
 967
 968	// Basic schema validation (check required fields)
 969	// TODO: more robust schema validation using JSON Schema or similar
 970	toolInfo := tool.Info()
 971	for _, required := range toolInfo.Required {
 972		if _, exists := input[required]; !exists {
 973			return fmt.Errorf("missing required parameter: %s", required)
 974		}
 975	}
 976	return nil
 977}
 978
 979func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
 980	if prompt == "" {
 981		return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
 982	}
 983
 984	var preparedPrompt Prompt
 985
 986	if system != "" {
 987		preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
 988	}
 989	preparedPrompt = append(preparedPrompt, messages...)
 990	preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
 991	return preparedPrompt, nil
 992}
 993
 994func WithSystemPrompt(prompt string) AgentOption {
 995	return func(s *agentSettings) {
 996		s.systemPrompt = prompt
 997	}
 998}
 999
1000func WithMaxOutputTokens(tokens int64) AgentOption {
1001	return func(s *agentSettings) {
1002		s.maxOutputTokens = &tokens
1003	}
1004}
1005
1006func WithTemperature(temp float64) AgentOption {
1007	return func(s *agentSettings) {
1008		s.temperature = &temp
1009	}
1010}
1011
1012func WithTopP(topP float64) AgentOption {
1013	return func(s *agentSettings) {
1014		s.topP = &topP
1015	}
1016}
1017
1018func WithTopK(topK int64) AgentOption {
1019	return func(s *agentSettings) {
1020		s.topK = &topK
1021	}
1022}
1023
1024func WithPresencePenalty(penalty float64) AgentOption {
1025	return func(s *agentSettings) {
1026		s.presencePenalty = &penalty
1027	}
1028}
1029
1030func WithFrequencyPenalty(penalty float64) AgentOption {
1031	return func(s *agentSettings) {
1032		s.frequencyPenalty = &penalty
1033	}
1034}
1035
1036func WithTools(tools ...AgentTool) AgentOption {
1037	return func(s *agentSettings) {
1038		s.tools = append(s.tools, tools...)
1039	}
1040}
1041
1042func WithStopConditions(conditions ...StopCondition) AgentOption {
1043	return func(s *agentSettings) {
1044		s.stopWhen = append(s.stopWhen, conditions...)
1045	}
1046}
1047
1048func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1049	return func(s *agentSettings) {
1050		s.prepareStep = fn
1051	}
1052}
1053
1054func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1055	return func(s *agentSettings) {
1056		s.repairToolCall = fn
1057	}
1058}
1059
1060// processStepStream processes a single step's stream and returns the step result.
1061func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
1062	var stepContent []Content
1063	var stepToolCalls []ToolCallContent
1064	var stepUsage Usage
1065	stepFinishReason := FinishReasonUnknown
1066	var stepWarnings []CallWarning
1067	var stepProviderMetadata ProviderMetadata
1068
1069	activeToolCalls := make(map[string]*ToolCallContent)
1070	activeTextContent := make(map[string]string)
1071	type reasoningContent struct {
1072		content string
1073		options ProviderMetadata
1074	}
1075	activeReasoningContent := make(map[string]reasoningContent)
1076
1077	// Process stream parts
1078	for part := range stream {
1079		// Forward all parts to chunk callback
1080		if opts.OnChunk != nil {
1081			err := opts.OnChunk(part)
1082			if err != nil {
1083				return StepResult{}, false, err
1084			}
1085		}
1086
1087		switch part.Type {
1088		case StreamPartTypeWarnings:
1089			stepWarnings = part.Warnings
1090			if opts.OnWarnings != nil {
1091				err := opts.OnWarnings(part.Warnings)
1092				if err != nil {
1093					return StepResult{}, false, err
1094				}
1095			}
1096
1097		case StreamPartTypeTextStart:
1098			activeTextContent[part.ID] = ""
1099			if opts.OnTextStart != nil {
1100				err := opts.OnTextStart(part.ID)
1101				if err != nil {
1102					return StepResult{}, false, err
1103				}
1104			}
1105
1106		case StreamPartTypeTextDelta:
1107			if _, exists := activeTextContent[part.ID]; exists {
1108				activeTextContent[part.ID] += part.Delta
1109			}
1110			if opts.OnTextDelta != nil {
1111				err := opts.OnTextDelta(part.ID, part.Delta)
1112				if err != nil {
1113					return StepResult{}, false, err
1114				}
1115			}
1116
1117		case StreamPartTypeTextEnd:
1118			if text, exists := activeTextContent[part.ID]; exists {
1119				stepContent = append(stepContent, TextContent{
1120					Text:             text,
1121					ProviderMetadata: part.ProviderMetadata,
1122				})
1123				delete(activeTextContent, part.ID)
1124			}
1125			if opts.OnTextEnd != nil {
1126				err := opts.OnTextEnd(part.ID)
1127				if err != nil {
1128					return StepResult{}, false, err
1129				}
1130			}
1131
1132		case StreamPartTypeReasoningStart:
1133			activeReasoningContent[part.ID] = reasoningContent{content: ""}
1134			if opts.OnReasoningStart != nil {
1135				err := opts.OnReasoningStart(part.ID)
1136				if err != nil {
1137					return StepResult{}, false, err
1138				}
1139			}
1140
1141		case StreamPartTypeReasoningDelta:
1142			if active, exists := activeReasoningContent[part.ID]; exists {
1143				active.content += part.Delta
1144				active.options = part.ProviderMetadata
1145				activeReasoningContent[part.ID] = active
1146			}
1147			if opts.OnReasoningDelta != nil {
1148				err := opts.OnReasoningDelta(part.ID, part.Delta)
1149				if err != nil {
1150					return StepResult{}, false, err
1151				}
1152			}
1153
1154		case StreamPartTypeReasoningEnd:
1155			if active, exists := activeReasoningContent[part.ID]; exists {
1156				content := ReasoningContent{
1157					Text:             active.content,
1158					ProviderMetadata: active.options,
1159				}
1160				stepContent = append(stepContent, content)
1161				if opts.OnReasoningEnd != nil {
1162					err := opts.OnReasoningEnd(part.ID, content)
1163					if err != nil {
1164						return StepResult{}, false, err
1165					}
1166				}
1167				delete(activeReasoningContent, part.ID)
1168			}
1169
1170		case StreamPartTypeToolInputStart:
1171			activeToolCalls[part.ID] = &ToolCallContent{
1172				ToolCallID:       part.ID,
1173				ToolName:         part.ToolCallName,
1174				Input:            "",
1175				ProviderExecuted: part.ProviderExecuted,
1176			}
1177			if opts.OnToolInputStart != nil {
1178				err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1179				if err != nil {
1180					return StepResult{}, false, err
1181				}
1182			}
1183
1184		case StreamPartTypeToolInputDelta:
1185			if toolCall, exists := activeToolCalls[part.ID]; exists {
1186				toolCall.Input += part.Delta
1187			}
1188			if opts.OnToolInputDelta != nil {
1189				err := opts.OnToolInputDelta(part.ID, part.Delta)
1190				if err != nil {
1191					return StepResult{}, false, err
1192				}
1193			}
1194
1195		case StreamPartTypeToolInputEnd:
1196			if opts.OnToolInputEnd != nil {
1197				err := opts.OnToolInputEnd(part.ID)
1198				if err != nil {
1199					return StepResult{}, false, err
1200				}
1201			}
1202
1203		case StreamPartTypeToolCall:
1204			toolCall := ToolCallContent{
1205				ToolCallID:       part.ID,
1206				ToolName:         part.ToolCallName,
1207				Input:            part.ToolCallInput,
1208				ProviderExecuted: part.ProviderExecuted,
1209				ProviderMetadata: part.ProviderMetadata,
1210			}
1211
1212			// Validate and potentially repair the tool call
1213			validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1214			stepToolCalls = append(stepToolCalls, validatedToolCall)
1215			stepContent = append(stepContent, validatedToolCall)
1216
1217			if opts.OnToolCall != nil {
1218				err := opts.OnToolCall(validatedToolCall)
1219				if err != nil {
1220					return StepResult{}, false, err
1221				}
1222			}
1223
1224			// Clean up active tool call
1225			delete(activeToolCalls, part.ID)
1226
1227		case StreamPartTypeSource:
1228			sourceContent := SourceContent{
1229				SourceType:       part.SourceType,
1230				ID:               part.ID,
1231				URL:              part.URL,
1232				Title:            part.Title,
1233				ProviderMetadata: part.ProviderMetadata,
1234			}
1235			stepContent = append(stepContent, sourceContent)
1236			if opts.OnSource != nil {
1237				err := opts.OnSource(sourceContent)
1238				if err != nil {
1239					return StepResult{}, false, err
1240				}
1241			}
1242
1243		case StreamPartTypeFinish:
1244			stepUsage = part.Usage
1245			stepFinishReason = part.FinishReason
1246			stepProviderMetadata = part.ProviderMetadata
1247			if opts.OnStreamFinish != nil {
1248				err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1249				if err != nil {
1250					return StepResult{}, false, err
1251				}
1252			}
1253
1254		case StreamPartTypeError:
1255			if opts.OnError != nil {
1256				opts.OnError(part.Error)
1257			}
1258			return StepResult{}, false, part.Error
1259		}
1260	}
1261
1262	// Execute tools if any
1263	var toolResults []ToolResultContent
1264	if len(stepToolCalls) > 0 {
1265		var err error
1266		toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
1267		if err != nil {
1268			return StepResult{}, false, err
1269		}
1270		// Add tool results to content
1271		for _, result := range toolResults {
1272			stepContent = append(stepContent, result)
1273		}
1274	}
1275
1276	stepResult := StepResult{
1277		Response: Response{
1278			Content:          stepContent,
1279			FinishReason:     stepFinishReason,
1280			Usage:            stepUsage,
1281			Warnings:         stepWarnings,
1282			ProviderMetadata: stepProviderMetadata,
1283		},
1284		Messages: toResponseMessages(stepContent),
1285	}
1286
1287	// Determine if we should continue (has tool calls and not stopped)
1288	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1289
1290	return stepResult, shouldContinue, nil
1291}
1292
1293func addUsage(a, b Usage) Usage {
1294	return Usage{
1295		InputTokens:         a.InputTokens + b.InputTokens,
1296		OutputTokens:        a.OutputTokens + b.OutputTokens,
1297		TotalTokens:         a.TotalTokens + b.TotalTokens,
1298		ReasoningTokens:     a.ReasoningTokens + b.ReasoningTokens,
1299		CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1300		CacheReadTokens:     a.CacheReadTokens + b.CacheReadTokens,
1301	}
1302}
1303
1304func WithHeaders(headers map[string]string) AgentOption {
1305	return func(s *agentSettings) {
1306		s.headers = headers
1307	}
1308}
1309
1310func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1311	return func(s *agentSettings) {
1312		s.providerOptions = providerOptions
1313	}
1314}