agent.go

   1package fantasy
   2
   3import (
   4	"cmp"
   5	"context"
   6	"encoding/json"
   7	"errors"
   8	"fmt"
   9	"maps"
  10	"slices"
  11)
  12
// StepResult represents the result of a single step in an agent execution.
//
// It embeds the step's Response, so fields such as Content, FinishReason and
// Usage are accessible directly on the StepResult.
type StepResult struct {
	Response
	// Messages holds the conversation messages produced by this step. In
	// Generate these are the step's content converted with
	// toResponseMessages.
	Messages []Message
}
  18
// StopCondition defines a function that determines when an agent should stop executing.
//
// It receives every step completed so far (the most recent step is last) and
// returns true to end the agent loop after the current step.
type StopCondition = func(steps []StepResult) bool
  21
  22// StepCountIs returns a stop condition that stops after the specified number of steps.
  23func StepCountIs(stepCount int) StopCondition {
  24	return func(steps []StepResult) bool {
  25		return len(steps) >= stepCount
  26	}
  27}
  28
  29// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
  30func HasToolCall(toolName string) StopCondition {
  31	return func(steps []StepResult) bool {
  32		if len(steps) == 0 {
  33			return false
  34		}
  35		lastStep := steps[len(steps)-1]
  36		toolCalls := lastStep.Content.ToolCalls()
  37		for _, toolCall := range toolCalls {
  38			if toolCall.ToolName == toolName {
  39				return true
  40			}
  41		}
  42		return false
  43	}
  44}
  45
  46// HasContent returns a stop condition that stops when the specified content type appears in the last step.
  47func HasContent(contentType ContentType) StopCondition {
  48	return func(steps []StepResult) bool {
  49		if len(steps) == 0 {
  50			return false
  51		}
  52		lastStep := steps[len(steps)-1]
  53		for _, content := range lastStep.Content {
  54			if content.GetType() == contentType {
  55				return true
  56			}
  57		}
  58		return false
  59	}
  60}
  61
  62// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
  63func FinishReasonIs(reason FinishReason) StopCondition {
  64	return func(steps []StepResult) bool {
  65		if len(steps) == 0 {
  66			return false
  67		}
  68		lastStep := steps[len(steps)-1]
  69		return lastStep.FinishReason == reason
  70	}
  71}
  72
  73// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
  74func MaxTokensUsed(maxTokens int64) StopCondition {
  75	return func(steps []StepResult) bool {
  76		var totalTokens int64
  77		for _, step := range steps {
  78			totalTokens += step.Usage.TotalTokens
  79		}
  80		return totalTokens >= maxTokens
  81	}
  82}
  83
// PrepareStepFunctionOptions contains the options for preparing a step in an agent execution.
//
// The agent fills it in before every step and passes it to the
// PrepareStepFunction.
type PrepareStepFunctionOptions struct {
	// Steps holds the results of all steps completed so far.
	Steps []StepResult
	// StepNumber is the zero-based index of the step about to run.
	StepNumber int
	// Model is the agent's configured language model.
	Model LanguageModel
	// Messages is the message list the step would send to the model: the
	// initial prompt followed by the messages produced by earlier steps.
	Messages []Message
}
  91
// PrepareStepResult contains the result of preparing a step in an agent execution.
//
// Every field is an optional override for the upcoming step; zero values
// leave the agent's defaults for that step untouched.
type PrepareStepResult struct {
	// Model, when non-nil, replaces the model used for the step.
	Model LanguageModel
	// Messages, when non-nil, replaces the messages sent to the model.
	Messages []Message
	// System, when non-nil, overrides the system prompt for the step.
	System *string
	// ToolChoice, when non-nil, overrides the default ToolChoiceAuto.
	ToolChoice *ToolChoice
	// ActiveTools, when non-empty, restricts the step to the named tools.
	ActiveTools []string
	// DisableAllTools, when true, sends no tools to the model for the step.
	DisableAllTools bool
}
 101
// ToolCallRepairOptions contains the options for repairing a tool call.
//
// It is passed to a RepairToolCallFunction after a tool call failed
// validation.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call that failed validation.
	OriginalToolCall ToolCallContent
	// ValidationError is the error reported by validation.
	ValidationError error
	// AvailableTools lists the tools the call was validated against.
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt in effect for the step.
	SystemPrompt string
	// Messages are the messages passed along by the caller of the repair
	// step; in Generate these are the step's input messages.
	Messages []Message
}
 110
type (
	// PrepareStepFunction defines a function that prepares a step in an agent execution.
	// It runs before each step. The returned context replaces the agent's
	// context from that point on, the returned PrepareStepResult overrides
	// per-step settings, and a non-nil error aborts the run.
	PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)

	// OnStepFinishedFunction defines a function that is called when a step finishes.
	// NOTE(review): nothing in the visible part of this file references this
	// type; confirm where it is used.
	OnStepFinishedFunction = func(step StepResult)

	// RepairToolCallFunction defines a function that repairs a tool call.
	// It is invoked when a tool call fails validation. A returned call is
	// validated again before it is accepted; returning nil or an error
	// leaves the original call marked as invalid.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
 121
