1package fantasy
2
3import (
4 "cmp"
5 "context"
6 "encoding/base64"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "maps"
11 "slices"
12 "sync"
13
14 "charm.land/fantasy/schema"
15 "github.com/charmbracelet/x/exp/slice"
16)
17
// StepResult represents the result of a single step in an agent execution.
// It embeds the step's Response, so fields such as Content, FinishReason and
// Usage are available directly on the StepResult.
type StepResult struct {
	Response
	// Messages holds the conversation messages derived from this step's
	// content.
	Messages []Message
}
23
// stepExecutionResult encapsulates the result of executing a step with stream processing.
type stepExecutionResult struct {
	// StepResult is the completed step that gets appended to the run's steps.
	StepResult StepResult
	// ShouldContinue reports whether another step should follow; Stream
	// leaves its loop when this is false.
	ShouldContinue bool
}
29
// StopCondition defines a function that determines when an agent should stop executing.
// It receives every step completed so far (the most recent one last) and
// returns true when the run should end.
type StopCondition = func(steps []StepResult) bool
32
33// StepCountIs returns a stop condition that stops after the specified number of steps.
34func StepCountIs(stepCount int) StopCondition {
35 return func(steps []StepResult) bool {
36 return len(steps) >= stepCount
37 }
38}
39
40// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
41func HasToolCall(toolName string) StopCondition {
42 return func(steps []StepResult) bool {
43 if len(steps) == 0 {
44 return false
45 }
46 lastStep := steps[len(steps)-1]
47 toolCalls := lastStep.Content.ToolCalls()
48 for _, toolCall := range toolCalls {
49 if toolCall.ToolName == toolName {
50 return true
51 }
52 }
53 return false
54 }
55}
56
57// HasContent returns a stop condition that stops when the specified content type appears in the last step.
58func HasContent(contentType ContentType) StopCondition {
59 return func(steps []StepResult) bool {
60 if len(steps) == 0 {
61 return false
62 }
63 lastStep := steps[len(steps)-1]
64 for _, content := range lastStep.Content {
65 if content.GetType() == contentType {
66 return true
67 }
68 }
69 return false
70 }
71}
72
73// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
74func FinishReasonIs(reason FinishReason) StopCondition {
75 return func(steps []StepResult) bool {
76 if len(steps) == 0 {
77 return false
78 }
79 lastStep := steps[len(steps)-1]
80 return lastStep.FinishReason == reason
81 }
82}
83
84// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
85func MaxTokensUsed(maxTokens int64) StopCondition {
86 return func(steps []StepResult) bool {
87 var totalTokens int64
88 for _, step := range steps {
89 totalTokens += step.Usage.TotalTokens
90 }
91 return totalTokens >= maxTokens
92 }
93}
94
// PrepareStepFunctionOptions contains the options for preparing a step in an agent execution.
type PrepareStepFunctionOptions struct {
	// Steps lists the steps completed before the one being prepared.
	Steps []StepResult
	// StepNumber is the zero-based index of the step being prepared.
	StepNumber int
	// Model is the agent's configured model, i.e. the default for this step.
	Model LanguageModel
	// Messages are the messages that would be sent for this step if the
	// prepare function does not replace them.
	Messages []Message
}
102
// PrepareStepResult contains the result of preparing a step in an agent execution.
// Each field is an optional override for a single step; fields left at their
// zero value keep the step's default.
type PrepareStepResult struct {
	// Model, when non-nil, is used instead of the agent's model.
	Model LanguageModel
	// Messages, when non-nil, replaces the step's input messages.
	Messages []Message
	// System, when non-nil, replaces the system prompt for the step.
	System *string
	// ToolChoice, when non-nil, replaces the default ToolChoiceAuto.
	ToolChoice *ToolChoice
	// ActiveTools, when non-empty, replaces the call's active tool names.
	ActiveTools []string
	// DisableAllTools, when true, sends no tools to the model for the step.
	DisableAllTools bool
	// Tools, when non-nil, replaces the agent's tools for the step.
	Tools []AgentTool
}
113
// ToolCallRepairOptions contains the options for repairing a tool call.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call that failed validation.
	OriginalToolCall ToolCallContent
	// ValidationError is the error returned by validation of the original call.
	ValidationError error
	// AvailableTools are the agent tools the call was validated against.
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt in effect for the step.
	SystemPrompt string
	// Messages are the input messages of the step that produced the call.
	Messages []Message
}
122
type (
	// PrepareStepFunction defines a function that prepares a step in an agent execution.
	// It is called before each model call. The returned context replaces the
	// context used for the rest of the run, the returned PrepareStepResult
	// overrides the step's defaults, and a non-nil error aborts the run.
	PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)

	// OnStepFinishedFunction defines a function that is called when a step finishes.
	OnStepFinishedFunction = func(step StepResult)

	// RepairToolCallFunction defines a function that repairs a tool call.
	// It is invoked when a tool call fails validation. A non-nil result with
	// a nil error is validated again and, if valid, used in place of the
	// original call; in every other case the original call is marked invalid.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
