agent.go

   1package fantasy
   2
   3import (
   4	"cmp"
   5	"context"
   6	"encoding/json"
   7	"errors"
   8	"fmt"
   9	"maps"
  10	"slices"
  11)
  12
// StepResult is the outcome of a single step of an agent run.
//
// The embedded Response holds the step's content, finish reason, usage,
// warnings and provider metadata. Messages holds the conversation messages
// derived from that content; in Generate they are built with
// toResponseMessages and appended to the history sent to later steps.
type StepResult struct {
	Response
	Messages []Message
}
  18
// stepExecutionResult is what one streamed step produces: the completed
// StepResult and whether the agent loop should run another step. Stream
// stops when ShouldContinue is false (or when a stop condition is met).
type stepExecutionResult struct {
	StepResult     StepResult
	ShouldContinue bool
}
  24
// StopCondition decides whether an agent run should stop. It is evaluated
// after each step with every step completed so far, in execution order; the
// run stops as soon as any configured condition returns true.
type StopCondition = func(steps []StepResult) bool
  27
  28// StepCountIs returns a stop condition that stops after the specified number of steps.
  29func StepCountIs(stepCount int) StopCondition {
  30	return func(steps []StepResult) bool {
  31		return len(steps) >= stepCount
  32	}
  33}
  34
  35// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
  36func HasToolCall(toolName string) StopCondition {
  37	return func(steps []StepResult) bool {
  38		if len(steps) == 0 {
  39			return false
  40		}
  41		lastStep := steps[len(steps)-1]
  42		toolCalls := lastStep.Content.ToolCalls()
  43		for _, toolCall := range toolCalls {
  44			if toolCall.ToolName == toolName {
  45				return true
  46			}
  47		}
  48		return false
  49	}
  50}
  51
  52// HasContent returns a stop condition that stops when the specified content type appears in the last step.
  53func HasContent(contentType ContentType) StopCondition {
  54	return func(steps []StepResult) bool {
  55		if len(steps) == 0 {
  56			return false
  57		}
  58		lastStep := steps[len(steps)-1]
  59		for _, content := range lastStep.Content {
  60			if content.GetType() == contentType {
  61				return true
  62			}
  63		}
  64		return false
  65	}
  66}
  67
  68// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
  69func FinishReasonIs(reason FinishReason) StopCondition {
  70	return func(steps []StepResult) bool {
  71		if len(steps) == 0 {
  72			return false
  73		}
  74		lastStep := steps[len(steps)-1]
  75		return lastStep.FinishReason == reason
  76	}
  77}
  78
  79// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
  80func MaxTokensUsed(maxTokens int64) StopCondition {
  81	return func(steps []StepResult) bool {
  82		var totalTokens int64
  83		for _, step := range steps {
  84			totalTokens += step.Usage.TotalTokens
  85		}
  86		return totalTokens >= maxTokens
  87	}
  88}
  89
// PrepareStepFunctionOptions is passed to a PrepareStepFunction before each
// step runs.
//
//   - Steps: the steps completed so far.
//   - StepNumber: zero-based index of the step about to run.
//   - Model: the model the step will use unless overridden.
//   - Messages: the input messages the step will send unless overridden.
type PrepareStepFunctionOptions struct {
	Steps      []StepResult
	StepNumber int
	Model      LanguageModel
	Messages   []Message
}
  97
// PrepareStepResult holds per-step overrides returned by a
// PrepareStepFunction. Unset fields keep the step's defaults:
//
//   - Model: replaces the agent's model for this step when non-nil.
//   - Messages: replaces the step's input messages when non-nil.
//   - System: replaces the agent's system prompt for this step when non-nil.
//   - ToolChoice: replaces the default ToolChoiceAuto when non-nil.
//   - ActiveTools: when non-empty, limits the offered tools to these names;
//     an empty list keeps the call's ActiveTools.
//   - DisableAllTools: when true, no tools are offered to the model.
type PrepareStepResult struct {
	Model           LanguageModel
	Messages        []Message
	System          *string
	ToolChoice      *ToolChoice
	ActiveTools     []string
	DisableAllTools bool
}
 107
// ToolCallRepairOptions is passed to a RepairToolCallFunction when a tool
// call fails validation.
//
//   - OriginalToolCall: the tool call as returned by the model.
//   - ValidationError: why validation failed (unknown tool, input that does
//     not unmarshal into an object, or a missing required parameter).
//   - AvailableTools: the tools the call was validated against.
//   - SystemPrompt: the system prompt in effect for the step.
//   - Messages: the input messages sent to the model for the step.
type ToolCallRepairOptions struct {
	OriginalToolCall ToolCallContent
	ValidationError  error
	AvailableTools   []AgentTool
	SystemPrompt     string
	Messages         []Message
}
 116
type (
	// PrepareStepFunction is called before each step to inspect the run so
	// far and optionally override the step's model, messages, system prompt,
	// tool choice or tools. The returned context replaces ctx for the rest of
	// the run (not only for this step), so it should not be nil. A non-nil
	// error aborts the run.
	PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)

	// OnStepFinishedFunction defines a function that is called when a step finishes.
	// NOTE(review): not referenced by the agent loop in the code visible here;
	// AgentStreamCall uses OnStepFinishFunc instead.
	OnStepFinishedFunction = func(step StepResult)

	// RepairToolCallFunction attempts to fix a tool call that failed
	// validation. A non-nil call returned with a nil error is used only if it
	// passes validation itself; otherwise the original call is marked invalid
	// and reported back to the model as an error tool result.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
 127
