agent.go

   1package fantasy
   2
   3import (
   4	"cmp"
   5	"context"
   6	"encoding/json"
   7	"errors"
   8	"fmt"
   9	"maps"
  10	"slices"
  11	"sync"
  12)
  13
// StepResult represents the result of a single step in an agent execution.
//
// The embedded Response carries that step's content, finish reason, usage,
// warnings and provider metadata. In Generate, Content holds the model output
// with each tool call replaced by its validated (or repaired / invalid-marked)
// form, followed by the results of executing those calls. Messages holds the
// assistant and/or tool messages that the step contributed to the
// conversation; in Generate these come from toResponseMessages.
type StepResult struct {
	Response
	Messages []Message
}
  19
  20// StopCondition defines a function that determines when an agent should stop executing.
  21type StopCondition = func(steps []StepResult) bool
  22
  23// StepCountIs returns a stop condition that stops after the specified number of steps.
  24func StepCountIs(stepCount int) StopCondition {
  25	return func(steps []StepResult) bool {
  26		return len(steps) >= stepCount
  27	}
  28}
  29
  30// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
  31func HasToolCall(toolName string) StopCondition {
  32	return func(steps []StepResult) bool {
  33		if len(steps) == 0 {
  34			return false
  35		}
  36		lastStep := steps[len(steps)-1]
  37		toolCalls := lastStep.Content.ToolCalls()
  38		for _, toolCall := range toolCalls {
  39			if toolCall.ToolName == toolName {
  40				return true
  41			}
  42		}
  43		return false
  44	}
  45}
  46
  47// HasContent returns a stop condition that stops when the specified content type appears in the last step.
  48func HasContent(contentType ContentType) StopCondition {
  49	return func(steps []StepResult) bool {
  50		if len(steps) == 0 {
  51			return false
  52		}
  53		lastStep := steps[len(steps)-1]
  54		for _, content := range lastStep.Content {
  55			if content.GetType() == contentType {
  56				return true
  57			}
  58		}
  59		return false
  60	}
  61}
  62
  63// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
  64func FinishReasonIs(reason FinishReason) StopCondition {
  65	return func(steps []StepResult) bool {
  66		if len(steps) == 0 {
  67			return false
  68		}
  69		lastStep := steps[len(steps)-1]
  70		return lastStep.FinishReason == reason
  71	}
  72}
  73
  74// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
  75func MaxTokensUsed(maxTokens int64) StopCondition {
  76	return func(steps []StepResult) bool {
  77		var totalTokens int64
  78		for _, step := range steps {
  79			totalTokens += step.Usage.TotalTokens
  80		}
  81		return totalTokens >= maxTokens
  82	}
  83}
  84
// PrepareStepFunctionOptions contains the options for preparing a step in an agent execution.
//
// Steps holds the results of every step completed so far, and StepNumber is
// the zero-based index of the step about to run. Model is the agent's
// configured language model. Messages is the input prepared for this step: the
// initial prompt followed by all messages produced by earlier steps.
type PrepareStepFunctionOptions struct {
	Steps      []StepResult
	StepNumber int
	Model      LanguageModel
	Messages   []Message
}
  92
// PrepareStepResult contains the result of preparing a step in an agent execution.
//
// Every field is an optional override that applies to the upcoming step only:
//   - Model, when non-nil, replaces the agent's model.
//   - Messages, when non-nil, replaces the step's input messages.
//   - System, when non-nil, replaces the agent's system prompt.
//   - ToolChoice, when non-nil, replaces the default ToolChoiceAuto.
//   - ActiveTools, when non-empty, restricts the tools offered to the model.
//   - DisableAllTools, when true, offers no tools at all.
type PrepareStepResult struct {
	Model           LanguageModel
	Messages        []Message
	System          *string
	ToolChoice      *ToolChoice
	ActiveTools     []string
	DisableAllTools bool
}
 102
// ToolCallRepairOptions contains the options for repairing a tool call.
//
// OriginalToolCall is the call that failed validation and ValidationError is
// the reason it failed. AvailableTools, SystemPrompt and Messages describe the
// context of the step in which the call was produced, so a repair function can
// re-ask a model if it wants to.
type ToolCallRepairOptions struct {
	OriginalToolCall ToolCallContent
	ValidationError  error
	AvailableTools   []AgentTool
	SystemPrompt     string
	Messages         []Message
}
 111
 112type (
 113	// PrepareStepFunction defines a function that prepares a step in an agent execution.
 114	PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)
 115
 116	// OnStepFinishedFunction defines a function that is called when a step finishes.
 117	OnStepFinishedFunction = func(step StepResult)
 118
 119	// RepairToolCallFunction defines a function that repairs a tool call.
 120	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
 121)
 122