// agentSettings holds the agent-level configuration assembled by the
// AgentOption functions passed to NewAgent. Most fields act as defaults that
// an individual call can override (see prepareCall).
type agentSettings struct {
	// systemPrompt is placed first in the prompt when it is non-empty.
	systemPrompt string
	// Sampling and limit defaults; nil means "not configured".
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	// NOTE(review): headers are copied in prepareCall but not forwarded to
	// the model call in the visible code; confirm intended use.
	headers map[string]string
	// providerOptions are merged with call-level options; call-level entries
	// win on key conflicts.
	providerOptions ProviderOptions

	// TODO: add support for provider tools
	tools      []AgentTool
	maxRetries *int

	// model is the default language model used for every step.
	model LanguageModel

	// Default hooks, used when the call does not supply its own.
	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
 144
// AgentCall represents a call to an agent.
//
// Nil pointer fields, empty StopWhen and nil hooks fall back to the
// agent-level settings. ProviderOptions are merged with the agent-level
// options, with the call's entries taking precedence.
type AgentCall struct {
	// Prompt is the user prompt; it must not be empty.
	Prompt string `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files []FilePart `json:"files"`
	// Messages is prior conversation history, placed after the system
	// message and before the user prompt.
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools restricts the tools offered to the model by name; when
	// empty, all of the agent's tools are offered.
	ActiveTools      []string `json:"active_tools"`
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	// StopWhen lists conditions checked after every step; any one returning
	// true ends the run.
	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction
}
 165
// Agent-level callbacks.
//
// These are invoked by Stream. Errors returned by OnStepStartFunc,
// OnStepFinishFunc and OnAgentFinishFunc are discarded by Stream.
type (
	// OnAgentStartFunc is called when agent starts.
	// Stream calls it once, before the first step.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called when agent finishes.
	// Stream calls it with the final result after OnFinishFunc.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called when a step starts.
	// stepNumber is zero-based.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called when a step finishes.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called when entire agent completes.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called when an error occurs.
	// Stream calls it when starting the model stream or processing a step
	// fails, just before returning that error.
	OnErrorFunc func(error)
)
 186
// Stream part callbacks - called for each corresponding stream part type.
//
// NOTE(review): the code that invokes these callbacks is outside the visible
// part of this file (presumably the step stream processing); confirm how a
// returned error is handled there.
type (
	// OnChunkFunc is called for each stream part (catch-all).
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when text starts.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for text deltas.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when text ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when reasoning starts.
	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error

	// OnReasoningDeltaFunc is called for reasoning deltas.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning ends.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when tool input starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for tool input deltas.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnToolResultFunc is called when tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when stream finishes.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
 234
// AgentStreamCall represents a streaming call to an agent.
//
// It carries the same request settings as AgentCall plus a Headers field and
// the callbacks invoked while streaming. Nil pointer fields, empty StopWhen
// and nil hooks fall back to the agent-level settings.
type AgentStreamCall struct {
	// Prompt is the user prompt; it must not be empty.
	Prompt string `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files []FilePart `json:"files"`
	// Messages is prior conversation history, placed after the system
	// message and before the user prompt.
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools restricts the tools offered to the model by name; when
	// empty, all of the agent's tools are offered.
	ActiveTools      []string `json:"active_tools"`
	// NOTE(review): Headers is not read by the visible Stream code; confirm
	// where it is consumed.
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	// StopWhen lists conditions checked after every step; any one returning
	// true ends the run.
	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called when agent starts
	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
	OnWarnings       OnWarningsFunc       // Called for warnings
	OnTextStart      OnTextStartFunc      // Called when text starts
	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
	OnTextEnd        OnTextEndFunc        // Called when text ends
	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
	OnToolCall       OnToolCallFunc       // Called when tool call is complete
	OnToolResult     OnToolResultFunc     // Called when tool execution completes
	OnSource         OnSourceFunc         // Called for source references
	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
}
 281
// AgentResult represents the result of an agent execution.
type AgentResult struct {
	// Steps lists every step that was executed, in order.
	Steps []StepResult
	// Final response
	// (the Response of the last entry in Steps).
	Response   Response
	// TotalUsage is the token usage accumulated over all steps.
	TotalUsage Usage
}
 289
// Agent represents an AI agent that can generate responses and stream responses.
//
// Both methods run the agent loop and return the aggregated result of all
// executed steps. Stream additionally reports progress through the callbacks
// of AgentStreamCall.
type Agent interface {
	Generate(context.Context, AgentCall) (*AgentResult, error)
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
 295
// AgentOption defines a function that configures agent settings.
//
// Options are applied by NewAgent in the order they are passed, so a later
// option overrides an earlier one that sets the same field.
type AgentOption = func(*agentSettings)
 298
// agent is the Agent implementation returned by NewAgent. It holds the
// settings collected from the options passed at construction time.
type agent struct {
	settings agentSettings
}
 302
 303// NewAgent creates a new agent with the given language model and options.
 304func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
 305	settings := agentSettings{
 306		model: model,
 307	}
 308	for _, o := range opts {
 309		o(&settings)
 310	}
 311	return &agent{
 312		settings: settings,
 313	}
 314}
 315
 316func (a *agent) prepareCall(call AgentCall) AgentCall {
 317	call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
 318	call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
 319	call.TopP = cmp.Or(call.TopP, a.settings.topP)
 320	call.TopK = cmp.Or(call.TopK, a.settings.topK)
 321	call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
 322	call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
 323	call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
 324
 325	if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
 326		call.StopWhen = a.settings.stopWhen
 327	}
 328	if call.PrepareStep == nil && a.settings.prepareStep != nil {
 329		call.PrepareStep = a.settings.prepareStep
 330	}
 331	if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
 332		call.RepairToolCall = a.settings.repairToolCall
 333	}
 334	if call.OnRetry == nil && a.settings.onRetry != nil {
 335		call.OnRetry = a.settings.onRetry
 336	}
 337
 338	providerOptions := ProviderOptions{}
 339	if a.settings.providerOptions != nil {
 340		maps.Copy(providerOptions, a.settings.providerOptions)
 341	}
 342	if call.ProviderOptions != nil {
 343		maps.Copy(providerOptions, call.ProviderOptions)
 344	}
 345	call.ProviderOptions = providerOptions
 346
 347	headers := map[string]string{}
 348
 349	if a.settings.headers != nil {
 350		maps.Copy(headers, a.settings.headers)
 351	}
 352
 353	return call
 354}
 355
 356// Generate implements Agent.
 357func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
 358	opts = a.prepareCall(opts)
 359	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 360	if err != nil {
 361		return nil, err
 362	}
 363	var responseMessages []Message
 364	var steps []StepResult
 365
 366	for {
 367		stepInputMessages := append(initialPrompt, responseMessages...)
 368		stepModel := a.settings.model
 369		stepSystemPrompt := a.settings.systemPrompt
 370		stepActiveTools := opts.ActiveTools
 371		stepToolChoice := ToolChoiceAuto
 372		disableAllTools := false
 373
 374		if opts.PrepareStep != nil {
 375			updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
 376				Model:      stepModel,
 377				Steps:      steps,
 378				StepNumber: len(steps),
 379				Messages:   stepInputMessages,
 380			})
 381			if err != nil {
 382				return nil, err
 383			}
 384
 385			ctx = updatedCtx
 386
 387			// Apply prepared step modifications
 388			if prepared.Messages != nil {
 389				stepInputMessages = prepared.Messages
 390			}
 391			if prepared.Model != nil {
 392				stepModel = prepared.Model
 393			}
 394			if prepared.System != nil {
 395				stepSystemPrompt = *prepared.System
 396			}
 397			if prepared.ToolChoice != nil {
 398				stepToolChoice = *prepared.ToolChoice
 399			}
 400			if len(prepared.ActiveTools) > 0 {
 401				stepActiveTools = prepared.ActiveTools
 402			}
 403			disableAllTools = prepared.DisableAllTools
 404		}
 405
 406		// Recreate prompt with potentially modified system prompt
 407		if stepSystemPrompt != a.settings.systemPrompt {
 408			stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 409			if err != nil {
 410				return nil, err
 411			}
 412			// Replace system message part, keep the rest
 413			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 414				stepInputMessages[0] = stepPrompt[0] // Replace system message
 415			}
 416		}
 417
 418		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 419
 420		retryOptions := DefaultRetryOptions()
 421		if opts.MaxRetries != nil {
 422			retryOptions.MaxRetries = *opts.MaxRetries
 423		}
 424		retryOptions.OnRetry = opts.OnRetry
 425		retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
 426
 427		result, err := retry(ctx, func() (*Response, error) {
 428			return stepModel.Generate(ctx, Call{
 429				Prompt:           stepInputMessages,
 430				MaxOutputTokens:  opts.MaxOutputTokens,
 431				Temperature:      opts.Temperature,
 432				TopP:             opts.TopP,
 433				TopK:             opts.TopK,
 434				PresencePenalty:  opts.PresencePenalty,
 435				FrequencyPenalty: opts.FrequencyPenalty,
 436				Tools:            preparedTools,
 437				ToolChoice:       &stepToolChoice,
 438				ProviderOptions:  opts.ProviderOptions,
 439			})
 440		})
 441		if err != nil {
 442			return nil, err
 443		}
 444
 445		var stepToolCalls []ToolCallContent
 446		for _, content := range result.Content {
 447			if content.GetType() == ContentTypeToolCall {
 448				toolCall, ok := AsContentType[ToolCallContent](content)
 449				if !ok {
 450					continue
 451				}
 452
 453				// Validate and potentially repair the tool call
 454				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
 455				stepToolCalls = append(stepToolCalls, validatedToolCall)
 456			}
 457		}
 458
 459		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
 460
 461		// Build step content with validated tool calls and tool results
 462		stepContent := []Content{}
 463		toolCallIndex := 0
 464		for _, content := range result.Content {
 465			if content.GetType() == ContentTypeToolCall {
 466				// Replace with validated tool call
 467				if toolCallIndex < len(stepToolCalls) {
 468					stepContent = append(stepContent, stepToolCalls[toolCallIndex])
 469					toolCallIndex++
 470				}
 471			} else {
 472				// Keep other content as-is
 473				stepContent = append(stepContent, content)
 474			}
 475		}
 476		// Add tool results
 477		for _, result := range toolResults {
 478			stepContent = append(stepContent, result)
 479		}
 480		currentStepMessages := toResponseMessages(stepContent)
 481		responseMessages = append(responseMessages, currentStepMessages...)
 482
 483		stepResult := StepResult{
 484			Response: Response{
 485				Content:          stepContent,
 486				FinishReason:     result.FinishReason,
 487				Usage:            result.Usage,
 488				Warnings:         result.Warnings,
 489				ProviderMetadata: result.ProviderMetadata,
 490			},
 491			Messages: currentStepMessages,
 492		}
 493		steps = append(steps, stepResult)
 494		shouldStop := isStopConditionMet(opts.StopWhen, steps)
 495
 496		if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
 497			break
 498		}
 499	}
 500
 501	totalUsage := Usage{}
 502
 503	for _, step := range steps {
 504		usage := step.Usage
 505		totalUsage.InputTokens += usage.InputTokens
 506		totalUsage.OutputTokens += usage.OutputTokens
 507		totalUsage.ReasoningTokens += usage.ReasoningTokens
 508		totalUsage.CacheCreationTokens += usage.CacheCreationTokens
 509		totalUsage.CacheReadTokens += usage.CacheReadTokens
 510		totalUsage.TotalTokens += usage.TotalTokens
 511	}
 512
 513	agentResult := &AgentResult{
 514		Steps:      steps,
 515		Response:   steps[len(steps)-1].Response,
 516		TotalUsage: totalUsage,
 517	}
 518	return agentResult, nil
 519}
 520
 521func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
 522	if len(conditions) == 0 {
 523		return false
 524	}
 525
 526	for _, condition := range conditions {
 527		if condition(steps) {
 528			return true
 529		}
 530	}
 531	return false
 532}
 533
 534func toResponseMessages(content []Content) []Message {
 535	var assistantParts []MessagePart
 536	var toolParts []MessagePart
 537
 538	for _, c := range content {
 539		switch c.GetType() {
 540		case ContentTypeText:
 541			text, ok := AsContentType[TextContent](c)
 542			if !ok {
 543				continue
 544			}
 545			assistantParts = append(assistantParts, TextPart{
 546				Text:            text.Text,
 547				ProviderOptions: ProviderOptions(text.ProviderMetadata),
 548			})
 549		case ContentTypeReasoning:
 550			reasoning, ok := AsContentType[ReasoningContent](c)
 551			if !ok {
 552				continue
 553			}
 554			assistantParts = append(assistantParts, ReasoningPart{
 555				Text:            reasoning.Text,
 556				ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
 557			})
 558		case ContentTypeToolCall:
 559			toolCall, ok := AsContentType[ToolCallContent](c)
 560			if !ok {
 561				continue
 562			}
 563			assistantParts = append(assistantParts, ToolCallPart{
 564				ToolCallID:       toolCall.ToolCallID,
 565				ToolName:         toolCall.ToolName,
 566				Input:            toolCall.Input,
 567				ProviderExecuted: toolCall.ProviderExecuted,
 568				ProviderOptions:  ProviderOptions(toolCall.ProviderMetadata),
 569			})
 570		case ContentTypeFile:
 571			file, ok := AsContentType[FileContent](c)
 572			if !ok {
 573				continue
 574			}
 575			assistantParts = append(assistantParts, FilePart{
 576				Data:            file.Data,
 577				MediaType:       file.MediaType,
 578				ProviderOptions: ProviderOptions(file.ProviderMetadata),
 579			})
 580		case ContentTypeSource:
 581			// Sources are metadata about references used to generate the response.
 582			// They don't need to be included in the conversation messages.
 583			continue
 584		case ContentTypeToolResult:
 585			result, ok := AsContentType[ToolResultContent](c)
 586			if !ok {
 587				continue
 588			}
 589			toolParts = append(toolParts, ToolResultPart{
 590				ToolCallID:      result.ToolCallID,
 591				Output:          result.Result,
 592				ProviderOptions: ProviderOptions(result.ProviderMetadata),
 593			})
 594		}
 595	}
 596
 597	var messages []Message
 598	if len(assistantParts) > 0 {
 599		messages = append(messages, Message{
 600			Role:    MessageRoleAssistant,
 601			Content: assistantParts,
 602		})
 603	}
 604	if len(toolParts) > 0 {
 605		messages = append(messages, Message{
 606			Role:    MessageRoleTool,
 607			Content: toolParts,
 608		})
 609	}
 610	return messages
 611}
 612
 613func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
 614	if len(toolCalls) == 0 {
 615		return nil, nil
 616	}
 617
 618	// Create a map for quick tool lookup
 619	toolMap := make(map[string]AgentTool)
 620	for _, tool := range allTools {
 621		toolMap[tool.Info().Name] = tool
 622	}
 623
 624	// Execute all tool calls sequentially in order
 625	results := make([]ToolResultContent, 0, len(toolCalls))
 626
 627	for _, toolCall := range toolCalls {
 628		// Skip invalid tool calls - create error result
 629		if toolCall.Invalid {
 630			result := ToolResultContent{
 631				ToolCallID: toolCall.ToolCallID,
 632				ToolName:   toolCall.ToolName,
 633				Result: ToolResultOutputContentError{
 634					Error: toolCall.ValidationError,
 635				},
 636				ProviderExecuted: false,
 637			}
 638			results = append(results, result)
 639			if toolResultCallback != nil {
 640				if err := toolResultCallback(result); err != nil {
 641					return nil, err
 642				}
 643			}
 644			continue
 645		}
 646
 647		tool, exists := toolMap[toolCall.ToolName]
 648		if !exists {
 649			result := ToolResultContent{
 650				ToolCallID: toolCall.ToolCallID,
 651				ToolName:   toolCall.ToolName,
 652				Result: ToolResultOutputContentError{
 653					Error: errors.New("Error: Tool not found: " + toolCall.ToolName),
 654				},
 655				ProviderExecuted: false,
 656			}
 657			results = append(results, result)
 658			if toolResultCallback != nil {
 659				if err := toolResultCallback(result); err != nil {
 660					return nil, err
 661				}
 662			}
 663			continue
 664		}
 665
 666		// Execute the tool
 667		toolResult, err := tool.Run(ctx, ToolCall{
 668			ID:    toolCall.ToolCallID,
 669			Name:  toolCall.ToolName,
 670			Input: toolCall.Input,
 671		})
 672		if err != nil {
 673			result := ToolResultContent{
 674				ToolCallID: toolCall.ToolCallID,
 675				ToolName:   toolCall.ToolName,
 676				Result: ToolResultOutputContentError{
 677					Error: err,
 678				},
 679				ClientMetadata:   toolResult.Metadata,
 680				ProviderExecuted: false,
 681			}
 682			if toolResultCallback != nil {
 683				if cbErr := toolResultCallback(result); cbErr != nil {
 684					return nil, cbErr
 685				}
 686			}
 687			return nil, err
 688		}
 689
 690		var result ToolResultContent
 691		if toolResult.IsError {
 692			result = ToolResultContent{
 693				ToolCallID: toolCall.ToolCallID,
 694				ToolName:   toolCall.ToolName,
 695				Result: ToolResultOutputContentError{
 696					Error: errors.New(toolResult.Content),
 697				},
 698				ClientMetadata:   toolResult.Metadata,
 699				ProviderExecuted: false,
 700			}
 701		} else {
 702			result = ToolResultContent{
 703				ToolCallID: toolCall.ToolCallID,
 704				ToolName:   toolCall.ToolName,
 705				Result: ToolResultOutputContentText{
 706					Text: toolResult.Content,
 707				},
 708				ClientMetadata:   toolResult.Metadata,
 709				ProviderExecuted: false,
 710			}
 711		}
 712		results = append(results, result)
 713		if toolResultCallback != nil {
 714			if err := toolResultCallback(result); err != nil {
 715				return nil, err
 716			}
 717		}
 718	}
 719
 720	return results, nil
 721}
 722
 723// Stream implements Agent.
 724func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
 725	// Convert AgentStreamCall to AgentCall for preparation
 726	call := AgentCall{
 727		Prompt:           opts.Prompt,
 728		Files:            opts.Files,
 729		Messages:         opts.Messages,
 730		MaxOutputTokens:  opts.MaxOutputTokens,
 731		Temperature:      opts.Temperature,
 732		TopP:             opts.TopP,
 733		TopK:             opts.TopK,
 734		PresencePenalty:  opts.PresencePenalty,
 735		FrequencyPenalty: opts.FrequencyPenalty,
 736		ActiveTools:      opts.ActiveTools,
 737		ProviderOptions:  opts.ProviderOptions,
 738		MaxRetries:       opts.MaxRetries,
 739		StopWhen:         opts.StopWhen,
 740		PrepareStep:      opts.PrepareStep,
 741		RepairToolCall:   opts.RepairToolCall,
 742	}
 743
 744	call = a.prepareCall(call)
 745
 746	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
 747	if err != nil {
 748		return nil, err
 749	}
 750
 751	var responseMessages []Message
 752	var steps []StepResult
 753	var totalUsage Usage
 754
 755	// Start agent stream
 756	if opts.OnAgentStart != nil {
 757		opts.OnAgentStart()
 758	}
 759
 760	for stepNumber := 0; ; stepNumber++ {
 761		stepInputMessages := append(initialPrompt, responseMessages...)
 762		stepModel := a.settings.model
 763		stepSystemPrompt := a.settings.systemPrompt
 764		stepActiveTools := call.ActiveTools
 765		stepToolChoice := ToolChoiceAuto
 766		disableAllTools := false
 767
 768		// Apply step preparation if provided
 769		if call.PrepareStep != nil {
 770			updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
 771				Model:      stepModel,
 772				Steps:      steps,
 773				StepNumber: stepNumber,
 774				Messages:   stepInputMessages,
 775			})
 776			if err != nil {
 777				return nil, err
 778			}
 779
 780			ctx = updatedCtx
 781
 782			if prepared.Messages != nil {
 783				stepInputMessages = prepared.Messages
 784			}
 785			if prepared.Model != nil {
 786				stepModel = prepared.Model
 787			}
 788			if prepared.System != nil {
 789				stepSystemPrompt = *prepared.System
 790			}
 791			if prepared.ToolChoice != nil {
 792				stepToolChoice = *prepared.ToolChoice
 793			}
 794			if len(prepared.ActiveTools) > 0 {
 795				stepActiveTools = prepared.ActiveTools
 796			}
 797			disableAllTools = prepared.DisableAllTools
 798		}
 799
 800		// Recreate prompt with potentially modified system prompt
 801		if stepSystemPrompt != a.settings.systemPrompt {
 802			stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
 803			if err != nil {
 804				return nil, err
 805			}
 806			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 807				stepInputMessages[0] = stepPrompt[0]
 808			}
 809		}
 810
 811		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 812
 813		// Start step stream
 814		if opts.OnStepStart != nil {
 815			_ = opts.OnStepStart(stepNumber)
 816		}
 817
 818		// Create streaming call
 819		streamCall := Call{
 820			Prompt:           stepInputMessages,
 821			MaxOutputTokens:  call.MaxOutputTokens,
 822			Temperature:      call.Temperature,
 823			TopP:             call.TopP,
 824			TopK:             call.TopK,
 825			PresencePenalty:  call.PresencePenalty,
 826			FrequencyPenalty: call.FrequencyPenalty,
 827			Tools:            preparedTools,
 828			ToolChoice:       &stepToolChoice,
 829			ProviderOptions:  call.ProviderOptions,
 830		}
 831
 832		// Get streaming response with retry logic
 833		retryOptions := DefaultRetryOptions()
 834		if call.MaxRetries != nil {
 835			retryOptions.MaxRetries = *call.MaxRetries
 836		}
 837		retryOptions.OnRetry = call.OnRetry
 838		retry := RetryWithExponentialBackoffRespectingRetryHeaders[StreamResponse](retryOptions)
 839
 840		stream, err := retry(ctx, func() (StreamResponse, error) {
 841			return stepModel.Stream(ctx, streamCall)
 842		})
 843		if err != nil {
 844			if opts.OnError != nil {
 845				opts.OnError(err)
 846			}
 847			return nil, err
 848		}
 849
 850		// Process stream with tool execution
 851		stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
 852		if err != nil {
 853			if opts.OnError != nil {
 854				opts.OnError(err)
 855			}
 856			return nil, err
 857		}
 858
 859		steps = append(steps, stepResult)
 860		totalUsage = addUsage(totalUsage, stepResult.Usage)
 861
 862		// Call step finished callback
 863		if opts.OnStepFinish != nil {
 864			_ = opts.OnStepFinish(stepResult)
 865		}
 866
 867		// Add step messages to response messages
 868		stepMessages := toResponseMessages(stepResult.Content)
 869		responseMessages = append(responseMessages, stepMessages...)
 870
 871		// Check stop conditions
 872		shouldStop := isStopConditionMet(call.StopWhen, steps)
 873		if shouldStop || !shouldContinue {
 874			break
 875		}
 876	}
 877
 878	// Finish agent stream
 879	agentResult := &AgentResult{
 880		Steps:      steps,
 881		Response:   steps[len(steps)-1].Response,
 882		TotalUsage: totalUsage,
 883	}
 884
 885	if opts.OnFinish != nil {
 886		opts.OnFinish(agentResult)
 887	}
 888
 889	if opts.OnAgentFinish != nil {
 890		_ = opts.OnAgentFinish(agentResult)
 891	}
 892
 893	return agentResult, nil
 894}
 895
 896func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
 897	preparedTools := make([]Tool, 0, len(tools))
 898
 899	// If explicitly disabling all tools, return no tools
 900	if disableAllTools {
 901		return preparedTools
 902	}
 903
 904	for _, tool := range tools {
 905		// If activeTools has items, only include tools in the list
 906		// If activeTools is empty, include all tools
 907		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
 908			continue
 909		}
 910		info := tool.Info()
 911		preparedTools = append(preparedTools, FunctionTool{
 912			Name:        info.Name,
 913			Description: info.Description,
 914			InputSchema: map[string]any{
 915				"type":       "object",
 916				"properties": info.Parameters,
 917				"required":   info.Required,
 918			},
 919			ProviderOptions: tool.ProviderOptions(),
 920		})
 921	}
 922	return preparedTools
 923}
 924
 925// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
 926func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
 927	if err := a.validateToolCall(toolCall, availableTools); err == nil {
 928		return toolCall
 929	} else { //nolint: revive
 930		if repairFunc != nil {
 931			repairOptions := ToolCallRepairOptions{
 932				OriginalToolCall: toolCall,
 933				ValidationError:  err,
 934				AvailableTools:   availableTools,
 935				SystemPrompt:     systemPrompt,
 936				Messages:         messages,
 937			}
 938
 939			if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
 940				if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
 941					return *repairedToolCall
 942				}
 943			}
 944		}
 945
 946		invalidToolCall := toolCall
 947		invalidToolCall.Invalid = true
 948		invalidToolCall.ValidationError = err
 949		return invalidToolCall
 950	}
 951}
 952
 953// validateToolCall validates a tool call against available tools and their schemas.
 954func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
 955	var tool AgentTool
 956	for _, t := range availableTools {
 957		if t.Info().Name == toolCall.ToolName {
 958			tool = t
 959			break
 960		}
 961	}
 962
 963	if tool == nil {
 964		return fmt.Errorf("tool not found: %s", toolCall.ToolName)
 965	}
 966
 967	// Validate JSON parsing
 968	var input map[string]any
 969	if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
 970		return fmt.Errorf("invalid JSON input: %w", err)
 971	}
 972
 973	// Basic schema validation (check required fields)
 974	// TODO: more robust schema validation using JSON Schema or similar
 975	toolInfo := tool.Info()
 976	for _, required := range toolInfo.Required {
 977		if _, exists := input[required]; !exists {
 978			return fmt.Errorf("missing required parameter: %s", required)
 979		}
 980	}
 981	return nil
 982}
 983
 984func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
 985	if prompt == "" {
 986		return nil, &Error{Title: "invalid argument", Message: "prompt can't be empty"}
 987	}
 988
 989	var preparedPrompt Prompt
 990
 991	if system != "" {
 992		preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
 993	}
 994	preparedPrompt = append(preparedPrompt, messages...)
 995	preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
 996	return preparedPrompt, nil
 997}
 998
 999// WithSystemPrompt sets the system prompt for the agent.