// agentSettings holds the agent-level defaults configured through
// AgentOption values. Values set on an individual AgentCall/AgentStreamCall
// take precedence (see prepareCall).
//
// NOTE(review): headers is not forwarded to model calls anywhere in the code
// visible here.
type agentSettings struct {
	// Generation defaults; nil pointers mean "not set".
	systemPrompt     string
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	headers          map[string]string
	providerOptions  ProviderOptions

	// TODO: add support for provider tools
	tools      []AgentTool
	maxRetries *int

	// Default model for every step unless PrepareStep overrides it.
	model LanguageModel

	// Loop-control hooks used when the call does not supply its own.
	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
 150
// AgentCall represents a call to an agent.
//
// Prompt is the user message for this call and must be non-empty. Messages
// is prior conversation history placed before it, and Files are attached to
// the prompt message. An empty ActiveTools offers every configured tool.
//
// Nil generation settings (MaxOutputTokens, Temperature, TopP, TopK,
// PresencePenalty, FrequencyPenalty), MaxRetries, OnRetry, StopWhen,
// PrepareStep and RepairToolCall fall back to the agent-level defaults.
// ProviderOptions are merged with the agent's, with these entries taking
// precedence.
type AgentCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction
}
 171
// Agent-level callbacks, invoked by Stream.
type (
	// OnAgentStartFunc is called once, before the first step of a streaming run.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called after OnFinishFunc when a streaming run
	// completes successfully. Stream currently discards its returned error.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called before each step with the zero-based step
	// number. Stream currently discards its returned error.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called with each step's result after the step
	// completes. Stream currently discards its returned error.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called once when the entire run completes
	// successfully, before OnAgentFinishFunc.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called when a step's model stream fails (after any
	// retries). Errors from prompt creation or PrepareStep are returned
	// without calling it.
	OnErrorFunc func(error)
)
 192
// Stream part callbacks - called for each corresponding stream part type.
//
// NOTE(review): these are dispatched while each step's stream is processed
// (processStepStream, not shown in this section); confirm there whether a
// non-nil returned error aborts the stream.
type (
	// OnChunkFunc is called for each stream part (catch-all).
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when a text block with the given id starts.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for each text delta of block id.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when text block id ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when reasoning block id starts.
	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error

	// OnReasoningDeltaFunc is called for each reasoning delta of block id.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning block id ends.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when streaming of a tool call's input starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for each tool input delta.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when a tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnToolResultFunc is called when tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when a step's stream finishes.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
 240
// AgentStreamCall represents a streaming call to an agent.
//
// Its request fields mirror AgentCall and are merged with the agent-level
// defaults the same way; the callbacks observe the run as it progresses.
//
// NOTE(review): Headers is not copied into the call that Stream prepares, so
// it currently has no effect in the code visible here.
type AgentStreamCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called when agent starts
	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when a step's stream fails

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
	OnWarnings       OnWarningsFunc       // Called for warnings
	OnTextStart      OnTextStartFunc      // Called when text starts
	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
	OnTextEnd        OnTextEndFunc        // Called when text ends
	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
	OnToolCall       OnToolCallFunc       // Called when tool call is complete
	OnToolResult     OnToolResultFunc     // Called when tool execution completes
	OnSource         OnSourceFunc         // Called for source references
	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
}
 287
// AgentResult is the outcome of a complete agent run: every step in
// execution order, the final step's response, and the token usage summed
// over all steps.
type AgentResult struct {
	Steps []StepResult
	// Response of the last step.
	Response   Response
	TotalUsage Usage
}
 295
