agent.go

   1package ai
   2
   3import (
   4	"cmp"
   5	"context"
   6	"encoding/json"
   7	"errors"
   8	"fmt"
   9	"maps"
  10	"slices"
  11	"sync"
  12)
  13
// StepResult records the outcome of a single agent step. The model Response is
// embedded, so its fields (Content, FinishReason, Usage, ...) are promoted onto
// the step.
type StepResult struct {
	Response
	// Messages are the conversation messages this step contributed; the agent
	// loop appends them to the history sent with the next step.
	Messages []Message
}
  18
// StopCondition inspects every step completed so far (in order) and reports
// whether the agent loop should stop after the current step.
type StopCondition = func(steps []StepResult) bool
  20
  21// StepCountIs returns a stop condition that stops after the specified number of steps.
  22func StepCountIs(stepCount int) StopCondition {
  23	return func(steps []StepResult) bool {
  24		return len(steps) >= stepCount
  25	}
  26}
  27
  28// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
  29func HasToolCall(toolName string) StopCondition {
  30	return func(steps []StepResult) bool {
  31		if len(steps) == 0 {
  32			return false
  33		}
  34		lastStep := steps[len(steps)-1]
  35		toolCalls := lastStep.Content.ToolCalls()
  36		for _, toolCall := range toolCalls {
  37			if toolCall.ToolName == toolName {
  38				return true
  39			}
  40		}
  41		return false
  42	}
  43}
  44
  45// HasContent returns a stop condition that stops when the specified content type appears in the last step.
  46func HasContent(contentType ContentType) StopCondition {
  47	return func(steps []StepResult) bool {
  48		if len(steps) == 0 {
  49			return false
  50		}
  51		lastStep := steps[len(steps)-1]
  52		for _, content := range lastStep.Content {
  53			if content.GetType() == contentType {
  54				return true
  55			}
  56		}
  57		return false
  58	}
  59}
  60
  61// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
  62func FinishReasonIs(reason FinishReason) StopCondition {
  63	return func(steps []StepResult) bool {
  64		if len(steps) == 0 {
  65			return false
  66		}
  67		lastStep := steps[len(steps)-1]
  68		return lastStep.FinishReason == reason
  69	}
  70}
  71
  72// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
  73func MaxTokensUsed(maxTokens int64) StopCondition {
  74	return func(steps []StepResult) bool {
  75		var totalTokens int64
  76		for _, step := range steps {
  77			totalTokens += step.Usage.TotalTokens
  78		}
  79		return totalTokens >= maxTokens
  80	}
  81}
  82
// PrepareStepFunctionOptions is the input passed to a PrepareStepFunction
// before each step runs.
type PrepareStepFunctionOptions struct {
	// Steps holds the results of all steps completed so far.
	Steps      []StepResult
	// StepNumber is the zero-based index of the step about to run.
	StepNumber int
	// Model is the model the step will use unless the result overrides it.
	Model      LanguageModel
	// Messages are the input messages the step will send unless overridden.
	Messages   []Message
}
  89
// PrepareStepResult carries per-step overrides returned by a
// PrepareStepFunction. Zero-valued fields leave the step's defaults in place.
type PrepareStepResult struct {
	// Model, when non-nil, replaces the model for this step.
	Model           LanguageModel
	// Messages, when non-nil, replaces the step's input messages entirely.
	Messages        []Message
	// System, when non-nil, is used as the system prompt for this step.
	System          *string
	// ToolChoice, when non-nil, replaces the default ToolChoiceAuto.
	ToolChoice      *ToolChoice
	// ActiveTools, when non-empty, restricts this step to the named tools.
	ActiveTools     []string
	// DisableAllTools, when true, sends no tools to the model for this step.
	DisableAllTools bool
}
  98
// ToolCallRepairOptions is the input given to a RepairToolCallFunction when a
// tool call returned by the model fails validation.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the call, as produced by the model, that failed validation.
	OriginalToolCall ToolCallContent
	// ValidationError describes why validation failed.
	ValidationError  error
	// AvailableTools are the tools the call was validated against.
	AvailableTools   []AgentTool
	// SystemPrompt is the system prompt in effect for the step.
	SystemPrompt     string
	// Messages are the input messages sent to the model for the step.
	Messages         []Message
}
 106
type (
	// PrepareStepFunction is invoked before every step with the current context
	// and step state. The returned context is used from then on and the returned
	// PrepareStepResult overrides the step's settings; a non-nil error aborts
	// the run.
	PrepareStepFunction    = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)
	// OnStepFinishedFunction receives a completed step.
	// NOTE(review): not referenced in the visible part of this file.
	OnStepFinishedFunction = func(step StepResult)
	// RepairToolCallFunction is asked to fix a tool call that failed validation.
	// The repair is used only when it returns a nil error and a non-nil call that
	// itself passes validation; otherwise the original call is marked invalid.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
 112
