agent.go

   1package fantasy
   2
   3import (
   4	"cmp"
   5	"context"
   6	"encoding/json"
   7	"errors"
   8	"fmt"
   9	"maps"
  10	"slices"
  11)
  12
// StepResult represents the result of a single step in an agent execution.
// It embeds the step's Response (content, finish reason, usage, warnings and
// provider metadata) alongside the conversation messages derived from it.
type StepResult struct {
	Response
	// Messages holds the conversation messages produced from this step's
	// content (assistant parts and, when tools ran, tool results).
	Messages []Message
}

// StopCondition defines a function that determines when an agent should stop executing.
// It receives every step completed so far, oldest first, and returns true to
// end the agent loop after the current step.
type StopCondition = func(steps []StepResult) bool
  21
  22// StepCountIs returns a stop condition that stops after the specified number of steps.
  23func StepCountIs(stepCount int) StopCondition {
  24	return func(steps []StepResult) bool {
  25		return len(steps) >= stepCount
  26	}
  27}
  28
  29// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
  30func HasToolCall(toolName string) StopCondition {
  31	return func(steps []StepResult) bool {
  32		if len(steps) == 0 {
  33			return false
  34		}
  35		lastStep := steps[len(steps)-1]
  36		toolCalls := lastStep.Content.ToolCalls()
  37		for _, toolCall := range toolCalls {
  38			if toolCall.ToolName == toolName {
  39				return true
  40			}
  41		}
  42		return false
  43	}
  44}
  45
  46// HasContent returns a stop condition that stops when the specified content type appears in the last step.
  47func HasContent(contentType ContentType) StopCondition {
  48	return func(steps []StepResult) bool {
  49		if len(steps) == 0 {
  50			return false
  51		}
  52		lastStep := steps[len(steps)-1]
  53		for _, content := range lastStep.Content {
  54			if content.GetType() == contentType {
  55				return true
  56			}
  57		}
  58		return false
  59	}
  60}
  61
  62// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
  63func FinishReasonIs(reason FinishReason) StopCondition {
  64	return func(steps []StepResult) bool {
  65		if len(steps) == 0 {
  66			return false
  67		}
  68		lastStep := steps[len(steps)-1]
  69		return lastStep.FinishReason == reason
  70	}
  71}
  72
  73// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
  74func MaxTokensUsed(maxTokens int64) StopCondition {
  75	return func(steps []StepResult) bool {
  76		var totalTokens int64
  77		for _, step := range steps {
  78			totalTokens += step.Usage.TotalTokens
  79		}
  80		return totalTokens >= maxTokens
  81	}
  82}
  83
// PrepareStepFunctionOptions contains the options for preparing a step in an agent execution.
type PrepareStepFunctionOptions struct {
	// Steps holds the steps completed so far, oldest first.
	Steps []StepResult
	// StepNumber is the zero-based index of the step about to run.
	StepNumber int
	// Model is the agent's configured model, used for the step unless
	// PrepareStepResult.Model overrides it.
	Model LanguageModel
	// Messages is the input the step would send (initial prompt followed by
	// all messages produced by earlier steps).
	Messages []Message
}

// PrepareStepResult contains the result of preparing a step in an agent execution.
// Nil/empty fields leave the corresponding step setting unchanged;
// DisableAllTools is always applied.
type PrepareStepResult struct {
	// Model, when non-nil, replaces the model for this step.
	Model LanguageModel
	// Messages, when non-nil, replaces the step's input messages.
	Messages []Message
	// System, when non-nil, replaces the system prompt for this step.
	System *string
	// ToolChoice, when non-nil, replaces the default automatic tool choice.
	ToolChoice *ToolChoice
	// ActiveTools, when non-empty, replaces the call's ActiveTools filter.
	ActiveTools []string
	// DisableAllTools, when true, offers no tools to the model for this step.
	DisableAllTools bool
}

// ToolCallRepairOptions contains the options for repairing a tool call.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call that failed validation.
	OriginalToolCall ToolCallContent
	// ValidationError describes why validation failed.
	ValidationError error
	// AvailableTools are the tools the call was validated against.
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt in effect for the step.
	SystemPrompt string
	// Messages are the input messages sent to the model for the step.
	Messages []Message
}

type (
	// PrepareStepFunction defines a function that prepares a step in an agent execution.
	// It runs before each model call. The returned context replaces the
	// agent's context for this step and all later steps; a non-nil error
	// aborts the run.
	PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)

	// OnStepFinishedFunction defines a function that is called when a step finishes.
	OnStepFinishedFunction = func(step StepResult)

	// RepairToolCallFunction defines a function that repairs a tool call.
	// It is consulted only for tool calls that fail validation; the returned
	// call is used only when it is non-nil, the error is nil, and the repaired
	// call itself passes validation.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
 121