133
// agentSettings holds the agent-wide configuration assembled by the
// AgentOption functions passed to NewAgent. The pointer-typed and
// callback fields act as defaults: prepareCall uses them only when the
// individual call leaves the corresponding field unset.
type agentSettings struct {
	// systemPrompt is the default system prompt used to build each prompt.
	systemPrompt     string
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	headers          map[string]string
	// userAgent is forwarded on every model call.
	userAgent string
	// providerOptions are merged with per-call options; per-call keys win.
	providerOptions ProviderOptions

	// providerDefinedTools are offered to the model alongside tools.
	providerDefinedTools []ProviderDefinedTool
	// executableProviderTools can be run by the agent when the model calls
	// them by name.
	executableProviderTools []ExecutableProviderTool
	tools                   []AgentTool
	maxRetries              *int

	model LanguageModel

	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
158
// AgentCall represents a call to an agent.
// Fields left unset fall back to the agent-level settings where such a
// setting exists.
type AgentCall struct {
	// Prompt, Files and Messages are used to build the initial prompt.
	Prompt   string     `json:"prompt"`
	Files    []FilePart `json:"files"`
	Messages []Message  `json:"messages"`
	// Sampling parameters; nil means "use the agent-level value".
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, restricts the tools offered to the model
	// to those whose names are listed.
	ActiveTools []string `json:"active_tools"`
	// ProviderOptions are merged over the agent-level provider options.
	ProviderOptions ProviderOptions
	// OnRetry is passed to the retry helper that wraps each model call.
	OnRetry OnRetryCallback
	// MaxRetries, when non-nil, overrides the default retry count.
	MaxRetries *int

	// StopWhen conditions are evaluated after every step; the run ends as
	// soon as any of them returns true.
	StopWhen []StopCondition
	// PrepareStep, when set, is called before every step.
	PrepareStep PrepareStepFunction
	// RepairToolCall, when set, is offered tool calls that failed validation.
	RepairToolCall RepairToolCallFunction
}
179
// Agent-level callbacks.
// Stream invokes these around the whole run and around each step. Where a
// callback returns an error, Stream currently discards that error.
type (
	// OnAgentStartFunc is called when agent starts.
	// Stream calls it once, before the first step.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called when agent finishes.
	// Stream calls it after OnFinishFunc with the final result.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called when a step starts.
	// stepNumber is zero-based.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called when a step finishes.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called when entire agent completes.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called when an error occurs.
	// Stream calls it when executing a step fails after retries.
	OnErrorFunc func(error)
)
200
// Stream part callbacks - called for each corresponding stream part type.
// Every callback returns an error; how that error is treated is decided by
// the stream-processing code, which is outside this section of the file.
type (
	// OnChunkFunc is called for each stream part (catch-all).
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when text starts.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for text deltas.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when text ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when reasoning starts.
	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error

	// OnReasoningDeltaFunc is called for reasoning deltas.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning ends.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when tool input starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for tool input deltas.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnToolResultFunc is called when tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when stream finishes.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
248
// AgentStreamCall represents a streaming call to an agent.
// It carries the same request fields as AgentCall plus the callbacks that
// observe the run. Stream converts the request fields into an AgentCall so
// that agent-level defaults are applied the same way as for Generate.
type AgentStreamCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	// NOTE(review): Stream does not copy Headers into the AgentCall it
	// builds, so the code visible here never forwards this field — confirm
	// whether it is consumed elsewhere.
	Headers         map[string]string
	ProviderOptions ProviderOptions
	OnRetry         OnRetryCallback
	MaxRetries      *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called when agent starts
	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
	OnWarnings       OnWarningsFunc       // Called for warnings
	OnTextStart      OnTextStartFunc      // Called when text starts
	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
	OnTextEnd        OnTextEndFunc        // Called when text ends
	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
	OnToolCall       OnToolCallFunc       // Called when tool call is complete
	OnToolResult     OnToolResultFunc     // Called when tool execution completes
	OnSource         OnSourceFunc         // Called for source references
	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
}
295
// AgentResult represents the result of an agent execution.
type AgentResult struct {
	// Steps lists every step of the run in execution order.
	Steps []StepResult
	// Final response
	// (the Response of the last entry in Steps).
	Response Response
	// TotalUsage is the token usage summed over all steps.
	TotalUsage Usage
}
303
// Agent represents an AI agent that can generate responses and stream responses.
// Both methods run the model repeatedly, executing requested tools between
// model calls, and return the accumulated result.
type Agent interface {
	// Generate runs the agent using non-streaming model calls.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the agent using streaming model calls and reports
	// progress through the callbacks on AgentStreamCall.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
309
// AgentOption defines a function that configures agent settings.
// NewAgent applies options in the order given, so a later option can
// overwrite what an earlier one set.
type AgentOption = func(*agentSettings)
312
// agent is the Agent implementation returned by NewAgent.
type agent struct {
	// settings holds the configuration captured when the agent was created.
	settings agentSettings
}
316
317// NewAgent creates a new agent with the given language model and options.
318func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
319 settings := agentSettings{
320 model: model,
321 }
322 for _, o := range opts {
323 o(&settings)
324 }
325 return &agent{
326 settings: settings,
327 }
328}
329
330func (a *agent) prepareCall(call AgentCall) AgentCall {
331 call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
332 call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
333 call.TopP = cmp.Or(call.TopP, a.settings.topP)
334 call.TopK = cmp.Or(call.TopK, a.settings.topK)
335 call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
336 call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
337 call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
338
339 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
340 call.StopWhen = a.settings.stopWhen
341 }
342 if call.PrepareStep == nil && a.settings.prepareStep != nil {
343 call.PrepareStep = a.settings.prepareStep
344 }
345 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
346 call.RepairToolCall = a.settings.repairToolCall
347 }
348 if call.OnRetry == nil && a.settings.onRetry != nil {
349 call.OnRetry = a.settings.onRetry
350 }
351
352 providerOptions := ProviderOptions{}
353 if a.settings.providerOptions != nil {
354 maps.Copy(providerOptions, a.settings.providerOptions)
355 }
356 if call.ProviderOptions != nil {
357 maps.Copy(providerOptions, call.ProviderOptions)
358 }
359 call.ProviderOptions = providerOptions
360
361 headers := map[string]string{}
362
363 if a.settings.headers != nil {
364 maps.Copy(headers, a.settings.headers)
365 }
366
367 return call
368}
369
370// Generate implements Agent.
371func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
372 opts = a.prepareCall(opts)
373 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
374 if err != nil {
375 return nil, err
376 }
377 var responseMessages []Message
378 var steps []StepResult
379
380 for {
381 stepInputMessages := append(initialPrompt, responseMessages...)
382 stepModel := a.settings.model
383 stepSystemPrompt := a.settings.systemPrompt
384 stepActiveTools := opts.ActiveTools
385 stepToolChoice := ToolChoiceAuto
386 disableAllTools := false
387 stepTools := a.settings.tools
388 if opts.PrepareStep != nil {
389 updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
390 Model: stepModel,
391 Steps: steps,
392 StepNumber: len(steps),
393 Messages: stepInputMessages,
394 })
395 if err != nil {
396 return nil, err
397 }
398
399 ctx = updatedCtx
400
401 // Apply prepared step modifications
402 if prepared.Messages != nil {
403 stepInputMessages = prepared.Messages
404 }
405 if prepared.Model != nil {
406 stepModel = prepared.Model
407 }
408 if prepared.System != nil {
409 stepSystemPrompt = *prepared.System
410 }
411 if prepared.ToolChoice != nil {
412 stepToolChoice = *prepared.ToolChoice
413 }
414 if len(prepared.ActiveTools) > 0 {
415 stepActiveTools = prepared.ActiveTools
416 }
417 disableAllTools = prepared.DisableAllTools
418 if prepared.Tools != nil {
419 stepTools = prepared.Tools
420 }
421 }
422
423 // Recreate prompt with potentially modified system prompt
424 if stepSystemPrompt != a.settings.systemPrompt {
425 stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
426 if err != nil {
427 return nil, err
428 }
429 // Replace system message part, keep the rest
430 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
431 stepInputMessages[0] = stepPrompt[0] // Replace system message
432 }
433 }
434
435 preparedTools := a.prepareTools(stepTools, a.settings.providerDefinedTools, stepActiveTools, disableAllTools)
436
437 // Filter executable provider tools by activeTools at the
438 // step level, consistent with how stepTools (AgentTools)
439 // are scoped before being passed to inner functions.
440 stepExecProviderTools := a.filterExecProviderTools(stepActiveTools)
441
442 retryOptions := DefaultRetryOptions()
443 if opts.MaxRetries != nil {
444 retryOptions.MaxRetries = *opts.MaxRetries
445 }
446 retryOptions.OnRetry = opts.OnRetry
447 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
448 result, err := retry(ctx, func() (*Response, error) {
449 return stepModel.Generate(ctx, Call{
450 Prompt: stepInputMessages,
451 MaxOutputTokens: opts.MaxOutputTokens,
452 Temperature: opts.Temperature,
453 TopP: opts.TopP,
454 TopK: opts.TopK,
455 PresencePenalty: opts.PresencePenalty,
456 FrequencyPenalty: opts.FrequencyPenalty,
457 Tools: preparedTools,
458 ToolChoice: &stepToolChoice,
459 UserAgent: a.settings.userAgent,
460 ProviderOptions: opts.ProviderOptions,
461 })
462 })
463 if err != nil {
464 return nil, err
465 }
466
467 var stepToolCalls []ToolCallContent
468 for _, content := range result.Content {
469 if content.GetType() == ContentTypeToolCall {
470 toolCall, ok := AsContentType[ToolCallContent](content)
471 if !ok {
472 continue
473 }
474 // Provider-executed tool calls (e.g. web search) are
475 // handled by the provider and should not be validated
476 // or executed by the agent.
477 if toolCall.ProviderExecuted {
478 continue
479 }
480 // Validate and potentially repair the tool call
481 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, stepExecProviderTools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
482 stepToolCalls = append(stepToolCalls, validatedToolCall)
483 }
484 }
485
486 toolResults, err := a.executeTools(ctx, stepTools, stepExecProviderTools, stepToolCalls, nil)
487
488 // If any tool result requested a stop, deliver all results but don't
489 // request another completion from the model.
490 stopTurnRequested := hasStopTurn(toolResults)
491
492 // Build step content with validated tool calls and tool results.
493 // Provider-executed tool calls are kept as-is.
494 stepContent := []Content{}
495 toolCallIndex := 0
496 for _, content := range result.Content {
497 if content.GetType() == ContentTypeToolCall {
498 tc, ok := AsContentType[ToolCallContent](content)
499 if ok && tc.ProviderExecuted {
500 stepContent = append(stepContent, content)
501 continue
502 }
503 // Replace with validated tool call.
504 if toolCallIndex < len(stepToolCalls) {
505 stepContent = append(stepContent, stepToolCalls[toolCallIndex])
506 toolCallIndex++
507 }
508 } else {
509 stepContent = append(stepContent, content)
510 }
511 } // Add tool results
512 for _, result := range toolResults {
513 stepContent = append(stepContent, result)
514 }
515 currentStepMessages := toResponseMessages(stepContent)
516 responseMessages = append(responseMessages, currentStepMessages...)
517
518 stepResult := StepResult{
519 Response: Response{
520 Content: stepContent,
521 FinishReason: result.FinishReason,
522 Usage: result.Usage,
523 Warnings: result.Warnings,
524 ProviderMetadata: result.ProviderMetadata,
525 },
526 Messages: currentStepMessages,
527 }
528 steps = append(steps, stepResult)
529 shouldStop := isStopConditionMet(opts.StopWhen, steps)
530
531 if shouldStop || err != nil || stopTurnRequested || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
532 break
533 }
534 }
535
536 totalUsage := Usage{}
537
538 for _, step := range steps {
539 usage := step.Usage
540 totalUsage.InputTokens += usage.InputTokens
541 totalUsage.OutputTokens += usage.OutputTokens
542 totalUsage.ReasoningTokens += usage.ReasoningTokens
543 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
544 totalUsage.CacheReadTokens += usage.CacheReadTokens
545 totalUsage.TotalTokens += usage.TotalTokens
546 }
547
548 agentResult := &AgentResult{
549 Steps: steps,
550 Response: steps[len(steps)-1].Response,
551 TotalUsage: totalUsage,
552 }
553 return agentResult, nil
554}
555
556func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
557 if len(conditions) == 0 {
558 return false
559 }
560
561 for _, condition := range conditions {
562 if condition(steps) {
563 return true
564 }
565 }
566 return false
567}
568
569func hasStopTurn(results []ToolResultContent) bool {
570 for _, r := range results {
571 if r.StopTurn {
572 return true
573 }
574 }
575 return false
576}
577
578func toResponseMessages(content []Content) []Message {
579 var assistantParts []MessagePart
580 var toolParts []MessagePart
581
582 for _, c := range content {
583 switch c.GetType() {
584 case ContentTypeText:
585 text, ok := AsContentType[TextContent](c)
586 if !ok {
587 continue
588 }
589 assistantParts = append(assistantParts, TextPart{
590 Text: text.Text,
591 ProviderOptions: ProviderOptions(text.ProviderMetadata),
592 })
593 case ContentTypeReasoning:
594 reasoning, ok := AsContentType[ReasoningContent](c)
595 if !ok {
596 continue
597 }
598 assistantParts = append(assistantParts, ReasoningPart{
599 Text: reasoning.Text,
600 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
601 })
602 case ContentTypeToolCall:
603 toolCall, ok := AsContentType[ToolCallContent](c)
604 if !ok {
605 continue
606 }
607 assistantParts = append(assistantParts, ToolCallPart{
608 ToolCallID: toolCall.ToolCallID,
609 ToolName: toolCall.ToolName,
610 Input: toolCall.Input,
611 ProviderExecuted: toolCall.ProviderExecuted,
612 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
613 })
614 case ContentTypeFile:
615 file, ok := AsContentType[FileContent](c)
616 if !ok {
617 continue
618 }
619 assistantParts = append(assistantParts, FilePart{
620 Data: file.Data,
621 MediaType: file.MediaType,
622 ProviderOptions: ProviderOptions(file.ProviderMetadata),
623 })
624 case ContentTypeSource:
625 // Sources are metadata about references used to generate the response.
626 // They don't need to be included in the conversation messages.
627 continue
628 case ContentTypeToolResult:
629 result, ok := AsContentType[ToolResultContent](c)
630 if !ok {
631 continue
632 }
633 resultPart := ToolResultPart{
634 ToolCallID: result.ToolCallID,
635 Output: result.Result,
636 ProviderExecuted: result.ProviderExecuted,
637 ProviderOptions: ProviderOptions(result.ProviderMetadata),
638 }
639 if result.ProviderExecuted {
640 // Provider-executed tool results (e.g. web search)
641 // belong in the assistant message alongside the
642 // server_tool_use block that produced them.
643 assistantParts = append(assistantParts, resultPart)
644 } else {
645 toolParts = append(toolParts, resultPart)
646 }
647 }
648 }
649
650 var messages []Message
651 if len(assistantParts) > 0 {
652 messages = append(messages, Message{
653 Role: MessageRoleAssistant,
654 Content: assistantParts,
655 })
656 }
657 if len(toolParts) > 0 {
658 messages = append(messages, Message{
659 Role: MessageRoleTool,
660 Content: toolParts,
661 })
662 }
663 return messages
664}
665
666func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, execProviderTools []ExecutableProviderTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
667 if len(toolCalls) == 0 {
668 return nil, nil
669 }
670
671 // Create a map for quick tool lookup
672 toolMap := make(map[string]AgentTool)
673 for _, tool := range allTools {
674 toolMap[tool.Info().Name] = tool
675 }
676
677 execProviderToolMap := make(map[string]ExecutableProviderTool, len(execProviderTools))
678 for _, ept := range execProviderTools {
679 execProviderToolMap[ept.GetName()] = ept
680 }
681
682 // Execute all tool calls sequentially in order
683 results := make([]ToolResultContent, 0, len(toolCalls))
684
685 for _, toolCall := range toolCalls {
686 result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, toolCall, toolResultCallback)
687 results = append(results, result)
688 if isCriticalError {
689 if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
690 return nil, errorResult.Error
691 }
692 }
693 }
694
695 return results, nil
696}
697
698// executeSingleTool executes a single tool and returns its result and a critical error flag.
699func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentTool, execProviderToolMap map[string]ExecutableProviderTool, toolCall ToolCallContent, toolResultCallback func(result ToolResultContent) error) (ToolResultContent, bool) {
700 result := ToolResultContent{
701 ToolCallID: toolCall.ToolCallID,
702 ToolName: toolCall.ToolName,
703 ProviderExecuted: false,
704 }
705
706 // Skip invalid tool calls - create error result (not critical)
707 if toolCall.Invalid {
708 result.Result = ToolResultOutputContentError{
709 Error: toolCall.ValidationError,
710 }
711 if toolResultCallback != nil {
712 _ = toolResultCallback(result)
713 }
714 return result, false
715 }
716
717 // Find the run function — either from a regular AgentTool or an
718 // executable provider tool.
719 var runTool func(ctx context.Context, call ToolCall) (ToolResponse, error)
720 if tool, exists := toolMap[toolCall.ToolName]; exists {
721 runTool = tool.Run
722 } else if ept, ok := execProviderToolMap[toolCall.ToolName]; ok {
723 runTool = ept.Run
724 }
725 if runTool == nil {
726 result.Result = ToolResultOutputContentError{
727 Error: errors.New("tool not found: " + toolCall.ToolName),
728 }
729 if toolResultCallback != nil {
730 _ = toolResultCallback(result)
731 }
732 return result, false
733 }
734
735 // Execute the tool
736 toolResult, err := runTool(ctx, ToolCall{
737 ID: toolCall.ToolCallID,
738 Name: toolCall.ToolName,
739 Input: toolCall.Input,
740 })
741 if err != nil {
742 result.Result = ToolResultOutputContentError{
743 Error: err,
744 }
745 result.ClientMetadata = toolResult.Metadata
746 result.StopTurn = toolResult.StopTurn
747 if toolResultCallback != nil {
748 _ = toolResultCallback(result)
749 }
750 return result, true
751 }
752
753 result.ClientMetadata = toolResult.Metadata
754 result.StopTurn = toolResult.StopTurn
755 if toolResult.IsError {
756 result.Result = ToolResultOutputContentError{
757 Error: errors.New(toolResult.Content),
758 }
759 } else if toolResult.Type == "image" || toolResult.Type == "media" {
760 result.Result = ToolResultOutputContentMedia{
761 Data: base64.StdEncoding.EncodeToString(toolResult.Data),
762 MediaType: toolResult.MediaType,
763 Text: toolResult.Content,
764 }
765 } else {
766 result.Result = ToolResultOutputContentText{
767 Text: toolResult.Content,
768 }
769 }
770 if toolResultCallback != nil {
771 _ = toolResultCallback(result)
772 }
773 return result, false
774}
775
776// Stream implements Agent.
777func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
778 // Convert AgentStreamCall to AgentCall for preparation
779 call := AgentCall{
780 Prompt: opts.Prompt,
781 Files: opts.Files,
782 Messages: opts.Messages,
783 MaxOutputTokens: opts.MaxOutputTokens,
784 Temperature: opts.Temperature,
785 TopP: opts.TopP,
786 TopK: opts.TopK,
787 PresencePenalty: opts.PresencePenalty,
788 FrequencyPenalty: opts.FrequencyPenalty,
789 ActiveTools: opts.ActiveTools,
790 ProviderOptions: opts.ProviderOptions,
791 MaxRetries: opts.MaxRetries,
792 OnRetry: opts.OnRetry,
793 StopWhen: opts.StopWhen,
794 PrepareStep: opts.PrepareStep,
795 RepairToolCall: opts.RepairToolCall,
796 }
797
798 call = a.prepareCall(call)
799
800 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
801 if err != nil {
802 return nil, err
803 }
804
805 var responseMessages []Message
806 var steps []StepResult
807 var totalUsage Usage
808
809 // Start agent stream
810 if opts.OnAgentStart != nil {
811 opts.OnAgentStart()
812 }
813
814 for stepNumber := 0; ; stepNumber++ {
815 stepInputMessages := append(initialPrompt, responseMessages...)
816 stepModel := a.settings.model
817 stepSystemPrompt := a.settings.systemPrompt
818 stepActiveTools := call.ActiveTools
819 stepToolChoice := ToolChoiceAuto
820 disableAllTools := false
821 stepTools := a.settings.tools
822 // Apply step preparation if provided
823 if call.PrepareStep != nil {
824 updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
825 Model: stepModel,
826 Steps: steps,
827 StepNumber: stepNumber,
828 Messages: stepInputMessages,
829 })
830 if err != nil {
831 return nil, err
832 }
833
834 ctx = updatedCtx
835
836 if prepared.Messages != nil {
837 stepInputMessages = prepared.Messages
838 }
839 if prepared.Model != nil {
840 stepModel = prepared.Model
841 }
842 if prepared.System != nil {
843 stepSystemPrompt = *prepared.System
844 }
845 if prepared.ToolChoice != nil {
846 stepToolChoice = *prepared.ToolChoice
847 }
848 if len(prepared.ActiveTools) > 0 {
849 stepActiveTools = prepared.ActiveTools
850 }
851 disableAllTools = prepared.DisableAllTools
852 if prepared.Tools != nil {
853 stepTools = prepared.Tools
854 }
855 }
856
857 // Recreate prompt with potentially modified system prompt
858 if stepSystemPrompt != a.settings.systemPrompt {
859 stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
860 if err != nil {
861 return nil, err
862 }
863 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
864 stepInputMessages[0] = stepPrompt[0]
865 }
866 }
867
868 preparedTools := a.prepareTools(stepTools, a.settings.providerDefinedTools, stepActiveTools, disableAllTools)
869
870 // Filter executable provider tools by activeTools at the
871 // step level, consistent with how stepTools (AgentTools)
872 // are scoped before being passed to inner functions.
873 stepExecProviderTools := a.filterExecProviderTools(stepActiveTools)
874
875 // Start step stream
876 if opts.OnStepStart != nil {
877 _ = opts.OnStepStart(stepNumber)
878 }
879 // Create streaming call
880 streamCall := Call{
881 Prompt: stepInputMessages,
882 MaxOutputTokens: call.MaxOutputTokens,
883 Temperature: call.Temperature,
884 TopP: call.TopP,
885 TopK: call.TopK,
886 PresencePenalty: call.PresencePenalty,
887 FrequencyPenalty: call.FrequencyPenalty,
888 Tools: preparedTools,
889 ToolChoice: &stepToolChoice,
890 UserAgent: a.settings.userAgent,
891 ProviderOptions: call.ProviderOptions,
892 }
893
894 // Execute step with retry logic wrapping both stream creation and processing
895 retryOptions := DefaultRetryOptions()
896 if call.MaxRetries != nil {
897 retryOptions.MaxRetries = *call.MaxRetries
898 }
899 retryOptions.OnRetry = call.OnRetry
900 retry := RetryWithExponentialBackoffRespectingRetryHeaders[stepExecutionResult](retryOptions)
901
902 result, err := retry(ctx, func() (stepExecutionResult, error) {
903 // Create the stream
904 stream, err := stepModel.Stream(ctx, streamCall)
905 if err != nil {
906 return stepExecutionResult{}, err
907 }
908
909 // Process the stream
910 result, err := a.processStepStream(ctx, stream, opts, steps, stepTools, stepExecProviderTools)
911 if err != nil {
912 return stepExecutionResult{}, err
913 }
914 return result, nil
915 })
916 if err != nil {
917 if opts.OnError != nil {
918 opts.OnError(err)
919 }
920 return nil, err
921 }
922
923 steps = append(steps, result.StepResult)
924 totalUsage = addUsage(totalUsage, result.StepResult.Usage)
925
926 // Call step finished callback
927 if opts.OnStepFinish != nil {
928 _ = opts.OnStepFinish(result.StepResult)
929 }
930
931 // Add step messages to response messages
932 stepMessages := toResponseMessages(result.StepResult.Content)
933 responseMessages = append(responseMessages, stepMessages...)
934
935 // Check stop conditions
936 shouldStop := isStopConditionMet(call.StopWhen, steps)
937 if shouldStop || !result.ShouldContinue {
938 break
939 }
940 }
941
942 // Finish agent stream
943 agentResult := &AgentResult{
944 Steps: steps,
945 Response: steps[len(steps)-1].Response,
946 TotalUsage: totalUsage,
947 }
948
949 if opts.OnFinish != nil {
950 opts.OnFinish(agentResult)
951 }
952
953 if opts.OnAgentFinish != nil {
954 _ = opts.OnAgentFinish(agentResult)
955 }
956
957 return agentResult, nil
958}
959
960// filterExecProviderTools returns the subset of executable provider
961// tools permitted by activeTools. When activeTools is empty every
962// tool is included (no filtering).
963func (a *agent) filterExecProviderTools(activeTools []string) []ExecutableProviderTool {
964 if len(activeTools) == 0 {
965 return a.settings.executableProviderTools
966 }
967 filtered := make([]ExecutableProviderTool, 0, len(a.settings.executableProviderTools))
968 for _, ept := range a.settings.executableProviderTools {
969 if slices.Contains(activeTools, ept.GetName()) {
970 filtered = append(filtered, ept)
971 }
972 }
973 return filtered
974}
975
976func (a *agent) prepareTools(tools []AgentTool, providerDefinedTools []ProviderDefinedTool, activeTools []string, disableAllTools bool) []Tool {
977 preparedTools := make([]Tool, 0, len(tools)+len(providerDefinedTools))
978
979 // If explicitly disabling all tools, return no tools
980 if disableAllTools {
981 return preparedTools
982 }
983
984 for _, tool := range tools {
985 // If activeTools has items, only include tools in the list
986 // If activeTools is empty, include all tools
987 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
988 continue
989 }
990 info := tool.Info()
991 inputSchema := map[string]any{
992 "type": "object",
993 "properties": info.Parameters,
994 "required": info.Required,
995 }
996 schema.Normalize(inputSchema)
997 preparedTools = append(preparedTools, FunctionTool{
998 Name: info.Name,
999 Description: info.Description,
1000 InputSchema: inputSchema,
1001 ProviderOptions: tool.ProviderOptions(),
1002 })
1003 }
1004 for _, tool := range providerDefinedTools {
1005 // If activeTools has items, only include tools in the list. If
1006 // activeTools is empty, include all tools
1007 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.GetName()) {
1008 continue
1009 }
1010 preparedTools = append(preparedTools, tool)
1011 }
1012 return preparedTools
1013}
1014
1015// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
1016func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, execProviderTools []ExecutableProviderTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
1017 if err := a.validateToolCall(toolCall, availableTools, execProviderTools); err == nil {
1018 return toolCall
1019 } else { //nolint: revive
1020 if repairFunc != nil {
1021 repairOptions := ToolCallRepairOptions{
1022 OriginalToolCall: toolCall,
1023 ValidationError: err,
1024 AvailableTools: availableTools,
1025 SystemPrompt: systemPrompt,
1026 Messages: messages,
1027 }
1028
1029 if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
1030 if validateErr := a.validateToolCall(*repairedToolCall, availableTools, execProviderTools); validateErr == nil {
1031 return *repairedToolCall
1032 }
1033 }
1034 }
1035
1036 invalidToolCall := toolCall
1037 invalidToolCall.Invalid = true
1038 invalidToolCall.ValidationError = err
1039 return invalidToolCall
1040 }
1041}
1042
1043// validateToolCall validates a tool call against available tools and their schemas.
1044// Both availableTools and execProviderTools must already be filtered by the
1045// caller (e.g. via activeTools); this function trusts that the slices
1046// represent exactly the tools permitted for the current step.
1047func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool, execProviderTools []ExecutableProviderTool) error {
1048 var tool AgentTool
1049 for _, t := range availableTools {
1050 if t.Info().Name == toolCall.ToolName {
1051 tool = t
1052 break
1053 }
1054 }
1055
1056 if tool == nil {
1057 // Check if this is an executable provider tool. Provider-
1058 // defined tools have their schema enforced server-side, so
1059 // we only validate that the input is parseable JSON.
1060 for _, ept := range execProviderTools {
1061 if ept.GetName() == toolCall.ToolName {
1062 var input map[string]any
1063 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
1064 return fmt.Errorf("invalid JSON input: %w", err)
1065 }
1066 return nil
1067 }
1068 }
1069 return fmt.Errorf("tool not found: %s", toolCall.ToolName)
1070 }
1071
1072 // Validate JSON parsing
1073 var input map[string]any
1074 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
1075 return fmt.Errorf("invalid JSON input: %w", err)
1076 }
1077
1078 // Basic schema validation (check required fields)
1079 // TODO: more robust schema validation using JSON Schema or similar
1080 toolInfo := tool.Info()
1081 for _, required := range toolInfo.Required {
1082 if _, exists := input[required]; !exists {
1083 return fmt.Errorf("missing required parameter: %s", required)
1084 }
1085 }
1086 return nil
1087}
1088
1089func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
1090 // Validation: empty prompt is only allowed when there are messages,
1091 // no files to attach, and the last message is a user or tool message.
1092 if prompt == "" {
1093 lastMessage, hasMessages := slice.Last(messages)
1094
1095 if !hasMessages {
1096 return nil, &Error{
1097 Title: "invalid argument",
1098 Message: "prompt can't be empty when there are no messages",
1099 }
1100 }
1101
1102 if len(files) > 0 {
1103 return nil, &Error{
1104 Title: "invalid argument",
1105 Message: "prompt can't be empty when there are files",
1106 }
1107 }
1108
1109 switch lastMessage.Role {
1110 case MessageRoleUser, MessageRoleTool:
1111 default:
1112 return nil, &Error{
1113 Title: "invalid argument",
1114 Message: "prompt can't be empty when the last message is not a user or tool message",
1115 }
1116 }
1117 }
1118
1119 var preparedPrompt Prompt
1120
1121 if system != "" {
1122 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
1123 }
1124 preparedPrompt = append(preparedPrompt, messages...)
1125 if prompt != "" {
1126 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
1127 }
1128 return preparedPrompt, nil
1129}
1130
1131// WithSystemPrompt sets the system prompt for the agent.
1132func WithSystemPrompt(prompt string) AgentOption {
1133 return func(s *agentSettings) {
1134 s.systemPrompt = prompt
1135 }
1136}
1137
1138// WithMaxOutputTokens sets the maximum output tokens for the agent.
1139func WithMaxOutputTokens(tokens int64) AgentOption {
1140 return func(s *agentSettings) {
1141 s.maxOutputTokens = &tokens
1142 }
1143}
1144
1145// WithTemperature sets the temperature for the agent.
1146func WithTemperature(temp float64) AgentOption {
1147 return func(s *agentSettings) {
1148 s.temperature = &temp
1149 }
1150}
1151
1152// WithTopP sets the top-p value for the agent.
1153func WithTopP(topP float64) AgentOption {
1154 return func(s *agentSettings) {
1155 s.topP = &topP
1156 }
1157}
1158
1159// WithTopK sets the top-k value for the agent.
1160func WithTopK(topK int64) AgentOption {
1161 return func(s *agentSettings) {
1162 s.topK = &topK
1163 }
1164}
1165
1166// WithPresencePenalty sets the presence penalty for the agent.
1167func WithPresencePenalty(penalty float64) AgentOption {
1168 return func(s *agentSettings) {
1169 s.presencePenalty = &penalty
1170 }
1171}
1172
1173// WithFrequencyPenalty sets the frequency penalty for the agent.
1174func WithFrequencyPenalty(penalty float64) AgentOption {
1175 return func(s *agentSettings) {
1176 s.frequencyPenalty = &penalty
1177 }
1178}
1179
1180// WithTools sets the tools for the agent.
1181func WithTools(tools ...AgentTool) AgentOption {
1182 return func(s *agentSettings) {
1183 s.tools = append(s.tools, tools...)
1184 }
1185}
1186
1187// WithProviderDefinedTools registers provider-defined tools with the
1188// agent. Provider-executed tools (e.g. web search) are passed through
1189// to the API. Client-executed tools (ExecutableProviderTool) are also
1190// registered for local execution.
1191func WithProviderDefinedTools(tools ...ProviderTool) AgentOption {
1192 return func(s *agentSettings) {
1193 for _, t := range tools {
1194 // Every provider tool goes into providerDefinedTools
1195 // for wire formatting.
1196 s.providerDefinedTools = append(
1197 s.providerDefinedTools, t.providerDefinedTool(),
1198 )
1199 // Executable ones also register for local execution.
1200 if exec, ok := t.(ExecutableProviderTool); ok {
1201 s.executableProviderTools = append(
1202 s.executableProviderTools, exec,
1203 )
1204 }
1205 }
1206 }
1207}
1208
1209// WithStopConditions sets the stop conditions for the agent.
1210func WithStopConditions(conditions ...StopCondition) AgentOption {
1211 return func(s *agentSettings) {
1212 s.stopWhen = append(s.stopWhen, conditions...)
1213 }
1214}
1215
1216// WithPrepareStep sets the prepare step function for the agent.
1217func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1218 return func(s *agentSettings) {
1219 s.prepareStep = fn
1220 }
1221}
1222
1223// WithRepairToolCall sets the repair tool call function for the agent.
1224func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1225 return func(s *agentSettings) {
1226 s.repairToolCall = fn
1227 }
1228}
1229
1230// WithMaxRetries sets the maximum number of retries for the agent.
1231func WithMaxRetries(maxRetries int) AgentOption {
1232 return func(s *agentSettings) {
1233 s.maxRetries = &maxRetries
1234 }
1235}
1236
1237// WithOnRetry sets the retry callback for the agent.
1238func WithOnRetry(callback OnRetryCallback) AgentOption {
1239 return func(s *agentSettings) {
1240 s.onRetry = callback
1241 }
1242}
1243
// processStepStream processes a single step's stream and returns the step result.
//
// Every part of stream is forwarded to opts.OnChunk and then to the typed
// callback in opts that matches the part, while text, reasoning, tool-call,
// provider tool-result and source content are accumulated for the step. If
// any callback returns an error, or a StreamPartTypeError part arrives,
// processing stops and that error is returned with a zero
// stepExecutionResult.
//
// Tool calls that are not provider-executed are validated (and possibly
// repaired) as they arrive, but their execution is deferred until the stream
// has been fully consumed. Tools found in stepTools whose Info().Parallel is
// true run concurrently, at most five at a time; every other call runs
// inline on the coordinator goroutine, one at a time. Results are appended
// to the step content in the order executions finish, so results of
// parallel tools are not guaranteed to follow call order. The first
// critical tool error that carries a non-nil Error aborts the step.
//
// ShouldContinue is true only when the step produced at least one
// agent-handled tool call, the finish reason is FinishReasonToolCalls, and
// hasStopTurn(toolResults) is false.
func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult, stepTools []AgentTool, execProviderTools []ExecutableProviderTool) (stepExecutionResult, error) {
	var stepContent []Content
	var stepToolCalls []ToolCallContent
	var stepUsage Usage
	stepFinishReason := FinishReasonUnknown
	var stepWarnings []CallWarning
	var stepProviderMetadata ProviderMetadata

	// In-progress content keyed by stream part ID. Entries are created by
	// the *Start parts and removed when the matching *End / ToolCall part
	// arrives.
	activeToolCalls := make(map[string]*ToolCallContent)
	activeTextContent := make(map[string]string)
	type reasoningContent struct {
		content string
		options ProviderMetadata
	}
	activeReasoningContent := make(map[string]reasoningContent)

	// Set up concurrent tool execution
	type toolExecutionRequest struct {
		toolCall ToolCallContent
		parallel bool
	}
	// Tool calls collected during streaming; dispatched only after the
	// stream loop has finished.
	var pendingDispatches []toolExecutionRequest

	// Create a map for quick tool lookup
	toolMap := make(map[string]AgentTool)
	for _, tool := range stepTools {
		toolMap[tool.Info().Name] = tool
	}

	execProviderToolMap := make(map[string]ExecutableProviderTool, len(execProviderTools))
	for _, ept := range execProviderTools {
		execProviderToolMap[ept.GetName()] = ept
	}

	// Process stream parts
	for part := range stream {
		// Forward all parts to chunk callback
		if opts.OnChunk != nil {
			err := opts.OnChunk(part)
			if err != nil {
				return stepExecutionResult{}, err
			}
		}

		switch part.Type {
		case StreamPartTypeWarnings:
			// A later warnings part replaces, rather than extends, earlier ones.
			stepWarnings = part.Warnings
			if opts.OnWarnings != nil {
				err := opts.OnWarnings(part.Warnings)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeTextStart:
			activeTextContent[part.ID] = ""
			if opts.OnTextStart != nil {
				err := opts.OnTextStart(part.ID)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeTextDelta:
			// Deltas for an ID without a preceding start are not accumulated,
			// but the callback still sees them.
			if _, exists := activeTextContent[part.ID]; exists {
				activeTextContent[part.ID] += part.Delta
			}
			if opts.OnTextDelta != nil {
				err := opts.OnTextDelta(part.ID, part.Delta)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeTextEnd:
			// The accumulated text becomes step content only once its end
			// part arrives.
			if text, exists := activeTextContent[part.ID]; exists {
				stepContent = append(stepContent, TextContent{
					Text:             text,
					ProviderMetadata: part.ProviderMetadata,
				})
				delete(activeTextContent, part.ID)
			}
			if opts.OnTextEnd != nil {
				err := opts.OnTextEnd(part.ID)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeReasoningStart:
			// The start part may already carry initial text and metadata.
			activeReasoningContent[part.ID] = reasoningContent{content: part.Delta, options: part.ProviderMetadata}
			if opts.OnReasoningStart != nil {
				content := ReasoningContent{
					Text:             part.Delta,
					ProviderMetadata: part.ProviderMetadata,
				}
				err := opts.OnReasoningStart(part.ID, content)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeReasoningDelta:
			if active, exists := activeReasoningContent[part.ID]; exists {
				active.content += part.Delta
				// Keep the most recent non-nil metadata.
				if part.ProviderMetadata != nil {
					active.options = part.ProviderMetadata
				}
				// The map stores values, so the updated copy must be written back.
				activeReasoningContent[part.ID] = active
			}
			if opts.OnReasoningDelta != nil {
				err := opts.OnReasoningDelta(part.ID, part.Delta)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeReasoningEnd:
			// Unlike text, the end callback fires only for IDs that were started.
			if active, exists := activeReasoningContent[part.ID]; exists {
				if part.ProviderMetadata != nil {
					active.options = part.ProviderMetadata
				}
				content := ReasoningContent{
					Text:             active.content,
					ProviderMetadata: active.options,
				}
				stepContent = append(stepContent, content)
				if opts.OnReasoningEnd != nil {
					err := opts.OnReasoningEnd(part.ID, content)
					if err != nil {
						return stepExecutionResult{}, err
					}
				}
				delete(activeReasoningContent, part.ID)
			}

		case StreamPartTypeToolInputStart:
			activeToolCalls[part.ID] = &ToolCallContent{
				ToolCallID:       part.ID,
				ToolName:         part.ToolCallName,
				Input:            "",
				ProviderExecuted: part.ProviderExecuted,
			}
			if opts.OnToolInputStart != nil {
				err := opts.OnToolInputStart(part.ID, part.ToolCallName)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeToolInputDelta:
			if toolCall, exists := activeToolCalls[part.ID]; exists {
				toolCall.Input += part.Delta
			}
			if opts.OnToolInputDelta != nil {
				err := opts.OnToolInputDelta(part.ID, part.Delta)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeToolInputEnd:
			if opts.OnToolInputEnd != nil {
				err := opts.OnToolInputEnd(part.ID)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeToolCall:
			// The call content is built from this part's own fields, not from
			// the input accumulated in activeToolCalls.
			toolCall := ToolCallContent{
				ToolCallID:       part.ID,
				ToolName:         part.ToolCallName,
				Input:            part.ToolCallInput,
				ProviderExecuted: part.ProviderExecuted,
				ProviderMetadata: part.ProviderMetadata,
			}

			// Provider-executed tool calls are handled by the provider
			// and should not be validated or executed by the agent.
			if toolCall.ProviderExecuted {
				stepContent = append(stepContent, toolCall)
				if opts.OnToolCall != nil {
					err := opts.OnToolCall(toolCall)
					if err != nil {
						return stepExecutionResult{}, err
					}
				}
				delete(activeToolCalls, part.ID)
			} else {
				// Validate and potentially repair the tool call
				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, execProviderTools, a.settings.systemPrompt, nil, opts.RepairToolCall)
				stepToolCalls = append(stepToolCalls, validatedToolCall)
				stepContent = append(stepContent, validatedToolCall)

				if opts.OnToolCall != nil {
					err := opts.OnToolCall(validatedToolCall)
					if err != nil {
						return stepExecutionResult{}, err
					}
				}

				// Determine if tool can run in parallel. Only tools present in
				// toolMap can opt in; anything else is dispatched as
				// non-parallel.
				isParallel := false
				if tool, exists := toolMap[validatedToolCall.ToolName]; exists {
					isParallel = tool.Info().Parallel
				}

				// Buffer dispatch until stream is fully consumed so that all
				// OnToolCall callbacks complete before any tool result is written.
				pendingDispatches = append(pendingDispatches, toolExecutionRequest{toolCall: validatedToolCall, parallel: isParallel})

				// Clean up active tool call
				delete(activeToolCalls, part.ID)
			}

		case StreamPartTypeToolResult:
			// Provider-executed tool results (e.g. web search)
			// are emitted by the provider and added directly
			// to the step content for multi-turn round-tripping.
			// Results that are not provider-executed are ignored here.
			if part.ProviderExecuted {
				resultContent := ToolResultContent{
					ToolCallID:       part.ID,
					ToolName:         part.ToolCallName,
					ProviderExecuted: true,
					ProviderMetadata: part.ProviderMetadata,
				}
				stepContent = append(stepContent, resultContent)
				if opts.OnToolResult != nil {
					err := opts.OnToolResult(resultContent)
					if err != nil {
						return stepExecutionResult{}, err
					}
				}
			}

		case StreamPartTypeSource:
			sourceContent := SourceContent{
				SourceType:       part.SourceType,
				ID:               part.ID,
				URL:              part.URL,
				Title:            part.Title,
				ProviderMetadata: part.ProviderMetadata,
			}
			stepContent = append(stepContent, sourceContent)
			if opts.OnSource != nil {
				err := opts.OnSource(sourceContent)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeFinish:
			stepUsage = part.Usage
			stepFinishReason = part.FinishReason
			stepProviderMetadata = part.ProviderMetadata
			if opts.OnStreamFinish != nil {
				err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
				if err != nil {
					return stepExecutionResult{}, err
				}
			}

		case StreamPartTypeError:
			// A stream error aborts the step; content gathered so far is discarded.
			return stepExecutionResult{}, part.Error
		}
	}

	// All tool calls are now collected. Create the execution channel sized to
	// avoid blocking during dispatch, start the coordinator, then flush the batch.
	toolChan := make(chan toolExecutionRequest, len(pendingDispatches))
	var toolExecutionWg sync.WaitGroup
	// toolStateMu guards toolResults and toolExecutionErr, which are written
	// from the coordinator and from the parallel worker goroutines.
	var toolStateMu sync.Mutex
	toolResults := make([]ToolResultContent, 0, len(pendingDispatches))
	var toolExecutionErr error

	// Semaphores for controlling parallelism. parallelSem caps concurrently
	// running parallel tools at five.
	parallelSem := make(chan struct{}, 5)
	var sequentialMu sync.Mutex

	// Single coordinator goroutine that dispatches tools.
	toolExecutionWg.Go(func() {
		for req := range toolChan {
			if req.parallel {
				// Acquire a slot before spawning; this blocks the coordinator
				// while five parallel tools are already running.
				parallelSem <- struct{}{}
				toolExecutionWg.Go(func() {
					defer func() { <-parallelSem }()
					result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, req.toolCall, opts.OnToolResult)
					toolStateMu.Lock()
					toolResults = append(toolResults, result)
					// Only the first critical error with a non-nil Error is kept.
					if isCriticalError && toolExecutionErr == nil {
						if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
							toolExecutionErr = errorResult.Error
						}
					}
					toolStateMu.Unlock()
				})
			} else {
				// Non-parallel tools run inline, so the coordinator does not
				// pick up the next request until this one has finished.
				sequentialMu.Lock()
				result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, req.toolCall, opts.OnToolResult)
				toolStateMu.Lock()
				toolResults = append(toolResults, result)
				if isCriticalError && toolExecutionErr == nil {
					if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
						toolExecutionErr = errorResult.Error
					}
				}
				toolStateMu.Unlock()
				sequentialMu.Unlock()
			}
		}
	})

	// Dispatch all buffered tool calls now that every OnToolCall callback has
	// been called, then close and wait.
	for _, req := range pendingDispatches {
		toolChan <- req
	}

	// Close the tool execution channel and wait for all executions to complete.
	close(toolChan)
	toolExecutionWg.Wait()

	// Check for tool execution errors
	if toolExecutionErr != nil {
		return stepExecutionResult{}, toolExecutionErr
	}

	// Add tool results to content if any
	if len(toolResults) > 0 {
		for _, result := range toolResults {
			stepContent = append(stepContent, result)
		}
	}

	stepResult := StepResult{
		Response: Response{
			Content:          stepContent,
			FinishReason:     stepFinishReason,
			Usage:            stepUsage,
			Warnings:         stepWarnings,
			ProviderMetadata: stepProviderMetadata,
		},
		Messages: toResponseMessages(stepContent),
	}

	// Determine if we should continue (has tool calls and not stopped)
	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls && !hasStopTurn(toolResults)

	return stepExecutionResult{
		StepResult:     stepResult,
		ShouldContinue: shouldContinue,
	}, nil
}
1599
1600func addUsage(a, b Usage) Usage {
1601 return Usage{
1602 InputTokens: a.InputTokens + b.InputTokens,
1603 OutputTokens: a.OutputTokens + b.OutputTokens,
1604 TotalTokens: a.TotalTokens + b.TotalTokens,
1605 ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens,
1606 CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1607 CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens,
1608 }
1609}
1610
1611// WithHeaders sets the headers for the agent.
1612func WithHeaders(headers map[string]string) AgentOption {
1613 return func(s *agentSettings) {
1614 s.headers = headers
1615 }
1616}
1617
1618// WithUserAgent sets the User-Agent header for the agent. This overrides any
1619// provider-level User-Agent setting.
1620func WithUserAgent(ua string) AgentOption {
1621 return func(s *agentSettings) {
1622 s.userAgent = ua
1623 }
1624}
1625
1626// WithProviderOptions sets the provider options for the agent.
1627func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1628 return func(s *agentSettings) {
1629 s.providerOptions = providerOptions
1630 }
1631}