// agentSettings holds the agent-wide configuration assembled by NewAgent from
// AgentOption values. Pointer fields are optional: nil means "unset", and
// prepareCall uses them only as fallbacks for values a call leaves nil.
type agentSettings struct {
	systemPrompt     string
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	// NOTE(review): headers does not appear to be forwarded to the model Call
	// in the visible code — confirm.
	headers          map[string]string
	providerOptions  ProviderOptions

	// TODO: add support for provider tools
	tools      []AgentTool
	maxRetries *int

	// model is the default language model; PrepareStep may override it per step.
	model LanguageModel

	// Defaults for the loop hooks, used when a call does not set its own.
	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
 135
// AgentCall holds the per-call inputs for Agent.Generate. Optional (nil or
// empty) fields fall back to the agent-level settings supplied through
// AgentOption values when prepareCall merges them. Prompt must be non-empty.
//
// NOTE(review): MaxOutputTokens, ProviderOptions, OnRetry and MaxRetries carry
// no json tags, and encoding/json cannot marshal function-typed fields such as
// PrepareStep and RepairToolCall (json.Marshal returns an UnsupportedTypeError),
// so the existing tags only help decoding — confirm intent.
type AgentCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	// Loop hooks; each falls back to the agent-level value when unset.
	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction
}
 155
// Agent-level callbacks, supplied through AgentStreamCall.
type (
	// OnAgentStartFunc is called once when the agent starts, before the first step.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called when the agent finishes, after OnFinish, with
	// the final result. Stream currently ignores the returned error.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called with the zero-based step number before each
	// step's model call. Stream currently ignores the returned error.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called after each step's result has been recorded.
	// Stream currently ignores the returned error.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called once the entire agent run completes successfully,
	// before OnAgentFinish.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called with the error when opening the model stream or
	// processing a step fails, just before Stream returns that error.
	OnErrorFunc func(error)
)
 176
// Stream part callbacks - called for each corresponding stream part type.
// NOTE(review): these are presumably dispatched from processStepStream, which
// is not visible here; confirm ordering and concurrency guarantees there.
type (
	// OnChunkFunc is called for each stream part (catch-all).
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when text starts.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for text deltas.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when text ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when reasoning starts.
	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error

	// OnReasoningDeltaFunc is called for reasoning deltas.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning ends.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when tool input starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for tool input deltas.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnToolResultFunc is called when tool execution completes.
	// NOTE(review): if this is routed through executeTools' result callback, it
	// runs on the tool worker goroutines and may be invoked concurrently.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when stream finishes.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
 224
// AgentStreamCall holds the per-call inputs for Agent.Stream: the request
// fields of AgentCall plus Headers and the callbacks through which progress is
// reported. Request fields that Stream copies into an AgentCall fall back to
// agent-level settings when unset.
//
// NOTE(review): Headers is not copied into the AgentCall or into the model Call
// built by Stream in the visible code — confirm whether it is applied elsewhere.
type AgentStreamCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	// Loop hooks; each falls back to the agent-level value when unset.
	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called when agent starts
	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
	OnWarnings       OnWarningsFunc       // Called for warnings
	OnTextStart      OnTextStartFunc      // Called when text starts
	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
	OnTextEnd        OnTextEndFunc        // Called when text ends
	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
	OnToolCall       OnToolCallFunc       // Called when tool call is complete
	OnToolResult     OnToolResultFunc     // Called when tool execution completes
	OnSource         OnSourceFunc         // Called for source references
	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
}
 270
// AgentResult is the outcome of a completed agent run.
type AgentResult struct {
	// Steps lists every step in execution order.
	Steps []StepResult
	// Response is the response of the final step; TotalUsage sums the token
	// usage reported by every step.
	Response   Response
	TotalUsage Usage
}
 277
