1package fantasy
2
3import (
4 "cmp"
5 "context"
6 "encoding/base64"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "maps"
11 "slices"
12 "sync"
13
14 "charm.land/fantasy/schema"
15 "github.com/charmbracelet/x/exp/slice"
16)
17
// StepResult represents the result of a single step in an agent execution.
//
// The embedded Response carries the step's content, finish reason, usage,
// warnings and provider metadata. In Generate, client-side tool calls in
// that content are replaced by their validated (possibly repaired) form
// and the step's tool results are appended after the model's content.
// Messages holds the conversation messages this step produced, as built
// by toResponseMessages from the step content.
type StepResult struct {
	Response
	Messages []Message
}

// stepExecutionResult encapsulates the result of executing a step with stream processing.
// Stream obtains it from processStepStream inside its retry wrapper.
type stepExecutionResult struct {
	// StepResult is the completed step that Stream appends to its history.
	StepResult StepResult
	// ShouldContinue reports whether Stream should run another step; when
	// false, Stream stops even if no stop condition is met.
	ShouldContinue bool
}

// StopCondition defines a function that determines when an agent should stop executing.
// It receives every step recorded so far, in execution order, and is
// evaluated after each step; returning true ends the run after that step.
type StopCondition = func(steps []StepResult) bool
32
33// StepCountIs returns a stop condition that stops after the specified number of steps.
34func StepCountIs(stepCount int) StopCondition {
35 return func(steps []StepResult) bool {
36 return len(steps) >= stepCount
37 }
38}
39
40// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
41func HasToolCall(toolName string) StopCondition {
42 return func(steps []StepResult) bool {
43 if len(steps) == 0 {
44 return false
45 }
46 lastStep := steps[len(steps)-1]
47 toolCalls := lastStep.Content.ToolCalls()
48 for _, toolCall := range toolCalls {
49 if toolCall.ToolName == toolName {
50 return true
51 }
52 }
53 return false
54 }
55}
56
57// HasContent returns a stop condition that stops when the specified content type appears in the last step.
58func HasContent(contentType ContentType) StopCondition {
59 return func(steps []StepResult) bool {
60 if len(steps) == 0 {
61 return false
62 }
63 lastStep := steps[len(steps)-1]
64 for _, content := range lastStep.Content {
65 if content.GetType() == contentType {
66 return true
67 }
68 }
69 return false
70 }
71}
72
73// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
74func FinishReasonIs(reason FinishReason) StopCondition {
75 return func(steps []StepResult) bool {
76 if len(steps) == 0 {
77 return false
78 }
79 lastStep := steps[len(steps)-1]
80 return lastStep.FinishReason == reason
81 }
82}
83
84// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
85func MaxTokensUsed(maxTokens int64) StopCondition {
86 return func(steps []StepResult) bool {
87 var totalTokens int64
88 for _, step := range steps {
89 totalTokens += step.Usage.TotalTokens
90 }
91 return totalTokens >= maxTokens
92 }
93}
94
// PrepareStepFunctionOptions contains the options for preparing a step in an agent execution.
type PrepareStepFunctionOptions struct {
	// Steps holds the results of all steps completed so far.
	Steps []StepResult
	// StepNumber is the zero-based index of the step about to run.
	StepNumber int
	// Model is the agent's configured language model.
	Model LanguageModel
	// Messages is the step's default input: the initial prompt followed by
	// all response messages accumulated in earlier steps.
	Messages []Message
}

// PrepareStepResult contains the result of preparing a step in an agent execution.
// Every override applies to the current step only; zero-valued fields leave
// the corresponding step setting unchanged.
type PrepareStepResult struct {
	// Model, when non-nil, replaces the model used for this step.
	Model LanguageModel
	// Messages, when non-nil, replaces the step's input messages.
	Messages []Message
	// System, when non-nil, replaces the system prompt for this step.
	System *string
	// ToolChoice, when non-nil, overrides the tool choice for this step.
	ToolChoice *ToolChoice
	// ActiveTools, when non-empty, limits which tools are sent to the model
	// and which executable provider tools may run in this step.
	ActiveTools []string
	// DisableAllTools, when true, sends no tools to the model for this step.
	DisableAllTools bool
	// Tools, when non-nil, replaces the agent's tools for this step.
	Tools []AgentTool
}

// ToolCallRepairOptions contains the options for repairing a tool call.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call that failed validation.
	OriginalToolCall ToolCallContent
	// ValidationError is the error that validation reported.
	ValidationError error
	// AvailableTools are the agent tools available in the current step.
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt in effect for the current step.
	SystemPrompt string
	// Messages are the input messages sent to the model for the current step.
	Messages []Message
}

type (
	// PrepareStepFunction defines a function that prepares a step in an agent execution.
	// It runs before every step. The returned context replaces the agent's
	// context for this step and all later ones; a non-nil error aborts the run.
	PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)

	// OnStepFinishedFunction defines a function that is called when a step finishes.
	// NOTE(review): not referenced in this part of the file (Stream uses
	// OnStepFinishFunc); confirm whether this alias is still used.
	OnStepFinishedFunction = func(step StepResult)

	// RepairToolCallFunction defines a function that repairs a tool call.
	// It is invoked for a tool call that failed validation. When it returns a
	// non-nil call and a nil error, the repaired call is validated again and
	// used if it passes.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
133
// agentSettings holds the agent-level configuration that NewAgent builds
// from the model and the supplied AgentOption functions. Per-call values in
// AgentCall/AgentStreamCall take precedence over these defaults.
type agentSettings struct {
	// systemPrompt is used to build the prompt of every call.
	systemPrompt string
	// Default sampling parameters; nil means "not set".
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	// headers are agent-level headers (presumably HTTP request headers).
	// NOTE(review): nothing visible in this file forwards them to the model
	// call; confirm they are applied elsewhere.
	headers map[string]string
	// userAgent is passed as Call.UserAgent on every model request.
	userAgent string
	// providerOptions are merged beneath per-call provider options.
	providerOptions ProviderOptions

	// providerDefinedTools are sent to the model alongside the agent tools.
	providerDefinedTools []ProviderDefinedTool
	// executableProviderTools are provider tools whose calls the agent runs
	// itself via their Run method.
	executableProviderTools []ExecutableProviderTool
	// tools are the agent's default tools.
	tools []AgentTool
	// toolChoice is the default tool choice; nil means ToolChoiceAuto.
	toolChoice *ToolChoice
	// maxRetries overrides the default retry count for model calls.
	maxRetries *int

	// model is the default language model used for each step.
	model LanguageModel

	// Default loop controls; used when the call does not supply its own.
	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
159
// AgentCall represents a call to an agent.
//
// The sampling parameters, ToolChoice and MaxRetries fall back to the
// agent-level defaults when nil. StopWhen falls back to the default when
// empty, and PrepareStep, RepairToolCall and OnRetry fall back when nil.
// ProviderOptions are merged with the agent-level ones, with call-level
// keys winning. ActiveTools, when non-empty, limits which tools are
// offered to the model.
//
// NOTE(review): MaxOutputTokens has no json tag, unlike the neighbouring
// sampling fields; confirm whether that is intentional.
type AgentCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64    `json:"temperature"`
	TopP             *float64    `json:"top_p"`
	TopK             *int64      `json:"top_k"`
	PresencePenalty  *float64    `json:"presence_penalty"`
	FrequencyPenalty *float64    `json:"frequency_penalty"`
	ActiveTools      []string    `json:"active_tools"`
	ToolChoice       *ToolChoice `json:"tool_choice"`
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction
}
181
// Agent-level callbacks.
// In Stream, errors returned by OnStepStartFunc, OnStepFinishFunc and
// OnAgentFinishFunc are ignored.
type (
	// OnAgentStartFunc is called when agent starts.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called when agent finishes.
	// It runs after OnFinishFunc and receives the same result.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called when a step starts.
	// stepNumber is zero-based.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called when a step finishes.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called when entire agent completes.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called when an error occurs.
	OnErrorFunc func(error)
)