1000func WithSystemPrompt(prompt string) AgentOption {
1001	return func(s *agentSettings) {
1002		s.systemPrompt = prompt
1003	}
1004}
1005
1006// WithMaxOutputTokens sets the maximum output tokens for the agent.
1007func WithMaxOutputTokens(tokens int64) AgentOption {
1008	return func(s *agentSettings) {
1009		s.maxOutputTokens = &tokens
1010	}
1011}
1012
1013// WithTemperature sets the temperature for the agent.
1014func WithTemperature(temp float64) AgentOption {
1015	return func(s *agentSettings) {
1016		s.temperature = &temp
1017	}
1018}
1019
1020// WithTopP sets the top-p value for the agent.
1021func WithTopP(topP float64) AgentOption {
1022	return func(s *agentSettings) {
1023		s.topP = &topP
1024	}
1025}
1026
1027// WithTopK sets the top-k value for the agent.
1028func WithTopK(topK int64) AgentOption {
1029	return func(s *agentSettings) {
1030		s.topK = &topK
1031	}
1032}
1033
1034// WithPresencePenalty sets the presence penalty for the agent.
1035func WithPresencePenalty(penalty float64) AgentOption {
1036	return func(s *agentSettings) {
1037		s.presencePenalty = &penalty
1038	}
1039}
1040
1041// WithFrequencyPenalty sets the frequency penalty for the agent.
1042func WithFrequencyPenalty(penalty float64) AgentOption {
1043	return func(s *agentSettings) {
1044		s.frequencyPenalty = &penalty
1045	}
1046}
1047
1048// WithTools sets the tools for the agent.
1049func WithTools(tools ...AgentTool) AgentOption {
1050	return func(s *agentSettings) {
1051		s.tools = append(s.tools, tools...)
1052	}
1053}
1054
1055// WithStopConditions sets the stop conditions for the agent.
1056func WithStopConditions(conditions ...StopCondition) AgentOption {
1057	return func(s *agentSettings) {
1058		s.stopWhen = append(s.stopWhen, conditions...)
1059	}
1060}
1061
1062// WithPrepareStep sets the prepare step function for the agent.
1063func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1064	return func(s *agentSettings) {
1065		s.prepareStep = fn
1066	}
1067}
1068
1069// WithRepairToolCall sets the repair tool call function for the agent.
1070func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1071	return func(s *agentSettings) {
1072		s.repairToolCall = fn
1073	}
1074}
1075
1076// WithMaxRetries sets the maximum number of retries for the agent.
1077func WithMaxRetries(maxRetries int) AgentOption {
1078	return func(s *agentSettings) {
1079		s.maxRetries = &maxRetries
1080	}
1081}
1082
1083// WithOnRetry sets the retry callback for the agent.
1084func WithOnRetry(callback OnRetryCallback) AgentOption {
1085	return func(s *agentSettings) {
1086		s.onRetry = callback
1087	}
1088}
1089
1090// processStepStream processes a single step's stream and returns the step result.
1091func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
1092	var stepContent []Content
1093	var stepToolCalls []ToolCallContent
1094	var stepUsage Usage
1095	stepFinishReason := FinishReasonUnknown
1096	var stepWarnings []CallWarning
1097	var stepProviderMetadata ProviderMetadata
1098
1099	activeToolCalls := make(map[string]*ToolCallContent)
1100	activeTextContent := make(map[string]string)
1101	type reasoningContent struct {
1102		content string
1103		options ProviderMetadata
1104	}
1105	activeReasoningContent := make(map[string]reasoningContent)
1106
1107	// Process stream parts
1108	for part := range stream {
1109		// Forward all parts to chunk callback
1110		if opts.OnChunk != nil {
1111			err := opts.OnChunk(part)
1112			if err != nil {
1113				return StepResult{}, false, err
1114			}
1115		}
1116
1117		switch part.Type {
1118		case StreamPartTypeWarnings:
1119			stepWarnings = part.Warnings
1120			if opts.OnWarnings != nil {
1121				err := opts.OnWarnings(part.Warnings)
1122				if err != nil {
1123					return StepResult{}, false, err
1124				}
1125			}
1126
1127		case StreamPartTypeTextStart:
1128			activeTextContent[part.ID] = ""
1129			if opts.OnTextStart != nil {
1130				err := opts.OnTextStart(part.ID)
1131				if err != nil {
1132					return StepResult{}, false, err
1133				}
1134			}
1135
1136		case StreamPartTypeTextDelta:
1137			if _, exists := activeTextContent[part.ID]; exists {
1138				activeTextContent[part.ID] += part.Delta
1139			}
1140			if opts.OnTextDelta != nil {
1141				err := opts.OnTextDelta(part.ID, part.Delta)
1142				if err != nil {
1143					return StepResult{}, false, err
1144				}
1145			}
1146
1147		case StreamPartTypeTextEnd:
1148			if text, exists := activeTextContent[part.ID]; exists {
1149				stepContent = append(stepContent, TextContent{
1150					Text:             text,
1151					ProviderMetadata: part.ProviderMetadata,
1152				})
1153				delete(activeTextContent, part.ID)
1154			}
1155			if opts.OnTextEnd != nil {
1156				err := opts.OnTextEnd(part.ID)
1157				if err != nil {
1158					return StepResult{}, false, err
1159				}
1160			}
1161
1162		case StreamPartTypeReasoningStart:
1163			activeReasoningContent[part.ID] = reasoningContent{content: part.Delta, options: part.ProviderMetadata}
1164			if opts.OnReasoningStart != nil {
1165				content := ReasoningContent{
1166					Text:             part.Delta,
1167					ProviderMetadata: part.ProviderMetadata,
1168				}
1169				err := opts.OnReasoningStart(part.ID, content)
1170				if err != nil {
1171					return StepResult{}, false, err
1172				}
1173			}
1174
1175		case StreamPartTypeReasoningDelta:
1176			if active, exists := activeReasoningContent[part.ID]; exists {
1177				active.content += part.Delta
1178				active.options = part.ProviderMetadata
1179				activeReasoningContent[part.ID] = active
1180			}
1181			if opts.OnReasoningDelta != nil {
1182				err := opts.OnReasoningDelta(part.ID, part.Delta)
1183				if err != nil {
1184					return StepResult{}, false, err
1185				}
1186			}
1187
1188		case StreamPartTypeReasoningEnd:
1189			if active, exists := activeReasoningContent[part.ID]; exists {
1190				if part.ProviderMetadata != nil {
1191					active.options = part.ProviderMetadata
1192				}
1193				content := ReasoningContent{
1194					Text:             active.content,
1195					ProviderMetadata: active.options,
1196				}
1197				stepContent = append(stepContent, content)
1198				if opts.OnReasoningEnd != nil {
1199					err := opts.OnReasoningEnd(part.ID, content)
1200					if err != nil {
1201						return StepResult{}, false, err
1202					}
1203				}
1204				delete(activeReasoningContent, part.ID)
1205			}
1206
1207		case StreamPartTypeToolInputStart:
1208			activeToolCalls[part.ID] = &ToolCallContent{
1209				ToolCallID:       part.ID,
1210				ToolName:         part.ToolCallName,
1211				Input:            "",
1212				ProviderExecuted: part.ProviderExecuted,
1213			}
1214			if opts.OnToolInputStart != nil {
1215				err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1216				if err != nil {
1217					return StepResult{}, false, err
1218				}
1219			}
1220
1221		case StreamPartTypeToolInputDelta:
1222			if toolCall, exists := activeToolCalls[part.ID]; exists {
1223				toolCall.Input += part.Delta
1224			}
1225			if opts.OnToolInputDelta != nil {
1226				err := opts.OnToolInputDelta(part.ID, part.Delta)
1227				if err != nil {
1228					return StepResult{}, false, err
1229				}
1230			}
1231
1232		case StreamPartTypeToolInputEnd:
1233			if opts.OnToolInputEnd != nil {
1234				err := opts.OnToolInputEnd(part.ID)
1235				if err != nil {
1236					return StepResult{}, false, err
1237				}
1238			}
1239
1240		case StreamPartTypeToolCall:
1241			toolCall := ToolCallContent{
1242				ToolCallID:       part.ID,
1243				ToolName:         part.ToolCallName,
1244				Input:            part.ToolCallInput,
1245				ProviderExecuted: part.ProviderExecuted,
1246				ProviderMetadata: part.ProviderMetadata,
1247			}
1248
1249			// Validate and potentially repair the tool call
1250			validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1251			stepToolCalls = append(stepToolCalls, validatedToolCall)
1252			stepContent = append(stepContent, validatedToolCall)
1253
1254			if opts.OnToolCall != nil {
1255				err := opts.OnToolCall(validatedToolCall)
1256				if err != nil {
1257					return StepResult{}, false, err
1258				}
1259			}
1260
1261			// Clean up active tool call
1262			delete(activeToolCalls, part.ID)
1263
1264		case StreamPartTypeSource:
1265			sourceContent := SourceContent{
1266				SourceType:       part.SourceType,
1267				ID:               part.ID,
1268				URL:              part.URL,
1269				Title:            part.Title,
1270				ProviderMetadata: part.ProviderMetadata,
1271			}
1272			stepContent = append(stepContent, sourceContent)
1273			if opts.OnSource != nil {
1274				err := opts.OnSource(sourceContent)
1275				if err != nil {
1276					return StepResult{}, false, err
1277				}
1278			}
1279
1280		case StreamPartTypeFinish:
1281			stepUsage = part.Usage
1282			stepFinishReason = part.FinishReason
1283			stepProviderMetadata = part.ProviderMetadata
1284			if opts.OnStreamFinish != nil {
1285				err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1286				if err != nil {
1287					return StepResult{}, false, err
1288				}
1289			}
1290
1291		case StreamPartTypeError:
1292			if opts.OnError != nil {
1293				opts.OnError(part.Error)
1294			}
1295			return StepResult{}, false, part.Error
1296		}
1297	}
1298
1299	// Execute tools if any
1300	var toolResults []ToolResultContent
1301	if len(stepToolCalls) > 0 {
1302		var err error
1303		toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
1304		if err != nil {
1305			return StepResult{}, false, err
1306		}
1307		// Add tool results to content
1308		for _, result := range toolResults {
1309			stepContent = append(stepContent, result)
1310		}
1311	}
1312
1313	stepResult := StepResult{
1314		Response: Response{
1315			Content:          stepContent,
1316			FinishReason:     stepFinishReason,
1317			Usage:            stepUsage,
1318			Warnings:         stepWarnings,
1319			ProviderMetadata: stepProviderMetadata,
1320		},
1321		Messages: toResponseMessages(stepContent),
1322	}
1323
1324	// Determine if we should continue (has tool calls and not stopped)
1325	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1326
1327	return stepResult, shouldContinue, nil
1328}
1329
1330func addUsage(a, b Usage) Usage {
1331	return Usage{
1332		InputTokens:         a.InputTokens + b.InputTokens,
1333		OutputTokens:        a.OutputTokens + b.OutputTokens,
1334		TotalTokens:         a.TotalTokens + b.TotalTokens,
1335		ReasoningTokens:     a.ReasoningTokens + b.ReasoningTokens,
1336		CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1337		CacheReadTokens:     a.CacheReadTokens + b.CacheReadTokens,
1338	}
1339}
1340
1341// WithHeaders sets the headers for the agent.
1342func WithHeaders(headers map[string]string) AgentOption {
1343	return func(s *agentSettings) {
1344		s.headers = headers
1345	}
1346}
1347
1348// WithProviderOptions sets the provider options for the agent.
1349func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1350	return func(s *agentSettings) {
1351		s.providerOptions = providerOptions
1352	}
1353}