// Agent runs a language model in a multi-step tool-calling loop.
type Agent interface {
	// Generate runs the loop using non-streaming model calls.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the loop using streaming model calls, reporting progress
	// through the callbacks on AgentStreamCall.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
 282
// AgentOption configures agent-wide settings; options are applied by NewAgent.
type AgentOption = func(*agentSettings)
 284
// agent is the Agent implementation returned by NewAgent.
type agent struct {
	// settings is built once by NewAgent; the visible methods only read it.
	settings agentSettings
}
 288
 289func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
 290	settings := agentSettings{
 291		model: model,
 292	}
 293	for _, o := range opts {
 294		o(&settings)
 295	}
 296	return &agent{
 297		settings: settings,
 298	}
 299}
 300
 301func (a *agent) prepareCall(call AgentCall) AgentCall {
 302	call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
 303	call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
 304	call.TopP = cmp.Or(call.TopP, a.settings.topP)
 305	call.TopK = cmp.Or(call.TopK, a.settings.topK)
 306	call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
 307	call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
 308	call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
 309
 310	if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
 311		call.StopWhen = a.settings.stopWhen
 312	}
 313	if call.PrepareStep == nil && a.settings.prepareStep != nil {
 314		call.PrepareStep = a.settings.prepareStep
 315	}
 316	if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
 317		call.RepairToolCall = a.settings.repairToolCall
 318	}
 319	if call.OnRetry == nil && a.settings.onRetry != nil {
 320		call.OnRetry = a.settings.onRetry
 321	}
 322
 323	providerOptions := ProviderOptions{}
 324	if a.settings.providerOptions != nil {
 325		maps.Copy(providerOptions, a.settings.providerOptions)
 326	}
 327	if call.ProviderOptions != nil {
 328		maps.Copy(providerOptions, call.ProviderOptions)
 329	}
 330	call.ProviderOptions = providerOptions
 331
 332	headers := map[string]string{}
 333
 334	if a.settings.headers != nil {
 335		maps.Copy(headers, a.settings.headers)
 336	}
 337
 338	return call
 339}
 340
 341// Generate implements Agent.
 342func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
 343	opts = a.prepareCall(opts)
 344	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 345	if err != nil {
 346		return nil, err
 347	}
 348	var responseMessages []Message
 349	var steps []StepResult
 350
 351	for {
 352		stepInputMessages := append(initialPrompt, responseMessages...)
 353		stepModel := a.settings.model
 354		stepSystemPrompt := a.settings.systemPrompt
 355		stepActiveTools := opts.ActiveTools
 356		stepToolChoice := ToolChoiceAuto
 357		disableAllTools := false
 358
 359		if opts.PrepareStep != nil {
 360			updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
 361				Model:      stepModel,
 362				Steps:      steps,
 363				StepNumber: len(steps),
 364				Messages:   stepInputMessages,
 365			})
 366			if err != nil {
 367				return nil, err
 368			}
 369
 370			ctx = updatedCtx
 371
 372			// Apply prepared step modifications
 373			if prepared.Messages != nil {
 374				stepInputMessages = prepared.Messages
 375			}
 376			if prepared.Model != nil {
 377				stepModel = prepared.Model
 378			}
 379			if prepared.System != nil {
 380				stepSystemPrompt = *prepared.System
 381			}
 382			if prepared.ToolChoice != nil {
 383				stepToolChoice = *prepared.ToolChoice
 384			}
 385			if len(prepared.ActiveTools) > 0 {
 386				stepActiveTools = prepared.ActiveTools
 387			}
 388			disableAllTools = prepared.DisableAllTools
 389		}
 390
 391		// Recreate prompt with potentially modified system prompt
 392		if stepSystemPrompt != a.settings.systemPrompt {
 393			stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 394			if err != nil {
 395				return nil, err
 396			}
 397			// Replace system message part, keep the rest
 398			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 399				stepInputMessages[0] = stepPrompt[0] // Replace system message
 400			}
 401		}
 402
 403		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 404
 405		retryOptions := DefaultRetryOptions()
 406		retryOptions.OnRetry = opts.OnRetry
 407		retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
 408
 409		result, err := retry(ctx, func() (*Response, error) {
 410			return stepModel.Generate(ctx, Call{
 411				Prompt:           stepInputMessages,
 412				MaxOutputTokens:  opts.MaxOutputTokens,
 413				Temperature:      opts.Temperature,
 414				TopP:             opts.TopP,
 415				TopK:             opts.TopK,
 416				PresencePenalty:  opts.PresencePenalty,
 417				FrequencyPenalty: opts.FrequencyPenalty,
 418				Tools:            preparedTools,
 419				ToolChoice:       &stepToolChoice,
 420				ProviderOptions:  opts.ProviderOptions,
 421			})
 422		})
 423		if err != nil {
 424			return nil, err
 425		}
 426
 427		var stepToolCalls []ToolCallContent
 428		for _, content := range result.Content {
 429			if content.GetType() == ContentTypeToolCall {
 430				toolCall, ok := AsContentType[ToolCallContent](content)
 431				if !ok {
 432					continue
 433				}
 434
 435				// Validate and potentially repair the tool call
 436				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
 437				stepToolCalls = append(stepToolCalls, validatedToolCall)
 438			}
 439		}
 440
 441		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
 442
 443		// Build step content with validated tool calls and tool results
 444		stepContent := []Content{}
 445		toolCallIndex := 0
 446		for _, content := range result.Content {
 447			if content.GetType() == ContentTypeToolCall {
 448				// Replace with validated tool call
 449				if toolCallIndex < len(stepToolCalls) {
 450					stepContent = append(stepContent, stepToolCalls[toolCallIndex])
 451					toolCallIndex++
 452				}
 453			} else {
 454				// Keep other content as-is
 455				stepContent = append(stepContent, content)
 456			}
 457		}
 458		// Add tool results
 459		for _, result := range toolResults {
 460			stepContent = append(stepContent, result)
 461		}
 462		currentStepMessages := toResponseMessages(stepContent)
 463		responseMessages = append(responseMessages, currentStepMessages...)
 464
 465		stepResult := StepResult{
 466			Response: Response{
 467				Content:          stepContent,
 468				FinishReason:     result.FinishReason,
 469				Usage:            result.Usage,
 470				Warnings:         result.Warnings,
 471				ProviderMetadata: result.ProviderMetadata,
 472			},
 473			Messages: currentStepMessages,
 474		}
 475		steps = append(steps, stepResult)
 476		shouldStop := isStopConditionMet(opts.StopWhen, steps)
 477
 478		if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
 479			break
 480		}
 481	}
 482
 483	totalUsage := Usage{}
 484
 485	for _, step := range steps {
 486		usage := step.Usage
 487		totalUsage.InputTokens += usage.InputTokens
 488		totalUsage.OutputTokens += usage.OutputTokens
 489		totalUsage.ReasoningTokens += usage.ReasoningTokens
 490		totalUsage.CacheCreationTokens += usage.CacheCreationTokens
 491		totalUsage.CacheReadTokens += usage.CacheReadTokens
 492		totalUsage.TotalTokens += usage.TotalTokens
 493	}
 494
 495	agentResult := &AgentResult{
 496		Steps:      steps,
 497		Response:   steps[len(steps)-1].Response,
 498		TotalUsage: totalUsage,
 499	}
 500	return agentResult, nil
 501}
 502
 503func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
 504	if len(conditions) == 0 {
 505		return false
 506	}
 507
 508	for _, condition := range conditions {
 509		if condition(steps) {
 510			return true
 511		}
 512	}
 513	return false
 514}
 515
 516func toResponseMessages(content []Content) []Message {
 517	var assistantParts []MessagePart
 518	var toolParts []MessagePart
 519
 520	for _, c := range content {
 521		switch c.GetType() {
 522		case ContentTypeText:
 523			text, ok := AsContentType[TextContent](c)
 524			if !ok {
 525				continue
 526			}
 527			assistantParts = append(assistantParts, TextPart{
 528				Text:            text.Text,
 529				ProviderOptions: ProviderOptions(text.ProviderMetadata),
 530			})
 531		case ContentTypeReasoning:
 532			reasoning, ok := AsContentType[ReasoningContent](c)
 533			if !ok {
 534				continue
 535			}
 536			assistantParts = append(assistantParts, ReasoningPart{
 537				Text:            reasoning.Text,
 538				ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
 539			})
 540		case ContentTypeToolCall:
 541			toolCall, ok := AsContentType[ToolCallContent](c)
 542			if !ok {
 543				continue
 544			}
 545			assistantParts = append(assistantParts, ToolCallPart{
 546				ToolCallID:       toolCall.ToolCallID,
 547				ToolName:         toolCall.ToolName,
 548				Input:            toolCall.Input,
 549				ProviderExecuted: toolCall.ProviderExecuted,
 550				ProviderOptions:  ProviderOptions(toolCall.ProviderMetadata),
 551			})
 552		case ContentTypeFile:
 553			file, ok := AsContentType[FileContent](c)
 554			if !ok {
 555				continue
 556			}
 557			assistantParts = append(assistantParts, FilePart{
 558				Data:            file.Data,
 559				MediaType:       file.MediaType,
 560				ProviderOptions: ProviderOptions(file.ProviderMetadata),
 561			})
 562		case ContentTypeSource:
 563			// Sources are metadata about references used to generate the response.
 564			// They don't need to be included in the conversation messages.
 565			continue
 566		case ContentTypeToolResult:
 567			result, ok := AsContentType[ToolResultContent](c)
 568			if !ok {
 569				continue
 570			}
 571			toolParts = append(toolParts, ToolResultPart{
 572				ToolCallID:      result.ToolCallID,
 573				Output:          result.Result,
 574				ProviderOptions: ProviderOptions(result.ProviderMetadata),
 575			})
 576		}
 577	}
 578
 579	var messages []Message
 580	if len(assistantParts) > 0 {
 581		messages = append(messages, Message{
 582			Role:    MessageRoleAssistant,
 583			Content: assistantParts,
 584		})
 585	}
 586	if len(toolParts) > 0 {
 587		messages = append(messages, Message{
 588			Role:    MessageRoleTool,
 589			Content: toolParts,
 590		})
 591	}
 592	return messages
 593}
 594
 595func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
 596	if len(toolCalls) == 0 {
 597		return nil, nil
 598	}
 599
 600	// Create a map for quick tool lookup
 601	toolMap := make(map[string]AgentTool)
 602	for _, tool := range allTools {
 603		toolMap[tool.Info().Name] = tool
 604	}
 605
 606	// Execute all tool calls in parallel
 607	results := make([]ToolResultContent, len(toolCalls))
 608	executeErrors := make([]error, len(toolCalls))
 609
 610	var wg sync.WaitGroup
 611
 612	for i, toolCall := range toolCalls {
 613		wg.Add(1)
 614		go func(index int, call ToolCallContent) {
 615			defer wg.Done()
 616
 617			// Skip invalid tool calls - create error result
 618			if call.Invalid {
 619				results[index] = ToolResultContent{
 620					ToolCallID: call.ToolCallID,
 621					ToolName:   call.ToolName,
 622					Result: ToolResultOutputContentError{
 623						Error: call.ValidationError,
 624					},
 625					ProviderExecuted: false,
 626				}
 627				if toolResultCallback != nil {
 628					err := toolResultCallback(results[index])
 629					if err != nil {
 630						executeErrors[index] = err
 631					}
 632				}
 633
 634				return
 635			}
 636
 637			tool, exists := toolMap[call.ToolName]
 638			if !exists {
 639				results[index] = ToolResultContent{
 640					ToolCallID: call.ToolCallID,
 641					ToolName:   call.ToolName,
 642					Result: ToolResultOutputContentError{
 643						Error: errors.New("Error: Tool not found: " + call.ToolName),
 644					},
 645					ProviderExecuted: false,
 646				}
 647
 648				if toolResultCallback != nil {
 649					err := toolResultCallback(results[index])
 650					if err != nil {
 651						executeErrors[index] = err
 652					}
 653				}
 654				return
 655			}
 656
 657			// Execute the tool
 658			result, err := tool.Run(ctx, ToolCall{
 659				ID:    call.ToolCallID,
 660				Name:  call.ToolName,
 661				Input: call.Input,
 662			})
 663			if err != nil {
 664				results[index] = ToolResultContent{
 665					ToolCallID: call.ToolCallID,
 666					ToolName:   call.ToolName,
 667					Result: ToolResultOutputContentError{
 668						Error: err,
 669					},
 670					ClientMetadata:   result.Metadata,
 671					ProviderExecuted: false,
 672				}
 673				if toolResultCallback != nil {
 674					cbErr := toolResultCallback(results[index])
 675					if cbErr != nil {
 676						executeErrors[index] = cbErr
 677					}
 678				}
 679				executeErrors[index] = err
 680				return
 681			}
 682
 683			if result.IsError {
 684				results[index] = ToolResultContent{
 685					ToolCallID: call.ToolCallID,
 686					ToolName:   call.ToolName,
 687					Result: ToolResultOutputContentError{
 688						Error: errors.New(result.Content),
 689					},
 690					ClientMetadata:   result.Metadata,
 691					ProviderExecuted: false,
 692				}
 693
 694				if toolResultCallback != nil {
 695					err := toolResultCallback(results[index])
 696					if err != nil {
 697						executeErrors[index] = err
 698					}
 699				}
 700			} else {
 701				results[index] = ToolResultContent{
 702					ToolCallID: call.ToolCallID,
 703					ToolName:   toolCall.ToolName,
 704					Result: ToolResultOutputContentText{
 705						Text: result.Content,
 706					},
 707					ClientMetadata:   result.Metadata,
 708					ProviderExecuted: false,
 709				}
 710				if toolResultCallback != nil {
 711					err := toolResultCallback(results[index])
 712					if err != nil {
 713						executeErrors[index] = err
 714					}
 715				}
 716			}
 717		}(i, toolCall)
 718	}
 719
 720	// Wait for all tool executions to complete
 721	wg.Wait()
 722
 723	for _, err := range executeErrors {
 724		if err != nil {
 725			return nil, err
 726		}
 727	}
 728
 729	return results, nil
 730}
 731
 732// Stream implements Agent.
 733func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
 734	// Convert AgentStreamCall to AgentCall for preparation
 735	call := AgentCall{
 736		Prompt:           opts.Prompt,
 737		Files:            opts.Files,
 738		Messages:         opts.Messages,
 739		MaxOutputTokens:  opts.MaxOutputTokens,
 740		Temperature:      opts.Temperature,
 741		TopP:             opts.TopP,
 742		TopK:             opts.TopK,
 743		PresencePenalty:  opts.PresencePenalty,
 744		FrequencyPenalty: opts.FrequencyPenalty,
 745		ActiveTools:      opts.ActiveTools,
 746		ProviderOptions:  opts.ProviderOptions,
 747		MaxRetries:       opts.MaxRetries,
 748		StopWhen:         opts.StopWhen,
 749		PrepareStep:      opts.PrepareStep,
 750		RepairToolCall:   opts.RepairToolCall,
 751	}
 752
 753	call = a.prepareCall(call)
 754
 755	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
 756	if err != nil {
 757		return nil, err
 758	}
 759
 760	var responseMessages []Message
 761	var steps []StepResult
 762	var totalUsage Usage
 763
 764	// Start agent stream
 765	if opts.OnAgentStart != nil {
 766		opts.OnAgentStart()
 767	}
 768
 769	for stepNumber := 0; ; stepNumber++ {
 770		stepInputMessages := append(initialPrompt, responseMessages...)
 771		stepModel := a.settings.model
 772		stepSystemPrompt := a.settings.systemPrompt
 773		stepActiveTools := call.ActiveTools
 774		stepToolChoice := ToolChoiceAuto
 775		disableAllTools := false
 776
 777		// Apply step preparation if provided
 778		if call.PrepareStep != nil {
 779			updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
 780				Model:      stepModel,
 781				Steps:      steps,
 782				StepNumber: stepNumber,
 783				Messages:   stepInputMessages,
 784			})
 785			if err != nil {
 786				return nil, err
 787			}
 788
 789			ctx = updatedCtx
 790
 791			if prepared.Messages != nil {
 792				stepInputMessages = prepared.Messages
 793			}
 794			if prepared.Model != nil {
 795				stepModel = prepared.Model
 796			}
 797			if prepared.System != nil {
 798				stepSystemPrompt = *prepared.System
 799			}
 800			if prepared.ToolChoice != nil {
 801				stepToolChoice = *prepared.ToolChoice
 802			}
 803			if len(prepared.ActiveTools) > 0 {
 804				stepActiveTools = prepared.ActiveTools
 805			}
 806			disableAllTools = prepared.DisableAllTools
 807		}
 808
 809		// Recreate prompt with potentially modified system prompt
 810		if stepSystemPrompt != a.settings.systemPrompt {
 811			stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
 812			if err != nil {
 813				return nil, err
 814			}
 815			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 816				stepInputMessages[0] = stepPrompt[0]
 817			}
 818		}
 819
 820		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 821
 822		// Start step stream
 823		if opts.OnStepStart != nil {
 824			_ = opts.OnStepStart(stepNumber)
 825		}
 826
 827		// Create streaming call
 828		streamCall := Call{
 829			Prompt:           stepInputMessages,
 830			MaxOutputTokens:  call.MaxOutputTokens,
 831			Temperature:      call.Temperature,
 832			TopP:             call.TopP,
 833			TopK:             call.TopK,
 834			PresencePenalty:  call.PresencePenalty,
 835			FrequencyPenalty: call.FrequencyPenalty,
 836			Tools:            preparedTools,
 837			ToolChoice:       &stepToolChoice,
 838			ProviderOptions:  call.ProviderOptions,
 839		}
 840
 841		// Get streaming response
 842		stream, err := stepModel.Stream(ctx, streamCall)
 843		if err != nil {
 844			if opts.OnError != nil {
 845				opts.OnError(err)
 846			}
 847			return nil, err
 848		}
 849
 850		// Process stream with tool execution
 851		stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
 852		if err != nil {
 853			if opts.OnError != nil {
 854				opts.OnError(err)
 855			}
 856			return nil, err
 857		}
 858
 859		steps = append(steps, stepResult)
 860		totalUsage = addUsage(totalUsage, stepResult.Usage)
 861
 862		// Call step finished callback
 863		if opts.OnStepFinish != nil {
 864			_ = opts.OnStepFinish(stepResult)
 865		}
 866
 867		// Add step messages to response messages
 868		stepMessages := toResponseMessages(stepResult.Content)
 869		responseMessages = append(responseMessages, stepMessages...)
 870
 871		// Check stop conditions
 872		shouldStop := isStopConditionMet(call.StopWhen, steps)
 873		if shouldStop || !shouldContinue {
 874			break
 875		}
 876	}
 877
 878	// Finish agent stream
 879	agentResult := &AgentResult{
 880		Steps:      steps,
 881		Response:   steps[len(steps)-1].Response,
 882		TotalUsage: totalUsage,
 883	}
 884
 885	if opts.OnFinish != nil {
 886		opts.OnFinish(agentResult)
 887	}
 888
 889	if opts.OnAgentFinish != nil {
 890		_ = opts.OnAgentFinish(agentResult)
 891	}
 892
 893	return agentResult, nil
 894}
 895
 896func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
 897	preparedTools := make([]Tool, 0, len(tools))
 898
 899	// If explicitly disabling all tools, return no tools
 900	if disableAllTools {
 901		return preparedTools
 902	}
 903
 904	for _, tool := range tools {
 905		// If activeTools has items, only include tools in the list
 906		// If activeTools is empty, include all tools
 907		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
 908			continue
 909		}
 910		info := tool.Info()
 911		preparedTools = append(preparedTools, FunctionTool{
 912			Name:        info.Name,
 913			Description: info.Description,
 914			InputSchema: map[string]any{
 915				"type":       "object",
 916				"properties": info.Parameters,
 917				"required":   info.Required,
 918			},
 919			ProviderOptions: tool.ProviderOptions(),
 920		})
 921	}
 922	return preparedTools
 923}
 924
 925// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
 926func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
 927	if err := a.validateToolCall(toolCall, availableTools); err == nil {
 928		return toolCall
 929	} else {
 930		if repairFunc != nil {
 931			repairOptions := ToolCallRepairOptions{
 932				OriginalToolCall: toolCall,
 933				ValidationError:  err,
 934				AvailableTools:   availableTools,
 935				SystemPrompt:     systemPrompt,
 936				Messages:         messages,
 937			}
 938
 939			if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
 940				if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
 941					return *repairedToolCall
 942				}
 943			}
 944		}
 945
 946		invalidToolCall := toolCall
 947		invalidToolCall.Invalid = true
 948		invalidToolCall.ValidationError = err
 949		return invalidToolCall
 950	}
 951}
 952
 953// validateToolCall validates a tool call against available tools and their schemas.
 954func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
 955	var tool AgentTool
 956	for _, t := range availableTools {
 957		if t.Info().Name == toolCall.ToolName {
 958			tool = t
 959			break
 960		}
 961	}
 962
 963	if tool == nil {
 964		return fmt.Errorf("tool not found: %s", toolCall.ToolName)
 965	}
 966
 967	// Validate JSON parsing
 968	var input map[string]any
 969	if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
 970		return fmt.Errorf("invalid JSON input: %w", err)
 971	}
 972
 973	// Basic schema validation (check required fields)
 974	// TODO: more robust schema validation using JSON Schema or similar
 975	toolInfo := tool.Info()
 976	for _, required := range toolInfo.Required {
 977		if _, exists := input[required]; !exists {
 978			return fmt.Errorf("missing required parameter: %s", required)
 979		}
 980	}
 981	return nil
 982}
 983
 984func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
 985	if prompt == "" {
 986		return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
 987	}
 988
 989	var preparedPrompt Prompt
 990
 991	if system != "" {
 992		preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
 993	}
 994	preparedPrompt = append(preparedPrompt, messages...)
 995	preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
 996	return preparedPrompt, nil
 997}
 998
 999func WithSystemPrompt(prompt string) AgentOption {
1000	return func(s *agentSettings) {
1001		s.systemPrompt = prompt
1002	}
1003}
1004
1005func WithMaxOutputTokens(tokens int64) AgentOption {
1006	return func(s *agentSettings) {
1007		s.maxOutputTokens = &tokens
1008	}
1009}
1010
1011func WithTemperature(temp float64) AgentOption {
1012	return func(s *agentSettings) {
1013		s.temperature = &temp
1014	}
1015}
1016
1017func WithTopP(topP float64) AgentOption {
1018	return func(s *agentSettings) {
1019		s.topP = &topP
1020	}
1021}
1022
1023func WithTopK(topK int64) AgentOption {
1024	return func(s *agentSettings) {
1025		s.topK = &topK
1026	}
1027}
1028
1029func WithPresencePenalty(penalty float64) AgentOption {
1030	return func(s *agentSettings) {
1031		s.presencePenalty = &penalty
1032	}
1033}
1034
1035func WithFrequencyPenalty(penalty float64) AgentOption {
1036	return func(s *agentSettings) {
1037		s.frequencyPenalty = &penalty
1038	}
1039}
1040
1041func WithTools(tools ...AgentTool) AgentOption {
1042	return func(s *agentSettings) {
1043		s.tools = append(s.tools, tools...)
1044	}
1045}
1046
1047func WithStopConditions(conditions ...StopCondition) AgentOption {
1048	return func(s *agentSettings) {
1049		s.stopWhen = append(s.stopWhen, conditions...)
1050	}
1051}
1052
1053func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1054	return func(s *agentSettings) {
1055		s.prepareStep = fn
1056	}
1057}
1058
1059func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1060	return func(s *agentSettings) {
1061		s.repairToolCall = fn
1062	}
1063}
1064
1065// processStepStream processes a single step's stream and returns the step result.
1066func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
1067	var stepContent []Content
1068	var stepToolCalls []ToolCallContent
1069	var stepUsage Usage
1070	stepFinishReason := FinishReasonUnknown
1071	var stepWarnings []CallWarning
1072	var stepProviderMetadata ProviderMetadata
1073
1074	activeToolCalls := make(map[string]*ToolCallContent)
1075	activeTextContent := make(map[string]string)
1076	type reasoningContent struct {
1077		content string
1078		options ProviderMetadata
1079	}
1080	activeReasoningContent := make(map[string]reasoningContent)
1081
1082	// Process stream parts
1083	for part := range stream {
1084		// Forward all parts to chunk callback
1085		if opts.OnChunk != nil {
1086			err := opts.OnChunk(part)
1087			if err != nil {
1088				return StepResult{}, false, err
1089			}
1090		}
1091
1092		switch part.Type {
1093		case StreamPartTypeWarnings:
1094			stepWarnings = part.Warnings
1095			if opts.OnWarnings != nil {
1096				err := opts.OnWarnings(part.Warnings)
1097				if err != nil {
1098					return StepResult{}, false, err
1099				}
1100			}
1101
1102		case StreamPartTypeTextStart:
1103			activeTextContent[part.ID] = ""
1104			if opts.OnTextStart != nil {
1105				err := opts.OnTextStart(part.ID)
1106				if err != nil {
1107					return StepResult{}, false, err
1108				}
1109			}
1110
1111		case StreamPartTypeTextDelta:
1112			if _, exists := activeTextContent[part.ID]; exists {
1113				activeTextContent[part.ID] += part.Delta
1114			}
1115			if opts.OnTextDelta != nil {
1116				err := opts.OnTextDelta(part.ID, part.Delta)
1117				if err != nil {
1118					return StepResult{}, false, err
1119				}
1120			}
1121
1122		case StreamPartTypeTextEnd:
1123			if text, exists := activeTextContent[part.ID]; exists {
1124				stepContent = append(stepContent, TextContent{
1125					Text:             text,
1126					ProviderMetadata: part.ProviderMetadata,
1127				})
1128				delete(activeTextContent, part.ID)
1129			}
1130			if opts.OnTextEnd != nil {
1131				err := opts.OnTextEnd(part.ID)
1132				if err != nil {
1133					return StepResult{}, false, err
1134				}
1135			}
1136
1137		case StreamPartTypeReasoningStart:
1138			activeReasoningContent[part.ID] = reasoningContent{content: "", options: part.ProviderMetadata}
1139			if opts.OnReasoningStart != nil {
1140				content := ReasoningContent{
1141					Text:             "",
1142					ProviderMetadata: part.ProviderMetadata,
1143				}
1144				err := opts.OnReasoningStart(part.ID, content)
1145				if err != nil {
1146					return StepResult{}, false, err
1147				}
1148			}
1149
1150		case StreamPartTypeReasoningDelta:
1151			if active, exists := activeReasoningContent[part.ID]; exists {
1152				active.content += part.Delta
1153				active.options = part.ProviderMetadata
1154				activeReasoningContent[part.ID] = active
1155			}
1156			if opts.OnReasoningDelta != nil {
1157				err := opts.OnReasoningDelta(part.ID, part.Delta)
1158				if err != nil {
1159					return StepResult{}, false, err
1160				}
1161			}
1162
1163		case StreamPartTypeReasoningEnd:
1164			if active, exists := activeReasoningContent[part.ID]; exists {
1165				if part.ProviderMetadata != nil {
1166					active.options = part.ProviderMetadata
1167				}
1168				content := ReasoningContent{
1169					Text:             active.content,
1170					ProviderMetadata: active.options,
1171				}
1172				stepContent = append(stepContent, content)
1173				if opts.OnReasoningEnd != nil {
1174					err := opts.OnReasoningEnd(part.ID, content)
1175					if err != nil {
1176						return StepResult{}, false, err
1177					}
1178				}
1179				delete(activeReasoningContent, part.ID)
1180			}
1181
1182		case StreamPartTypeToolInputStart:
1183			activeToolCalls[part.ID] = &ToolCallContent{
1184				ToolCallID:       part.ID,
1185				ToolName:         part.ToolCallName,
1186				Input:            "",
1187				ProviderExecuted: part.ProviderExecuted,
1188			}
1189			if opts.OnToolInputStart != nil {
1190				err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1191				if err != nil {
1192					return StepResult{}, false, err
1193				}
1194			}
1195
1196		case StreamPartTypeToolInputDelta:
1197			if toolCall, exists := activeToolCalls[part.ID]; exists {
1198				toolCall.Input += part.Delta
1199			}
1200			if opts.OnToolInputDelta != nil {
1201				err := opts.OnToolInputDelta(part.ID, part.Delta)
1202				if err != nil {
1203					return StepResult{}, false, err
1204				}
1205			}
1206
1207		case StreamPartTypeToolInputEnd:
1208			if opts.OnToolInputEnd != nil {
1209				err := opts.OnToolInputEnd(part.ID)
1210				if err != nil {
1211					return StepResult{}, false, err
1212				}
1213			}
1214
1215		case StreamPartTypeToolCall:
1216			toolCall := ToolCallContent{
1217				ToolCallID:       part.ID,
1218				ToolName:         part.ToolCallName,
1219				Input:            part.ToolCallInput,
1220				ProviderExecuted: part.ProviderExecuted,
1221				ProviderMetadata: part.ProviderMetadata,
1222			}
1223
1224			// Validate and potentially repair the tool call
1225			validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1226			stepToolCalls = append(stepToolCalls, validatedToolCall)
1227			stepContent = append(stepContent, validatedToolCall)
1228
1229			if opts.OnToolCall != nil {
1230				err := opts.OnToolCall(validatedToolCall)
1231				if err != nil {
1232					return StepResult{}, false, err
1233				}
1234			}
1235
1236			// Clean up active tool call
1237			delete(activeToolCalls, part.ID)
1238
1239		case StreamPartTypeSource:
1240			sourceContent := SourceContent{
1241				SourceType:       part.SourceType,
1242				ID:               part.ID,
1243				URL:              part.URL,
1244				Title:            part.Title,
1245				ProviderMetadata: part.ProviderMetadata,
1246			}
1247			stepContent = append(stepContent, sourceContent)
1248			if opts.OnSource != nil {
1249				err := opts.OnSource(sourceContent)
1250				if err != nil {
1251					return StepResult{}, false, err
1252				}
1253			}
1254
1255		case StreamPartTypeFinish:
1256			stepUsage = part.Usage
1257			stepFinishReason = part.FinishReason
1258			stepProviderMetadata = part.ProviderMetadata
1259			if opts.OnStreamFinish != nil {
1260				err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1261				if err != nil {
1262					return StepResult{}, false, err
1263				}
1264			}
1265
1266		case StreamPartTypeError:
1267			if opts.OnError != nil {
1268				opts.OnError(part.Error)
1269			}
1270			return StepResult{}, false, part.Error
1271		}
1272	}
1273
1274	// Execute tools if any
1275	var toolResults []ToolResultContent
1276	if len(stepToolCalls) > 0 {
1277		var err error
1278		toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
1279		if err != nil {
1280			return StepResult{}, false, err
1281		}
1282		// Add tool results to content
1283		for _, result := range toolResults {
1284			stepContent = append(stepContent, result)
1285		}
1286	}
1287
1288	stepResult := StepResult{
1289		Response: Response{
1290			Content:          stepContent,
1291			FinishReason:     stepFinishReason,
1292			Usage:            stepUsage,
1293			Warnings:         stepWarnings,
1294			ProviderMetadata: stepProviderMetadata,
1295		},
1296		Messages: toResponseMessages(stepContent),
1297	}
1298
1299	// Determine if we should continue (has tool calls and not stopped)
1300	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1301
1302	return stepResult, shouldContinue, nil
1303}
1304
1305func addUsage(a, b Usage) Usage {
1306	return Usage{
1307		InputTokens:         a.InputTokens + b.InputTokens,
1308		OutputTokens:        a.OutputTokens + b.OutputTokens,
1309		TotalTokens:         a.TotalTokens + b.TotalTokens,
1310		ReasoningTokens:     a.ReasoningTokens + b.ReasoningTokens,
1311		CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1312		CacheReadTokens:     a.CacheReadTokens + b.CacheReadTokens,
1313	}
1314}
1315
1316func WithHeaders(headers map[string]string) AgentOption {
1317	return func(s *agentSettings) {
1318		s.headers = headers
1319	}
1320}
1321
1322func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1323	return func(s *agentSettings) {
1324		s.providerOptions = providerOptions
1325	}
1326}