// Stream part callbacks - called for each corresponding stream part type.
// NOTE(review): these are dispatched by processStepStream, which is not
// visible here; the semantics below follow the type comments only.
type (
	// OnChunkFunc is called for each stream part (catch-all).
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when text starts.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for text deltas.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when text ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when reasoning starts.
	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error

	// OnReasoningDeltaFunc is called for reasoning deltas.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning ends.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when tool input starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for tool input deltas.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnToolResultFunc is called when tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when stream finishes.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
250
// AgentStreamCall represents a streaming call to an agent.
//
// Its call options are converted to an AgentCall inside Stream and receive
// the same agent-level defaults as Generate. The callback fields are
// optional; nil callbacks are skipped.
//
// NOTE(review): Headers is not copied into the AgentCall and is not read
// anywhere in Stream in this file; confirm whether it takes effect.
type AgentStreamCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64    `json:"temperature"`
	TopP             *float64    `json:"top_p"`
	TopK             *int64      `json:"top_k"`
	PresencePenalty  *float64    `json:"presence_penalty"`
	FrequencyPenalty *float64    `json:"frequency_penalty"`
	ActiveTools      []string    `json:"active_tools"`
	ToolChoice       *ToolChoice `json:"tool_choice"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called when agent starts
	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
	OnWarnings       OnWarningsFunc       // Called for warnings
	OnTextStart      OnTextStartFunc      // Called when text starts
	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
	OnTextEnd        OnTextEndFunc        // Called when text ends
	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
	OnToolCall       OnToolCallFunc       // Called when tool call is complete
	OnToolResult     OnToolResultFunc     // Called when tool execution completes
	OnSource         OnSourceFunc         // Called for source references
	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
}
298
// AgentResult represents the result of an agent execution.
type AgentResult struct {
	// Steps holds every executed step, in execution order.
	Steps []StepResult
	// Final response: the Response of the last step.
	Response Response
	// TotalUsage is the token usage summed over all steps.
	TotalUsage Usage
}

// Agent represents an AI agent that can generate responses and stream responses.
type Agent interface {
	// Generate runs the agent loop with non-streaming model calls.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the agent loop with streaming model calls, reporting
	// progress through the callbacks on AgentStreamCall.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}

// AgentOption defines a function that configures agent settings.
// NewAgent applies options in order, so later options can override
// settings written by earlier ones.
type AgentOption = func(*agentSettings)

// agent is the Agent implementation returned by NewAgent. The methods in
// this file only read settings; they never modify it.
type agent struct {
	settings agentSettings
}
319
320// NewAgent creates a new agent with the given language model and options.
321func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
322 settings := agentSettings{
323 model: model,
324 }
325 for _, o := range opts {
326 o(&settings)
327 }
328 return &agent{
329 settings: settings,
330 }
331}
332
333func (a *agent) prepareCall(call AgentCall) AgentCall {
334 call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
335 call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
336 call.TopP = cmp.Or(call.TopP, a.settings.topP)
337 call.TopK = cmp.Or(call.TopK, a.settings.topK)
338 call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
339 call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
340 call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
341 call.ToolChoice = cmp.Or(call.ToolChoice, a.settings.toolChoice)
342
343 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
344 call.StopWhen = a.settings.stopWhen
345 }
346 if call.PrepareStep == nil && a.settings.prepareStep != nil {
347 call.PrepareStep = a.settings.prepareStep
348 }
349 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
350 call.RepairToolCall = a.settings.repairToolCall
351 }
352 if call.OnRetry == nil && a.settings.onRetry != nil {
353 call.OnRetry = a.settings.onRetry
354 }
355
356 providerOptions := ProviderOptions{}
357 if a.settings.providerOptions != nil {
358 maps.Copy(providerOptions, a.settings.providerOptions)
359 }
360 if call.ProviderOptions != nil {
361 maps.Copy(providerOptions, call.ProviderOptions)
362 }
363 call.ProviderOptions = providerOptions
364
365 headers := map[string]string{}
366
367 if a.settings.headers != nil {
368 maps.Copy(headers, a.settings.headers)
369 }
370
371 return call
372}
373
374// Generate implements Agent.
375func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
376 opts = a.prepareCall(opts)
377 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
378 if err != nil {
379 return nil, err
380 }
381 var responseMessages []Message
382 var steps []StepResult
383
384 for {
385 stepInputMessages := append(initialPrompt, responseMessages...)
386 stepModel := a.settings.model
387 stepSystemPrompt := a.settings.systemPrompt
388 stepActiveTools := opts.ActiveTools
389 stepToolChoice := ToolChoiceAuto
390 if opts.ToolChoice != nil {
391 stepToolChoice = *opts.ToolChoice
392 }
393 disableAllTools := false
394 stepTools := a.settings.tools
395 if opts.PrepareStep != nil {
396 updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
397 Model: stepModel,
398 Steps: steps,
399 StepNumber: len(steps),
400 Messages: stepInputMessages,
401 })
402 if err != nil {
403 return nil, err
404 }
405
406 ctx = updatedCtx
407
408 // Apply prepared step modifications
409 if prepared.Messages != nil {
410 stepInputMessages = prepared.Messages
411 }
412 if prepared.Model != nil {
413 stepModel = prepared.Model
414 }
415 if prepared.System != nil {
416 stepSystemPrompt = *prepared.System
417 }
418 if prepared.ToolChoice != nil {
419 stepToolChoice = *prepared.ToolChoice
420 }
421 if len(prepared.ActiveTools) > 0 {
422 stepActiveTools = prepared.ActiveTools
423 }
424 disableAllTools = prepared.DisableAllTools
425 if prepared.Tools != nil {
426 stepTools = prepared.Tools
427 }
428 }
429
430 // Recreate prompt with potentially modified system prompt
431 if stepSystemPrompt != a.settings.systemPrompt {
432 stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
433 if err != nil {
434 return nil, err
435 }
436 // Replace system message part, keep the rest
437 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
438 stepInputMessages[0] = stepPrompt[0] // Replace system message
439 }
440 }
441
442 preparedTools := a.prepareTools(stepTools, a.settings.providerDefinedTools, stepActiveTools, disableAllTools)
443
444 // Filter executable provider tools by activeTools at the
445 // step level, consistent with how stepTools (AgentTools)
446 // are scoped before being passed to inner functions.
447 stepExecProviderTools := a.filterExecProviderTools(stepActiveTools)
448
449 retryOptions := DefaultRetryOptions()
450 if opts.MaxRetries != nil {
451 retryOptions.MaxRetries = *opts.MaxRetries
452 }
453 retryOptions.OnRetry = opts.OnRetry
454 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
455 result, err := retry(ctx, func() (*Response, error) {
456 return stepModel.Generate(ctx, Call{
457 Prompt: stepInputMessages,
458 MaxOutputTokens: opts.MaxOutputTokens,
459 Temperature: opts.Temperature,
460 TopP: opts.TopP,
461 TopK: opts.TopK,
462 PresencePenalty: opts.PresencePenalty,
463 FrequencyPenalty: opts.FrequencyPenalty,
464 Tools: preparedTools,
465 ToolChoice: &stepToolChoice,
466 UserAgent: a.settings.userAgent,
467 ProviderOptions: opts.ProviderOptions,
468 })
469 })
470 if err != nil {
471 return nil, err
472 }
473
474 var stepToolCalls []ToolCallContent
475 for _, content := range result.Content {
476 if content.GetType() == ContentTypeToolCall {
477 toolCall, ok := AsContentType[ToolCallContent](content)
478 if !ok {
479 continue
480 }
481 // Provider-executed tool calls (e.g. web search) are
482 // handled by the provider and should not be validated
483 // or executed by the agent.
484 if toolCall.ProviderExecuted {
485 continue
486 }
487 // Validate and potentially repair the tool call
488 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, stepExecProviderTools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
489 stepToolCalls = append(stepToolCalls, validatedToolCall)
490 }
491 }
492
493 toolResults, err := a.executeTools(ctx, stepTools, stepExecProviderTools, stepToolCalls, nil)
494
495 // If any tool result requested a stop, deliver all results but don't
496 // request another completion from the model.
497 stopTurnRequested := hasStopTurn(toolResults)
498
499 // Build step content with validated tool calls and tool results.
500 // Provider-executed tool calls are kept as-is.
501 stepContent := []Content{}
502 toolCallIndex := 0
503 for _, content := range result.Content {
504 if content.GetType() == ContentTypeToolCall {
505 tc, ok := AsContentType[ToolCallContent](content)
506 if ok && tc.ProviderExecuted {
507 stepContent = append(stepContent, content)
508 continue
509 }
510 // Replace with validated tool call.
511 if toolCallIndex < len(stepToolCalls) {
512 stepContent = append(stepContent, stepToolCalls[toolCallIndex])
513 toolCallIndex++
514 }
515 } else {
516 stepContent = append(stepContent, content)
517 }
518 } // Add tool results
519 for _, result := range toolResults {
520 stepContent = append(stepContent, result)
521 }
522 currentStepMessages := toResponseMessages(stepContent)
523 responseMessages = append(responseMessages, currentStepMessages...)
524
525 stepResult := StepResult{
526 Response: Response{
527 Content: stepContent,
528 FinishReason: result.FinishReason,
529 Usage: result.Usage,
530 Warnings: result.Warnings,
531 ProviderMetadata: result.ProviderMetadata,
532 },
533 Messages: currentStepMessages,
534 }
535 steps = append(steps, stepResult)
536 shouldStop := isStopConditionMet(opts.StopWhen, steps)
537
538 if shouldStop || err != nil || stopTurnRequested || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
539 break
540 }
541 }
542
543 totalUsage := Usage{}
544
545 for _, step := range steps {
546 usage := step.Usage
547 totalUsage.InputTokens += usage.InputTokens
548 totalUsage.OutputTokens += usage.OutputTokens
549 totalUsage.ReasoningTokens += usage.ReasoningTokens
550 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
551 totalUsage.CacheReadTokens += usage.CacheReadTokens
552 totalUsage.TotalTokens += usage.TotalTokens
553 }
554
555 agentResult := &AgentResult{
556 Steps: steps,
557 Response: steps[len(steps)-1].Response,
558 TotalUsage: totalUsage,
559 }
560 return agentResult, nil
561}
562
563func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
564 if len(conditions) == 0 {
565 return false
566 }
567
568 for _, condition := range conditions {
569 if condition(steps) {
570 return true
571 }
572 }
573 return false
574}
575
576func hasStopTurn(results []ToolResultContent) bool {
577 for _, r := range results {
578 if r.StopTurn {
579 return true
580 }
581 }
582 return false
583}
584
585func toResponseMessages(content []Content) []Message {
586 var assistantParts []MessagePart
587 var toolParts []MessagePart
588
589 for _, c := range content {
590 switch c.GetType() {
591 case ContentTypeText:
592 text, ok := AsContentType[TextContent](c)
593 if !ok {
594 continue
595 }
596 assistantParts = append(assistantParts, TextPart{
597 Text: text.Text,
598 ProviderOptions: ProviderOptions(text.ProviderMetadata),
599 })
600 case ContentTypeReasoning:
601 reasoning, ok := AsContentType[ReasoningContent](c)
602 if !ok {
603 continue
604 }
605 assistantParts = append(assistantParts, ReasoningPart{
606 Text: reasoning.Text,
607 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
608 })
609 case ContentTypeToolCall:
610 toolCall, ok := AsContentType[ToolCallContent](c)
611 if !ok {
612 continue
613 }
614 assistantParts = append(assistantParts, ToolCallPart{
615 ToolCallID: toolCall.ToolCallID,
616 ToolName: toolCall.ToolName,
617 Input: toolCall.Input,
618 ProviderExecuted: toolCall.ProviderExecuted,
619 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
620 })
621 case ContentTypeFile:
622 file, ok := AsContentType[FileContent](c)
623 if !ok {
624 continue
625 }
626 assistantParts = append(assistantParts, FilePart{
627 Data: file.Data,
628 MediaType: file.MediaType,
629 ProviderOptions: ProviderOptions(file.ProviderMetadata),
630 })
631 case ContentTypeSource:
632 // Sources are metadata about references used to generate the response.
633 // They don't need to be included in the conversation messages.
634 continue
635 case ContentTypeToolResult:
636 result, ok := AsContentType[ToolResultContent](c)
637 if !ok {
638 continue
639 }
640 resultPart := ToolResultPart{
641 ToolCallID: result.ToolCallID,
642 Output: result.Result,
643 ProviderExecuted: result.ProviderExecuted,
644 ProviderOptions: ProviderOptions(result.ProviderMetadata),
645 }
646 if result.ProviderExecuted {
647 // Provider-executed tool results (e.g. web search)
648 // belong in the assistant message alongside the
649 // server_tool_use block that produced them.
650 assistantParts = append(assistantParts, resultPart)
651 } else {
652 toolParts = append(toolParts, resultPart)
653 }
654 }
655 }
656
657 var messages []Message
658 if len(assistantParts) > 0 {
659 messages = append(messages, Message{
660 Role: MessageRoleAssistant,
661 Content: assistantParts,
662 })
663 }
664 if len(toolParts) > 0 {
665 messages = append(messages, Message{
666 Role: MessageRoleTool,
667 Content: toolParts,
668 })
669 }
670 return messages
671}
672
673func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, execProviderTools []ExecutableProviderTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
674 if len(toolCalls) == 0 {
675 return nil, nil
676 }
677
678 // Create a map for quick tool lookup
679 toolMap := make(map[string]AgentTool)
680 for _, tool := range allTools {
681 toolMap[tool.Info().Name] = tool
682 }
683
684 execProviderToolMap := make(map[string]ExecutableProviderTool, len(execProviderTools))
685 for _, ept := range execProviderTools {
686 execProviderToolMap[ept.GetName()] = ept
687 }
688
689 // Execute all tool calls sequentially in order
690 results := make([]ToolResultContent, 0, len(toolCalls))
691
692 for _, toolCall := range toolCalls {
693 result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, toolCall, toolResultCallback)
694 results = append(results, result)
695 if isCriticalError {
696 if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
697 return nil, errorResult.Error
698 }
699 }
700 }
701
702 return results, nil
703}
704
705// executeSingleTool executes a single tool and returns its result and a critical error flag.
706func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentTool, execProviderToolMap map[string]ExecutableProviderTool, toolCall ToolCallContent, toolResultCallback func(result ToolResultContent) error) (ToolResultContent, bool) {
707 result := ToolResultContent{
708 ToolCallID: toolCall.ToolCallID,
709 ToolName: toolCall.ToolName,
710 ProviderExecuted: false,
711 }
712
713 // Skip invalid tool calls - create error result (not critical)
714 if toolCall.Invalid {
715 result.Result = ToolResultOutputContentError{
716 Error: toolCall.ValidationError,
717 }
718 if toolResultCallback != nil {
719 _ = toolResultCallback(result)
720 }
721 return result, false
722 }
723
724 // Find the run function — either from a regular AgentTool or an
725 // executable provider tool.
726 var runTool func(ctx context.Context, call ToolCall) (ToolResponse, error)
727 if tool, exists := toolMap[toolCall.ToolName]; exists {
728 runTool = tool.Run
729 } else if ept, ok := execProviderToolMap[toolCall.ToolName]; ok {
730 runTool = ept.Run
731 }
732 if runTool == nil {
733 result.Result = ToolResultOutputContentError{
734 Error: errors.New("tool not found: " + toolCall.ToolName),
735 }
736 if toolResultCallback != nil {
737 _ = toolResultCallback(result)
738 }
739 return result, false
740 }
741
742 // Execute the tool
743 toolResult, err := runTool(ctx, ToolCall{
744 ID: toolCall.ToolCallID,
745 Name: toolCall.ToolName,
746 Input: toolCall.Input,
747 })
748 if err != nil {
749 result.Result = ToolResultOutputContentError{
750 Error: err,
751 }
752 result.ClientMetadata = toolResult.Metadata
753 result.StopTurn = toolResult.StopTurn
754 if toolResultCallback != nil {
755 _ = toolResultCallback(result)
756 }
757 return result, true
758 }
759
760 result.ClientMetadata = toolResult.Metadata
761 result.StopTurn = toolResult.StopTurn
762 if toolResult.IsError {
763 result.Result = ToolResultOutputContentError{
764 Error: errors.New(toolResult.Content),
765 }
766 } else if toolResult.Type == "image" || toolResult.Type == "media" {
767 result.Result = ToolResultOutputContentMedia{
768 Data: base64.StdEncoding.EncodeToString(toolResult.Data),
769 MediaType: toolResult.MediaType,
770 Text: toolResult.Content,
771 }
772 } else {
773 result.Result = ToolResultOutputContentText{
774 Text: toolResult.Content,
775 }
776 }
777 if toolResultCallback != nil {
778 _ = toolResultCallback(result)
779 }
780 return result, false
781}
782
783// Stream implements Agent.
784func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
785 // Convert AgentStreamCall to AgentCall for preparation
786 call := AgentCall{
787 Prompt: opts.Prompt,
788 Files: opts.Files,
789 Messages: opts.Messages,
790 MaxOutputTokens: opts.MaxOutputTokens,
791 Temperature: opts.Temperature,
792 TopP: opts.TopP,
793 TopK: opts.TopK,
794 PresencePenalty: opts.PresencePenalty,
795 FrequencyPenalty: opts.FrequencyPenalty,
796 ActiveTools: opts.ActiveTools,
797 ToolChoice: opts.ToolChoice,
798 ProviderOptions: opts.ProviderOptions,
799 MaxRetries: opts.MaxRetries,
800 OnRetry: opts.OnRetry,
801 StopWhen: opts.StopWhen,
802 PrepareStep: opts.PrepareStep,
803 RepairToolCall: opts.RepairToolCall,
804 }
805
806 call = a.prepareCall(call)
807
808 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
809 if err != nil {
810 return nil, err
811 }
812
813 var responseMessages []Message
814 var steps []StepResult
815 var totalUsage Usage
816
817 // Start agent stream
818 if opts.OnAgentStart != nil {
819 opts.OnAgentStart()
820 }
821
822 for stepNumber := 0; ; stepNumber++ {
823 stepInputMessages := append(initialPrompt, responseMessages...)
824 stepModel := a.settings.model
825 stepSystemPrompt := a.settings.systemPrompt
826 stepActiveTools := call.ActiveTools
827 stepToolChoice := ToolChoiceAuto
828 if call.ToolChoice != nil {
829 stepToolChoice = *call.ToolChoice
830 }
831 disableAllTools := false
832 stepTools := a.settings.tools
833 // Apply step preparation if provided
834 if call.PrepareStep != nil {
835 updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
836 Model: stepModel,
837 Steps: steps,
838 StepNumber: stepNumber,
839 Messages: stepInputMessages,
840 })
841 if err != nil {
842 return nil, err
843 }
844
845 ctx = updatedCtx
846
847 if prepared.Messages != nil {
848 stepInputMessages = prepared.Messages
849 }
850 if prepared.Model != nil {
851 stepModel = prepared.Model
852 }
853 if prepared.System != nil {
854 stepSystemPrompt = *prepared.System
855 }
856 if prepared.ToolChoice != nil {
857 stepToolChoice = *prepared.ToolChoice
858 }
859 if len(prepared.ActiveTools) > 0 {
860 stepActiveTools = prepared.ActiveTools
861 }
862 disableAllTools = prepared.DisableAllTools
863 if prepared.Tools != nil {
864 stepTools = prepared.Tools
865 }
866 }
867
868 // Recreate prompt with potentially modified system prompt
869 if stepSystemPrompt != a.settings.systemPrompt {
870 stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
871 if err != nil {
872 return nil, err
873 }
874 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
875 stepInputMessages[0] = stepPrompt[0]
876 }
877 }
878
879 preparedTools := a.prepareTools(stepTools, a.settings.providerDefinedTools, stepActiveTools, disableAllTools)
880
881 // Filter executable provider tools by activeTools at the
882 // step level, consistent with how stepTools (AgentTools)
883 // are scoped before being passed to inner functions.
884 stepExecProviderTools := a.filterExecProviderTools(stepActiveTools)
885
886 // Start step stream
887 if opts.OnStepStart != nil {
888 _ = opts.OnStepStart(stepNumber)
889 }
890 // Create streaming call
891 streamCall := Call{
892 Prompt: stepInputMessages,
893 MaxOutputTokens: call.MaxOutputTokens,
894 Temperature: call.Temperature,
895 TopP: call.TopP,
896 TopK: call.TopK,
897 PresencePenalty: call.PresencePenalty,
898 FrequencyPenalty: call.FrequencyPenalty,
899 Tools: preparedTools,
900 ToolChoice: &stepToolChoice,
901 UserAgent: a.settings.userAgent,
902 ProviderOptions: call.ProviderOptions,
903 }
904
905 // Execute step with retry logic wrapping both stream creation and processing
906 retryOptions := DefaultRetryOptions()
907 if call.MaxRetries != nil {
908 retryOptions.MaxRetries = *call.MaxRetries
909 }
910 retryOptions.OnRetry = call.OnRetry
911 retry := RetryWithExponentialBackoffRespectingRetryHeaders[stepExecutionResult](retryOptions)
912
913 result, err := retry(ctx, func() (stepExecutionResult, error) {
914 // Create the stream
915 stream, err := stepModel.Stream(ctx, streamCall)
916 if err != nil {
917 return stepExecutionResult{}, err
918 }
919
920 // Process the stream
921 result, err := a.processStepStream(ctx, stream, opts, steps, stepTools, stepExecProviderTools)
922 if err != nil {
923 return stepExecutionResult{}, err
924 }
925 return result, nil
926 })
927 if err != nil {
928 if opts.OnError != nil {
929 opts.OnError(err)
930 }
931 return nil, err
932 }
933
934 steps = append(steps, result.StepResult)
935 totalUsage = addUsage(totalUsage, result.StepResult.Usage)
936
937 // Call step finished callback
938 if opts.OnStepFinish != nil {
939 _ = opts.OnStepFinish(result.StepResult)
940 }
941
942 // Add step messages to response messages
943 stepMessages := toResponseMessages(result.StepResult.Content)
944 responseMessages = append(responseMessages, stepMessages...)
945
946 // Check stop conditions
947 shouldStop := isStopConditionMet(call.StopWhen, steps)
948 if shouldStop || !result.ShouldContinue {
949 break
950 }
951 }
952
953 // Finish agent stream
954 agentResult := &AgentResult{
955 Steps: steps,
956 Response: steps[len(steps)-1].Response,
957 TotalUsage: totalUsage,
958 }
959
960 if opts.OnFinish != nil {
961 opts.OnFinish(agentResult)
962 }
963
964 if opts.OnAgentFinish != nil {
965 _ = opts.OnAgentFinish(agentResult)
966 }
967
968 return agentResult, nil
969}
970
971// filterExecProviderTools returns the subset of executable provider
972// tools permitted by activeTools. When activeTools is empty every
973// tool is included (no filtering).
974func (a *agent) filterExecProviderTools(activeTools []string) []ExecutableProviderTool {
975 if len(activeTools) == 0 {
976 return a.settings.executableProviderTools
977 }
978 filtered := make([]ExecutableProviderTool, 0, len(a.settings.executableProviderTools))
979 for _, ept := range a.settings.executableProviderTools {
980 if slices.Contains(activeTools, ept.GetName()) {
981 filtered = append(filtered, ept)
982 }
983 }
984 return filtered
985}
986
987func (a *agent) prepareTools(tools []AgentTool, providerDefinedTools []ProviderDefinedTool, activeTools []string, disableAllTools bool) []Tool {
988 preparedTools := make([]Tool, 0, len(tools)+len(providerDefinedTools))
989
990 // If explicitly disabling all tools, return no tools
991 if disableAllTools {
992 return preparedTools
993 }
994
995 for _, tool := range tools {
996 // If activeTools has items, only include tools in the list
997 // If activeTools is empty, include all tools
998 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
999 continue
1000 }
1001 info := tool.Info()
1002 inputSchema := map[string]any{
1003 "type": "object",
1004 "properties": info.Parameters,
1005 "required": info.Required,
1006 }
1007 schema.Normalize(inputSchema)
1008 preparedTools = append(preparedTools, FunctionTool{
1009 Name: info.Name,
1010 Description: info.Description,
1011 InputSchema: inputSchema,
1012 ProviderOptions: tool.ProviderOptions(),
1013 })
1014 }
1015 for _, tool := range providerDefinedTools {
1016 // If activeTools has items, only include tools in the list. If
1017 // activeTools is empty, include all tools
1018 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.GetName()) {
1019 continue
1020 }
1021 preparedTools = append(preparedTools, tool)
1022 }
1023 return preparedTools
1024}
1025
1026// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
1027func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, execProviderTools []ExecutableProviderTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
1028 if err := a.validateToolCall(toolCall, availableTools, execProviderTools); err == nil {
1029 return toolCall
1030 } else { //nolint: revive
1031 if repairFunc != nil {
1032 repairOptions := ToolCallRepairOptions{
1033 OriginalToolCall: toolCall,
1034 ValidationError: err,
1035 AvailableTools: availableTools,
1036 SystemPrompt: systemPrompt,
1037 Messages: messages,
1038 }
1039
1040 if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
1041 if validateErr := a.validateToolCall(*repairedToolCall, availableTools, execProviderTools); validateErr == nil {
1042 return *repairedToolCall
1043 }
1044 }
1045 }
1046
1047 invalidToolCall := toolCall
1048 invalidToolCall.Invalid = true
1049 invalidToolCall.ValidationError = err
1050 return invalidToolCall
1051 }
1052}
1053
1054// validateToolCall validates a tool call against available tools and their schemas.
1055// Both availableTools and execProviderTools must already be filtered by the
1056// caller (e.g. via activeTools); this function trusts that the slices
1057// represent exactly the tools permitted for the current step.
1058func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool, execProviderTools []ExecutableProviderTool) error {
1059 var tool AgentTool
1060 for _, t := range availableTools {
1061 if t.Info().Name == toolCall.ToolName {
1062 tool = t
1063 break
1064 }
1065 }
1066
1067 if tool == nil {
1068 // Check if this is an executable provider tool. Provider-
1069 // defined tools have their schema enforced server-side, so
1070 // we only validate that the input is parseable JSON.
1071 for _, ept := range execProviderTools {
1072 if ept.GetName() == toolCall.ToolName {
1073 var input map[string]any
1074 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
1075 return fmt.Errorf("invalid JSON input: %w", err)
1076 }
1077 return nil
1078 }
1079 }
1080 return fmt.Errorf("tool not found: %s", toolCall.ToolName)
1081 }
1082
1083 // Validate JSON parsing
1084 var input map[string]any
1085 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
1086 return fmt.Errorf("invalid JSON input: %w", err)
1087 }
1088
1089 // Basic schema validation (check required fields)
1090 // TODO: more robust schema validation using JSON Schema or similar
1091 toolInfo := tool.Info()
1092 for _, required := range toolInfo.Required {
1093 if _, exists := input[required]; !exists {
1094 return fmt.Errorf("missing required parameter: %s", required)
1095 }
1096 }
1097 return nil
1098}
1099
1100func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
1101 // Validation: empty prompt is only allowed when there are messages,
1102 // no files to attach, and the last message is a user or tool message.
1103 if prompt == "" {
1104 lastMessage, hasMessages := slice.Last(messages)
1105
1106 if !hasMessages {
1107 return nil, &Error{
1108 Title: "invalid argument",
1109 Message: "prompt can't be empty when there are no messages",
1110 }
1111 }
1112
1113 if len(files) > 0 {
1114 return nil, &Error{
1115 Title: "invalid argument",
1116 Message: "prompt can't be empty when there are files",
1117 }
1118 }
1119
1120 switch lastMessage.Role {
1121 case MessageRoleUser, MessageRoleTool:
1122 default:
1123 return nil, &Error{
1124 Title: "invalid argument",
1125 Message: "prompt can't be empty when the last message is not a user or tool message",
1126 }
1127 }
1128 }
1129
1130 var preparedPrompt Prompt
1131
1132 if system != "" {
1133 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
1134 }
1135 preparedPrompt = append(preparedPrompt, messages...)
1136 if prompt != "" {
1137 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
1138 }
1139 return preparedPrompt, nil
1140}
1141
1142// WithSystemPrompt sets the system prompt for the agent.
1143func WithSystemPrompt(prompt string) AgentOption {
1144 return func(s *agentSettings) {
1145 s.systemPrompt = prompt
1146 }
1147}
1148
1149// WithMaxOutputTokens sets the maximum output tokens for the agent.
1150func WithMaxOutputTokens(tokens int64) AgentOption {
1151 return func(s *agentSettings) {
1152 s.maxOutputTokens = &tokens
1153 }
1154}
1155
1156// WithTemperature sets the temperature for the agent.
1157func WithTemperature(temp float64) AgentOption {
1158 return func(s *agentSettings) {
1159 s.temperature = &temp
1160 }
1161}
1162
1163// WithTopP sets the top-p value for the agent.
1164func WithTopP(topP float64) AgentOption {
1165 return func(s *agentSettings) {
1166 s.topP = &topP
1167 }
1168}
1169
1170// WithTopK sets the top-k value for the agent.
1171func WithTopK(topK int64) AgentOption {
1172 return func(s *agentSettings) {
1173 s.topK = &topK
1174 }
1175}
1176
1177// WithPresencePenalty sets the presence penalty for the agent.
1178func WithPresencePenalty(penalty float64) AgentOption {
1179 return func(s *agentSettings) {
1180 s.presencePenalty = &penalty
1181 }
1182}
1183
1184// WithFrequencyPenalty sets the frequency penalty for the agent.
1185func WithFrequencyPenalty(penalty float64) AgentOption {
1186 return func(s *agentSettings) {
1187 s.frequencyPenalty = &penalty
1188 }
1189}
1190
1191// WithTools sets the tools for the agent.
1192func WithTools(tools ...AgentTool) AgentOption {
1193 return func(s *agentSettings) {
1194 s.tools = append(s.tools, tools...)
1195 }
1196}
1197
1198// WithProviderDefinedTools registers provider-defined tools with the
1199// agent. Provider-executed tools (e.g. web search) are passed through
1200// to the API. Client-executed tools (ExecutableProviderTool) are also
1201// registered for local execution.
1202func WithProviderDefinedTools(tools ...ProviderTool) AgentOption {
1203 return func(s *agentSettings) {
1204 for _, t := range tools {
1205 // Every provider tool goes into providerDefinedTools
1206 // for wire formatting.
1207 s.providerDefinedTools = append(
1208 s.providerDefinedTools, t.providerDefinedTool(),
1209 )
1210 // Executable ones also register for local execution.
1211 if exec, ok := t.(ExecutableProviderTool); ok {
1212 s.executableProviderTools = append(
1213 s.executableProviderTools, exec,
1214 )
1215 }
1216 }
1217 }
1218}
1219
1220// WithToolChoice sets the default tool choice for the agent. It is overridden
1221// by the ToolChoice on a specific call, and by PrepareStep at the step level.
1222func WithToolChoice(choice ToolChoice) AgentOption {
1223 return func(s *agentSettings) {
1224 s.toolChoice = &choice
1225 }
1226}
1227
1228// WithStopConditions sets the stop conditions for the agent.
1229func WithStopConditions(conditions ...StopCondition) AgentOption {
1230 return func(s *agentSettings) {
1231 s.stopWhen = append(s.stopWhen, conditions...)
1232 }
1233}
1234
1235// WithPrepareStep sets the prepare step function for the agent.
1236func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1237 return func(s *agentSettings) {
1238 s.prepareStep = fn
1239 }
1240}
1241
1242// WithRepairToolCall sets the repair tool call function for the agent.
1243func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1244 return func(s *agentSettings) {
1245 s.repairToolCall = fn
1246 }
1247}
1248
1249// WithMaxRetries sets the maximum number of retries for the agent.
1250func WithMaxRetries(maxRetries int) AgentOption {
1251 return func(s *agentSettings) {
1252 s.maxRetries = &maxRetries
1253 }
1254}
1255
1256// WithOnRetry sets the retry callback for the agent.
1257func WithOnRetry(callback OnRetryCallback) AgentOption {
1258 return func(s *agentSettings) {
1259 s.onRetry = callback
1260 }
1261}
1262
1263// processStepStream processes a single step's stream and returns the step result.
1264func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult, stepTools []AgentTool, execProviderTools []ExecutableProviderTool) (stepExecutionResult, error) {
1265 var stepContent []Content
1266 var stepToolCalls []ToolCallContent
1267 var stepUsage Usage
1268 stepFinishReason := FinishReasonUnknown
1269 var stepWarnings []CallWarning
1270 var stepProviderMetadata ProviderMetadata
1271
1272 activeToolCalls := make(map[string]*ToolCallContent)
1273 activeTextContent := make(map[string]string)
1274 type reasoningContent struct {
1275 content string
1276 options ProviderMetadata
1277 }
1278 activeReasoningContent := make(map[string]reasoningContent)
1279
1280 // Set up concurrent tool execution
1281 type toolExecutionRequest struct {
1282 toolCall ToolCallContent
1283 parallel bool
1284 }
1285 var pendingDispatches []toolExecutionRequest
1286
1287 // Create a map for quick tool lookup
1288 toolMap := make(map[string]AgentTool)
1289 for _, tool := range stepTools {
1290 toolMap[tool.Info().Name] = tool
1291 }
1292
1293 execProviderToolMap := make(map[string]ExecutableProviderTool, len(execProviderTools))
1294 for _, ept := range execProviderTools {
1295 execProviderToolMap[ept.GetName()] = ept
1296 }
1297
1298 // Process stream parts
1299 for part := range stream {
1300 // Forward all parts to chunk callback
1301 if opts.OnChunk != nil {
1302 err := opts.OnChunk(part)
1303 if err != nil {
1304 return stepExecutionResult{}, err
1305 }
1306 }
1307
1308 switch part.Type {
1309 case StreamPartTypeWarnings:
1310 stepWarnings = part.Warnings
1311 if opts.OnWarnings != nil {
1312 err := opts.OnWarnings(part.Warnings)
1313 if err != nil {
1314 return stepExecutionResult{}, err
1315 }
1316 }
1317
1318 case StreamPartTypeTextStart:
1319 activeTextContent[part.ID] = ""
1320 if opts.OnTextStart != nil {
1321 err := opts.OnTextStart(part.ID)
1322 if err != nil {
1323 return stepExecutionResult{}, err
1324 }
1325 }
1326
1327 case StreamPartTypeTextDelta:
1328 if _, exists := activeTextContent[part.ID]; exists {
1329 activeTextContent[part.ID] += part.Delta
1330 }
1331 if opts.OnTextDelta != nil {
1332 err := opts.OnTextDelta(part.ID, part.Delta)
1333 if err != nil {
1334 return stepExecutionResult{}, err
1335 }
1336 }
1337
1338 case StreamPartTypeTextEnd:
1339 if text, exists := activeTextContent[part.ID]; exists {
1340 stepContent = append(stepContent, TextContent{
1341 Text: text,
1342 ProviderMetadata: part.ProviderMetadata,
1343 })
1344 delete(activeTextContent, part.ID)
1345 }
1346 if opts.OnTextEnd != nil {
1347 err := opts.OnTextEnd(part.ID)
1348 if err != nil {
1349 return stepExecutionResult{}, err
1350 }
1351 }
1352
1353 case StreamPartTypeReasoningStart:
1354 activeReasoningContent[part.ID] = reasoningContent{content: part.Delta, options: part.ProviderMetadata}
1355 if opts.OnReasoningStart != nil {
1356 content := ReasoningContent{
1357 Text: part.Delta,
1358 ProviderMetadata: part.ProviderMetadata,
1359 }
1360 err := opts.OnReasoningStart(part.ID, content)
1361 if err != nil {
1362 return stepExecutionResult{}, err
1363 }
1364 }
1365
1366 case StreamPartTypeReasoningDelta:
1367 if active, exists := activeReasoningContent[part.ID]; exists {
1368 active.content += part.Delta
1369 if part.ProviderMetadata != nil {
1370 active.options = part.ProviderMetadata
1371 }
1372 activeReasoningContent[part.ID] = active
1373 }
1374 if opts.OnReasoningDelta != nil {
1375 err := opts.OnReasoningDelta(part.ID, part.Delta)
1376 if err != nil {
1377 return stepExecutionResult{}, err
1378 }
1379 }
1380
1381 case StreamPartTypeReasoningEnd:
1382 if active, exists := activeReasoningContent[part.ID]; exists {
1383 if part.ProviderMetadata != nil {
1384 active.options = part.ProviderMetadata
1385 }
1386 content := ReasoningContent{
1387 Text: active.content,
1388 ProviderMetadata: active.options,
1389 }
1390 stepContent = append(stepContent, content)
1391 if opts.OnReasoningEnd != nil {
1392 err := opts.OnReasoningEnd(part.ID, content)
1393 if err != nil {
1394 return stepExecutionResult{}, err
1395 }
1396 }
1397 delete(activeReasoningContent, part.ID)
1398 }
1399
1400 case StreamPartTypeToolInputStart:
1401 activeToolCalls[part.ID] = &ToolCallContent{
1402 ToolCallID: part.ID,
1403 ToolName: part.ToolCallName,
1404 Input: "",
1405 ProviderExecuted: part.ProviderExecuted,
1406 }
1407 if opts.OnToolInputStart != nil {
1408 err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1409 if err != nil {
1410 return stepExecutionResult{}, err
1411 }
1412 }
1413
1414 case StreamPartTypeToolInputDelta:
1415 if toolCall, exists := activeToolCalls[part.ID]; exists {
1416 toolCall.Input += part.Delta
1417 }
1418 if opts.OnToolInputDelta != nil {
1419 err := opts.OnToolInputDelta(part.ID, part.Delta)
1420 if err != nil {
1421 return stepExecutionResult{}, err
1422 }
1423 }
1424
1425 case StreamPartTypeToolInputEnd:
1426 if opts.OnToolInputEnd != nil {
1427 err := opts.OnToolInputEnd(part.ID)
1428 if err != nil {
1429 return stepExecutionResult{}, err
1430 }
1431 }
1432
1433 case StreamPartTypeToolCall:
1434 toolCall := ToolCallContent{
1435 ToolCallID: part.ID,
1436 ToolName: part.ToolCallName,
1437 Input: part.ToolCallInput,
1438 ProviderExecuted: part.ProviderExecuted,
1439 ProviderMetadata: part.ProviderMetadata,
1440 }
1441
1442 // Provider-executed tool calls are handled by the provider
1443 // and should not be validated or executed by the agent.
1444 if toolCall.ProviderExecuted {
1445 stepContent = append(stepContent, toolCall)
1446 if opts.OnToolCall != nil {
1447 err := opts.OnToolCall(toolCall)
1448 if err != nil {
1449 return stepExecutionResult{}, err
1450 }
1451 }
1452 delete(activeToolCalls, part.ID)
1453 } else {
1454 // Validate and potentially repair the tool call
1455 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, execProviderTools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1456 stepToolCalls = append(stepToolCalls, validatedToolCall)
1457 stepContent = append(stepContent, validatedToolCall)
1458
1459 if opts.OnToolCall != nil {
1460 err := opts.OnToolCall(validatedToolCall)
1461 if err != nil {
1462 return stepExecutionResult{}, err
1463 }
1464 }
1465
1466 // Determine if tool can run in parallel
1467 isParallel := false
1468 if tool, exists := toolMap[validatedToolCall.ToolName]; exists {
1469 isParallel = tool.Info().Parallel
1470 }
1471
1472 // Buffer dispatch until stream is fully consumed so that all
1473 // OnToolCall callbacks complete before any tool result is written.
1474 pendingDispatches = append(pendingDispatches, toolExecutionRequest{toolCall: validatedToolCall, parallel: isParallel})
1475
1476 // Clean up active tool call
1477 delete(activeToolCalls, part.ID)
1478 }
1479
1480 case StreamPartTypeToolResult:
1481 // Provider-executed tool results (e.g. web search)
1482 // are emitted by the provider and added directly
1483 // to the step content for multi-turn round-tripping.
1484 if part.ProviderExecuted {
1485 resultContent := ToolResultContent{
1486 ToolCallID: part.ID,
1487 ToolName: part.ToolCallName,
1488 ProviderExecuted: true,
1489 ProviderMetadata: part.ProviderMetadata,
1490 }
1491 stepContent = append(stepContent, resultContent)
1492 if opts.OnToolResult != nil {
1493 err := opts.OnToolResult(resultContent)
1494 if err != nil {
1495 return stepExecutionResult{}, err
1496 }
1497 }
1498 }
1499
1500 case StreamPartTypeSource:
1501 sourceContent := SourceContent{
1502 SourceType: part.SourceType,
1503 ID: part.ID,
1504 URL: part.URL,
1505 Title: part.Title,
1506 ProviderMetadata: part.ProviderMetadata,
1507 }
1508 stepContent = append(stepContent, sourceContent)
1509 if opts.OnSource != nil {
1510 err := opts.OnSource(sourceContent)
1511 if err != nil {
1512 return stepExecutionResult{}, err
1513 }
1514 }
1515
1516 case StreamPartTypeFinish:
1517 stepUsage = part.Usage
1518 stepFinishReason = part.FinishReason
1519 stepProviderMetadata = part.ProviderMetadata
1520 if opts.OnStreamFinish != nil {
1521 err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1522 if err != nil {
1523 return stepExecutionResult{}, err
1524 }
1525 }
1526
1527 case StreamPartTypeError:
1528 return stepExecutionResult{}, part.Error
1529 }
1530 }
1531
1532 // All tool calls are now collected. Create the execution channel sized to
1533 // avoid blocking during dispatch, start the coordinator, then flush the batch.
1534 toolChan := make(chan toolExecutionRequest, len(pendingDispatches))
1535 var toolExecutionWg sync.WaitGroup
1536 var toolStateMu sync.Mutex
1537 toolResults := make([]ToolResultContent, 0, len(pendingDispatches))
1538 var toolExecutionErr error
1539
1540 // Semaphores for controlling parallelism.
1541 parallelSem := make(chan struct{}, 5)
1542 var sequentialMu sync.Mutex
1543
1544 // Single coordinator goroutine that dispatches tools.
1545 toolExecutionWg.Go(func() {
1546 for req := range toolChan {
1547 if req.parallel {
1548 parallelSem <- struct{}{}
1549 toolExecutionWg.Go(func() {
1550 defer func() { <-parallelSem }()
1551 result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, req.toolCall, opts.OnToolResult)
1552 toolStateMu.Lock()
1553 toolResults = append(toolResults, result)
1554 if isCriticalError && toolExecutionErr == nil {
1555 if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
1556 toolExecutionErr = errorResult.Error
1557 }
1558 }
1559 toolStateMu.Unlock()
1560 })
1561 } else {
1562 sequentialMu.Lock()
1563 result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, req.toolCall, opts.OnToolResult)
1564 toolStateMu.Lock()
1565 toolResults = append(toolResults, result)
1566 if isCriticalError && toolExecutionErr == nil {
1567 if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
1568 toolExecutionErr = errorResult.Error
1569 }
1570 }
1571 toolStateMu.Unlock()
1572 sequentialMu.Unlock()
1573 }
1574 }
1575 })
1576
1577 // Dispatch all buffered tool calls now that every OnToolCall callback has
1578 // been called, then close and wait.
1579 for _, req := range pendingDispatches {
1580 toolChan <- req
1581 }
1582
1583 // Close the tool execution channel and wait for all executions to complete.
1584 close(toolChan)
1585 toolExecutionWg.Wait()
1586
1587 // Check for tool execution errors
1588 if toolExecutionErr != nil {
1589 return stepExecutionResult{}, toolExecutionErr
1590 }
1591
1592 // Add tool results to content if any
1593 if len(toolResults) > 0 {
1594 for _, result := range toolResults {
1595 stepContent = append(stepContent, result)
1596 }
1597 }
1598
1599 stepResult := StepResult{
1600 Response: Response{
1601 Content: stepContent,
1602 FinishReason: stepFinishReason,
1603 Usage: stepUsage,
1604 Warnings: stepWarnings,
1605 ProviderMetadata: stepProviderMetadata,
1606 },
1607 Messages: toResponseMessages(stepContent),
1608 }
1609
1610 // Determine if we should continue (has tool calls and not stopped)
1611 shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls && !hasStopTurn(toolResults)
1612
1613 return stepExecutionResult{
1614 StepResult: stepResult,
1615 ShouldContinue: shouldContinue,
1616 }, nil
1617}
1618
1619func addUsage(a, b Usage) Usage {
1620 return Usage{
1621 InputTokens: a.InputTokens + b.InputTokens,
1622 OutputTokens: a.OutputTokens + b.OutputTokens,
1623 TotalTokens: a.TotalTokens + b.TotalTokens,
1624 ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens,
1625 CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1626 CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens,
1627 }
1628}
1629
1630// WithHeaders sets the headers for the agent.
1631func WithHeaders(headers map[string]string) AgentOption {
1632 return func(s *agentSettings) {
1633 s.headers = headers
1634 }
1635}
1636
1637// WithUserAgent sets the User-Agent header for the agent. This overrides any
1638// provider-level User-Agent setting.
1639func WithUserAgent(ua string) AgentOption {
1640 return func(s *agentSettings) {
1641 s.userAgent = ua
1642 }
1643}
1644
1645// WithProviderOptions sets the provider options for the agent.
1646func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1647 return func(s *agentSettings) {
1648 s.providerOptions = providerOptions
1649 }
1650}