// agentSettings holds the agent-level defaults configured through AgentOption
// functions passed to NewAgent. Values set on an individual AgentCall or
// AgentStreamCall take precedence where the call sets them.
type agentSettings struct {
	// systemPrompt seeds every run's prompt; "" means no system message.
	systemPrompt string

	// Sampling / generation defaults; nil means "not set".
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64

	// headers holds extra request headers.
	// NOTE(review): neither Generate nor Stream forwards these to the model
	// call in this file — confirm where (or whether) they are applied.
	headers map[string]string
	// providerOptions are merged with per-call provider options; per-call
	// entries win on key conflicts.
	providerOptions ProviderOptions

	// TODO: add support for provider tools
	// tools are the function tools offered to the model and executed locally.
	tools []AgentTool
	// maxRetries is the default retry budget for model calls (nil = not set).
	maxRetries *int

	// model is the default language model for every step.
	model LanguageModel

	// Default loop hooks, used when the call does not provide its own.
	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
 144
// AgentCall represents a call to an agent.
// Unset (nil/empty) fields fall back to the agent-level settings configured
// via AgentOption functions.
type AgentCall struct {
	// Prompt is the user message for this run; it must be non-empty.
	Prompt string `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files []FilePart `json:"files"`
	// Messages are prior conversation messages, placed between the system
	// message and the new user message.
	Messages []Message `json:"messages"`
	// NOTE(review): unlike the neighbouring fields, MaxOutputTokens has no
	// json tag, so it serializes as "MaxOutputTokens".
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, restricts the tools offered to the model
	// to those with matching names.
	ActiveTools     []string `json:"active_tools"`
	ProviderOptions ProviderOptions
	OnRetry         OnRetryCallback
	MaxRetries      *int

	// Loop hooks; see StopCondition, PrepareStepFunction and
	// RepairToolCallFunction.
	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction
}
 165
// Agent-level callbacks.
type (
	// OnAgentStartFunc is called when agent starts, before the first step.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called when agent finishes, after OnFinish.
	// NOTE(review): Stream currently discards the returned error.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called when a step starts, with the zero-based step
	// number. NOTE(review): Stream currently discards the returned error.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called when a step finishes.
	// NOTE(review): Stream currently discards the returned error.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called when entire agent completes successfully.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called when an error occurs.
	OnErrorFunc func(error)
)

// Stream part callbacks - called for each corresponding stream part type.
type (
	// OnChunkFunc is called for each stream part (catch-all).
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when text starts.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for text deltas.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when text ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when reasoning starts.
	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error

	// OnReasoningDeltaFunc is called for reasoning deltas.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning ends.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when tool input starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for tool input deltas.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnToolResultFunc is called when tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when stream finishes.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
 234
// AgentStreamCall represents a streaming call to an agent.
// The option fields mirror AgentCall (and fall back to agent-level settings
// the same way); the callback fields receive progress while Stream runs.
type AgentStreamCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	// Headers holds extra request headers.
	// NOTE(review): Stream does not forward Headers to the model call in
	// this file — confirm intended behavior.
	Headers         map[string]string
	ProviderOptions ProviderOptions
	OnRetry         OnRetryCallback
	MaxRetries      *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called when agent starts
	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
	OnWarnings       OnWarningsFunc       // Called for warnings
	OnTextStart      OnTextStartFunc      // Called when text starts
	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
	OnTextEnd        OnTextEndFunc        // Called when text ends
	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
	OnToolCall       OnToolCallFunc       // Called when tool call is complete
	OnToolResult     OnToolResultFunc     // Called when tool execution completes
	OnSource         OnSourceFunc         // Called for source references
	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
}
 281
// AgentResult represents the result of an agent execution.
type AgentResult struct {
	// Steps lists every step that ran, in order.
	Steps []StepResult
	// Response is the final response, i.e. the last step's response.
	Response Response
	// TotalUsage is the token usage summed across all steps.
	TotalUsage Usage
}

// Agent represents an AI agent that can generate responses and stream responses.
// Generate runs the step loop with non-streaming model calls; Stream streams
// each step and reports progress through the callbacks on AgentStreamCall.
type Agent interface {
	Generate(context.Context, AgentCall) (*AgentResult, error)
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}

// AgentOption defines a function that configures agent settings.
// Options are applied in order by NewAgent, so later options win.
type AgentOption = func(*agentSettings)

// agent is the default Agent implementation returned by NewAgent.
type agent struct {
	settings agentSettings
}
 302
 303// NewAgent creates a new agent with the given language model and options.
 304func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
 305	settings := agentSettings{
 306		model: model,
 307	}
 308	for _, o := range opts {
 309		o(&settings)
 310	}
 311	return &agent{
 312		settings: settings,
 313	}
 314}
 315
 316func (a *agent) prepareCall(call AgentCall) AgentCall {
 317	call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
 318	call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
 319	call.TopP = cmp.Or(call.TopP, a.settings.topP)
 320	call.TopK = cmp.Or(call.TopK, a.settings.topK)
 321	call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
 322	call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
 323	call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
 324
 325	if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
 326		call.StopWhen = a.settings.stopWhen
 327	}
 328	if call.PrepareStep == nil && a.settings.prepareStep != nil {
 329		call.PrepareStep = a.settings.prepareStep
 330	}
 331	if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
 332		call.RepairToolCall = a.settings.repairToolCall
 333	}
 334	if call.OnRetry == nil && a.settings.onRetry != nil {
 335		call.OnRetry = a.settings.onRetry
 336	}
 337
 338	providerOptions := ProviderOptions{}
 339	if a.settings.providerOptions != nil {
 340		maps.Copy(providerOptions, a.settings.providerOptions)
 341	}
 342	if call.ProviderOptions != nil {
 343		maps.Copy(providerOptions, call.ProviderOptions)
 344	}
 345	call.ProviderOptions = providerOptions
 346
 347	headers := map[string]string{}
 348
 349	if a.settings.headers != nil {
 350		maps.Copy(headers, a.settings.headers)
 351	}
 352
 353	return call
 354}
 355
 356// Generate implements Agent.
 357func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
 358	opts = a.prepareCall(opts)
 359	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 360	if err != nil {
 361		return nil, err
 362	}
 363	var responseMessages []Message
 364	var steps []StepResult
 365
 366	for {
 367		stepInputMessages := append(initialPrompt, responseMessages...)
 368		stepModel := a.settings.model
 369		stepSystemPrompt := a.settings.systemPrompt
 370		stepActiveTools := opts.ActiveTools
 371		stepToolChoice := ToolChoiceAuto
 372		disableAllTools := false
 373
 374		if opts.PrepareStep != nil {
 375			updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
 376				Model:      stepModel,
 377				Steps:      steps,
 378				StepNumber: len(steps),
 379				Messages:   stepInputMessages,
 380			})
 381			if err != nil {
 382				return nil, err
 383			}
 384
 385			ctx = updatedCtx
 386
 387			// Apply prepared step modifications
 388			if prepared.Messages != nil {
 389				stepInputMessages = prepared.Messages
 390			}
 391			if prepared.Model != nil {
 392				stepModel = prepared.Model
 393			}
 394			if prepared.System != nil {
 395				stepSystemPrompt = *prepared.System
 396			}
 397			if prepared.ToolChoice != nil {
 398				stepToolChoice = *prepared.ToolChoice
 399			}
 400			if len(prepared.ActiveTools) > 0 {
 401				stepActiveTools = prepared.ActiveTools
 402			}
 403			disableAllTools = prepared.DisableAllTools
 404		}
 405
 406		// Recreate prompt with potentially modified system prompt
 407		if stepSystemPrompt != a.settings.systemPrompt {
 408			stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 409			if err != nil {
 410				return nil, err
 411			}
 412			// Replace system message part, keep the rest
 413			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 414				stepInputMessages[0] = stepPrompt[0] // Replace system message
 415			}
 416		}
 417
 418		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 419
 420		retryOptions := DefaultRetryOptions()
 421		retryOptions.OnRetry = opts.OnRetry
 422		retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
 423
 424		result, err := retry(ctx, func() (*Response, error) {
 425			return stepModel.Generate(ctx, Call{
 426				Prompt:           stepInputMessages,
 427				MaxOutputTokens:  opts.MaxOutputTokens,
 428				Temperature:      opts.Temperature,
 429				TopP:             opts.TopP,
 430				TopK:             opts.TopK,
 431				PresencePenalty:  opts.PresencePenalty,
 432				FrequencyPenalty: opts.FrequencyPenalty,
 433				Tools:            preparedTools,
 434				ToolChoice:       &stepToolChoice,
 435				ProviderOptions:  opts.ProviderOptions,
 436			})
 437		})
 438		if err != nil {
 439			return nil, err
 440		}
 441
 442		var stepToolCalls []ToolCallContent
 443		for _, content := range result.Content {
 444			if content.GetType() == ContentTypeToolCall {
 445				toolCall, ok := AsContentType[ToolCallContent](content)
 446				if !ok {
 447					continue
 448				}
 449
 450				// Validate and potentially repair the tool call
 451				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
 452				stepToolCalls = append(stepToolCalls, validatedToolCall)
 453			}
 454		}
 455
 456		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
 457
 458		// Build step content with validated tool calls and tool results
 459		stepContent := []Content{}
 460		toolCallIndex := 0
 461		for _, content := range result.Content {
 462			if content.GetType() == ContentTypeToolCall {
 463				// Replace with validated tool call
 464				if toolCallIndex < len(stepToolCalls) {
 465					stepContent = append(stepContent, stepToolCalls[toolCallIndex])
 466					toolCallIndex++
 467				}
 468			} else {
 469				// Keep other content as-is
 470				stepContent = append(stepContent, content)
 471			}
 472		}
 473		// Add tool results
 474		for _, result := range toolResults {
 475			stepContent = append(stepContent, result)
 476		}
 477		currentStepMessages := toResponseMessages(stepContent)
 478		responseMessages = append(responseMessages, currentStepMessages...)
 479
 480		stepResult := StepResult{
 481			Response: Response{
 482				Content:          stepContent,
 483				FinishReason:     result.FinishReason,
 484				Usage:            result.Usage,
 485				Warnings:         result.Warnings,
 486				ProviderMetadata: result.ProviderMetadata,
 487			},
 488			Messages: currentStepMessages,
 489		}
 490		steps = append(steps, stepResult)
 491		shouldStop := isStopConditionMet(opts.StopWhen, steps)
 492
 493		if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
 494			break
 495		}
 496	}
 497
 498	totalUsage := Usage{}
 499
 500	for _, step := range steps {
 501		usage := step.Usage
 502		totalUsage.InputTokens += usage.InputTokens
 503		totalUsage.OutputTokens += usage.OutputTokens
 504		totalUsage.ReasoningTokens += usage.ReasoningTokens
 505		totalUsage.CacheCreationTokens += usage.CacheCreationTokens
 506		totalUsage.CacheReadTokens += usage.CacheReadTokens
 507		totalUsage.TotalTokens += usage.TotalTokens
 508	}
 509
 510	agentResult := &AgentResult{
 511		Steps:      steps,
 512		Response:   steps[len(steps)-1].Response,
 513		TotalUsage: totalUsage,
 514	}
 515	return agentResult, nil
 516}
 517
 518func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
 519	if len(conditions) == 0 {
 520		return false
 521	}
 522
 523	for _, condition := range conditions {
 524		if condition(steps) {
 525			return true
 526		}
 527	}
 528	return false
 529}
 530
 531func toResponseMessages(content []Content) []Message {
 532	var assistantParts []MessagePart
 533	var toolParts []MessagePart
 534
 535	for _, c := range content {
 536		switch c.GetType() {
 537		case ContentTypeText:
 538			text, ok := AsContentType[TextContent](c)
 539			if !ok {
 540				continue
 541			}
 542			assistantParts = append(assistantParts, TextPart{
 543				Text:            text.Text,
 544				ProviderOptions: ProviderOptions(text.ProviderMetadata),
 545			})
 546		case ContentTypeReasoning:
 547			reasoning, ok := AsContentType[ReasoningContent](c)
 548			if !ok {
 549				continue
 550			}
 551			assistantParts = append(assistantParts, ReasoningPart{
 552				Text:            reasoning.Text,
 553				ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
 554			})
 555		case ContentTypeToolCall:
 556			toolCall, ok := AsContentType[ToolCallContent](c)
 557			if !ok {
 558				continue
 559			}
 560			assistantParts = append(assistantParts, ToolCallPart{
 561				ToolCallID:       toolCall.ToolCallID,
 562				ToolName:         toolCall.ToolName,
 563				Input:            toolCall.Input,
 564				ProviderExecuted: toolCall.ProviderExecuted,
 565				ProviderOptions:  ProviderOptions(toolCall.ProviderMetadata),
 566			})
 567		case ContentTypeFile:
 568			file, ok := AsContentType[FileContent](c)
 569			if !ok {
 570				continue
 571			}
 572			assistantParts = append(assistantParts, FilePart{
 573				Data:            file.Data,
 574				MediaType:       file.MediaType,
 575				ProviderOptions: ProviderOptions(file.ProviderMetadata),
 576			})
 577		case ContentTypeSource:
 578			// Sources are metadata about references used to generate the response.
 579			// They don't need to be included in the conversation messages.
 580			continue
 581		case ContentTypeToolResult:
 582			result, ok := AsContentType[ToolResultContent](c)
 583			if !ok {
 584				continue
 585			}
 586			toolParts = append(toolParts, ToolResultPart{
 587				ToolCallID:      result.ToolCallID,
 588				Output:          result.Result,
 589				ProviderOptions: ProviderOptions(result.ProviderMetadata),
 590			})
 591		}
 592	}
 593
 594	var messages []Message
 595	if len(assistantParts) > 0 {
 596		messages = append(messages, Message{
 597			Role:    MessageRoleAssistant,
 598			Content: assistantParts,
 599		})
 600	}
 601	if len(toolParts) > 0 {
 602		messages = append(messages, Message{
 603			Role:    MessageRoleTool,
 604			Content: toolParts,
 605		})
 606	}
 607	return messages
 608}
 609
 610func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
 611	if len(toolCalls) == 0 {
 612		return nil, nil
 613	}
 614
 615	// Create a map for quick tool lookup
 616	toolMap := make(map[string]AgentTool)
 617	for _, tool := range allTools {
 618		toolMap[tool.Info().Name] = tool
 619	}
 620
 621	// Execute all tool calls sequentially in order
 622	results := make([]ToolResultContent, 0, len(toolCalls))
 623
 624	for _, toolCall := range toolCalls {
 625		// Skip invalid tool calls - create error result
 626		if toolCall.Invalid {
 627			result := ToolResultContent{
 628				ToolCallID: toolCall.ToolCallID,
 629				ToolName:   toolCall.ToolName,
 630				Result: ToolResultOutputContentError{
 631					Error: toolCall.ValidationError,
 632				},
 633				ProviderExecuted: false,
 634			}
 635			results = append(results, result)
 636			if toolResultCallback != nil {
 637				if err := toolResultCallback(result); err != nil {
 638					return nil, err
 639				}
 640			}
 641			continue
 642		}
 643
 644		tool, exists := toolMap[toolCall.ToolName]
 645		if !exists {
 646			result := ToolResultContent{
 647				ToolCallID: toolCall.ToolCallID,
 648				ToolName:   toolCall.ToolName,
 649				Result: ToolResultOutputContentError{
 650					Error: errors.New("Error: Tool not found: " + toolCall.ToolName),
 651				},
 652				ProviderExecuted: false,
 653			}
 654			results = append(results, result)
 655			if toolResultCallback != nil {
 656				if err := toolResultCallback(result); err != nil {
 657					return nil, err
 658				}
 659			}
 660			continue
 661		}
 662
 663		// Execute the tool
 664		toolResult, err := tool.Run(ctx, ToolCall{
 665			ID:    toolCall.ToolCallID,
 666			Name:  toolCall.ToolName,
 667			Input: toolCall.Input,
 668		})
 669		if err != nil {
 670			result := ToolResultContent{
 671				ToolCallID: toolCall.ToolCallID,
 672				ToolName:   toolCall.ToolName,
 673				Result: ToolResultOutputContentError{
 674					Error: err,
 675				},
 676				ClientMetadata:   toolResult.Metadata,
 677				ProviderExecuted: false,
 678			}
 679			if toolResultCallback != nil {
 680				if cbErr := toolResultCallback(result); cbErr != nil {
 681					return nil, cbErr
 682				}
 683			}
 684			return nil, err
 685		}
 686
 687		var result ToolResultContent
 688		if toolResult.IsError {
 689			result = ToolResultContent{
 690				ToolCallID: toolCall.ToolCallID,
 691				ToolName:   toolCall.ToolName,
 692				Result: ToolResultOutputContentError{
 693					Error: errors.New(toolResult.Content),
 694				},
 695				ClientMetadata:   toolResult.Metadata,
 696				ProviderExecuted: false,
 697			}
 698		} else {
 699			result = ToolResultContent{
 700				ToolCallID: toolCall.ToolCallID,
 701				ToolName:   toolCall.ToolName,
 702				Result: ToolResultOutputContentText{
 703					Text: toolResult.Content,
 704				},
 705				ClientMetadata:   toolResult.Metadata,
 706				ProviderExecuted: false,
 707			}
 708		}
 709		results = append(results, result)
 710		if toolResultCallback != nil {
 711			if err := toolResultCallback(result); err != nil {
 712				return nil, err
 713			}
 714		}
 715	}
 716
 717	return results, nil
 718}
 719
 720// Stream implements Agent.
 721func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
 722	// Convert AgentStreamCall to AgentCall for preparation
 723	call := AgentCall{
 724		Prompt:           opts.Prompt,
 725		Files:            opts.Files,
 726		Messages:         opts.Messages,
 727		MaxOutputTokens:  opts.MaxOutputTokens,
 728		Temperature:      opts.Temperature,
 729		TopP:             opts.TopP,
 730		TopK:             opts.TopK,
 731		PresencePenalty:  opts.PresencePenalty,
 732		FrequencyPenalty: opts.FrequencyPenalty,
 733		ActiveTools:      opts.ActiveTools,
 734		ProviderOptions:  opts.ProviderOptions,
 735		MaxRetries:       opts.MaxRetries,
 736		StopWhen:         opts.StopWhen,
 737		PrepareStep:      opts.PrepareStep,
 738		RepairToolCall:   opts.RepairToolCall,
 739	}
 740
 741	call = a.prepareCall(call)
 742
 743	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
 744	if err != nil {
 745		return nil, err
 746	}
 747
 748	var responseMessages []Message
 749	var steps []StepResult
 750	var totalUsage Usage
 751
 752	// Start agent stream
 753	if opts.OnAgentStart != nil {
 754		opts.OnAgentStart()
 755	}
 756
 757	for stepNumber := 0; ; stepNumber++ {
 758		stepInputMessages := append(initialPrompt, responseMessages...)
 759		stepModel := a.settings.model
 760		stepSystemPrompt := a.settings.systemPrompt
 761		stepActiveTools := call.ActiveTools
 762		stepToolChoice := ToolChoiceAuto
 763		disableAllTools := false
 764
 765		// Apply step preparation if provided
 766		if call.PrepareStep != nil {
 767			updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
 768				Model:      stepModel,
 769				Steps:      steps,
 770				StepNumber: stepNumber,
 771				Messages:   stepInputMessages,
 772			})
 773			if err != nil {
 774				return nil, err
 775			}
 776
 777			ctx = updatedCtx
 778
 779			if prepared.Messages != nil {
 780				stepInputMessages = prepared.Messages
 781			}
 782			if prepared.Model != nil {
 783				stepModel = prepared.Model
 784			}
 785			if prepared.System != nil {
 786				stepSystemPrompt = *prepared.System
 787			}
 788			if prepared.ToolChoice != nil {
 789				stepToolChoice = *prepared.ToolChoice
 790			}
 791			if len(prepared.ActiveTools) > 0 {
 792				stepActiveTools = prepared.ActiveTools
 793			}
 794			disableAllTools = prepared.DisableAllTools
 795		}
 796
 797		// Recreate prompt with potentially modified system prompt
 798		if stepSystemPrompt != a.settings.systemPrompt {
 799			stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
 800			if err != nil {
 801				return nil, err
 802			}
 803			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 804				stepInputMessages[0] = stepPrompt[0]
 805			}
 806		}
 807
 808		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 809
 810		// Start step stream
 811		if opts.OnStepStart != nil {
 812			_ = opts.OnStepStart(stepNumber)
 813		}
 814
 815		// Create streaming call
 816		streamCall := Call{
 817			Prompt:           stepInputMessages,
 818			MaxOutputTokens:  call.MaxOutputTokens,
 819			Temperature:      call.Temperature,
 820			TopP:             call.TopP,
 821			TopK:             call.TopK,
 822			PresencePenalty:  call.PresencePenalty,
 823			FrequencyPenalty: call.FrequencyPenalty,
 824			Tools:            preparedTools,
 825			ToolChoice:       &stepToolChoice,
 826			ProviderOptions:  call.ProviderOptions,
 827		}
 828
 829		// Get streaming response
 830		stream, err := stepModel.Stream(ctx, streamCall)
 831		if err != nil {
 832			if opts.OnError != nil {
 833				opts.OnError(err)
 834			}
 835			return nil, err
 836		}
 837
 838		// Process stream with tool execution
 839		stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
 840		if err != nil {
 841			if opts.OnError != nil {
 842				opts.OnError(err)
 843			}
 844			return nil, err
 845		}
 846
 847		steps = append(steps, stepResult)
 848		totalUsage = addUsage(totalUsage, stepResult.Usage)
 849
 850		// Call step finished callback
 851		if opts.OnStepFinish != nil {
 852			_ = opts.OnStepFinish(stepResult)
 853		}
 854
 855		// Add step messages to response messages
 856		stepMessages := toResponseMessages(stepResult.Content)
 857		responseMessages = append(responseMessages, stepMessages...)
 858
 859		// Check stop conditions
 860		shouldStop := isStopConditionMet(call.StopWhen, steps)
 861		if shouldStop || !shouldContinue {
 862			break
 863		}
 864	}
 865
 866	// Finish agent stream
 867	agentResult := &AgentResult{
 868		Steps:      steps,
 869		Response:   steps[len(steps)-1].Response,
 870		TotalUsage: totalUsage,
 871	}
 872
 873	if opts.OnFinish != nil {
 874		opts.OnFinish(agentResult)
 875	}
 876
 877	if opts.OnAgentFinish != nil {
 878		_ = opts.OnAgentFinish(agentResult)
 879	}
 880
 881	return agentResult, nil
 882}
 883
 884func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
 885	preparedTools := make([]Tool, 0, len(tools))
 886
 887	// If explicitly disabling all tools, return no tools
 888	if disableAllTools {
 889		return preparedTools
 890	}
 891
 892	for _, tool := range tools {
 893		// If activeTools has items, only include tools in the list
 894		// If activeTools is empty, include all tools
 895		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
 896			continue
 897		}
 898		info := tool.Info()
 899		preparedTools = append(preparedTools, FunctionTool{
 900			Name:        info.Name,
 901			Description: info.Description,
 902			InputSchema: map[string]any{
 903				"type":       "object",
 904				"properties": info.Parameters,
 905				"required":   info.Required,
 906			},
 907			ProviderOptions: tool.ProviderOptions(),
 908		})
 909	}
 910	return preparedTools
 911}
 912
 913// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
 914func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
 915	if err := a.validateToolCall(toolCall, availableTools); err == nil {
 916		return toolCall
 917	} else { //nolint: revive
 918		if repairFunc != nil {
 919			repairOptions := ToolCallRepairOptions{
 920				OriginalToolCall: toolCall,
 921				ValidationError:  err,
 922				AvailableTools:   availableTools,
 923				SystemPrompt:     systemPrompt,
 924				Messages:         messages,
 925			}
 926
 927			if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
 928				if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
 929					return *repairedToolCall
 930				}
 931			}
 932		}
 933
 934		invalidToolCall := toolCall
 935		invalidToolCall.Invalid = true
 936		invalidToolCall.ValidationError = err
 937		return invalidToolCall
 938	}
 939}
 940
 941// validateToolCall validates a tool call against available tools and their schemas.
 942func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
 943	var tool AgentTool
 944	for _, t := range availableTools {
 945		if t.Info().Name == toolCall.ToolName {
 946			tool = t
 947			break
 948		}
 949	}
 950
 951	if tool == nil {
 952		return fmt.Errorf("tool not found: %s", toolCall.ToolName)
 953	}
 954
 955	// Validate JSON parsing
 956	var input map[string]any
 957	if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
 958		return fmt.Errorf("invalid JSON input: %w", err)
 959	}
 960
 961	// Basic schema validation (check required fields)
 962	// TODO: more robust schema validation using JSON Schema or similar
 963	toolInfo := tool.Info()
 964	for _, required := range toolInfo.Required {
 965		if _, exists := input[required]; !exists {
 966			return fmt.Errorf("missing required parameter: %s", required)
 967		}
 968	}
 969	return nil
 970}
 971
 972func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
 973	if prompt == "" {
 974		return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
 975	}
 976
 977	var preparedPrompt Prompt
 978
 979	if system != "" {
 980		preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
 981	}
 982	preparedPrompt = append(preparedPrompt, messages...)
 983	preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
 984	return preparedPrompt, nil
 985}
 986
 987// WithSystemPrompt sets the system prompt for the agent.
 988func WithSystemPrompt(prompt string) AgentOption {
 989	return func(s *agentSettings) {
 990		s.systemPrompt = prompt
 991	}
 992}
 993
 994// WithMaxOutputTokens sets the maximum output tokens for the agent.
 995func WithMaxOutputTokens(tokens int64) AgentOption {
 996	return func(s *agentSettings) {
 997		s.maxOutputTokens = &tokens
 998	}
 999}
1000
1001// WithTemperature sets the temperature for the agent.
1002func WithTemperature(temp float64) AgentOption {
1003	return func(s *agentSettings) {
1004		s.temperature = &temp
1005	}
1006}
1007
1008// WithTopP sets the top-p value for the agent.
1009func WithTopP(topP float64) AgentOption {
1010	return func(s *agentSettings) {
1011		s.topP = &topP
1012	}
1013}
1014
1015// WithTopK sets the top-k value for the agent.
1016func WithTopK(topK int64) AgentOption {
1017	return func(s *agentSettings) {
1018		s.topK = &topK
1019	}
1020}
1021
1022// WithPresencePenalty sets the presence penalty for the agent.
1023func WithPresencePenalty(penalty float64) AgentOption {
1024	return func(s *agentSettings) {
1025		s.presencePenalty = &penalty
1026	}
1027}
1028
1029// WithFrequencyPenalty sets the frequency penalty for the agent.
1030func WithFrequencyPenalty(penalty float64) AgentOption {
1031	return func(s *agentSettings) {
1032		s.frequencyPenalty = &penalty
1033	}
1034}
1035
1036// WithTools sets the tools for the agent.
1037func WithTools(tools ...AgentTool) AgentOption {
1038	return func(s *agentSettings) {
1039		s.tools = append(s.tools, tools...)
1040	}
1041}
1042
1043// WithStopConditions sets the stop conditions for the agent.
1044func WithStopConditions(conditions ...StopCondition) AgentOption {
1045	return func(s *agentSettings) {
1046		s.stopWhen = append(s.stopWhen, conditions...)
1047	}
1048}
1049
1050// WithPrepareStep sets the prepare step function for the agent.
1051func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1052	return func(s *agentSettings) {
1053		s.prepareStep = fn
1054	}
1055}
1056
1057// WithRepairToolCall sets the repair tool call function for the agent.
1058func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1059	return func(s *agentSettings) {
1060		s.repairToolCall = fn
1061	}
1062}
1063
1064// processStepStream processes a single step's stream and returns the step result.
1065func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
1066	var stepContent []Content
1067	var stepToolCalls []ToolCallContent
1068	var stepUsage Usage
1069	stepFinishReason := FinishReasonUnknown
1070	var stepWarnings []CallWarning
1071	var stepProviderMetadata ProviderMetadata
1072
1073	activeToolCalls := make(map[string]*ToolCallContent)
1074	activeTextContent := make(map[string]string)
1075	type reasoningContent struct {
1076		content string
1077		options ProviderMetadata
1078	}
1079	activeReasoningContent := make(map[string]reasoningContent)
1080
1081	// Process stream parts
1082	for part := range stream {
1083		// Forward all parts to chunk callback
1084		if opts.OnChunk != nil {
1085			err := opts.OnChunk(part)
1086			if err != nil {
1087				return StepResult{}, false, err
1088			}
1089		}
1090
1091		switch part.Type {
1092		case StreamPartTypeWarnings:
1093			stepWarnings = part.Warnings
1094			if opts.OnWarnings != nil {
1095				err := opts.OnWarnings(part.Warnings)
1096				if err != nil {
1097					return StepResult{}, false, err
1098				}
1099			}
1100
1101		case StreamPartTypeTextStart:
1102			activeTextContent[part.ID] = ""
1103			if opts.OnTextStart != nil {
1104				err := opts.OnTextStart(part.ID)
1105				if err != nil {
1106					return StepResult{}, false, err
1107				}
1108			}
1109
1110		case StreamPartTypeTextDelta:
1111			if _, exists := activeTextContent[part.ID]; exists {
1112				activeTextContent[part.ID] += part.Delta
1113			}
1114			if opts.OnTextDelta != nil {
1115				err := opts.OnTextDelta(part.ID, part.Delta)
1116				if err != nil {
1117					return StepResult{}, false, err
1118				}
1119			}
1120
1121		case StreamPartTypeTextEnd:
1122			if text, exists := activeTextContent[part.ID]; exists {
1123				stepContent = append(stepContent, TextContent{
1124					Text:             text,
1125					ProviderMetadata: part.ProviderMetadata,
1126				})
1127				delete(activeTextContent, part.ID)
1128			}
1129			if opts.OnTextEnd != nil {
1130				err := opts.OnTextEnd(part.ID)
1131				if err != nil {
1132					return StepResult{}, false, err
1133				}
1134			}
1135
1136		case StreamPartTypeReasoningStart:
1137			activeReasoningContent[part.ID] = reasoningContent{content: part.Delta, options: part.ProviderMetadata}
1138			if opts.OnReasoningStart != nil {
1139				content := ReasoningContent{
1140					Text:             part.Delta,
1141					ProviderMetadata: part.ProviderMetadata,
1142				}
1143				err := opts.OnReasoningStart(part.ID, content)
1144				if err != nil {
1145					return StepResult{}, false, err
1146				}
1147			}
1148
1149		case StreamPartTypeReasoningDelta:
1150			if active, exists := activeReasoningContent[part.ID]; exists {
1151				active.content += part.Delta
1152				active.options = part.ProviderMetadata
1153				activeReasoningContent[part.ID] = active
1154			}
1155			if opts.OnReasoningDelta != nil {
1156				err := opts.OnReasoningDelta(part.ID, part.Delta)
1157				if err != nil {
1158					return StepResult{}, false, err
1159				}
1160			}
1161
1162		case StreamPartTypeReasoningEnd:
1163			if active, exists := activeReasoningContent[part.ID]; exists {
1164				if part.ProviderMetadata != nil {
1165					active.options = part.ProviderMetadata
1166				}
1167				content := ReasoningContent{
1168					Text:             active.content,
1169					ProviderMetadata: active.options,
1170				}
1171				stepContent = append(stepContent, content)
1172				if opts.OnReasoningEnd != nil {
1173					err := opts.OnReasoningEnd(part.ID, content)
1174					if err != nil {
1175						return StepResult{}, false, err
1176					}
1177				}
1178				delete(activeReasoningContent, part.ID)
1179			}
1180
1181		case StreamPartTypeToolInputStart:
1182			activeToolCalls[part.ID] = &ToolCallContent{
1183				ToolCallID:       part.ID,
1184				ToolName:         part.ToolCallName,
1185				Input:            "",
1186				ProviderExecuted: part.ProviderExecuted,
1187			}
1188			if opts.OnToolInputStart != nil {
1189				err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1190				if err != nil {
1191					return StepResult{}, false, err
1192				}
1193			}
1194
1195		case StreamPartTypeToolInputDelta:
1196			if toolCall, exists := activeToolCalls[part.ID]; exists {
1197				toolCall.Input += part.Delta
1198			}
1199			if opts.OnToolInputDelta != nil {
1200				err := opts.OnToolInputDelta(part.ID, part.Delta)
1201				if err != nil {
1202					return StepResult{}, false, err
1203				}
1204			}
1205
1206		case StreamPartTypeToolInputEnd:
1207			if opts.OnToolInputEnd != nil {
1208				err := opts.OnToolInputEnd(part.ID)
1209				if err != nil {
1210					return StepResult{}, false, err
1211				}
1212			}
1213
1214		case StreamPartTypeToolCall:
1215			toolCall := ToolCallContent{
1216				ToolCallID:       part.ID,
1217				ToolName:         part.ToolCallName,
1218				Input:            part.ToolCallInput,
1219				ProviderExecuted: part.ProviderExecuted,
1220				ProviderMetadata: part.ProviderMetadata,
1221			}
1222
1223			// Validate and potentially repair the tool call
1224			validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1225			stepToolCalls = append(stepToolCalls, validatedToolCall)
1226			stepContent = append(stepContent, validatedToolCall)
1227
1228			if opts.OnToolCall != nil {
1229				err := opts.OnToolCall(validatedToolCall)
1230				if err != nil {
1231					return StepResult{}, false, err
1232				}
1233			}
1234
1235			// Clean up active tool call
1236			delete(activeToolCalls, part.ID)
1237
1238		case StreamPartTypeSource:
1239			sourceContent := SourceContent{
1240				SourceType:       part.SourceType,
1241				ID:               part.ID,
1242				URL:              part.URL,
1243				Title:            part.Title,
1244				ProviderMetadata: part.ProviderMetadata,
1245			}
1246			stepContent = append(stepContent, sourceContent)
1247			if opts.OnSource != nil {
1248				err := opts.OnSource(sourceContent)
1249				if err != nil {
1250					return StepResult{}, false, err
1251				}
1252			}
1253
1254		case StreamPartTypeFinish:
1255			stepUsage = part.Usage
1256			stepFinishReason = part.FinishReason
1257			stepProviderMetadata = part.ProviderMetadata
1258			if opts.OnStreamFinish != nil {
1259				err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1260				if err != nil {
1261					return StepResult{}, false, err
1262				}
1263			}
1264
1265		case StreamPartTypeError:
1266			if opts.OnError != nil {
1267				opts.OnError(part.Error)
1268			}
1269			return StepResult{}, false, part.Error
1270		}
1271	}
1272
1273	// Execute tools if any
1274	var toolResults []ToolResultContent
1275	if len(stepToolCalls) > 0 {
1276		var err error
1277		toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
1278		if err != nil {
1279			return StepResult{}, false, err
1280		}
1281		// Add tool results to content
1282		for _, result := range toolResults {
1283			stepContent = append(stepContent, result)
1284		}
1285	}
1286
1287	stepResult := StepResult{
1288		Response: Response{
1289			Content:          stepContent,
1290			FinishReason:     stepFinishReason,
1291			Usage:            stepUsage,
1292			Warnings:         stepWarnings,
1293			ProviderMetadata: stepProviderMetadata,
1294		},
1295		Messages: toResponseMessages(stepContent),
1296	}
1297
1298	// Determine if we should continue (has tool calls and not stopped)
1299	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1300
1301	return stepResult, shouldContinue, nil
1302}
1303
1304func addUsage(a, b Usage) Usage {
1305	return Usage{
1306		InputTokens:         a.InputTokens + b.InputTokens,
1307		OutputTokens:        a.OutputTokens + b.OutputTokens,
1308		TotalTokens:         a.TotalTokens + b.TotalTokens,
1309		ReasoningTokens:     a.ReasoningTokens + b.ReasoningTokens,
1310		CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1311		CacheReadTokens:     a.CacheReadTokens + b.CacheReadTokens,
1312	}
1313}
1314
1315// WithHeaders sets the headers for the agent.
1316func WithHeaders(headers map[string]string) AgentOption {
1317	return func(s *agentSettings) {
1318		s.headers = headers
1319	}
1320}
1321
1322// WithProviderOptions sets the provider options for the agent.
1323func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1324	return func(s *agentSettings) {
1325		s.providerOptions = providerOptions
1326	}
1327}