agent.go

   1package ai
   2
   3import (
   4	"context"
   5	"encoding/json"
   6	"errors"
   7	"fmt"
   8	"maps"
   9	"slices"
  10	"sync"
  11
  12	"github.com/charmbracelet/crush/internal/llm/tools"
  13)
  14
  15type StepResult struct {
  16	Response
  17	Messages []Message
  18}
  19
  20type StopCondition = func(steps []StepResult) bool
  21
  22// StepCountIs returns a stop condition that stops after the specified number of steps.
  23func StepCountIs(stepCount int) StopCondition {
  24	return func(steps []StepResult) bool {
  25		return len(steps) >= stepCount
  26	}
  27}
  28
  29// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
  30func HasToolCall(toolName string) StopCondition {
  31	return func(steps []StepResult) bool {
  32		if len(steps) == 0 {
  33			return false
  34		}
  35		lastStep := steps[len(steps)-1]
  36		toolCalls := lastStep.Content.ToolCalls()
  37		for _, toolCall := range toolCalls {
  38			if toolCall.ToolName == toolName {
  39				return true
  40			}
  41		}
  42		return false
  43	}
  44}
  45
  46// HasContent returns a stop condition that stops when the specified content type appears in the last step.
  47func HasContent(contentType ContentType) StopCondition {
  48	return func(steps []StepResult) bool {
  49		if len(steps) == 0 {
  50			return false
  51		}
  52		lastStep := steps[len(steps)-1]
  53		for _, content := range lastStep.Content {
  54			if content.GetType() == contentType {
  55				return true
  56			}
  57		}
  58		return false
  59	}
  60}
  61
  62// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
  63func FinishReasonIs(reason FinishReason) StopCondition {
  64	return func(steps []StepResult) bool {
  65		if len(steps) == 0 {
  66			return false
  67		}
  68		lastStep := steps[len(steps)-1]
  69		return lastStep.FinishReason == reason
  70	}
  71}
  72
  73// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
  74func MaxTokensUsed(maxTokens int64) StopCondition {
  75	return func(steps []StepResult) bool {
  76		var totalTokens int64
  77		for _, step := range steps {
  78			totalTokens += step.Usage.TotalTokens
  79		}
  80		return totalTokens >= maxTokens
  81	}
  82}
  83
  84type PrepareStepFunctionOptions struct {
  85	Steps      []StepResult
  86	StepNumber int
  87	Model      LanguageModel
  88	Messages   []Message
  89}
  90
  91type PrepareStepResult struct {
  92	Model           LanguageModel
  93	Messages        []Message
  94	System          *string
  95	ToolChoice      *ToolChoice
  96	ActiveTools     []string
  97	DisableAllTools bool
  98}
  99
 100type ToolCallRepairOptions struct {
 101	OriginalToolCall ToolCallContent
 102	ValidationError  error
 103	AvailableTools   []tools.BaseTool
 104	SystemPrompt     string
 105	Messages         []Message
 106}
 107
 108type (
 109	PrepareStepFunction    = func(options PrepareStepFunctionOptions) PrepareStepResult
 110	OnStepFinishedFunction = func(step StepResult)
 111	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
 112)
 113
 114type AgentSettings struct {
 115	systemPrompt     string
 116	maxOutputTokens  *int64
 117	temperature      *float64
 118	topP             *float64
 119	topK             *int64
 120	presencePenalty  *float64
 121	frequencyPenalty *float64
 122	headers          map[string]string
 123	providerOptions  ProviderOptions
 124
 125	// TODO: add support for provider tools
 126	tools      []tools.BaseTool
 127	maxRetries *int
 128
 129	model LanguageModel
 130
 131	stopWhen       []StopCondition
 132	prepareStep    PrepareStepFunction
 133	repairToolCall RepairToolCallFunction
 134	onStepFinished OnStepFinishedFunction
 135	onRetry        OnRetryCallback
 136}
 137
 138type AgentCall struct {
 139	Prompt           string     `json:"prompt"`
 140	Files            []FilePart `json:"files"`
 141	Messages         []Message  `json:"messages"`
 142	MaxOutputTokens  *int64
 143	Temperature      *float64 `json:"temperature"`
 144	TopP             *float64 `json:"top_p"`
 145	TopK             *int64   `json:"top_k"`
 146	PresencePenalty  *float64 `json:"presence_penalty"`
 147	FrequencyPenalty *float64 `json:"frequency_penalty"`
 148	ActiveTools      []string `json:"active_tools"`
 149	Headers          map[string]string
 150	ProviderOptions  ProviderOptions
 151	OnRetry          OnRetryCallback
 152	MaxRetries       *int
 153
 154	StopWhen       []StopCondition
 155	PrepareStep    PrepareStepFunction
 156	RepairToolCall RepairToolCallFunction
 157	OnStepFinished OnStepFinishedFunction
 158}
 159
 160type AgentStreamCall struct {
 161	Prompt           string     `json:"prompt"`
 162	Files            []FilePart `json:"files"`
 163	Messages         []Message  `json:"messages"`
 164	MaxOutputTokens  *int64
 165	Temperature      *float64 `json:"temperature"`
 166	TopP             *float64 `json:"top_p"`
 167	TopK             *int64   `json:"top_k"`
 168	PresencePenalty  *float64 `json:"presence_penalty"`
 169	FrequencyPenalty *float64 `json:"frequency_penalty"`
 170	ActiveTools      []string `json:"active_tools"`
 171	Headers          map[string]string
 172	ProviderOptions  ProviderOptions
 173	OnRetry          OnRetryCallback
 174	MaxRetries       *int
 175
 176	StopWhen       []StopCondition
 177	PrepareStep    PrepareStepFunction
 178	RepairToolCall RepairToolCallFunction
 179
 180	// Agent-level callbacks
 181	OnAgentStart  func()                      // Called when agent starts
 182	OnAgentFinish func(result *AgentResult)   // Called when agent finishes
 183	OnStepStart   func(stepNumber int)        // Called when a step starts
 184	OnStepFinish  func(stepResult StepResult) // Called when a step finishes
 185	OnFinish      func(result *AgentResult)   // Called when entire agent completes
 186	OnError       func(error)                 // Called when an error occurs
 187
 188	// Stream part callbacks - called for each corresponding stream part type
 189	OnChunk          func(StreamPart)                                                               // Called for each stream part (catch-all)
 190	OnWarnings       func(warnings []CallWarning)                                                   // Called for warnings
 191	OnTextStart      func(id string)                                                                // Called when text starts
 192	OnTextDelta      func(id, text string)                                                          // Called for text deltas
 193	OnTextEnd        func(id string)                                                                // Called when text ends
 194	OnReasoningStart func(id string)                                                                // Called when reasoning starts
 195	OnReasoningDelta func(id, text string)                                                          // Called for reasoning deltas
 196	OnReasoningEnd   func(id string)                                                                // Called when reasoning ends
 197	OnToolInputStart func(id, toolName string)                                                      // Called when tool input starts
 198	OnToolInputDelta func(id, delta string)                                                         // Called for tool input deltas
 199	OnToolInputEnd   func(id string)                                                                // Called when tool input ends
 200	OnToolCall       func(toolCall ToolCallContent)                                                 // Called when tool call is complete
 201	OnToolResult     func(result ToolResultContent)                                                 // Called when tool execution completes
 202	OnSource         func(source SourceContent)                                                     // Called for source references
 203	OnStreamFinish   func(usage Usage, finishReason FinishReason, providerMetadata ProviderOptions) // Called when stream finishes
 204	OnStreamError    func(error)                                                                    // Called when stream error occurs
 205}
 206
 207type AgentResult struct {
 208	Steps []StepResult
 209	// Final response
 210	Response   Response
 211	TotalUsage Usage
 212}
 213
 214type Agent interface {
 215	Generate(context.Context, AgentCall) (*AgentResult, error)
 216	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
 217}
 218
 219type agentOption = func(*AgentSettings)
 220
 221type agent struct {
 222	settings AgentSettings
 223}
 224
 225func NewAgent(model LanguageModel, opts ...agentOption) Agent {
 226	settings := AgentSettings{
 227		model: model,
 228	}
 229	for _, o := range opts {
 230		o(&settings)
 231	}
 232	return &agent{
 233		settings: settings,
 234	}
 235}
 236
 237func (a *agent) prepareCall(call AgentCall) AgentCall {
 238	if call.MaxOutputTokens == nil && a.settings.maxOutputTokens != nil {
 239		call.MaxOutputTokens = a.settings.maxOutputTokens
 240	}
 241	if call.Temperature == nil && a.settings.temperature != nil {
 242		call.Temperature = a.settings.temperature
 243	}
 244	if call.TopP == nil && a.settings.topP != nil {
 245		call.TopP = a.settings.topP
 246	}
 247	if call.TopK == nil && a.settings.topK != nil {
 248		call.TopK = a.settings.topK
 249	}
 250	if call.PresencePenalty == nil && a.settings.presencePenalty != nil {
 251		call.PresencePenalty = a.settings.presencePenalty
 252	}
 253	if call.FrequencyPenalty == nil && a.settings.frequencyPenalty != nil {
 254		call.FrequencyPenalty = a.settings.frequencyPenalty
 255	}
 256	if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
 257		call.StopWhen = a.settings.stopWhen
 258	}
 259	if call.PrepareStep == nil && a.settings.prepareStep != nil {
 260		call.PrepareStep = a.settings.prepareStep
 261	}
 262	if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
 263		call.RepairToolCall = a.settings.repairToolCall
 264	}
 265	if call.OnStepFinished == nil && a.settings.onStepFinished != nil {
 266		call.OnStepFinished = a.settings.onStepFinished
 267	}
 268	if call.OnRetry == nil && a.settings.onRetry != nil {
 269		call.OnRetry = a.settings.onRetry
 270	}
 271	if call.MaxRetries == nil && a.settings.maxRetries != nil {
 272		call.MaxRetries = a.settings.maxRetries
 273	}
 274
 275	providerOptions := ProviderOptions{}
 276	if a.settings.providerOptions != nil {
 277		maps.Copy(providerOptions, a.settings.providerOptions)
 278	}
 279	if call.ProviderOptions != nil {
 280		maps.Copy(providerOptions, call.ProviderOptions)
 281	}
 282	call.ProviderOptions = providerOptions
 283
 284	headers := map[string]string{}
 285
 286	if a.settings.headers != nil {
 287		maps.Copy(headers, a.settings.headers)
 288	}
 289
 290	if call.Headers != nil {
 291		maps.Copy(headers, call.Headers)
 292	}
 293	call.Headers = headers
 294	return call
 295}
 296
 297// Generate implements Agent.
 298func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
 299	opts = a.prepareCall(opts)
 300	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 301	if err != nil {
 302		return nil, err
 303	}
 304	var responseMessages []Message
 305	var steps []StepResult
 306
 307	for {
 308		stepInputMessages := append(initialPrompt, responseMessages...)
 309		stepModel := a.settings.model
 310		stepSystemPrompt := a.settings.systemPrompt
 311		stepActiveTools := opts.ActiveTools
 312		stepToolChoice := ToolChoiceAuto
 313		disableAllTools := false
 314
 315		if opts.PrepareStep != nil {
 316			prepared := opts.PrepareStep(PrepareStepFunctionOptions{
 317				Model:      stepModel,
 318				Steps:      steps,
 319				StepNumber: len(steps),
 320				Messages:   stepInputMessages,
 321			})
 322
 323			// Apply prepared step modifications
 324			if prepared.Messages != nil {
 325				stepInputMessages = prepared.Messages
 326			}
 327			if prepared.Model != nil {
 328				stepModel = prepared.Model
 329			}
 330			if prepared.System != nil {
 331				stepSystemPrompt = *prepared.System
 332			}
 333			if prepared.ToolChoice != nil {
 334				stepToolChoice = *prepared.ToolChoice
 335			}
 336			if len(prepared.ActiveTools) > 0 {
 337				stepActiveTools = prepared.ActiveTools
 338			}
 339			disableAllTools = prepared.DisableAllTools
 340		}
 341
 342		// Recreate prompt with potentially modified system prompt
 343		if stepSystemPrompt != a.settings.systemPrompt {
 344			stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
 345			if err != nil {
 346				return nil, err
 347			}
 348			// Replace system message part, keep the rest
 349			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 350				stepInputMessages[0] = stepPrompt[0] // Replace system message
 351			}
 352		}
 353
 354		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 355
 356		retryOptions := DefaultRetryOptions()
 357		retryOptions.OnRetry = opts.OnRetry
 358		retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
 359
 360		result, err := retry(ctx, func() (*Response, error) {
 361			return stepModel.Generate(ctx, Call{
 362				Prompt:           stepInputMessages,
 363				MaxOutputTokens:  opts.MaxOutputTokens,
 364				Temperature:      opts.Temperature,
 365				TopP:             opts.TopP,
 366				TopK:             opts.TopK,
 367				PresencePenalty:  opts.PresencePenalty,
 368				FrequencyPenalty: opts.FrequencyPenalty,
 369				Tools:            preparedTools,
 370				ToolChoice:       &stepToolChoice,
 371				Headers:          opts.Headers,
 372				ProviderOptions:  opts.ProviderOptions,
 373			})
 374		})
 375		if err != nil {
 376			return nil, err
 377		}
 378
 379		var stepToolCalls []ToolCallContent
 380		for _, content := range result.Content {
 381			if content.GetType() == ContentTypeToolCall {
 382				toolCall, ok := AsContentType[ToolCallContent](content)
 383				if !ok {
 384					continue
 385				}
 386
 387				// Validate and potentially repair the tool call
 388				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
 389				stepToolCalls = append(stepToolCalls, validatedToolCall)
 390			}
 391		}
 392
 393		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
 394
 395		// Build step content with validated tool calls and tool results
 396		stepContent := []Content{}
 397		toolCallIndex := 0
 398		for _, content := range result.Content {
 399			if content.GetType() == ContentTypeToolCall {
 400				// Replace with validated tool call
 401				if toolCallIndex < len(stepToolCalls) {
 402					stepContent = append(stepContent, stepToolCalls[toolCallIndex])
 403					toolCallIndex++
 404				}
 405			} else {
 406				// Keep other content as-is
 407				stepContent = append(stepContent, content)
 408			}
 409		}
 410		// Add tool results
 411		for _, result := range toolResults {
 412			stepContent = append(stepContent, result)
 413		}
 414		currentStepMessages := toResponseMessages(stepContent)
 415		responseMessages = append(responseMessages, currentStepMessages...)
 416
 417		stepResult := StepResult{
 418			Response: Response{
 419				Content:          stepContent,
 420				FinishReason:     result.FinishReason,
 421				Usage:            result.Usage,
 422				Warnings:         result.Warnings,
 423				ProviderMetadata: result.ProviderMetadata,
 424			},
 425			Messages: currentStepMessages,
 426		}
 427		steps = append(steps, stepResult)
 428		if opts.OnStepFinished != nil {
 429			opts.OnStepFinished(stepResult)
 430		}
 431
 432		shouldStop := isStopConditionMet(opts.StopWhen, steps)
 433
 434		if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
 435			break
 436		}
 437	}
 438
 439	totalUsage := Usage{}
 440
 441	for _, step := range steps {
 442		usage := step.Usage
 443		totalUsage.InputTokens += usage.InputTokens
 444		totalUsage.OutputTokens += usage.OutputTokens
 445		totalUsage.ReasoningTokens += usage.ReasoningTokens
 446		totalUsage.CacheCreationTokens += usage.CacheCreationTokens
 447		totalUsage.CacheReadTokens += usage.CacheReadTokens
 448		totalUsage.TotalTokens += usage.TotalTokens
 449	}
 450
 451	agentResult := &AgentResult{
 452		Steps:      steps,
 453		Response:   steps[len(steps)-1].Response,
 454		TotalUsage: totalUsage,
 455	}
 456	return agentResult, nil
 457}
 458
 459func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
 460	if len(conditions) == 0 {
 461		return false
 462	}
 463
 464	for _, condition := range conditions {
 465		if condition(steps) {
 466			return true
 467		}
 468	}
 469	return false
 470}
 471
 472func toResponseMessages(content []Content) []Message {
 473	var assistantParts []MessagePart
 474	var toolParts []MessagePart
 475
 476	for _, c := range content {
 477		switch c.GetType() {
 478		case ContentTypeText:
 479			text, ok := AsContentType[TextContent](c)
 480			if !ok {
 481				continue
 482			}
 483			assistantParts = append(assistantParts, TextPart{
 484				Text:            text.Text,
 485				ProviderOptions: ProviderOptions(text.ProviderMetadata),
 486			})
 487		case ContentTypeReasoning:
 488			reasoning, ok := AsContentType[ReasoningContent](c)
 489			if !ok {
 490				continue
 491			}
 492			assistantParts = append(assistantParts, ReasoningPart{
 493				Text:            reasoning.Text,
 494				ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
 495			})
 496		case ContentTypeToolCall:
 497			toolCall, ok := AsContentType[ToolCallContent](c)
 498			if !ok {
 499				continue
 500			}
 501			assistantParts = append(assistantParts, ToolCallPart{
 502				ToolCallID:       toolCall.ToolCallID,
 503				ToolName:         toolCall.ToolName,
 504				Input:            toolCall.Input,
 505				ProviderExecuted: toolCall.ProviderExecuted,
 506				ProviderOptions:  ProviderOptions(toolCall.ProviderMetadata),
 507			})
 508		case ContentTypeFile:
 509			file, ok := AsContentType[FileContent](c)
 510			if !ok {
 511				continue
 512			}
 513			assistantParts = append(assistantParts, FilePart{
 514				Data:            file.Data,
 515				MediaType:       file.MediaType,
 516				ProviderOptions: ProviderOptions(file.ProviderMetadata),
 517			})
 518		case ContentTypeSource:
 519			// Sources are metadata about references used to generate the response.
 520			// They don't need to be included in the conversation messages.
 521			continue
 522		case ContentTypeToolResult:
 523			result, ok := AsContentType[ToolResultContent](c)
 524			if !ok {
 525				continue
 526			}
 527			toolParts = append(toolParts, ToolResultPart{
 528				ToolCallID:      result.ToolCallID,
 529				Output:          result.Result,
 530				ProviderOptions: ProviderOptions(result.ProviderMetadata),
 531			})
 532		}
 533	}
 534
 535	var messages []Message
 536	if len(assistantParts) > 0 {
 537		messages = append(messages, Message{
 538			Role:    MessageRoleAssistant,
 539			Content: assistantParts,
 540		})
 541	}
 542	if len(toolParts) > 0 {
 543		messages = append(messages, Message{
 544			Role:    MessageRoleTool,
 545			Content: toolParts,
 546		})
 547	}
 548	return messages
 549}
 550
 551func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent)) ([]ToolResultContent, error) {
 552	if len(toolCalls) == 0 {
 553		return nil, nil
 554	}
 555
 556	// Create a map for quick tool lookup
 557	toolMap := make(map[string]tools.BaseTool)
 558	for _, tool := range allTools {
 559		toolMap[tool.Info().Name] = tool
 560	}
 561
 562	// Execute all tool calls in parallel
 563	results := make([]ToolResultContent, len(toolCalls))
 564	var toolExecutionError error
 565	var wg sync.WaitGroup
 566
 567	for i, toolCall := range toolCalls {
 568		wg.Add(1)
 569		go func(index int, call ToolCallContent) {
 570			defer wg.Done()
 571
 572			// Skip invalid tool calls - create error result
 573			if call.Invalid {
 574				results[index] = ToolResultContent{
 575					ToolCallID: call.ToolCallID,
 576					ToolName:   call.ToolName,
 577					Result: ToolResultOutputContentError{
 578						Error: call.ValidationError,
 579					},
 580					ProviderExecuted: false,
 581				}
 582				if toolResultCallback != nil {
 583					toolResultCallback(results[index])
 584				}
 585
 586				return
 587			}
 588
 589			tool, exists := toolMap[call.ToolName]
 590			if !exists {
 591				results[index] = ToolResultContent{
 592					ToolCallID: call.ToolCallID,
 593					ToolName:   call.ToolName,
 594					Result: ToolResultOutputContentError{
 595						Error: errors.New("Error: Tool not found: " + call.ToolName),
 596					},
 597					ProviderExecuted: false,
 598				}
 599
 600				if toolResultCallback != nil {
 601					toolResultCallback(results[index])
 602				}
 603				return
 604			}
 605
 606			// Execute the tool
 607			result, err := tool.Run(ctx, tools.ToolCall{
 608				ID:    call.ToolCallID,
 609				Name:  call.ToolName,
 610				Input: call.Input,
 611			})
 612			if err != nil {
 613				results[index] = ToolResultContent{
 614					ToolCallID: call.ToolCallID,
 615					ToolName:   call.ToolName,
 616					Result: ToolResultOutputContentError{
 617						Error: err,
 618					},
 619					ProviderExecuted: false,
 620				}
 621				if toolResultCallback != nil {
 622					toolResultCallback(results[index])
 623				}
 624				toolExecutionError = err
 625				return
 626			}
 627
 628			if result.IsError {
 629				results[index] = ToolResultContent{
 630					ToolCallID: call.ToolCallID,
 631					ToolName:   call.ToolName,
 632					Result: ToolResultOutputContentError{
 633						Error: errors.New(result.Content),
 634					},
 635					ProviderExecuted: false,
 636				}
 637
 638				if toolResultCallback != nil {
 639					toolResultCallback(results[index])
 640				}
 641			} else {
 642				results[index] = ToolResultContent{
 643					ToolCallID: call.ToolCallID,
 644					ToolName:   toolCall.ToolName,
 645					Result: ToolResultOutputContentText{
 646						Text: result.Content,
 647					},
 648					ProviderExecuted: false,
 649				}
 650				if toolResultCallback != nil {
 651					toolResultCallback(results[index])
 652				}
 653			}
 654		}(i, toolCall)
 655	}
 656
 657	// Wait for all tool executions to complete
 658	wg.Wait()
 659
 660	return results, toolExecutionError
 661}
 662
 663// Stream implements Agent.
 664func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
 665	// Convert AgentStreamCall to AgentCall for preparation
 666	call := AgentCall{
 667		Prompt:           opts.Prompt,
 668		Files:            opts.Files,
 669		Messages:         opts.Messages,
 670		MaxOutputTokens:  opts.MaxOutputTokens,
 671		Temperature:      opts.Temperature,
 672		TopP:             opts.TopP,
 673		TopK:             opts.TopK,
 674		PresencePenalty:  opts.PresencePenalty,
 675		FrequencyPenalty: opts.FrequencyPenalty,
 676		ActiveTools:      opts.ActiveTools,
 677		Headers:          opts.Headers,
 678		ProviderOptions:  opts.ProviderOptions,
 679		MaxRetries:       opts.MaxRetries,
 680		StopWhen:         opts.StopWhen,
 681		PrepareStep:      opts.PrepareStep,
 682		RepairToolCall:   opts.RepairToolCall,
 683	}
 684
 685	call = a.prepareCall(call)
 686
 687	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
 688	if err != nil {
 689		return nil, err
 690	}
 691
 692	var responseMessages []Message
 693	var steps []StepResult
 694	var totalUsage Usage
 695
 696	// Start agent stream
 697	if opts.OnAgentStart != nil {
 698		opts.OnAgentStart()
 699	}
 700
 701	for stepNumber := 0; ; stepNumber++ {
 702		stepInputMessages := append(initialPrompt, responseMessages...)
 703		stepModel := a.settings.model
 704		stepSystemPrompt := a.settings.systemPrompt
 705		stepActiveTools := call.ActiveTools
 706		stepToolChoice := ToolChoiceAuto
 707		disableAllTools := false
 708
 709		// Apply step preparation if provided
 710		if call.PrepareStep != nil {
 711			prepared := call.PrepareStep(PrepareStepFunctionOptions{
 712				Model:      stepModel,
 713				Steps:      steps,
 714				StepNumber: stepNumber,
 715				Messages:   stepInputMessages,
 716			})
 717
 718			if prepared.Messages != nil {
 719				stepInputMessages = prepared.Messages
 720			}
 721			if prepared.Model != nil {
 722				stepModel = prepared.Model
 723			}
 724			if prepared.System != nil {
 725				stepSystemPrompt = *prepared.System
 726			}
 727			if prepared.ToolChoice != nil {
 728				stepToolChoice = *prepared.ToolChoice
 729			}
 730			if len(prepared.ActiveTools) > 0 {
 731				stepActiveTools = prepared.ActiveTools
 732			}
 733			disableAllTools = prepared.DisableAllTools
 734		}
 735
 736		// Recreate prompt with potentially modified system prompt
 737		if stepSystemPrompt != a.settings.systemPrompt {
 738			stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
 739			if err != nil {
 740				return nil, err
 741			}
 742			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
 743				stepInputMessages[0] = stepPrompt[0]
 744			}
 745		}
 746
 747		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 748
 749		// Start step stream
 750		if opts.OnStepStart != nil {
 751			opts.OnStepStart(stepNumber)
 752		}
 753
 754		// Create streaming call
 755		streamCall := Call{
 756			Prompt:           stepInputMessages,
 757			MaxOutputTokens:  call.MaxOutputTokens,
 758			Temperature:      call.Temperature,
 759			TopP:             call.TopP,
 760			TopK:             call.TopK,
 761			PresencePenalty:  call.PresencePenalty,
 762			FrequencyPenalty: call.FrequencyPenalty,
 763			Tools:            preparedTools,
 764			ToolChoice:       &stepToolChoice,
 765			Headers:          call.Headers,
 766			ProviderOptions:  call.ProviderOptions,
 767		}
 768
 769		// Get streaming response
 770		stream, err := stepModel.Stream(ctx, streamCall)
 771		if err != nil {
 772			if opts.OnError != nil {
 773				opts.OnError(err)
 774			}
 775			return nil, err
 776		}
 777
 778		// Process stream with tool execution
 779		stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
 780		if err != nil {
 781			if opts.OnError != nil {
 782				opts.OnError(err)
 783			}
 784			return nil, err
 785		}
 786
 787		steps = append(steps, stepResult)
 788		totalUsage = addUsage(totalUsage, stepResult.Usage)
 789
 790		// Call step finished callback
 791		if opts.OnStepFinish != nil {
 792			opts.OnStepFinish(stepResult)
 793		}
 794
 795		// Add step messages to response messages
 796		stepMessages := toResponseMessages(stepResult.Content)
 797		responseMessages = append(responseMessages, stepMessages...)
 798
 799		// Check stop conditions
 800		shouldStop := isStopConditionMet(call.StopWhen, steps)
 801		if shouldStop || !shouldContinue {
 802			break
 803		}
 804	}
 805
 806	// Finish agent stream
 807	agentResult := &AgentResult{
 808		Steps:      steps,
 809		Response:   steps[len(steps)-1].Response,
 810		TotalUsage: totalUsage,
 811	}
 812
 813	if opts.OnFinish != nil {
 814		opts.OnFinish(agentResult)
 815	}
 816
 817	if opts.OnAgentFinish != nil {
 818		opts.OnAgentFinish(agentResult)
 819	}
 820
 821	return agentResult, nil
 822}
 823
 824func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string, disableAllTools bool) []Tool {
 825	var preparedTools []Tool
 826
 827	// If explicitly disabling all tools, return no tools
 828	if disableAllTools {
 829		return preparedTools
 830	}
 831
 832	for _, tool := range tools {
 833		// If activeTools has items, only include tools in the list
 834		// If activeTools is empty, include all tools
 835		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
 836			continue
 837		}
 838		info := tool.Info()
 839		preparedTools = append(preparedTools, FunctionTool{
 840			Name:        info.Name,
 841			Description: info.Description,
 842			InputSchema: map[string]any{
 843				"type":       "object",
 844				"properties": info.Parameters,
 845				"required":   info.Required,
 846			},
 847		})
 848	}
 849	return preparedTools
 850}
 851
 852// validateAndRepairToolCall validates a tool call and attempts repair if validation fails
 853func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []tools.BaseTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
 854	if err := a.validateToolCall(toolCall, availableTools); err == nil {
 855		return toolCall
 856	} else {
 857		if repairFunc != nil {
 858			repairOptions := ToolCallRepairOptions{
 859				OriginalToolCall: toolCall,
 860				ValidationError:  err,
 861				AvailableTools:   availableTools,
 862				SystemPrompt:     systemPrompt,
 863				Messages:         messages,
 864			}
 865
 866			if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
 867				if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
 868					return *repairedToolCall
 869				}
 870			}
 871		}
 872
 873		invalidToolCall := toolCall
 874		invalidToolCall.Invalid = true
 875		invalidToolCall.ValidationError = err
 876		return invalidToolCall
 877	}
 878}
 879
 880// validateToolCall validates a tool call against available tools and their schemas
 881func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []tools.BaseTool) error {
 882	var tool tools.BaseTool
 883	for _, t := range availableTools {
 884		if t.Info().Name == toolCall.ToolName {
 885			tool = t
 886			break
 887		}
 888	}
 889
 890	if tool == nil {
 891		return fmt.Errorf("tool not found: %s", toolCall.ToolName)
 892	}
 893
 894	// Validate JSON parsing
 895	var input map[string]any
 896	if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
 897		return fmt.Errorf("invalid JSON input: %w", err)
 898	}
 899
 900	// Basic schema validation (check required fields)
 901	// TODO: more robust schema validation using JSON Schema or similar
 902	toolInfo := tool.Info()
 903	for _, required := range toolInfo.Required {
 904		if _, exists := input[required]; !exists {
 905			return fmt.Errorf("missing required parameter: %s", required)
 906		}
 907	}
 908	return nil
 909}
 910
 911func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
 912	if prompt == "" {
 913		return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
 914	}
 915
 916	var preparedPrompt Prompt
 917
 918	if system != "" {
 919		preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
 920	}
 921
 922	preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
 923	preparedPrompt = append(preparedPrompt, messages...)
 924	return preparedPrompt, nil
 925}
 926
 927func WithSystemPrompt(prompt string) agentOption {
 928	return func(s *AgentSettings) {
 929		s.systemPrompt = prompt
 930	}
 931}
 932
 933func WithMaxOutputTokens(tokens int64) agentOption {
 934	return func(s *AgentSettings) {
 935		s.maxOutputTokens = &tokens
 936	}
 937}
 938
 939func WithTemperature(temp float64) agentOption {
 940	return func(s *AgentSettings) {
 941		s.temperature = &temp
 942	}
 943}
 944
 945func WithTopP(topP float64) agentOption {
 946	return func(s *AgentSettings) {
 947		s.topP = &topP
 948	}
 949}
 950
 951func WithTopK(topK int64) agentOption {
 952	return func(s *AgentSettings) {
 953		s.topK = &topK
 954	}
 955}
 956
 957func WithPresencePenalty(penalty float64) agentOption {
 958	return func(s *AgentSettings) {
 959		s.presencePenalty = &penalty
 960	}
 961}
 962
 963func WithFrequencyPenalty(penalty float64) agentOption {
 964	return func(s *AgentSettings) {
 965		s.frequencyPenalty = &penalty
 966	}
 967}
 968
 969func WithTools(tools ...tools.BaseTool) agentOption {
 970	return func(s *AgentSettings) {
 971		s.tools = append(s.tools, tools...)
 972	}
 973}
 974
 975func WithStopConditions(conditions ...StopCondition) agentOption {
 976	return func(s *AgentSettings) {
 977		s.stopWhen = append(s.stopWhen, conditions...)
 978	}
 979}
 980
 981func WithPrepareStep(fn PrepareStepFunction) agentOption {
 982	return func(s *AgentSettings) {
 983		s.prepareStep = fn
 984	}
 985}
 986
 987func WithRepairToolCall(fn RepairToolCallFunction) agentOption {
 988	return func(s *AgentSettings) {
 989		s.repairToolCall = fn
 990	}
 991}
 992
 993func WithOnStepFinished(fn OnStepFinishedFunction) agentOption {
 994	return func(s *AgentSettings) {
 995		s.onStepFinished = fn
 996	}
 997}
 998
 999// processStepStream processes a single step's stream and returns the step result