// agentSettings holds the agent-level configuration assembled by NewAgent from
// AgentOption values. Most fields act as defaults that prepareCall merges into
// each call when the call leaves them unset.
type agentSettings struct {
	// Default system prompt and generation parameters. Nil pointers mean "not
	// set"; per-call values take precedence in prepareCall.
	systemPrompt     string
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	// NOTE(review): headers is not copied into AgentCall or the model Call in
	// the code shown here, so it appears to have no effect — confirm.
	headers         map[string]string
	providerOptions ProviderOptions

	// TODO: add support for provider tools
	// tools are offered to the model (subject to active-tool filtering) and
	// used to validate and execute tool calls.
	tools      []AgentTool
	maxRetries *int

	// model is the default language model; PrepareStep may override it per step.
	model LanguageModel

	// Default loop controls and callbacks, used when the call leaves them unset.
	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
 145
// AgentCall represents a call to an agent.
//
// Prompt becomes the final user message (with Files attached) and must be
// non-empty. Messages is prior history placed between the system prompt and
// that user message. Nil pointer settings and unset StopWhen, PrepareStep,
// RepairToolCall and OnRetry fall back to the agent-level defaults (see
// prepareCall). ActiveTools, when non-empty, limits the tools offered to the
// model.
//
// NOTE(review): several fields carry json tags, but encoding/json cannot
// encode func-typed fields (StopWhen, PrepareStep, RepairToolCall and likely
// OnRetry) unless they are tagged `json:"-"`, so marshaling this struct will
// fail — confirm the intended serialization.
type AgentCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction
}
 166
 167// Agent-level callbacks.
 168type (
 169	// OnAgentStartFunc is called when agent starts.
 170	OnAgentStartFunc func()
 171
 172	// OnAgentFinishFunc is called when agent finishes.
 173	OnAgentFinishFunc func(result *AgentResult) error
 174
 175	// OnStepStartFunc is called when a step starts.
 176	OnStepStartFunc func(stepNumber int) error
 177
 178	// OnStepFinishFunc is called when a step finishes.
 179	OnStepFinishFunc func(stepResult StepResult) error
 180
 181	// OnFinishFunc is called when entire agent completes.
 182	OnFinishFunc func(result *AgentResult)
 183
 184	// OnErrorFunc is called when an error occurs.
 185	OnErrorFunc func(error)
 186)
 187
 188// Stream part callbacks - called for each corresponding stream part type.
 189type (
 190	// OnChunkFunc is called for each stream part (catch-all).
 191	OnChunkFunc func(StreamPart) error
 192
 193	// OnWarningsFunc is called for warnings.
 194	OnWarningsFunc func(warnings []CallWarning) error
 195
 196	// OnTextStartFunc is called when text starts.
 197	OnTextStartFunc func(id string) error
 198
 199	// OnTextDeltaFunc is called for text deltas.
 200	OnTextDeltaFunc func(id, text string) error
 201
 202	// OnTextEndFunc is called when text ends.
 203	OnTextEndFunc func(id string) error
 204
 205	// OnReasoningStartFunc is called when reasoning starts.
 206	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error
 207
 208	// OnReasoningDeltaFunc is called for reasoning deltas.
 209	OnReasoningDeltaFunc func(id, text string) error
 210
 211	// OnReasoningEndFunc is called when reasoning ends.
 212	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error
 213
 214	// OnToolInputStartFunc is called when tool input starts.
 215	OnToolInputStartFunc func(id, toolName string) error
 216
 217	// OnToolInputDeltaFunc is called for tool input deltas.
 218	OnToolInputDeltaFunc func(id, delta string) error
 219
 220	// OnToolInputEndFunc is called when tool input ends.
 221	OnToolInputEndFunc func(id string) error
 222
 223	// OnToolCallFunc is called when tool call is complete.
 224	OnToolCallFunc func(toolCall ToolCallContent) error
 225
 226	// OnToolResultFunc is called when tool execution completes.
 227	OnToolResultFunc func(result ToolResultContent) error
 228
 229	// OnSourceFunc is called for source references.
 230	OnSourceFunc func(source SourceContent) error
 231
 232	// OnStreamFinishFunc is called when stream finishes.
 233	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
 234)
 235
// AgentStreamCall represents a streaming call to an agent.
//
// The request fields mirror AgentCall and are merged with agent-level defaults
// in the same way. The remaining fields are optional callbacks: agent-level
// lifecycle hooks used directly by Stream, and stream part hooks that Stream
// passes to processStepStream.
//
// NOTE(review): Stream does not set Headers on the model Call it builds, so
// Headers currently appears to have no effect — confirm. As with AgentCall,
// the func-typed fields make this struct unencodable by encoding/json despite
// its json tags.
type AgentStreamCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called when agent starts
	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
	OnWarnings       OnWarningsFunc       // Called for warnings
	OnTextStart      OnTextStartFunc      // Called when text starts
	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
	OnTextEnd        OnTextEndFunc        // Called when text ends
	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
	OnToolCall       OnToolCallFunc       // Called when tool call is complete
	OnToolResult     OnToolResultFunc     // Called when tool execution completes
	OnSource         OnSourceFunc         // Called for source references
	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
}
 282
// AgentResult represents the result of an agent execution.
//
// Steps lists every executed step in order. Response is the Response of the
// last step. TotalUsage accumulates the token usage of every step.
type AgentResult struct {
	Steps []StepResult
	// Final response
	Response   Response
	TotalUsage Usage
}
 290
 291// Agent represents an AI agent that can generate responses and stream responses.
 292type Agent interface {
 293	Generate(context.Context, AgentCall) (*AgentResult, error)
 294	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
 295}
 296
 297// AgentOption defines a function that configures agent settings.
 298type AgentOption = func(*agentSettings)
 299