// Agent represents an AI agent that can generate responses and stream responses.
//
// Both methods run the full multi-step loop (model call, tool execution,
// repeat) and return the aggregated result once the run ends.
type Agent interface {
	// Generate runs the agent loop using non-streaming model calls.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the agent loop using streaming model calls, reporting
	// progress through the callbacks in AgentStreamCall.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
 301
// AgentOption configures agent-level settings. NewAgent applies options in
// the order they are given, so later options override earlier ones.
type AgentOption = func(*agentSettings)
 304
// agent is the default Agent implementation; it holds the agent-level
// settings assembled by NewAgent.
type agent struct {
	settings agentSettings
}
 308
 309// NewAgent creates a new agent with the given language model and options.
 310func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
 311	settings := agentSettings{
 312		model: model,
 313	}
 314	for _, o := range opts {
 315		o(&settings)
 316	}
 317	return &agent{
 318		settings: settings,
 319	}
 320}
 321
 322func (a *agent) prepareCall(call AgentCall) AgentCall {
 323	call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
 324	call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
 325	call.TopP = cmp.Or(call.TopP, a.settings.topP)
 326	call.TopK = cmp.Or(call.TopK, a.settings.topK)
 327	call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
 328	call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
 329	call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
 330
 331	if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
 332		call.StopWhen = a.settings.stopWhen
 333	}
 334	if call.PrepareStep == nil && a.settings.prepareStep != nil {
 335		call.PrepareStep = a.settings.prepareStep
 336	}
 337	if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
 338		call.RepairToolCall = a.settings.repairToolCall
 339	}
 340	if call.OnRetry == nil && a.settings.onRetry != nil {
 341		call.OnRetry = a.settings.onRetry
 342	}
 343
 344	providerOptions := ProviderOptions{}
 345	if a.settings.providerOptions != nil {
 346		maps.Copy(providerOptions, a.settings.providerOptions)
 347	}
 348	if call.ProviderOptions != nil {
 349		maps.Copy(providerOptions, call.ProviderOptions)
 350	}
 351	call.ProviderOptions = providerOptions
 352
 353	headers := map[string]string{}
 354
 355	if a.settings.headers != nil {
 356		maps.Copy(headers, a.settings.headers)
 357	}
 358
 359	return call
 360}
 361
 362// Generate implements Agent.
 363func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
 364	opts = a.prepareCall(opts)
 365	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 366	if err != nil {
 367		return nil, err
 368	}
 369	var responseMessages []Message
 370	var steps []StepResult
 371
 372	for {
 373		stepInputMessages := append(initialPrompt, responseMessages...)
 374		stepModel := a.settings.model
 375		stepSystemPrompt := a.settings.systemPrompt
 376		stepActiveTools := opts.ActiveTools
 377		stepToolChoice := ToolChoiceAuto
 378		disableAllTools := false
 379
 380		if opts.PrepareStep != nil {
 381			updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
 382				Model:      stepModel,
 383				Steps:      steps,
 384				StepNumber: len(steps),
 385				Messages:   stepInputMessages,
 386			})
 387			if err != nil {
 388				return nil, err
 389			}
 390
 391			ctx = updatedCtx
 392
 393			// Apply prepared step modifications
 394			if prepared.Messages != nil {
 395				stepInputMessages = prepared.Messages
 396			}
 397			if prepared.Model != nil {
 398				stepModel = prepared.Model
 399			}
 400			if prepared.System != nil {
 401				stepSystemPrompt = *prepared.System
 402			}
 403			if prepared.ToolChoice != nil {
 404				stepToolChoice = *prepared.ToolChoice
 405			}
 406			if len(prepared.ActiveTools) > 0 {
 407				stepActiveTools = prepared.ActiveTools
 408			}
 409			disableAllTools = prepared.DisableAllTools
 410		}
 411
 412		// Recreate prompt with potentially modified system prompt
 413		if stepSystemPrompt != a.settings.systemPrompt {
 414			stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 415			if err != nil {
 416				return nil, err
 417			}
 418			// Replace system message part, keep the rest
 419			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 420				stepInputMessages[0] = stepPrompt[0] // Replace system message
 421			}
 422		}
 423
 424		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 425
 426		retryOptions := DefaultRetryOptions()
 427		if opts.MaxRetries != nil {
 428			retryOptions.MaxRetries = *opts.MaxRetries
 429		}
 430		retryOptions.OnRetry = opts.OnRetry
 431		retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
 432
 433		result, err := retry(ctx, func() (*Response, error) {
 434			return stepModel.Generate(ctx, Call{
 435				Prompt:           stepInputMessages,
 436				MaxOutputTokens:  opts.MaxOutputTokens,
 437				Temperature:      opts.Temperature,
 438				TopP:             opts.TopP,
 439				TopK:             opts.TopK,
 440				PresencePenalty:  opts.PresencePenalty,
 441				FrequencyPenalty: opts.FrequencyPenalty,
 442				Tools:            preparedTools,
 443				ToolChoice:       &stepToolChoice,
 444				ProviderOptions:  opts.ProviderOptions,
 445			})
 446		})
 447		if err != nil {
 448			return nil, err
 449		}
 450
 451		var stepToolCalls []ToolCallContent
 452		for _, content := range result.Content {
 453			if content.GetType() == ContentTypeToolCall {
 454				toolCall, ok := AsContentType[ToolCallContent](content)
 455				if !ok {
 456					continue
 457				}
 458
 459				// Validate and potentially repair the tool call
 460				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
 461				stepToolCalls = append(stepToolCalls, validatedToolCall)
 462			}
 463		}
 464
 465		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
 466
 467		// Build step content with validated tool calls and tool results
 468		stepContent := []Content{}
 469		toolCallIndex := 0
 470		for _, content := range result.Content {
 471			if content.GetType() == ContentTypeToolCall {
 472				// Replace with validated tool call
 473				if toolCallIndex < len(stepToolCalls) {
 474					stepContent = append(stepContent, stepToolCalls[toolCallIndex])
 475					toolCallIndex++
 476				}
 477			} else {
 478				// Keep other content as-is
 479				stepContent = append(stepContent, content)
 480			}
 481		}
 482		// Add tool results
 483		for _, result := range toolResults {
 484			stepContent = append(stepContent, result)
 485		}
 486		currentStepMessages := toResponseMessages(stepContent)
 487		responseMessages = append(responseMessages, currentStepMessages...)
 488
 489		stepResult := StepResult{
 490			Response: Response{
 491				Content:          stepContent,
 492				FinishReason:     result.FinishReason,
 493				Usage:            result.Usage,
 494				Warnings:         result.Warnings,
 495				ProviderMetadata: result.ProviderMetadata,
 496			},
 497			Messages: currentStepMessages,
 498		}
 499		steps = append(steps, stepResult)
 500		shouldStop := isStopConditionMet(opts.StopWhen, steps)
 501
 502		if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
 503			break
 504		}
 505	}
 506
 507	totalUsage := Usage{}
 508
 509	for _, step := range steps {
 510		usage := step.Usage
 511		totalUsage.InputTokens += usage.InputTokens
 512		totalUsage.OutputTokens += usage.OutputTokens
 513		totalUsage.ReasoningTokens += usage.ReasoningTokens
 514		totalUsage.CacheCreationTokens += usage.CacheCreationTokens
 515		totalUsage.CacheReadTokens += usage.CacheReadTokens
 516		totalUsage.TotalTokens += usage.TotalTokens
 517	}
 518
 519	agentResult := &AgentResult{
 520		Steps:      steps,
 521		Response:   steps[len(steps)-1].Response,
 522		TotalUsage: totalUsage,
 523	}
 524	return agentResult, nil
 525}
 526
 527func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
 528	if len(conditions) == 0 {
 529		return false
 530	}
 531
 532	for _, condition := range conditions {
 533		if condition(steps) {
 534			return true
 535		}
 536	}
 537	return false
 538}
 539
 540func toResponseMessages(content []Content) []Message {
 541	var assistantParts []MessagePart
 542	var toolParts []MessagePart
 543
 544	for _, c := range content {
 545		switch c.GetType() {
 546		case ContentTypeText:
 547			text, ok := AsContentType[TextContent](c)
 548			if !ok {
 549				continue
 550			}
 551			assistantParts = append(assistantParts, TextPart{
 552				Text:            text.Text,
 553				ProviderOptions: ProviderOptions(text.ProviderMetadata),
 554			})
 555		case ContentTypeReasoning:
 556			reasoning, ok := AsContentType[ReasoningContent](c)
 557			if !ok {
 558				continue
 559			}
 560			assistantParts = append(assistantParts, ReasoningPart{
 561				Text:            reasoning.Text,
 562				ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
 563			})
 564		case ContentTypeToolCall:
 565			toolCall, ok := AsContentType[ToolCallContent](c)
 566			if !ok {
 567				continue
 568			}
 569			assistantParts = append(assistantParts, ToolCallPart{
 570				ToolCallID:       toolCall.ToolCallID,
 571				ToolName:         toolCall.ToolName,
 572				Input:            toolCall.Input,
 573				ProviderExecuted: toolCall.ProviderExecuted,
 574				ProviderOptions:  ProviderOptions(toolCall.ProviderMetadata),
 575			})
 576		case ContentTypeFile:
 577			file, ok := AsContentType[FileContent](c)
 578			if !ok {
 579				continue
 580			}
 581			assistantParts = append(assistantParts, FilePart{
 582				Data:            file.Data,
 583				MediaType:       file.MediaType,
 584				ProviderOptions: ProviderOptions(file.ProviderMetadata),
 585			})
 586		case ContentTypeSource:
 587			// Sources are metadata about references used to generate the response.
 588			// They don't need to be included in the conversation messages.
 589			continue
 590		case ContentTypeToolResult:
 591			result, ok := AsContentType[ToolResultContent](c)
 592			if !ok {
 593				continue
 594			}
 595			toolParts = append(toolParts, ToolResultPart{
 596				ToolCallID:      result.ToolCallID,
 597				Output:          result.Result,
 598				ProviderOptions: ProviderOptions(result.ProviderMetadata),
 599			})
 600		}
 601	}
 602
 603	var messages []Message
 604	if len(assistantParts) > 0 {
 605		messages = append(messages, Message{
 606			Role:    MessageRoleAssistant,
 607			Content: assistantParts,
 608		})
 609	}
 610	if len(toolParts) > 0 {
 611		messages = append(messages, Message{
 612			Role:    MessageRoleTool,
 613			Content: toolParts,
 614		})
 615	}
 616	return messages
 617}
 618
 619func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
 620	if len(toolCalls) == 0 {
 621		return nil, nil
 622	}
 623
 624	// Create a map for quick tool lookup
 625	toolMap := make(map[string]AgentTool)
 626	for _, tool := range allTools {
 627		toolMap[tool.Info().Name] = tool
 628	}
 629
 630	// Execute all tool calls sequentially in order
 631	results := make([]ToolResultContent, 0, len(toolCalls))
 632
 633	for _, toolCall := range toolCalls {
 634		// Skip invalid tool calls - create error result
 635		if toolCall.Invalid {
 636			result := ToolResultContent{
 637				ToolCallID: toolCall.ToolCallID,
 638				ToolName:   toolCall.ToolName,
 639				Result: ToolResultOutputContentError{
 640					Error: toolCall.ValidationError,
 641				},
 642				ProviderExecuted: false,
 643			}
 644			results = append(results, result)
 645			if toolResultCallback != nil {
 646				if err := toolResultCallback(result); err != nil {
 647					return nil, err
 648				}
 649			}
 650			continue
 651		}
 652
 653		tool, exists := toolMap[toolCall.ToolName]
 654		if !exists {
 655			result := ToolResultContent{
 656				ToolCallID: toolCall.ToolCallID,
 657				ToolName:   toolCall.ToolName,
 658				Result: ToolResultOutputContentError{
 659					Error: errors.New("Error: Tool not found: " + toolCall.ToolName),
 660				},
 661				ProviderExecuted: false,
 662			}
 663			results = append(results, result)
 664			if toolResultCallback != nil {
 665				if err := toolResultCallback(result); err != nil {
 666					return nil, err
 667				}
 668			}
 669			continue
 670		}
 671
 672		// Execute the tool
 673		toolResult, err := tool.Run(ctx, ToolCall{
 674			ID:    toolCall.ToolCallID,
 675			Name:  toolCall.ToolName,
 676			Input: toolCall.Input,
 677		})
 678		if err != nil {
 679			result := ToolResultContent{
 680				ToolCallID: toolCall.ToolCallID,
 681				ToolName:   toolCall.ToolName,
 682				Result: ToolResultOutputContentError{
 683					Error: err,
 684				},
 685				ClientMetadata:   toolResult.Metadata,
 686				ProviderExecuted: false,
 687			}
 688			if toolResultCallback != nil {
 689				if cbErr := toolResultCallback(result); cbErr != nil {
 690					return nil, cbErr
 691				}
 692			}
 693			return nil, err
 694		}
 695
 696		var result ToolResultContent
 697		if toolResult.IsError {
 698			result = ToolResultContent{
 699				ToolCallID: toolCall.ToolCallID,
 700				ToolName:   toolCall.ToolName,
 701				Result: ToolResultOutputContentError{
 702					Error: errors.New(toolResult.Content),
 703				},
 704				ClientMetadata:   toolResult.Metadata,
 705				ProviderExecuted: false,
 706			}
 707		} else {
 708			result = ToolResultContent{
 709				ToolCallID: toolCall.ToolCallID,
 710				ToolName:   toolCall.ToolName,
 711				Result: ToolResultOutputContentText{
 712					Text: toolResult.Content,
 713				},
 714				ClientMetadata:   toolResult.Metadata,
 715				ProviderExecuted: false,
 716			}
 717		}
 718		results = append(results, result)
 719		if toolResultCallback != nil {
 720			if err := toolResultCallback(result); err != nil {
 721				return nil, err
 722			}
 723		}
 724	}
 725
 726	return results, nil
 727}
 728
 729// Stream implements Agent.
 730func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
 731	// Convert AgentStreamCall to AgentCall for preparation
 732	call := AgentCall{
 733		Prompt:           opts.Prompt,
 734		Files:            opts.Files,
 735		Messages:         opts.Messages,
 736		MaxOutputTokens:  opts.MaxOutputTokens,
 737		Temperature:      opts.Temperature,
 738		TopP:             opts.TopP,
 739		TopK:             opts.TopK,
 740		PresencePenalty:  opts.PresencePenalty,
 741		FrequencyPenalty: opts.FrequencyPenalty,
 742		ActiveTools:      opts.ActiveTools,
 743		ProviderOptions:  opts.ProviderOptions,
 744		MaxRetries:       opts.MaxRetries,
 745		OnRetry:          opts.OnRetry,
 746		StopWhen:         opts.StopWhen,
 747		PrepareStep:      opts.PrepareStep,
 748		RepairToolCall:   opts.RepairToolCall,
 749	}
 750
 751	call = a.prepareCall(call)
 752
 753	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
 754	if err != nil {
 755		return nil, err
 756	}
 757
 758	var responseMessages []Message
 759	var steps []StepResult
 760	var totalUsage Usage
 761
 762	// Start agent stream
 763	if opts.OnAgentStart != nil {
 764		opts.OnAgentStart()
 765	}
 766
 767	for stepNumber := 0; ; stepNumber++ {
 768		stepInputMessages := append(initialPrompt, responseMessages...)
 769		stepModel := a.settings.model
 770		stepSystemPrompt := a.settings.systemPrompt
 771		stepActiveTools := call.ActiveTools
 772		stepToolChoice := ToolChoiceAuto
 773		disableAllTools := false
 774
 775		// Apply step preparation if provided
 776		if call.PrepareStep != nil {
 777			updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
 778				Model:      stepModel,
 779				Steps:      steps,
 780				StepNumber: stepNumber,
 781				Messages:   stepInputMessages,
 782			})
 783			if err != nil {
 784				return nil, err
 785			}
 786
 787			ctx = updatedCtx
 788
 789			if prepared.Messages != nil {
 790				stepInputMessages = prepared.Messages
 791			}
 792			if prepared.Model != nil {
 793				stepModel = prepared.Model
 794			}
 795			if prepared.System != nil {
 796				stepSystemPrompt = *prepared.System
 797			}
 798			if prepared.ToolChoice != nil {
 799				stepToolChoice = *prepared.ToolChoice
 800			}
 801			if len(prepared.ActiveTools) > 0 {
 802				stepActiveTools = prepared.ActiveTools
 803			}
 804			disableAllTools = prepared.DisableAllTools
 805		}
 806
 807		// Recreate prompt with potentially modified system prompt
 808		if stepSystemPrompt != a.settings.systemPrompt {
 809			stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
 810			if err != nil {
 811				return nil, err
 812			}
 813			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 814				stepInputMessages[0] = stepPrompt[0]
 815			}
 816		}
 817
 818		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 819
 820		// Start step stream
 821		if opts.OnStepStart != nil {
 822			_ = opts.OnStepStart(stepNumber)
 823		}
 824
 825		// Create streaming call
 826		streamCall := Call{
 827			Prompt:           stepInputMessages,
 828			MaxOutputTokens:  call.MaxOutputTokens,
 829			Temperature:      call.Temperature,
 830			TopP:             call.TopP,
 831			TopK:             call.TopK,
 832			PresencePenalty:  call.PresencePenalty,
 833			FrequencyPenalty: call.FrequencyPenalty,
 834			Tools:            preparedTools,
 835			ToolChoice:       &stepToolChoice,
 836			ProviderOptions:  call.ProviderOptions,
 837		}
 838
 839		// Execute step with retry logic wrapping both stream creation and processing
 840		retryOptions := DefaultRetryOptions()
 841		if call.MaxRetries != nil {
 842			retryOptions.MaxRetries = *call.MaxRetries
 843		}
 844		retryOptions.OnRetry = call.OnRetry
 845		retry := RetryWithExponentialBackoffRespectingRetryHeaders[stepExecutionResult](retryOptions)
 846
 847		result, err := retry(ctx, func() (stepExecutionResult, error) {
 848			// Create the stream
 849			stream, err := stepModel.Stream(ctx, streamCall)
 850			if err != nil {
 851				return stepExecutionResult{}, err
 852			}
 853
 854			// Process the stream
 855			result, err := a.processStepStream(ctx, stream, opts, steps)
 856			if err != nil {
 857				return stepExecutionResult{}, err
 858			}
 859
 860			return result, nil
 861		})
 862		if err != nil {
 863			if opts.OnError != nil {
 864				opts.OnError(err)
 865			}
 866			return nil, err
 867		}
 868
 869		steps = append(steps, result.StepResult)
 870		totalUsage = addUsage(totalUsage, result.StepResult.Usage)
 871
 872		// Call step finished callback
 873		if opts.OnStepFinish != nil {
 874			_ = opts.OnStepFinish(result.StepResult)
 875		}
 876
 877		// Add step messages to response messages
 878		stepMessages := toResponseMessages(result.StepResult.Content)
 879		responseMessages = append(responseMessages, stepMessages...)
 880
 881		// Check stop conditions
 882		shouldStop := isStopConditionMet(call.StopWhen, steps)
 883		if shouldStop || !result.ShouldContinue {
 884			break
 885		}
 886	}
 887
 888	// Finish agent stream
 889	agentResult := &AgentResult{
 890		Steps:      steps,
 891		Response:   steps[len(steps)-1].Response,
 892		TotalUsage: totalUsage,
 893	}
 894
 895	if opts.OnFinish != nil {
 896		opts.OnFinish(agentResult)
 897	}
 898
 899	if opts.OnAgentFinish != nil {
 900		_ = opts.OnAgentFinish(agentResult)
 901	}
 902
 903	return agentResult, nil
 904}
 905
 906func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
 907	preparedTools := make([]Tool, 0, len(tools))
 908
 909	// If explicitly disabling all tools, return no tools
 910	if disableAllTools {
 911		return preparedTools
 912	}
 913
 914	for _, tool := range tools {
 915		// If activeTools has items, only include tools in the list
 916		// If activeTools is empty, include all tools
 917		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
 918			continue
 919		}
 920		info := tool.Info()
 921		preparedTools = append(preparedTools, FunctionTool{
 922			Name:        info.Name,
 923			Description: info.Description,
 924			InputSchema: map[string]any{
 925				"type":       "object",
 926				"properties": info.Parameters,
 927				"required":   info.Required,
 928			},
 929			ProviderOptions: tool.ProviderOptions(),
 930		})
 931	}
 932	return preparedTools
 933}
 934
 935// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
 936func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
 937	if err := a.validateToolCall(toolCall, availableTools); err == nil {
 938		return toolCall
 939	} else { //nolint: revive
 940		if repairFunc != nil {
 941			repairOptions := ToolCallRepairOptions{
 942				OriginalToolCall: toolCall,
 943				ValidationError:  err,
 944				AvailableTools:   availableTools,
 945				SystemPrompt:     systemPrompt,
 946				Messages:         messages,
 947			}
 948
 949			if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
 950				if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
 951					return *repairedToolCall
 952				}
 953			}
 954		}
 955
 956		invalidToolCall := toolCall
 957		invalidToolCall.Invalid = true
 958		invalidToolCall.ValidationError = err
 959		return invalidToolCall
 960	}
 961}
 962
 963// validateToolCall validates a tool call against available tools and their schemas.
 964func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
 965	var tool AgentTool
 966	for _, t := range availableTools {
 967		if t.Info().Name == toolCall.ToolName {
 968			tool = t
 969			break
 970		}
 971	}
 972
 973	if tool == nil {
 974		return fmt.Errorf("tool not found: %s", toolCall.ToolName)
 975	}
 976
 977	// Validate JSON parsing
 978	var input map[string]any
 979	if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
 980		return fmt.Errorf("invalid JSON input: %w", err)
 981	}
 982
 983	// Basic schema validation (check required fields)
 984	// TODO: more robust schema validation using JSON Schema or similar
 985	toolInfo := tool.Info()
 986	for _, required := range toolInfo.Required {
 987		if _, exists := input[required]; !exists {
 988			return fmt.Errorf("missing required parameter: %s", required)
 989		}
 990	}
 991	return nil
 992}
 993
 994func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
 995	if prompt == "" {
 996		return nil, &Error{Title: "invalid argument", Message: "prompt can't be empty"}
 997	}
 998
 999	var preparedPrompt Prompt