1000func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
1001	var stepContent []Content
1002	var stepToolCalls []ToolCallContent
1003	var stepUsage Usage
1004	var stepFinishReason FinishReason = FinishReasonUnknown
1005	var stepWarnings []CallWarning
1006	var stepProviderMetadata ProviderMetadata
1007
1008	activeToolCalls := make(map[string]*ToolCallContent)
1009	activeTextContent := make(map[string]string)
1010
1011	// Process stream parts
1012	for part := range stream {
1013		// Forward all parts to chunk callback
1014		if opts.OnChunk != nil {
1015			opts.OnChunk(part)
1016		}
1017
1018		switch part.Type {
1019		case StreamPartTypeWarnings:
1020			stepWarnings = part.Warnings
1021			if opts.OnWarnings != nil {
1022				opts.OnWarnings(part.Warnings)
1023			}
1024
1025		case StreamPartTypeTextStart:
1026			activeTextContent[part.ID] = ""
1027			if opts.OnTextStart != nil {
1028				opts.OnTextStart(part.ID)
1029			}
1030
1031		case StreamPartTypeTextDelta:
1032			if _, exists := activeTextContent[part.ID]; exists {
1033				activeTextContent[part.ID] += part.Delta
1034			}
1035			if opts.OnTextDelta != nil {
1036				opts.OnTextDelta(part.ID, part.Delta)
1037			}
1038
1039		case StreamPartTypeTextEnd:
1040			if text, exists := activeTextContent[part.ID]; exists {
1041				stepContent = append(stepContent, TextContent{
1042					Text:             text,
1043					ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1044				})
1045				delete(activeTextContent, part.ID)
1046			}
1047			if opts.OnTextEnd != nil {
1048				opts.OnTextEnd(part.ID)
1049			}
1050
1051		case StreamPartTypeReasoningStart:
1052			activeTextContent[part.ID] = ""
1053			if opts.OnReasoningStart != nil {
1054				opts.OnReasoningStart(part.ID)
1055			}
1056
1057		case StreamPartTypeReasoningDelta:
1058			if _, exists := activeTextContent[part.ID]; exists {
1059				activeTextContent[part.ID] += part.Delta
1060			}
1061			if opts.OnReasoningDelta != nil {
1062				opts.OnReasoningDelta(part.ID, part.Delta)
1063			}
1064
1065		case StreamPartTypeReasoningEnd:
1066			if text, exists := activeTextContent[part.ID]; exists {
1067				stepContent = append(stepContent, ReasoningContent{
1068					Text:             text,
1069					ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1070				})
1071				delete(activeTextContent, part.ID)
1072			}
1073			if opts.OnReasoningEnd != nil {
1074				opts.OnReasoningEnd(part.ID)
1075			}
1076
1077		case StreamPartTypeToolInputStart:
1078			activeToolCalls[part.ID] = &ToolCallContent{
1079				ToolCallID:       part.ID,
1080				ToolName:         part.ToolCallName,
1081				Input:            "",
1082				ProviderExecuted: part.ProviderExecuted,
1083			}
1084			if opts.OnToolInputStart != nil {
1085				opts.OnToolInputStart(part.ID, part.ToolCallName)
1086			}
1087
1088		case StreamPartTypeToolInputDelta:
1089			if toolCall, exists := activeToolCalls[part.ID]; exists {
1090				toolCall.Input += part.Delta
1091			}
1092			if opts.OnToolInputDelta != nil {
1093				opts.OnToolInputDelta(part.ID, part.Delta)
1094			}
1095
1096		case StreamPartTypeToolInputEnd:
1097			if opts.OnToolInputEnd != nil {
1098				opts.OnToolInputEnd(part.ID)
1099			}
1100
1101		case StreamPartTypeToolCall:
1102			toolCall := ToolCallContent{
1103				ToolCallID:       part.ID,
1104				ToolName:         part.ToolCallName,
1105				Input:            part.ToolCallInput,
1106				ProviderExecuted: part.ProviderExecuted,
1107				ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1108			}
1109
1110			// Validate and potentially repair the tool call
1111			validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1112			stepToolCalls = append(stepToolCalls, validatedToolCall)
1113			stepContent = append(stepContent, validatedToolCall)
1114
1115			if opts.OnToolCall != nil {
1116				opts.OnToolCall(validatedToolCall)
1117			}
1118
1119			// Clean up active tool call
1120			delete(activeToolCalls, part.ID)
1121
1122		case StreamPartTypeSource:
1123			sourceContent := SourceContent{
1124				SourceType:       part.SourceType,
1125				ID:               part.ID,
1126				URL:              part.URL,
1127				Title:            part.Title,
1128				ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1129			}
1130			stepContent = append(stepContent, sourceContent)
1131			if opts.OnSource != nil {
1132				opts.OnSource(sourceContent)
1133			}
1134
1135		case StreamPartTypeFinish:
1136			stepUsage = part.Usage
1137			stepFinishReason = part.FinishReason
1138			stepProviderMetadata = ProviderMetadata(part.ProviderMetadata)
1139			if opts.OnStreamFinish != nil {
1140				opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1141			}
1142
1143		case StreamPartTypeError:
1144			if opts.OnStreamError != nil {
1145				opts.OnStreamError(part.Error)
1146			}
1147			if opts.OnError != nil {
1148				opts.OnError(part.Error)
1149			}
1150			return StepResult{}, false, part.Error
1151		}
1152	}
1153
1154	// Execute tools if any
1155	var toolResults []ToolResultContent
1156	if len(stepToolCalls) > 0 {
1157		var err error
1158		toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
1159		if err != nil {
1160			return StepResult{}, false, err
1161		}
1162		// Add tool results to content
1163		for _, result := range toolResults {
1164			stepContent = append(stepContent, result)
1165		}
1166	}
1167
1168	stepResult := StepResult{
1169		Response: Response{
1170			Content:          stepContent,
1171			FinishReason:     stepFinishReason,
1172			Usage:            stepUsage,
1173			Warnings:         stepWarnings,
1174			ProviderMetadata: stepProviderMetadata,
1175		},
1176		Messages: toResponseMessages(stepContent),
1177	}
1178
1179	// Determine if we should continue (has tool calls and not stopped)
1180	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1181
1182	return stepResult, shouldContinue, nil
1183}
1184
1185func addUsage(a, b Usage) Usage {
1186	return Usage{
1187		InputTokens:         a.InputTokens + b.InputTokens,
1188		OutputTokens:        a.OutputTokens + b.OutputTokens,
1189		TotalTokens:         a.TotalTokens + b.TotalTokens,
1190		ReasoningTokens:     a.ReasoningTokens + b.ReasoningTokens,
1191		CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1192		CacheReadTokens:     a.CacheReadTokens + b.CacheReadTokens,
1193	}
1194}
1195
1196func WithHeaders(headers map[string]string) agentOption {
1197	return func(s *AgentSettings) {
1198		s.headers = headers
1199	}
1200}
1201
1202func WithProviderOptions(providerOptions ProviderOptions) agentOption {
1203	return func(s *AgentSettings) {
1204		s.providerOptions = providerOptions
1205	}
1206}