// agent is the default Agent implementation returned by NewAgent. Its
// configuration is the agentSettings value assembled from the AgentOptions.
type agent struct {
	settings agentSettings
}
 303
 304// NewAgent creates a new agent with the given language model and options.
 305func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
 306	settings := agentSettings{
 307		model: model,
 308	}
 309	for _, o := range opts {
 310		o(&settings)
 311	}
 312	return &agent{
 313		settings: settings,
 314	}
 315}
 316
 317func (a *agent) prepareCall(call AgentCall) AgentCall {
 318	call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
 319	call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
 320	call.TopP = cmp.Or(call.TopP, a.settings.topP)
 321	call.TopK = cmp.Or(call.TopK, a.settings.topK)
 322	call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
 323	call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
 324	call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
 325
 326	if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
 327		call.StopWhen = a.settings.stopWhen
 328	}
 329	if call.PrepareStep == nil && a.settings.prepareStep != nil {
 330		call.PrepareStep = a.settings.prepareStep
 331	}
 332	if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
 333		call.RepairToolCall = a.settings.repairToolCall
 334	}
 335	if call.OnRetry == nil && a.settings.onRetry != nil {
 336		call.OnRetry = a.settings.onRetry
 337	}
 338
 339	providerOptions := ProviderOptions{}
 340	if a.settings.providerOptions != nil {
 341		maps.Copy(providerOptions, a.settings.providerOptions)
 342	}
 343	if call.ProviderOptions != nil {
 344		maps.Copy(providerOptions, call.ProviderOptions)
 345	}
 346	call.ProviderOptions = providerOptions
 347
 348	headers := map[string]string{}
 349
 350	if a.settings.headers != nil {
 351		maps.Copy(headers, a.settings.headers)
 352	}
 353
 354	return call
 355}
 356
 357// Generate implements Agent.
 358func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
 359	opts = a.prepareCall(opts)
 360	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 361	if err != nil {
 362		return nil, err
 363	}
 364	var responseMessages []Message
 365	var steps []StepResult
 366
 367	for {
 368		stepInputMessages := append(initialPrompt, responseMessages...)
 369		stepModel := a.settings.model
 370		stepSystemPrompt := a.settings.systemPrompt
 371		stepActiveTools := opts.ActiveTools
 372		stepToolChoice := ToolChoiceAuto
 373		disableAllTools := false
 374
 375		if opts.PrepareStep != nil {
 376			updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
 377				Model:      stepModel,
 378				Steps:      steps,
 379				StepNumber: len(steps),
 380				Messages:   stepInputMessages,
 381			})
 382			if err != nil {
 383				return nil, err
 384			}
 385
 386			ctx = updatedCtx
 387
 388			// Apply prepared step modifications
 389			if prepared.Messages != nil {
 390				stepInputMessages = prepared.Messages
 391			}
 392			if prepared.Model != nil {
 393				stepModel = prepared.Model
 394			}
 395			if prepared.System != nil {
 396				stepSystemPrompt = *prepared.System
 397			}
 398			if prepared.ToolChoice != nil {
 399				stepToolChoice = *prepared.ToolChoice
 400			}
 401			if len(prepared.ActiveTools) > 0 {
 402				stepActiveTools = prepared.ActiveTools
 403			}
 404			disableAllTools = prepared.DisableAllTools
 405		}
 406
 407		// Recreate prompt with potentially modified system prompt
 408		if stepSystemPrompt != a.settings.systemPrompt {
 409			stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 410			if err != nil {
 411				return nil, err
 412			}
 413			// Replace system message part, keep the rest
 414			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 415				stepInputMessages[0] = stepPrompt[0] // Replace system message
 416			}
 417		}
 418
 419		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 420
 421		retryOptions := DefaultRetryOptions()
 422		retryOptions.OnRetry = opts.OnRetry
 423		retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
 424
 425		result, err := retry(ctx, func() (*Response, error) {
 426			return stepModel.Generate(ctx, Call{
 427				Prompt:           stepInputMessages,
 428				MaxOutputTokens:  opts.MaxOutputTokens,
 429				Temperature:      opts.Temperature,
 430				TopP:             opts.TopP,
 431				TopK:             opts.TopK,
 432				PresencePenalty:  opts.PresencePenalty,
 433				FrequencyPenalty: opts.FrequencyPenalty,
 434				Tools:            preparedTools,
 435				ToolChoice:       &stepToolChoice,
 436				ProviderOptions:  opts.ProviderOptions,
 437			})
 438		})
 439		if err != nil {
 440			return nil, err
 441		}
 442
 443		var stepToolCalls []ToolCallContent
 444		for _, content := range result.Content {
 445			if content.GetType() == ContentTypeToolCall {
 446				toolCall, ok := AsContentType[ToolCallContent](content)
 447				if !ok {
 448					continue
 449				}
 450
 451				// Validate and potentially repair the tool call
 452				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
 453				stepToolCalls = append(stepToolCalls, validatedToolCall)
 454			}
 455		}
 456
 457		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
 458
 459		// Build step content with validated tool calls and tool results
 460		stepContent := []Content{}
 461		toolCallIndex := 0
 462		for _, content := range result.Content {
 463			if content.GetType() == ContentTypeToolCall {
 464				// Replace with validated tool call
 465				if toolCallIndex < len(stepToolCalls) {
 466					stepContent = append(stepContent, stepToolCalls[toolCallIndex])
 467					toolCallIndex++
 468				}
 469			} else {
 470				// Keep other content as-is
 471				stepContent = append(stepContent, content)
 472			}
 473		}
 474		// Add tool results
 475		for _, result := range toolResults {
 476			stepContent = append(stepContent, result)
 477		}
 478		currentStepMessages := toResponseMessages(stepContent)
 479		responseMessages = append(responseMessages, currentStepMessages...)
 480
 481		stepResult := StepResult{
 482			Response: Response{
 483				Content:          stepContent,
 484				FinishReason:     result.FinishReason,
 485				Usage:            result.Usage,
 486				Warnings:         result.Warnings,
 487				ProviderMetadata: result.ProviderMetadata,
 488			},
 489			Messages: currentStepMessages,
 490		}
 491		steps = append(steps, stepResult)
 492		shouldStop := isStopConditionMet(opts.StopWhen, steps)
 493
 494		if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
 495			break
 496		}
 497	}
 498
 499	totalUsage := Usage{}
 500
 501	for _, step := range steps {
 502		usage := step.Usage
 503		totalUsage.InputTokens += usage.InputTokens
 504		totalUsage.OutputTokens += usage.OutputTokens
 505		totalUsage.ReasoningTokens += usage.ReasoningTokens
 506		totalUsage.CacheCreationTokens += usage.CacheCreationTokens
 507		totalUsage.CacheReadTokens += usage.CacheReadTokens
 508		totalUsage.TotalTokens += usage.TotalTokens
 509	}
 510
 511	agentResult := &AgentResult{
 512		Steps:      steps,
 513		Response:   steps[len(steps)-1].Response,
 514		TotalUsage: totalUsage,
 515	}
 516	return agentResult, nil
 517}
 518
 519func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
 520	if len(conditions) == 0 {
 521		return false
 522	}
 523
 524	for _, condition := range conditions {
 525		if condition(steps) {
 526			return true
 527		}
 528	}
 529	return false
 530}
 531
 532func toResponseMessages(content []Content) []Message {
 533	var assistantParts []MessagePart
 534	var toolParts []MessagePart
 535
 536	for _, c := range content {
 537		switch c.GetType() {
 538		case ContentTypeText:
 539			text, ok := AsContentType[TextContent](c)
 540			if !ok {
 541				continue
 542			}
 543			assistantParts = append(assistantParts, TextPart{
 544				Text:            text.Text,
 545				ProviderOptions: ProviderOptions(text.ProviderMetadata),
 546			})
 547		case ContentTypeReasoning:
 548			reasoning, ok := AsContentType[ReasoningContent](c)
 549			if !ok {
 550				continue
 551			}
 552			assistantParts = append(assistantParts, ReasoningPart{
 553				Text:            reasoning.Text,
 554				ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
 555			})
 556		case ContentTypeToolCall:
 557			toolCall, ok := AsContentType[ToolCallContent](c)
 558			if !ok {
 559				continue
 560			}
 561			assistantParts = append(assistantParts, ToolCallPart{
 562				ToolCallID:       toolCall.ToolCallID,
 563				ToolName:         toolCall.ToolName,
 564				Input:            toolCall.Input,
 565				ProviderExecuted: toolCall.ProviderExecuted,
 566				ProviderOptions:  ProviderOptions(toolCall.ProviderMetadata),
 567			})
 568		case ContentTypeFile:
 569			file, ok := AsContentType[FileContent](c)
 570			if !ok {
 571				continue
 572			}
 573			assistantParts = append(assistantParts, FilePart{
 574				Data:            file.Data,
 575				MediaType:       file.MediaType,
 576				ProviderOptions: ProviderOptions(file.ProviderMetadata),
 577			})
 578		case ContentTypeSource:
 579			// Sources are metadata about references used to generate the response.
 580			// They don't need to be included in the conversation messages.
 581			continue
 582		case ContentTypeToolResult:
 583			result, ok := AsContentType[ToolResultContent](c)
 584			if !ok {
 585				continue
 586			}
 587			toolParts = append(toolParts, ToolResultPart{
 588				ToolCallID:      result.ToolCallID,
 589				Output:          result.Result,
 590				ProviderOptions: ProviderOptions(result.ProviderMetadata),
 591			})
 592		}
 593	}
 594
 595	var messages []Message
 596	if len(assistantParts) > 0 {
 597		messages = append(messages, Message{
 598			Role:    MessageRoleAssistant,
 599			Content: assistantParts,
 600		})
 601	}
 602	if len(toolParts) > 0 {
 603		messages = append(messages, Message{
 604			Role:    MessageRoleTool,
 605			Content: toolParts,
 606		})
 607	}
 608	return messages
 609}
 610
 611func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
 612	if len(toolCalls) == 0 {
 613		return nil, nil
 614	}
 615
 616	// Create a map for quick tool lookup
 617	toolMap := make(map[string]AgentTool)
 618	for _, tool := range allTools {
 619		toolMap[tool.Info().Name] = tool
 620	}
 621
 622	// Execute all tool calls in parallel
 623	results := make([]ToolResultContent, len(toolCalls))
 624	executeErrors := make([]error, len(toolCalls))
 625
 626	var wg sync.WaitGroup
 627
 628	for i, toolCall := range toolCalls {
 629		wg.Add(1)
 630		go func(index int, call ToolCallContent) {
 631			defer wg.Done()
 632
 633			// Skip invalid tool calls - create error result
 634			if call.Invalid {
 635				results[index] = ToolResultContent{
 636					ToolCallID: call.ToolCallID,
 637					ToolName:   call.ToolName,
 638					Result: ToolResultOutputContentError{
 639						Error: call.ValidationError,
 640					},
 641					ProviderExecuted: false,
 642				}
 643				if toolResultCallback != nil {
 644					err := toolResultCallback(results[index])
 645					if err != nil {
 646						executeErrors[index] = err
 647					}
 648				}
 649
 650				return
 651			}
 652
 653			tool, exists := toolMap[call.ToolName]
 654			if !exists {
 655				results[index] = ToolResultContent{
 656					ToolCallID: call.ToolCallID,
 657					ToolName:   call.ToolName,
 658					Result: ToolResultOutputContentError{
 659						Error: errors.New("Error: Tool not found: " + call.ToolName),
 660					},
 661					ProviderExecuted: false,
 662				}
 663
 664				if toolResultCallback != nil {
 665					err := toolResultCallback(results[index])
 666					if err != nil {
 667						executeErrors[index] = err
 668					}
 669				}
 670				return
 671			}
 672
 673			// Execute the tool
 674			result, err := tool.Run(ctx, ToolCall{
 675				ID:    call.ToolCallID,
 676				Name:  call.ToolName,
 677				Input: call.Input,
 678			})
 679			if err != nil {
 680				results[index] = ToolResultContent{
 681					ToolCallID: call.ToolCallID,
 682					ToolName:   call.ToolName,
 683					Result: ToolResultOutputContentError{
 684						Error: err,
 685					},
 686					ClientMetadata:   result.Metadata,
 687					ProviderExecuted: false,
 688				}
 689				if toolResultCallback != nil {
 690					cbErr := toolResultCallback(results[index])
 691					if cbErr != nil {
 692						executeErrors[index] = cbErr
 693					}
 694				}
 695				executeErrors[index] = err
 696				return
 697			}
 698
 699			if result.IsError {
 700				results[index] = ToolResultContent{
 701					ToolCallID: call.ToolCallID,
 702					ToolName:   call.ToolName,
 703					Result: ToolResultOutputContentError{
 704						Error: errors.New(result.Content),
 705					},
 706					ClientMetadata:   result.Metadata,
 707					ProviderExecuted: false,
 708				}
 709
 710				if toolResultCallback != nil {
 711					err := toolResultCallback(results[index])
 712					if err != nil {
 713						executeErrors[index] = err
 714					}
 715				}
 716			} else {
 717				results[index] = ToolResultContent{
 718					ToolCallID: call.ToolCallID,
 719					ToolName:   toolCall.ToolName,
 720					Result: ToolResultOutputContentText{
 721						Text: result.Content,
 722					},
 723					ClientMetadata:   result.Metadata,
 724					ProviderExecuted: false,
 725				}
 726				if toolResultCallback != nil {
 727					err := toolResultCallback(results[index])
 728					if err != nil {
 729						executeErrors[index] = err
 730					}
 731				}
 732			}
 733		}(i, toolCall)
 734	}
 735
 736	// Wait for all tool executions to complete
 737	wg.Wait()
 738
 739	for _, err := range executeErrors {
 740		if err != nil {
 741			return nil, err
 742		}
 743	}
 744
 745	return results, nil
 746}
 747
 748// Stream implements Agent.
 749func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
 750	// Convert AgentStreamCall to AgentCall for preparation
 751	call := AgentCall{
 752		Prompt:           opts.Prompt,
 753		Files:            opts.Files,
 754		Messages:         opts.Messages,
 755		MaxOutputTokens:  opts.MaxOutputTokens,
 756		Temperature:      opts.Temperature,
 757		TopP:             opts.TopP,
 758		TopK:             opts.TopK,
 759		PresencePenalty:  opts.PresencePenalty,
 760		FrequencyPenalty: opts.FrequencyPenalty,
 761		ActiveTools:      opts.ActiveTools,
 762		ProviderOptions:  opts.ProviderOptions,
 763		MaxRetries:       opts.MaxRetries,
 764		StopWhen:         opts.StopWhen,
 765		PrepareStep:      opts.PrepareStep,
 766		RepairToolCall:   opts.RepairToolCall,
 767	}
 768
 769	call = a.prepareCall(call)
 770
 771	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
 772	if err != nil {
 773		return nil, err
 774	}
 775
 776	var responseMessages []Message
 777	var steps []StepResult
 778	var totalUsage Usage
 779
 780	// Start agent stream
 781	if opts.OnAgentStart != nil {
 782		opts.OnAgentStart()
 783	}
 784
 785	for stepNumber := 0; ; stepNumber++ {
 786		stepInputMessages := append(initialPrompt, responseMessages...)
 787		stepModel := a.settings.model
 788		stepSystemPrompt := a.settings.systemPrompt
 789		stepActiveTools := call.ActiveTools
 790		stepToolChoice := ToolChoiceAuto
 791		disableAllTools := false
 792
 793		// Apply step preparation if provided
 794		if call.PrepareStep != nil {
 795			updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
 796				Model:      stepModel,
 797				Steps:      steps,
 798				StepNumber: stepNumber,
 799				Messages:   stepInputMessages,
 800			})
 801			if err != nil {
 802				return nil, err
 803			}
 804
 805			ctx = updatedCtx
 806
 807			if prepared.Messages != nil {
 808				stepInputMessages = prepared.Messages
 809			}
 810			if prepared.Model != nil {
 811				stepModel = prepared.Model
 812			}
 813			if prepared.System != nil {
 814				stepSystemPrompt = *prepared.System
 815			}
 816			if prepared.ToolChoice != nil {
 817				stepToolChoice = *prepared.ToolChoice
 818			}
 819			if len(prepared.ActiveTools) > 0 {
 820				stepActiveTools = prepared.ActiveTools
 821			}
 822			disableAllTools = prepared.DisableAllTools
 823		}
 824
 825		// Recreate prompt with potentially modified system prompt
 826		if stepSystemPrompt != a.settings.systemPrompt {
 827			stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
 828			if err != nil {
 829				return nil, err
 830			}
 831			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 832				stepInputMessages[0] = stepPrompt[0]
 833			}
 834		}
 835
 836		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 837
 838		// Start step stream
 839		if opts.OnStepStart != nil {
 840			_ = opts.OnStepStart(stepNumber)
 841		}
 842
 843		// Create streaming call
 844		streamCall := Call{
 845			Prompt:           stepInputMessages,
 846			MaxOutputTokens:  call.MaxOutputTokens,
 847			Temperature:      call.Temperature,
 848			TopP:             call.TopP,
 849			TopK:             call.TopK,
 850			PresencePenalty:  call.PresencePenalty,
 851			FrequencyPenalty: call.FrequencyPenalty,
 852			Tools:            preparedTools,
 853			ToolChoice:       &stepToolChoice,
 854			ProviderOptions:  call.ProviderOptions,
 855		}
 856
 857		// Get streaming response
 858		stream, err := stepModel.Stream(ctx, streamCall)
 859		if err != nil {
 860			if opts.OnError != nil {
 861				opts.OnError(err)
 862			}
 863			return nil, err
 864		}
 865
 866		// Process stream with tool execution
 867		stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
 868		if err != nil {
 869			if opts.OnError != nil {
 870				opts.OnError(err)
 871			}
 872			return nil, err
 873		}
 874
 875		steps = append(steps, stepResult)
 876		totalUsage = addUsage(totalUsage, stepResult.Usage)
 877
 878		// Call step finished callback
 879		if opts.OnStepFinish != nil {
 880			_ = opts.OnStepFinish(stepResult)
 881		}
 882
 883		// Add step messages to response messages
 884		stepMessages := toResponseMessages(stepResult.Content)
 885		responseMessages = append(responseMessages, stepMessages...)
 886
 887		// Check stop conditions
 888		shouldStop := isStopConditionMet(call.StopWhen, steps)
 889		if shouldStop || !shouldContinue {
 890			break
 891		}
 892	}
 893
 894	// Finish agent stream
 895	agentResult := &AgentResult{
 896		Steps:      steps,
 897		Response:   steps[len(steps)-1].Response,
 898		TotalUsage: totalUsage,
 899	}
 900
 901	if opts.OnFinish != nil {
 902		opts.OnFinish(agentResult)
 903	}
 904
 905	if opts.OnAgentFinish != nil {
 906		_ = opts.OnAgentFinish(agentResult)
 907	}
 908
 909	return agentResult, nil
 910}
 911
 912func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
 913	preparedTools := make([]Tool, 0, len(tools))
 914
 915	// If explicitly disabling all tools, return no tools
 916	if disableAllTools {
 917		return preparedTools
 918	}
 919
 920	for _, tool := range tools {
 921		// If activeTools has items, only include tools in the list
 922		// If activeTools is empty, include all tools
 923		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
 924			continue
 925		}
 926		info := tool.Info()
 927		preparedTools = append(preparedTools, FunctionTool{
 928			Name:        info.Name,
 929			Description: info.Description,
 930			InputSchema: map[string]any{
 931				"type":       "object",
 932				"properties": info.Parameters,
 933				"required":   info.Required,
 934			},
 935			ProviderOptions: tool.ProviderOptions(),
 936		})
 937	}
 938	return preparedTools
 939}
 940
 941// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
 942func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
 943	if err := a.validateToolCall(toolCall, availableTools); err == nil {
 944		return toolCall
 945	} else { //nolint: revive
 946		if repairFunc != nil {
 947			repairOptions := ToolCallRepairOptions{
 948				OriginalToolCall: toolCall,
 949				ValidationError:  err,
 950				AvailableTools:   availableTools,
 951				SystemPrompt:     systemPrompt,
 952				Messages:         messages,
 953			}
 954
 955			if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
 956				if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
 957					return *repairedToolCall
 958				}
 959			}
 960		}
 961
 962		invalidToolCall := toolCall
 963		invalidToolCall.Invalid = true
 964		invalidToolCall.ValidationError = err
 965		return invalidToolCall
 966	}
 967}
 968
 969// validateToolCall validates a tool call against available tools and their schemas.
 970func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
 971	var tool AgentTool
 972	for _, t := range availableTools {
 973		if t.Info().Name == toolCall.ToolName {
 974			tool = t
 975			break
 976		}
 977	}
 978
 979	if tool == nil {
 980		return fmt.Errorf("tool not found: %s", toolCall.ToolName)
 981	}
 982
 983	// Validate JSON parsing
 984	var input map[string]any
 985	if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
 986		return fmt.Errorf("invalid JSON input: %w", err)
 987	}
 988
 989	// Basic schema validation (check required fields)
 990	// TODO: more robust schema validation using JSON Schema or similar
 991	toolInfo := tool.Info()
 992	for _, required := range toolInfo.Required {
 993		if _, exists := input[required]; !exists {
 994			return fmt.Errorf("missing required parameter: %s", required)
 995		}
 996	}
 997	return nil
 998}
 999