1000
1001	if system != "" {
1002		preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
1003	}
1004	preparedPrompt = append(preparedPrompt, messages...)
1005	preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
1006	return preparedPrompt, nil
1007}
1008
1009// WithSystemPrompt sets the system prompt for the agent.
1010func WithSystemPrompt(prompt string) AgentOption {
1011	return func(s *agentSettings) {
1012		s.systemPrompt = prompt
1013	}
1014}
1015
1016// WithMaxOutputTokens sets the maximum output tokens for the agent.
1017func WithMaxOutputTokens(tokens int64) AgentOption {
1018	return func(s *agentSettings) {
1019		s.maxOutputTokens = &tokens
1020	}
1021}
1022
1023// WithTemperature sets the temperature for the agent.
1024func WithTemperature(temp float64) AgentOption {
1025	return func(s *agentSettings) {
1026		s.temperature = &temp
1027	}
1028}
1029
1030// WithTopP sets the top-p value for the agent.
1031func WithTopP(topP float64) AgentOption {
1032	return func(s *agentSettings) {
1033		s.topP = &topP
1034	}
1035}
1036
1037// WithTopK sets the top-k value for the agent.
1038func WithTopK(topK int64) AgentOption {
1039	return func(s *agentSettings) {
1040		s.topK = &topK
1041	}
1042}
1043
1044// WithPresencePenalty sets the presence penalty for the agent.
1045func WithPresencePenalty(penalty float64) AgentOption {
1046	return func(s *agentSettings) {
1047		s.presencePenalty = &penalty
1048	}
1049}
1050
1051// WithFrequencyPenalty sets the frequency penalty for the agent.
1052func WithFrequencyPenalty(penalty float64) AgentOption {
1053	return func(s *agentSettings) {
1054		s.frequencyPenalty = &penalty
1055	}
1056}
1057
1058// WithTools sets the tools for the agent.
1059func WithTools(tools ...AgentTool) AgentOption {
1060	return func(s *agentSettings) {
1061		s.tools = append(s.tools, tools...)
1062	}
1063}
1064
1065// WithStopConditions sets the stop conditions for the agent.
1066func WithStopConditions(conditions ...StopCondition) AgentOption {
1067	return func(s *agentSettings) {
1068		s.stopWhen = append(s.stopWhen, conditions...)
1069	}
1070}
1071
1072// WithPrepareStep sets the prepare step function for the agent.
1073func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1074	return func(s *agentSettings) {
1075		s.prepareStep = fn
1076	}
1077}
1078
1079// WithRepairToolCall sets the repair tool call function for the agent.
1080func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1081	return func(s *agentSettings) {
1082		s.repairToolCall = fn
1083	}
1084}
1085
1086// WithMaxRetries sets the maximum number of retries for the agent.
1087func WithMaxRetries(maxRetries int) AgentOption {
1088	return func(s *agentSettings) {
1089		s.maxRetries = &maxRetries
1090	}
1091}
1092
1093// WithOnRetry sets the retry callback for the agent.
1094func WithOnRetry(callback OnRetryCallback) AgentOption {
1095	return func(s *agentSettings) {
1096		s.onRetry = callback
1097	}
1098}
1099
1100// processStepStream processes a single step's stream and returns the step result.
1101func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (stepExecutionResult, error) {
1102	var stepContent []Content
1103	var stepToolCalls []ToolCallContent
1104	var stepUsage Usage
1105	stepFinishReason := FinishReasonUnknown
1106	var stepWarnings []CallWarning
1107	var stepProviderMetadata ProviderMetadata
1108
1109	activeToolCalls := make(map[string]*ToolCallContent)
1110	activeTextContent := make(map[string]string)
1111	type reasoningContent struct {
1112		content string
1113		options ProviderMetadata
1114	}
1115	activeReasoningContent := make(map[string]reasoningContent)
1116
1117	// Process stream parts
1118	for part := range stream {
1119		// Forward all parts to chunk callback
1120		if opts.OnChunk != nil {
1121			err := opts.OnChunk(part)
1122			if err != nil {
1123				return stepExecutionResult{}, err
1124			}
1125		}
1126
1127		switch part.Type {
1128		case StreamPartTypeWarnings:
1129			stepWarnings = part.Warnings
1130			if opts.OnWarnings != nil {
1131				err := opts.OnWarnings(part.Warnings)
1132				if err != nil {
1133					return stepExecutionResult{}, err
1134				}
1135			}
1136
1137		case StreamPartTypeTextStart:
1138			activeTextContent[part.ID] = ""
1139			if opts.OnTextStart != nil {
1140				err := opts.OnTextStart(part.ID)
1141				if err != nil {
1142					return stepExecutionResult{}, err
1143				}
1144			}
1145
1146		case StreamPartTypeTextDelta:
1147			if _, exists := activeTextContent[part.ID]; exists {
1148				activeTextContent[part.ID] += part.Delta
1149			}
1150			if opts.OnTextDelta != nil {
1151				err := opts.OnTextDelta(part.ID, part.Delta)
1152				if err != nil {
1153					return stepExecutionResult{}, err
1154				}
1155			}
1156
1157		case StreamPartTypeTextEnd:
1158			if text, exists := activeTextContent[part.ID]; exists {
1159				stepContent = append(stepContent, TextContent{
1160					Text:             text,
1161					ProviderMetadata: part.ProviderMetadata,
1162				})
1163				delete(activeTextContent, part.ID)
1164			}
1165			if opts.OnTextEnd != nil {
1166				err := opts.OnTextEnd(part.ID)
1167				if err != nil {
1168					return stepExecutionResult{}, err
1169				}
1170			}
1171
1172		case StreamPartTypeReasoningStart:
1173			activeReasoningContent[part.ID] = reasoningContent{content: part.Delta, options: part.ProviderMetadata}
1174			if opts.OnReasoningStart != nil {
1175				content := ReasoningContent{
1176					Text:             part.Delta,
1177					ProviderMetadata: part.ProviderMetadata,
1178				}
1179				err := opts.OnReasoningStart(part.ID, content)
1180				if err != nil {
1181					return stepExecutionResult{}, err
1182				}
1183			}
1184
1185		case StreamPartTypeReasoningDelta:
1186			if active, exists := activeReasoningContent[part.ID]; exists {
1187				active.content += part.Delta
1188				active.options = part.ProviderMetadata
1189				activeReasoningContent[part.ID] = active
1190			}
1191			if opts.OnReasoningDelta != nil {
1192				err := opts.OnReasoningDelta(part.ID, part.Delta)
1193				if err != nil {
1194					return stepExecutionResult{}, err
1195				}
1196			}
1197
1198		case StreamPartTypeReasoningEnd:
1199			if active, exists := activeReasoningContent[part.ID]; exists {
1200				if part.ProviderMetadata != nil {
1201					active.options = part.ProviderMetadata
1202				}
1203				content := ReasoningContent{
1204					Text:             active.content,
1205					ProviderMetadata: active.options,
1206				}
1207				stepContent = append(stepContent, content)
1208				if opts.OnReasoningEnd != nil {
1209					err := opts.OnReasoningEnd(part.ID, content)
1210					if err != nil {
1211						return stepExecutionResult{}, err
1212					}
1213				}
1214				delete(activeReasoningContent, part.ID)
1215			}
1216
1217		case StreamPartTypeToolInputStart:
1218			activeToolCalls[part.ID] = &ToolCallContent{
1219				ToolCallID:       part.ID,
1220				ToolName:         part.ToolCallName,
1221				Input:            "",
1222				ProviderExecuted: part.ProviderExecuted,
1223			}
1224			if opts.OnToolInputStart != nil {
1225				err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1226				if err != nil {
1227					return stepExecutionResult{}, err
1228				}
1229			}
1230
1231		case StreamPartTypeToolInputDelta:
1232			if toolCall, exists := activeToolCalls[part.ID]; exists {
1233				toolCall.Input += part.Delta
1234			}
1235			if opts.OnToolInputDelta != nil {
1236				err := opts.OnToolInputDelta(part.ID, part.Delta)
1237				if err != nil {
1238					return stepExecutionResult{}, err
1239				}
1240			}
1241
1242		case StreamPartTypeToolInputEnd:
1243			if opts.OnToolInputEnd != nil {
1244				err := opts.OnToolInputEnd(part.ID)
1245				if err != nil {
1246					return stepExecutionResult{}, err
1247				}
1248			}
1249
1250		case StreamPartTypeToolCall:
1251			toolCall := ToolCallContent{
1252				ToolCallID:       part.ID,
1253				ToolName:         part.ToolCallName,
1254				Input:            part.ToolCallInput,
1255				ProviderExecuted: part.ProviderExecuted,
1256				ProviderMetadata: part.ProviderMetadata,
1257			}
1258
1259			// Validate and potentially repair the tool call
1260			validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1261			stepToolCalls = append(stepToolCalls, validatedToolCall)
1262			stepContent = append(stepContent, validatedToolCall)
1263
1264			if opts.OnToolCall != nil {
1265				err := opts.OnToolCall(validatedToolCall)
1266				if err != nil {
1267					return stepExecutionResult{}, err
1268				}
1269			}
1270
1271			// Clean up active tool call
1272			delete(activeToolCalls, part.ID)
1273
1274		case StreamPartTypeSource:
1275			sourceContent := SourceContent{
1276				SourceType:       part.SourceType,
1277				ID:               part.ID,
1278				URL:              part.URL,
1279				Title:            part.Title,
1280				ProviderMetadata: part.ProviderMetadata,
1281			}
1282			stepContent = append(stepContent, sourceContent)
1283			if opts.OnSource != nil {
1284				err := opts.OnSource(sourceContent)
1285				if err != nil {
1286					return stepExecutionResult{}, err
1287				}
1288			}
1289
1290		case StreamPartTypeFinish:
1291			stepUsage = part.Usage
1292			stepFinishReason = part.FinishReason
1293			stepProviderMetadata = part.ProviderMetadata
1294			if opts.OnStreamFinish != nil {
1295				err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1296				if err != nil {
1297					return stepExecutionResult{}, err
1298				}
1299			}
1300
1301		case StreamPartTypeError:
1302			return stepExecutionResult{}, part.Error
1303		}
1304	}
1305
1306	// Execute tools if any
1307	var toolResults []ToolResultContent
1308	if len(stepToolCalls) > 0 {
1309		var err error
1310		toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
1311		if err != nil {
1312			return stepExecutionResult{}, err
1313		}
1314		// Add tool results to content
1315		for _, result := range toolResults {
1316			stepContent = append(stepContent, result)
1317		}
1318	}
1319
1320	stepResult := StepResult{
1321		Response: Response{
1322			Content:          stepContent,
1323			FinishReason:     stepFinishReason,
1324			Usage:            stepUsage,
1325			Warnings:         stepWarnings,
1326			ProviderMetadata: stepProviderMetadata,
1327		},
1328		Messages: toResponseMessages(stepContent),
1329	}
1330
1331	// Determine if we should continue (has tool calls and not stopped)
1332	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1333
1334	return stepExecutionResult{
1335		StepResult:     stepResult,
1336		ShouldContinue: shouldContinue,
1337	}, nil
1338}
1339
1340func addUsage(a, b Usage) Usage {
1341	return Usage{
1342		InputTokens:         a.InputTokens + b.InputTokens,
1343		OutputTokens:        a.OutputTokens + b.OutputTokens,
1344		TotalTokens:         a.TotalTokens + b.TotalTokens,
1345		ReasoningTokens:     a.ReasoningTokens + b.ReasoningTokens,
1346		CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1347		CacheReadTokens:     a.CacheReadTokens + b.CacheReadTokens,
1348	}
1349}
1350
1351// WithHeaders sets the headers for the agent.
1352func WithHeaders(headers map[string]string) AgentOption {
1353	return func(s *agentSettings) {
1354		s.headers = headers
1355	}
1356}
1357
1358// WithProviderOptions sets the provider options for the agent.
1359func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1360	return func(s *agentSettings) {
1361		s.providerOptions = providerOptions
1362	}
1363}