1000func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
1001	if prompt == "" {
1002		return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
1003	}
1004
1005	var preparedPrompt Prompt
1006
1007	if system != "" {
1008		preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
1009	}
1010	preparedPrompt = append(preparedPrompt, messages...)
1011	preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
1012	return preparedPrompt, nil
1013}
1014
1015// WithSystemPrompt sets the system prompt for the agent.
1016func WithSystemPrompt(prompt string) AgentOption {
1017	return func(s *agentSettings) {
1018		s.systemPrompt = prompt
1019	}
1020}
1021
1022// WithMaxOutputTokens sets the maximum output tokens for the agent.
1023func WithMaxOutputTokens(tokens int64) AgentOption {
1024	return func(s *agentSettings) {
1025		s.maxOutputTokens = &tokens
1026	}
1027}
1028
1029// WithTemperature sets the temperature for the agent.
1030func WithTemperature(temp float64) AgentOption {
1031	return func(s *agentSettings) {
1032		s.temperature = &temp
1033	}
1034}
1035
1036// WithTopP sets the top-p value for the agent.
1037func WithTopP(topP float64) AgentOption {
1038	return func(s *agentSettings) {
1039		s.topP = &topP
1040	}
1041}
1042
1043// WithTopK sets the top-k value for the agent.
1044func WithTopK(topK int64) AgentOption {
1045	return func(s *agentSettings) {
1046		s.topK = &topK
1047	}
1048}
1049
1050// WithPresencePenalty sets the presence penalty for the agent.
1051func WithPresencePenalty(penalty float64) AgentOption {
1052	return func(s *agentSettings) {
1053		s.presencePenalty = &penalty
1054	}
1055}
1056
1057// WithFrequencyPenalty sets the frequency penalty for the agent.
1058func WithFrequencyPenalty(penalty float64) AgentOption {
1059	return func(s *agentSettings) {
1060		s.frequencyPenalty = &penalty
1061	}
1062}
1063
1064// WithTools sets the tools for the agent.
1065func WithTools(tools ...AgentTool) AgentOption {
1066	return func(s *agentSettings) {
1067		s.tools = append(s.tools, tools...)
1068	}
1069}
1070
1071// WithStopConditions sets the stop conditions for the agent.
1072func WithStopConditions(conditions ...StopCondition) AgentOption {
1073	return func(s *agentSettings) {
1074		s.stopWhen = append(s.stopWhen, conditions...)
1075	}
1076}
1077
1078// WithPrepareStep sets the prepare step function for the agent.
1079func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1080	return func(s *agentSettings) {
1081		s.prepareStep = fn
1082	}
1083}
1084
1085// WithRepairToolCall sets the repair tool call function for the agent.
1086func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1087	return func(s *agentSettings) {
1088		s.repairToolCall = fn
1089	}
1090}
1091
1092// processStepStream processes a single step's stream and returns the step result.
1093func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
1094	var stepContent []Content
1095	var stepToolCalls []ToolCallContent
1096	var stepUsage Usage
1097	stepFinishReason := FinishReasonUnknown
1098	var stepWarnings []CallWarning
1099	var stepProviderMetadata ProviderMetadata
1100
1101	activeToolCalls := make(map[string]*ToolCallContent)
1102	activeTextContent := make(map[string]string)
1103	type reasoningContent struct {
1104		content string
1105		options ProviderMetadata
1106	}
1107	activeReasoningContent := make(map[string]reasoningContent)
1108
1109	// Process stream parts
1110	for part := range stream {
1111		// Forward all parts to chunk callback
1112		if opts.OnChunk != nil {
1113			err := opts.OnChunk(part)
1114			if err != nil {
1115				return StepResult{}, false, err
1116			}
1117		}
1118
1119		switch part.Type {
1120		case StreamPartTypeWarnings:
1121			stepWarnings = part.Warnings
1122			if opts.OnWarnings != nil {
1123				err := opts.OnWarnings(part.Warnings)
1124				if err != nil {
1125					return StepResult{}, false, err
1126				}
1127			}
1128
1129		case StreamPartTypeTextStart:
1130			activeTextContent[part.ID] = ""
1131			if opts.OnTextStart != nil {
1132				err := opts.OnTextStart(part.ID)
1133				if err != nil {
1134					return StepResult{}, false, err
1135				}
1136			}
1137
1138		case StreamPartTypeTextDelta:
1139			if _, exists := activeTextContent[part.ID]; exists {
1140				activeTextContent[part.ID] += part.Delta
1141			}
1142			if opts.OnTextDelta != nil {
1143				err := opts.OnTextDelta(part.ID, part.Delta)
1144				if err != nil {
1145					return StepResult{}, false, err
1146				}
1147			}
1148
1149		case StreamPartTypeTextEnd:
1150			if text, exists := activeTextContent[part.ID]; exists {
1151				stepContent = append(stepContent, TextContent{
1152					Text:             text,
1153					ProviderMetadata: part.ProviderMetadata,
1154				})
1155				delete(activeTextContent, part.ID)
1156			}
1157			if opts.OnTextEnd != nil {
1158				err := opts.OnTextEnd(part.ID)
1159				if err != nil {
1160					return StepResult{}, false, err
1161				}
1162			}
1163
1164		case StreamPartTypeReasoningStart:
1165			activeReasoningContent[part.ID] = reasoningContent{content: "", options: part.ProviderMetadata}
1166			if opts.OnReasoningStart != nil {
1167				content := ReasoningContent{
1168					Text:             "",
1169					ProviderMetadata: part.ProviderMetadata,
1170				}
1171				err := opts.OnReasoningStart(part.ID, content)
1172				if err != nil {
1173					return StepResult{}, false, err
1174				}
1175			}
1176
1177		case StreamPartTypeReasoningDelta:
1178			if active, exists := activeReasoningContent[part.ID]; exists {
1179				active.content += part.Delta
1180				active.options = part.ProviderMetadata
1181				activeReasoningContent[part.ID] = active
1182			}
1183			if opts.OnReasoningDelta != nil {
1184				err := opts.OnReasoningDelta(part.ID, part.Delta)
1185				if err != nil {
1186					return StepResult{}, false, err
1187				}
1188			}
1189
1190		case StreamPartTypeReasoningEnd:
1191			if active, exists := activeReasoningContent[part.ID]; exists {
1192				if part.ProviderMetadata != nil {
1193					active.options = part.ProviderMetadata
1194				}
1195				content := ReasoningContent{
1196					Text:             active.content,
1197					ProviderMetadata: active.options,
1198				}
1199				stepContent = append(stepContent, content)
1200				if opts.OnReasoningEnd != nil {
1201					err := opts.OnReasoningEnd(part.ID, content)
1202					if err != nil {
1203						return StepResult{}, false, err
1204					}
1205				}
1206				delete(activeReasoningContent, part.ID)
1207			}
1208
1209		case StreamPartTypeToolInputStart:
1210			activeToolCalls[part.ID] = &ToolCallContent{
1211				ToolCallID:       part.ID,
1212				ToolName:         part.ToolCallName,
1213				Input:            "",
1214				ProviderExecuted: part.ProviderExecuted,
1215			}
1216			if opts.OnToolInputStart != nil {
1217				err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1218				if err != nil {
1219					return StepResult{}, false, err
1220				}
1221			}
1222
1223		case StreamPartTypeToolInputDelta:
1224			if toolCall, exists := activeToolCalls[part.ID]; exists {
1225				toolCall.Input += part.Delta
1226			}
1227			if opts.OnToolInputDelta != nil {
1228				err := opts.OnToolInputDelta(part.ID, part.Delta)
1229				if err != nil {
1230					return StepResult{}, false, err
1231				}
1232			}
1233
1234		case StreamPartTypeToolInputEnd:
1235			if opts.OnToolInputEnd != nil {
1236				err := opts.OnToolInputEnd(part.ID)
1237				if err != nil {
1238					return StepResult{}, false, err
1239				}
1240			}
1241
1242		case StreamPartTypeToolCall:
1243			toolCall := ToolCallContent{
1244				ToolCallID:       part.ID,
1245				ToolName:         part.ToolCallName,
1246				Input:            part.ToolCallInput,
1247				ProviderExecuted: part.ProviderExecuted,
1248				ProviderMetadata: part.ProviderMetadata,
1249			}
1250
1251			// Validate and potentially repair the tool call
1252			validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1253			stepToolCalls = append(stepToolCalls, validatedToolCall)
1254			stepContent = append(stepContent, validatedToolCall)
1255
1256			if opts.OnToolCall != nil {
1257				err := opts.OnToolCall(validatedToolCall)
1258				if err != nil {
1259					return StepResult{}, false, err
1260				}
1261			}
1262
1263			// Clean up active tool call
1264			delete(activeToolCalls, part.ID)
1265
1266		case StreamPartTypeSource:
1267			sourceContent := SourceContent{
1268				SourceType:       part.SourceType,
1269				ID:               part.ID,
1270				URL:              part.URL,
1271				Title:            part.Title,
1272				ProviderMetadata: part.ProviderMetadata,
1273			}
1274			stepContent = append(stepContent, sourceContent)
1275			if opts.OnSource != nil {
1276				err := opts.OnSource(sourceContent)
1277				if err != nil {
1278					return StepResult{}, false, err
1279				}
1280			}
1281
1282		case StreamPartTypeFinish:
1283			stepUsage = part.Usage
1284			stepFinishReason = part.FinishReason
1285			stepProviderMetadata = part.ProviderMetadata
1286			if opts.OnStreamFinish != nil {
1287				err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1288				if err != nil {
1289					return StepResult{}, false, err
1290				}
1291			}
1292
1293		case StreamPartTypeError:
1294			if opts.OnError != nil {
1295				opts.OnError(part.Error)
1296			}
1297			return StepResult{}, false, part.Error
1298		}
1299	}
1300
1301	// Execute tools if any
1302	var toolResults []ToolResultContent
1303	if len(stepToolCalls) > 0 {
1304		var err error
1305		toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
1306		if err != nil {
1307			return StepResult{}, false, err
1308		}
1309		// Add tool results to content
1310		for _, result := range toolResults {
1311			stepContent = append(stepContent, result)
1312		}
1313	}
1314
1315	stepResult := StepResult{
1316		Response: Response{
1317			Content:          stepContent,
1318			FinishReason:     stepFinishReason,
1319			Usage:            stepUsage,
1320			Warnings:         stepWarnings,
1321			ProviderMetadata: stepProviderMetadata,
1322		},
1323		Messages: toResponseMessages(stepContent),
1324	}
1325
1326	// Determine if we should continue (has tool calls and not stopped)
1327	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1328
1329	return stepResult, shouldContinue, nil
1330}
1331
1332func addUsage(a, b Usage) Usage {
1333	return Usage{
1334		InputTokens:         a.InputTokens + b.InputTokens,
1335		OutputTokens:        a.OutputTokens + b.OutputTokens,
1336		TotalTokens:         a.TotalTokens + b.TotalTokens,
1337		ReasoningTokens:     a.ReasoningTokens + b.ReasoningTokens,
1338		CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1339		CacheReadTokens:     a.CacheReadTokens + b.CacheReadTokens,
1340	}
1341}
1342
1343// WithHeaders sets the headers for the agent.
1344func WithHeaders(headers map[string]string) AgentOption {
1345	return func(s *agentSettings) {
1346		s.headers = headers
1347	}
1348}
1349
1350// WithProviderOptions sets the provider options for the agent.
1351func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1352	return func(s *agentSettings) {
1353		s.providerOptions = providerOptions
1354	}
1355}