1package fantasy
2
3import (
4 "cmp"
5 "context"
6 "encoding/base64"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "maps"
11 "slices"
12 "sync"
13
14 "charm.land/fantasy/schema"
15 "github.com/charmbracelet/x/exp/slice"
16)
17
// StepResult represents the result of a single step in an agent execution.
// It embeds the Response recorded for the step (content, finish reason,
// usage, warnings and provider metadata) and carries the conversation
// messages derived from that content.
type StepResult struct {
	Response
	// Messages holds the messages generated from this step's content; they
	// are appended to the conversation sent to the model on the next step.
	Messages []Message
}
23
// stepExecutionResult encapsulates the result of executing a step with stream processing.
type stepExecutionResult struct {
	// StepResult is the step outcome that Stream records in its step list.
	StepResult StepResult
	// ShouldContinue reports whether Stream should run another step; when it
	// is false the streaming loop ends after this step.
	ShouldContinue bool
}
29
// StopCondition defines a function that determines when an agent should stop executing.
// It receives every step completed so far (including the one that just
// finished) and returns true to end the run. When several conditions are
// configured, the run stops as soon as any one of them returns true.
type StopCondition = func(steps []StepResult) bool
32
33// StepCountIs returns a stop condition that stops after the specified number of steps.
34func StepCountIs(stepCount int) StopCondition {
35 return func(steps []StepResult) bool {
36 return len(steps) >= stepCount
37 }
38}
39
40// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
41func HasToolCall(toolName string) StopCondition {
42 return func(steps []StepResult) bool {
43 if len(steps) == 0 {
44 return false
45 }
46 lastStep := steps[len(steps)-1]
47 toolCalls := lastStep.Content.ToolCalls()
48 for _, toolCall := range toolCalls {
49 if toolCall.ToolName == toolName {
50 return true
51 }
52 }
53 return false
54 }
55}
56
57// HasContent returns a stop condition that stops when the specified content type appears in the last step.
58func HasContent(contentType ContentType) StopCondition {
59 return func(steps []StepResult) bool {
60 if len(steps) == 0 {
61 return false
62 }
63 lastStep := steps[len(steps)-1]
64 for _, content := range lastStep.Content {
65 if content.GetType() == contentType {
66 return true
67 }
68 }
69 return false
70 }
71}
72
73// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
74func FinishReasonIs(reason FinishReason) StopCondition {
75 return func(steps []StepResult) bool {
76 if len(steps) == 0 {
77 return false
78 }
79 lastStep := steps[len(steps)-1]
80 return lastStep.FinishReason == reason
81 }
82}
83
84// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
85func MaxTokensUsed(maxTokens int64) StopCondition {
86 return func(steps []StepResult) bool {
87 var totalTokens int64
88 for _, step := range steps {
89 totalTokens += step.Usage.TotalTokens
90 }
91 return totalTokens >= maxTokens
92 }
93}
94
// PrepareStepFunctionOptions contains the options for preparing a step in an agent execution.
// It is the input handed to a PrepareStepFunction before each model call.
type PrepareStepFunctionOptions struct {
	// Steps holds the results of the steps completed so far.
	Steps []StepResult
	// StepNumber is the zero-based index of the step about to run.
	StepNumber int
	// Model is the agent's configured language model, i.e. the model the
	// step will use unless the prepare function overrides it.
	Model LanguageModel
	// Messages is the input the step would send: the initial prompt followed
	// by the response messages of all previous steps.
	Messages []Message
}
102
// PrepareStepResult contains the result of preparing a step in an agent execution.
// Every field is optional; a zero value leaves the corresponding step
// setting untouched (except DisableAllTools, which is always applied).
type PrepareStepResult struct {
	// Model, when non-nil, replaces the language model for this step.
	Model LanguageModel
	// Messages, when non-nil, replaces the step's input messages.
	Messages []Message
	// System, when non-nil, overrides the system prompt for this step. If it
	// differs from the agent's system prompt, the first input message is
	// replaced with the first message of a prompt rebuilt from it.
	System *string
	// ToolChoice, when non-nil, overrides the default ToolChoiceAuto.
	ToolChoice *ToolChoice
	// ActiveTools, when non-empty, replaces the call's active tool list.
	ActiveTools []string
	// DisableAllTools, when true, causes no tools to be sent to the model.
	DisableAllTools bool
	// Tools, when non-nil, replaces the agent's tools for this step.
	Tools []AgentTool
}
113
// ToolCallRepairOptions contains the options for repairing a tool call.
// It is passed to a RepairToolCallFunction after a tool call failed
// validation.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call that failed validation.
	OriginalToolCall ToolCallContent
	// ValidationError is the error returned by validation.
	ValidationError error
	// AvailableTools lists the agent tools the call was validated against.
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt in effect for the step.
	SystemPrompt string
	// Messages is the message history supplied by the caller of the repair
	// step (the step's input messages in Generate).
	Messages []Message
}
122
type (
	// PrepareStepFunction defines a function that prepares a step in an agent execution.
	// It runs before every model call. The returned context replaces the
	// context used for the rest of the run, the returned PrepareStepResult
	// overrides settings for the step, and a non-nil error aborts the run.
	PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)

	// OnStepFinishedFunction defines a function that is called when a step finishes.
	OnStepFinishedFunction = func(step StepResult)

	// RepairToolCallFunction defines a function that repairs a tool call.
	// It is invoked for a tool call that failed validation. A returned call
	// is validated again before use; returning nil or an error leaves the
	// original call marked as invalid.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
133
// agentSettings holds the agent-level configuration populated by AgentOption
// functions in NewAgent. Most fields act as defaults that a per-call value
// overrides (see prepareCall).
type agentSettings struct {
	// systemPrompt is the system prompt used to build the initial prompt.
	systemPrompt string
	// Sampling defaults; each applies only when the call leaves the
	// matching field nil.
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	// headers are agent-level headers.
	// NOTE(review): prepareCall copies them into a local map that is never
	// used; confirm where they are meant to be applied.
	headers map[string]string
	// userAgent is forwarded as Call.UserAgent on every model call.
	userAgent string
	// providerOptions are merged under the call's provider options
	// (call-level keys win).
	providerOptions ProviderOptions

	// providerDefinedTools are sent to the model alongside function tools.
	providerDefinedTools []ProviderDefinedTool
	// executableProviderTools are provider tools the agent can run itself.
	executableProviderTools []ExecutableProviderTool
	// tools are the agent tools offered to the model and executed locally.
	tools []AgentTool
	// maxRetries is the default retry limit when the call sets none.
	maxRetries *int

	// model is the language model used unless a step overrides it.
	model LanguageModel

	// stopWhen, prepareStep, repairToolCall and onRetry are defaults used
	// when the call does not provide its own.
	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
158
// AgentCall represents a call to an agent.
// Fields left nil or empty fall back to the agent-level settings where one
// exists (see prepareCall).
// NOTE(review): only some fields carry JSON tags, and the function-typed
// fields cannot be encoded by encoding/json; confirm whether this struct is
// meant to be marshaled as a whole.
type AgentCall struct {
	// Prompt, Files and Messages are combined with the system prompt to
	// build the initial prompt.
	Prompt   string     `json:"prompt"`
	Files    []FilePart `json:"files"`
	Messages []Message  `json:"messages"`
	// MaxOutputTokens is forwarded to the model call.
	MaxOutputTokens *int64
	// Sampling parameters forwarded to the model call.
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, restricts the tools offered to the model
	// to the listed names.
	ActiveTools []string `json:"active_tools"`
	// ProviderOptions are merged over the agent-level provider options.
	ProviderOptions ProviderOptions
	// OnRetry is installed as the retry callback for model calls.
	OnRetry OnRetryCallback
	// MaxRetries, when non-nil, overrides the default retry limit.
	MaxRetries *int

	// StopWhen lists conditions checked after every step; any match ends
	// the run.
	StopWhen []StopCondition
	// PrepareStep, when set, runs before every model call.
	PrepareStep PrepareStepFunction
	// RepairToolCall, when set, is given a chance to fix invalid tool calls.
	RepairToolCall RepairToolCallFunction
}
179
// Agent-level callbacks.
// The notes on error handling below describe how Stream uses these
// callbacks.
type (
	// OnAgentStartFunc is called when agent starts.
	// Stream invokes it once, before the first step.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called when agent finishes.
	// Stream invokes it last, after OnFinishFunc, and discards the returned
	// error.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called when a step starts.
	// stepNumber is zero-based. Stream discards the returned error.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called when a step finishes.
	// Stream discards the returned error.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called when entire agent completes.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called when an error occurs.
	// Stream invokes it when a step still fails after retries, just before
	// returning that error.
	OnErrorFunc func(error)
)
200
// Stream part callbacks - called for each corresponding stream part type.
// Each callback returns an error; how that error is handled is decided by
// the stream-processing code, which is outside this section of the file.
type (
	// OnChunkFunc is called for each stream part (catch-all).
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when text starts.
	// id identifies the text block.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for text deltas.
	// id identifies the text block and text is the new fragment.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when text ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when reasoning starts.
	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error

	// OnReasoningDeltaFunc is called for reasoning deltas.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning ends.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when tool input starts.
	// id identifies the tool call and toolName names the tool.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for tool input deltas.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnToolResultFunc is called when tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when stream finishes.
	// It receives the usage, finish reason and provider metadata reported
	// for the stream.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
248
// AgentStreamCall represents a streaming call to an agent.
// It carries the same request options as AgentCall plus the callbacks that
// report progress. Stream copies the request options into an AgentCall so
// that agent-level defaults are applied the same way as for Generate.
type AgentStreamCall struct {
	// Prompt, Files and Messages are combined with the system prompt to
	// build the initial prompt.
	Prompt   string     `json:"prompt"`
	Files    []FilePart `json:"files"`
	Messages []Message  `json:"messages"`
	// MaxOutputTokens is forwarded to the model call.
	MaxOutputTokens *int64
	// Sampling parameters forwarded to the model call.
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, restricts the tools offered to the model
	// to the listed names.
	ActiveTools []string `json:"active_tools"`
	// Headers are per-call headers.
	// NOTE(review): Stream does not copy this field into the AgentCall it
	// prepares; confirm where it is consumed.
	Headers map[string]string
	// ProviderOptions are merged over the agent-level provider options.
	ProviderOptions ProviderOptions
	// OnRetry is installed as the retry callback for each step.
	OnRetry OnRetryCallback
	// MaxRetries, when non-nil, overrides the default retry limit.
	MaxRetries *int

	// StopWhen lists conditions checked after every step; any match ends
	// the run.
	StopWhen []StopCondition
	// PrepareStep, when set, runs before every model call.
	PrepareStep PrepareStepFunction
	// RepairToolCall, when set, is given a chance to fix invalid tool calls.
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called when agent starts
	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
	OnWarnings       OnWarningsFunc       // Called for warnings
	OnTextStart      OnTextStartFunc      // Called when text starts
	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
	OnTextEnd        OnTextEndFunc        // Called when text ends
	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
	OnToolCall       OnToolCallFunc       // Called when tool call is complete
	OnToolResult     OnToolResultFunc     // Called when tool execution completes
	OnSource         OnSourceFunc         // Called for source references
	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
}
295
// AgentResult represents the result of an agent execution.
type AgentResult struct {
	// Steps lists every step that was executed, in order.
	Steps []StepResult
	// Final response
	// (the Response of the last entry in Steps).
	Response Response
	// TotalUsage is the token usage summed over all steps.
	TotalUsage Usage
}
303
// Agent represents an AI agent that can generate responses and stream responses.
// Both methods run a multi-step loop: the model is called, any tool calls it
// makes are executed, and the loop repeats until a stop condition is met or
// the model stops requesting tools.
type Agent interface {
	// Generate runs the agent loop with non-streaming model calls and
	// returns the collected steps, the final response and the total usage.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the agent loop with streaming model calls, reporting
	// progress through the callbacks on AgentStreamCall.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
309
// AgentOption defines a function that configures agent settings.
// NewAgent applies options in the order given, so a later option overrides
// an earlier one that writes the same setting.
type AgentOption = func(*agentSettings)
312
// agent is the implementation of Agent returned by NewAgent.
type agent struct {
	// settings is the configuration assembled by NewAgent from the model and
	// the supplied options.
	settings agentSettings
}
316
317// NewAgent creates a new agent with the given language model and options.
318func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
319 settings := agentSettings{
320 model: model,
321 }
322 for _, o := range opts {
323 o(&settings)
324 }
325 return &agent{
326 settings: settings,
327 }
328}
329
330func (a *agent) prepareCall(call AgentCall) AgentCall {
331 call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
332 call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
333 call.TopP = cmp.Or(call.TopP, a.settings.topP)
334 call.TopK = cmp.Or(call.TopK, a.settings.topK)
335 call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
336 call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
337 call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
338
339 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
340 call.StopWhen = a.settings.stopWhen
341 }
342 if call.PrepareStep == nil && a.settings.prepareStep != nil {
343 call.PrepareStep = a.settings.prepareStep
344 }
345 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
346 call.RepairToolCall = a.settings.repairToolCall
347 }
348 if call.OnRetry == nil && a.settings.onRetry != nil {
349 call.OnRetry = a.settings.onRetry
350 }
351
352 providerOptions := ProviderOptions{}
353 if a.settings.providerOptions != nil {
354 maps.Copy(providerOptions, a.settings.providerOptions)
355 }
356 if call.ProviderOptions != nil {
357 maps.Copy(providerOptions, call.ProviderOptions)
358 }
359 call.ProviderOptions = providerOptions
360
361 headers := map[string]string{}
362
363 if a.settings.headers != nil {
364 maps.Copy(headers, a.settings.headers)
365 }
366
367 return call
368}
369
370// Generate implements Agent.
371func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
372 opts = a.prepareCall(opts)
373 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
374 if err != nil {
375 return nil, err
376 }
377 var responseMessages []Message
378 var steps []StepResult
379
380 for {
381 stepInputMessages := append(initialPrompt, responseMessages...)
382 stepModel := a.settings.model
383 stepSystemPrompt := a.settings.systemPrompt
384 stepActiveTools := opts.ActiveTools
385 stepToolChoice := ToolChoiceAuto
386 disableAllTools := false
387 stepTools := a.settings.tools
388 if opts.PrepareStep != nil {
389 updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
390 Model: stepModel,
391 Steps: steps,
392 StepNumber: len(steps),
393 Messages: stepInputMessages,
394 })
395 if err != nil {
396 return nil, err
397 }
398
399 ctx = updatedCtx
400
401 // Apply prepared step modifications
402 if prepared.Messages != nil {
403 stepInputMessages = prepared.Messages
404 }
405 if prepared.Model != nil {
406 stepModel = prepared.Model
407 }
408 if prepared.System != nil {
409 stepSystemPrompt = *prepared.System
410 }
411 if prepared.ToolChoice != nil {
412 stepToolChoice = *prepared.ToolChoice
413 }
414 if len(prepared.ActiveTools) > 0 {
415 stepActiveTools = prepared.ActiveTools
416 }
417 disableAllTools = prepared.DisableAllTools
418 if prepared.Tools != nil {
419 stepTools = prepared.Tools
420 }
421 }
422
423 // Recreate prompt with potentially modified system prompt
424 if stepSystemPrompt != a.settings.systemPrompt {
425 stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
426 if err != nil {
427 return nil, err
428 }
429 // Replace system message part, keep the rest
430 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
431 stepInputMessages[0] = stepPrompt[0] // Replace system message
432 }
433 }
434
435 preparedTools := a.prepareTools(stepTools, a.settings.providerDefinedTools, stepActiveTools, disableAllTools)
436
437 // Filter executable provider tools by activeTools at the
438 // step level, consistent with how stepTools (AgentTools)
439 // are scoped before being passed to inner functions.
440 stepExecProviderTools := a.filterExecProviderTools(stepActiveTools)
441
442 retryOptions := DefaultRetryOptions()
443 if opts.MaxRetries != nil {
444 retryOptions.MaxRetries = *opts.MaxRetries
445 }
446 retryOptions.OnRetry = opts.OnRetry
447 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
448 result, err := retry(ctx, func() (*Response, error) {
449 return stepModel.Generate(ctx, Call{
450 Prompt: stepInputMessages,
451 MaxOutputTokens: opts.MaxOutputTokens,
452 Temperature: opts.Temperature,
453 TopP: opts.TopP,
454 TopK: opts.TopK,
455 PresencePenalty: opts.PresencePenalty,
456 FrequencyPenalty: opts.FrequencyPenalty,
457 Tools: preparedTools,
458 ToolChoice: &stepToolChoice,
459 UserAgent: a.settings.userAgent,
460 ProviderOptions: opts.ProviderOptions,
461 })
462 })
463 if err != nil {
464 return nil, err
465 }
466
467 var stepToolCalls []ToolCallContent
468 for _, content := range result.Content {
469 if content.GetType() == ContentTypeToolCall {
470 toolCall, ok := AsContentType[ToolCallContent](content)
471 if !ok {
472 continue
473 }
474 // Provider-executed tool calls (e.g. web search) are
475 // handled by the provider and should not be validated
476 // or executed by the agent.
477 if toolCall.ProviderExecuted {
478 continue
479 }
480 // Validate and potentially repair the tool call
481 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, stepExecProviderTools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
482 stepToolCalls = append(stepToolCalls, validatedToolCall)
483 }
484 }
485
486 toolResults, err := a.executeTools(ctx, stepTools, stepExecProviderTools, stepToolCalls, nil)
487
488 // Build step content with validated tool calls and tool results. // Provider-executed tool calls are kept as-is.
489 stepContent := []Content{}
490 toolCallIndex := 0
491 for _, content := range result.Content {
492 if content.GetType() == ContentTypeToolCall {
493 tc, ok := AsContentType[ToolCallContent](content)
494 if ok && tc.ProviderExecuted {
495 stepContent = append(stepContent, content)
496 continue
497 }
498 // Replace with validated tool call.
499 if toolCallIndex < len(stepToolCalls) {
500 stepContent = append(stepContent, stepToolCalls[toolCallIndex])
501 toolCallIndex++
502 }
503 } else {
504 stepContent = append(stepContent, content)
505 }
506 } // Add tool results
507 for _, result := range toolResults {
508 stepContent = append(stepContent, result)
509 }
510 currentStepMessages := toResponseMessages(stepContent)
511 responseMessages = append(responseMessages, currentStepMessages...)
512
513 stepResult := StepResult{
514 Response: Response{
515 Content: stepContent,
516 FinishReason: result.FinishReason,
517 Usage: result.Usage,
518 Warnings: result.Warnings,
519 ProviderMetadata: result.ProviderMetadata,
520 },
521 Messages: currentStepMessages,
522 }
523 steps = append(steps, stepResult)
524 shouldStop := isStopConditionMet(opts.StopWhen, steps)
525
526 if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
527 break
528 }
529 }
530
531 totalUsage := Usage{}
532
533 for _, step := range steps {
534 usage := step.Usage
535 totalUsage.InputTokens += usage.InputTokens
536 totalUsage.OutputTokens += usage.OutputTokens
537 totalUsage.ReasoningTokens += usage.ReasoningTokens
538 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
539 totalUsage.CacheReadTokens += usage.CacheReadTokens
540 totalUsage.TotalTokens += usage.TotalTokens
541 }
542
543 agentResult := &AgentResult{
544 Steps: steps,
545 Response: steps[len(steps)-1].Response,
546 TotalUsage: totalUsage,
547 }
548 return agentResult, nil
549}
550
551func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
552 if len(conditions) == 0 {
553 return false
554 }
555
556 for _, condition := range conditions {
557 if condition(steps) {
558 return true
559 }
560 }
561 return false
562}
563
564func toResponseMessages(content []Content) []Message {
565 var assistantParts []MessagePart
566 var toolParts []MessagePart
567
568 for _, c := range content {
569 switch c.GetType() {
570 case ContentTypeText:
571 text, ok := AsContentType[TextContent](c)
572 if !ok {
573 continue
574 }
575 assistantParts = append(assistantParts, TextPart{
576 Text: text.Text,
577 ProviderOptions: ProviderOptions(text.ProviderMetadata),
578 })
579 case ContentTypeReasoning:
580 reasoning, ok := AsContentType[ReasoningContent](c)
581 if !ok {
582 continue
583 }
584 assistantParts = append(assistantParts, ReasoningPart{
585 Text: reasoning.Text,
586 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
587 })
588 case ContentTypeToolCall:
589 toolCall, ok := AsContentType[ToolCallContent](c)
590 if !ok {
591 continue
592 }
593 assistantParts = append(assistantParts, ToolCallPart{
594 ToolCallID: toolCall.ToolCallID,
595 ToolName: toolCall.ToolName,
596 Input: toolCall.Input,
597 ProviderExecuted: toolCall.ProviderExecuted,
598 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
599 })
600 case ContentTypeFile:
601 file, ok := AsContentType[FileContent](c)
602 if !ok {
603 continue
604 }
605 assistantParts = append(assistantParts, FilePart{
606 Data: file.Data,
607 MediaType: file.MediaType,
608 ProviderOptions: ProviderOptions(file.ProviderMetadata),
609 })
610 case ContentTypeSource:
611 // Sources are metadata about references used to generate the response.
612 // They don't need to be included in the conversation messages.
613 continue
614 case ContentTypeToolResult:
615 result, ok := AsContentType[ToolResultContent](c)
616 if !ok {
617 continue
618 }
619 resultPart := ToolResultPart{
620 ToolCallID: result.ToolCallID,
621 Output: result.Result,
622 ProviderExecuted: result.ProviderExecuted,
623 ProviderOptions: ProviderOptions(result.ProviderMetadata),
624 }
625 if result.ProviderExecuted {
626 // Provider-executed tool results (e.g. web search)
627 // belong in the assistant message alongside the
628 // server_tool_use block that produced them.
629 assistantParts = append(assistantParts, resultPart)
630 } else {
631 toolParts = append(toolParts, resultPart)
632 }
633 }
634 }
635
636 var messages []Message
637 if len(assistantParts) > 0 {
638 messages = append(messages, Message{
639 Role: MessageRoleAssistant,
640 Content: assistantParts,
641 })
642 }
643 if len(toolParts) > 0 {
644 messages = append(messages, Message{
645 Role: MessageRoleTool,
646 Content: toolParts,
647 })
648 }
649 return messages
650}
651
652func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, execProviderTools []ExecutableProviderTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
653 if len(toolCalls) == 0 {
654 return nil, nil
655 }
656
657 // Create a map for quick tool lookup
658 toolMap := make(map[string]AgentTool)
659 for _, tool := range allTools {
660 toolMap[tool.Info().Name] = tool
661 }
662
663 execProviderToolMap := make(map[string]ExecutableProviderTool, len(execProviderTools))
664 for _, ept := range execProviderTools {
665 execProviderToolMap[ept.GetName()] = ept
666 }
667
668 // Execute all tool calls sequentially in order
669 results := make([]ToolResultContent, 0, len(toolCalls))
670
671 for _, toolCall := range toolCalls {
672 result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, toolCall, toolResultCallback)
673 results = append(results, result)
674 if isCriticalError {
675 if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
676 return nil, errorResult.Error
677 }
678 }
679 }
680
681 return results, nil
682}
683
684// executeSingleTool executes a single tool and returns its result and a critical error flag.
685func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentTool, execProviderToolMap map[string]ExecutableProviderTool, toolCall ToolCallContent, toolResultCallback func(result ToolResultContent) error) (ToolResultContent, bool) {
686 result := ToolResultContent{
687 ToolCallID: toolCall.ToolCallID,
688 ToolName: toolCall.ToolName,
689 ProviderExecuted: false,
690 }
691
692 // Skip invalid tool calls - create error result (not critical)
693 if toolCall.Invalid {
694 result.Result = ToolResultOutputContentError{
695 Error: toolCall.ValidationError,
696 }
697 if toolResultCallback != nil {
698 _ = toolResultCallback(result)
699 }
700 return result, false
701 }
702
703 // Find the run function — either from a regular AgentTool or an
704 // executable provider tool.
705 var runTool func(ctx context.Context, call ToolCall) (ToolResponse, error)
706 if tool, exists := toolMap[toolCall.ToolName]; exists {
707 runTool = tool.Run
708 } else if ept, ok := execProviderToolMap[toolCall.ToolName]; ok {
709 runTool = ept.Run
710 }
711 if runTool == nil {
712 result.Result = ToolResultOutputContentError{
713 Error: errors.New("tool not found: " + toolCall.ToolName),
714 }
715 if toolResultCallback != nil {
716 _ = toolResultCallback(result)
717 }
718 return result, false
719 }
720
721 // Execute the tool
722 toolResult, err := runTool(ctx, ToolCall{
723 ID: toolCall.ToolCallID,
724 Name: toolCall.ToolName,
725 Input: toolCall.Input,
726 })
727 if err != nil {
728 result.Result = ToolResultOutputContentError{
729 Error: err,
730 }
731 result.ClientMetadata = toolResult.Metadata
732 if toolResultCallback != nil {
733 _ = toolResultCallback(result)
734 }
735 return result, true
736 }
737
738 result.ClientMetadata = toolResult.Metadata
739 if toolResult.IsError {
740 result.Result = ToolResultOutputContentError{
741 Error: errors.New(toolResult.Content),
742 }
743 } else if toolResult.Type == "image" || toolResult.Type == "media" {
744 result.Result = ToolResultOutputContentMedia{
745 Data: base64.StdEncoding.EncodeToString(toolResult.Data),
746 MediaType: toolResult.MediaType,
747 Text: toolResult.Content,
748 }
749 } else {
750 result.Result = ToolResultOutputContentText{
751 Text: toolResult.Content,
752 }
753 }
754 if toolResultCallback != nil {
755 _ = toolResultCallback(result)
756 }
757 return result, false
758}
759
760// Stream implements Agent.
761func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
762 // Convert AgentStreamCall to AgentCall for preparation
763 call := AgentCall{
764 Prompt: opts.Prompt,
765 Files: opts.Files,
766 Messages: opts.Messages,
767 MaxOutputTokens: opts.MaxOutputTokens,
768 Temperature: opts.Temperature,
769 TopP: opts.TopP,
770 TopK: opts.TopK,
771 PresencePenalty: opts.PresencePenalty,
772 FrequencyPenalty: opts.FrequencyPenalty,
773 ActiveTools: opts.ActiveTools,
774 ProviderOptions: opts.ProviderOptions,
775 MaxRetries: opts.MaxRetries,
776 OnRetry: opts.OnRetry,
777 StopWhen: opts.StopWhen,
778 PrepareStep: opts.PrepareStep,
779 RepairToolCall: opts.RepairToolCall,
780 }
781
782 call = a.prepareCall(call)
783
784 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
785 if err != nil {
786 return nil, err
787 }
788
789 var responseMessages []Message
790 var steps []StepResult
791 var totalUsage Usage
792
793 // Start agent stream
794 if opts.OnAgentStart != nil {
795 opts.OnAgentStart()
796 }
797
798 for stepNumber := 0; ; stepNumber++ {
799 stepInputMessages := append(initialPrompt, responseMessages...)
800 stepModel := a.settings.model
801 stepSystemPrompt := a.settings.systemPrompt
802 stepActiveTools := call.ActiveTools
803 stepToolChoice := ToolChoiceAuto
804 disableAllTools := false
805 stepTools := a.settings.tools
806 // Apply step preparation if provided
807 if call.PrepareStep != nil {
808 updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
809 Model: stepModel,
810 Steps: steps,
811 StepNumber: stepNumber,
812 Messages: stepInputMessages,
813 })
814 if err != nil {
815 return nil, err
816 }
817
818 ctx = updatedCtx
819
820 if prepared.Messages != nil {
821 stepInputMessages = prepared.Messages
822 }
823 if prepared.Model != nil {
824 stepModel = prepared.Model
825 }
826 if prepared.System != nil {
827 stepSystemPrompt = *prepared.System
828 }
829 if prepared.ToolChoice != nil {
830 stepToolChoice = *prepared.ToolChoice
831 }
832 if len(prepared.ActiveTools) > 0 {
833 stepActiveTools = prepared.ActiveTools
834 }
835 disableAllTools = prepared.DisableAllTools
836 if prepared.Tools != nil {
837 stepTools = prepared.Tools
838 }
839 }
840
841 // Recreate prompt with potentially modified system prompt
842 if stepSystemPrompt != a.settings.systemPrompt {
843 stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
844 if err != nil {
845 return nil, err
846 }
847 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
848 stepInputMessages[0] = stepPrompt[0]
849 }
850 }
851
852 preparedTools := a.prepareTools(stepTools, a.settings.providerDefinedTools, stepActiveTools, disableAllTools)
853
854 // Filter executable provider tools by activeTools at the
855 // step level, consistent with how stepTools (AgentTools)
856 // are scoped before being passed to inner functions.
857 stepExecProviderTools := a.filterExecProviderTools(stepActiveTools)
858
859 // Start step stream
860 if opts.OnStepStart != nil {
861 _ = opts.OnStepStart(stepNumber)
862 }
863 // Create streaming call
864 streamCall := Call{
865 Prompt: stepInputMessages,
866 MaxOutputTokens: call.MaxOutputTokens,
867 Temperature: call.Temperature,
868 TopP: call.TopP,
869 TopK: call.TopK,
870 PresencePenalty: call.PresencePenalty,
871 FrequencyPenalty: call.FrequencyPenalty,
872 Tools: preparedTools,
873 ToolChoice: &stepToolChoice,
874 UserAgent: a.settings.userAgent,
875 ProviderOptions: call.ProviderOptions,
876 }
877
878 // Execute step with retry logic wrapping both stream creation and processing
879 retryOptions := DefaultRetryOptions()
880 if call.MaxRetries != nil {
881 retryOptions.MaxRetries = *call.MaxRetries
882 }
883 retryOptions.OnRetry = call.OnRetry
884 retry := RetryWithExponentialBackoffRespectingRetryHeaders[stepExecutionResult](retryOptions)
885
886 result, err := retry(ctx, func() (stepExecutionResult, error) {
887 // Create the stream
888 stream, err := stepModel.Stream(ctx, streamCall)
889 if err != nil {
890 return stepExecutionResult{}, err
891 }
892
893 // Process the stream
894 result, err := a.processStepStream(ctx, stream, opts, steps, stepTools, stepExecProviderTools)
895 if err != nil {
896 return stepExecutionResult{}, err
897 }
898 return result, nil
899 })
900 if err != nil {
901 if opts.OnError != nil {
902 opts.OnError(err)
903 }
904 return nil, err
905 }
906
907 steps = append(steps, result.StepResult)
908 totalUsage = addUsage(totalUsage, result.StepResult.Usage)
909
910 // Call step finished callback
911 if opts.OnStepFinish != nil {
912 _ = opts.OnStepFinish(result.StepResult)
913 }
914
915 // Add step messages to response messages
916 stepMessages := toResponseMessages(result.StepResult.Content)
917 responseMessages = append(responseMessages, stepMessages...)
918
919 // Check stop conditions
920 shouldStop := isStopConditionMet(call.StopWhen, steps)
921 if shouldStop || !result.ShouldContinue {
922 break
923 }
924 }
925
926 // Finish agent stream
927 agentResult := &AgentResult{
928 Steps: steps,
929 Response: steps[len(steps)-1].Response,
930 TotalUsage: totalUsage,
931 }
932
933 if opts.OnFinish != nil {
934 opts.OnFinish(agentResult)
935 }
936
937 if opts.OnAgentFinish != nil {
938 _ = opts.OnAgentFinish(agentResult)
939 }
940
941 return agentResult, nil
942}
943
944// filterExecProviderTools returns the subset of executable provider
945// tools permitted by activeTools. When activeTools is empty every
946// tool is included (no filtering).
947func (a *agent) filterExecProviderTools(activeTools []string) []ExecutableProviderTool {
948 if len(activeTools) == 0 {
949 return a.settings.executableProviderTools
950 }
951 filtered := make([]ExecutableProviderTool, 0, len(a.settings.executableProviderTools))
952 for _, ept := range a.settings.executableProviderTools {
953 if slices.Contains(activeTools, ept.GetName()) {
954 filtered = append(filtered, ept)
955 }
956 }
957 return filtered
958}
959
960func (a *agent) prepareTools(tools []AgentTool, providerDefinedTools []ProviderDefinedTool, activeTools []string, disableAllTools bool) []Tool {
961 preparedTools := make([]Tool, 0, len(tools)+len(providerDefinedTools))
962
963 // If explicitly disabling all tools, return no tools
964 if disableAllTools {
965 return preparedTools
966 }
967
968 for _, tool := range tools {
969 // If activeTools has items, only include tools in the list
970 // If activeTools is empty, include all tools
971 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
972 continue
973 }
974 info := tool.Info()
975 inputSchema := map[string]any{
976 "type": "object",
977 "properties": info.Parameters,
978 "required": info.Required,
979 }
980 schema.Normalize(inputSchema)
981 preparedTools = append(preparedTools, FunctionTool{
982 Name: info.Name,
983 Description: info.Description,
984 InputSchema: inputSchema,
985 ProviderOptions: tool.ProviderOptions(),
986 })
987 }
988 for _, tool := range providerDefinedTools {
989 // If activeTools has items, only include tools in the list. If
990 // activeTools is empty, include all tools
991 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.GetName()) {
992 continue
993 }
994 preparedTools = append(preparedTools, tool)
995 }
996 return preparedTools
997}
998
999// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
1000func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, execProviderTools []ExecutableProviderTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
1001 if err := a.validateToolCall(toolCall, availableTools, execProviderTools); err == nil {
1002 return toolCall
1003 } else { //nolint: revive
1004 if repairFunc != nil {
1005 repairOptions := ToolCallRepairOptions{
1006 OriginalToolCall: toolCall,
1007 ValidationError: err,
1008 AvailableTools: availableTools,
1009 SystemPrompt: systemPrompt,
1010 Messages: messages,
1011 }
1012
1013 if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
1014 if validateErr := a.validateToolCall(*repairedToolCall, availableTools, execProviderTools); validateErr == nil {
1015 return *repairedToolCall
1016 }
1017 }
1018 }
1019
1020 invalidToolCall := toolCall
1021 invalidToolCall.Invalid = true
1022 invalidToolCall.ValidationError = err
1023 return invalidToolCall
1024 }
1025}
1026
1027// validateToolCall validates a tool call against available tools and their schemas.
1028// Both availableTools and execProviderTools must already be filtered by the
1029// caller (e.g. via activeTools); this function trusts that the slices
1030// represent exactly the tools permitted for the current step.
1031func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool, execProviderTools []ExecutableProviderTool) error {
1032 var tool AgentTool
1033 for _, t := range availableTools {
1034 if t.Info().Name == toolCall.ToolName {
1035 tool = t
1036 break
1037 }
1038 }
1039
1040 if tool == nil {
1041 // Check if this is an executable provider tool. Provider-
1042 // defined tools have their schema enforced server-side, so
1043 // we only validate that the input is parseable JSON.
1044 for _, ept := range execProviderTools {
1045 if ept.GetName() == toolCall.ToolName {
1046 var input map[string]any
1047 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
1048 return fmt.Errorf("invalid JSON input: %w", err)
1049 }
1050 return nil
1051 }
1052 }
1053 return fmt.Errorf("tool not found: %s", toolCall.ToolName)
1054 }
1055
1056 // Validate JSON parsing
1057 var input map[string]any
1058 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
1059 return fmt.Errorf("invalid JSON input: %w", err)
1060 }
1061
1062 // Basic schema validation (check required fields)
1063 // TODO: more robust schema validation using JSON Schema or similar
1064 toolInfo := tool.Info()
1065 for _, required := range toolInfo.Required {
1066 if _, exists := input[required]; !exists {
1067 return fmt.Errorf("missing required parameter: %s", required)
1068 }
1069 }
1070 return nil
1071}
1072
1073func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
1074 // Validation: empty prompt is only allowed when there are messages,
1075 // no files to attach, and the last message is a user or tool message.
1076 if prompt == "" {
1077 lastMessage, hasMessages := slice.Last(messages)
1078
1079 if !hasMessages {
1080 return nil, &Error{
1081 Title: "invalid argument",
1082 Message: "prompt can't be empty when there are no messages",
1083 }
1084 }
1085
1086 if len(files) > 0 {
1087 return nil, &Error{
1088 Title: "invalid argument",
1089 Message: "prompt can't be empty when there are files",
1090 }
1091 }
1092
1093 switch lastMessage.Role {
1094 case MessageRoleUser, MessageRoleTool:
1095 default:
1096 return nil, &Error{
1097 Title: "invalid argument",
1098 Message: "prompt can't be empty when the last message is not a user or tool message",
1099 }
1100 }
1101 }
1102
1103 var preparedPrompt Prompt
1104
1105 if system != "" {
1106 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
1107 }
1108 preparedPrompt = append(preparedPrompt, messages...)
1109 if prompt != "" {
1110 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
1111 }
1112 return preparedPrompt, nil
1113}
1114
1115// WithSystemPrompt sets the system prompt for the agent.
1116func WithSystemPrompt(prompt string) AgentOption {
1117 return func(s *agentSettings) {
1118 s.systemPrompt = prompt
1119 }
1120}
1121
1122// WithMaxOutputTokens sets the maximum output tokens for the agent.
1123func WithMaxOutputTokens(tokens int64) AgentOption {
1124 return func(s *agentSettings) {
1125 s.maxOutputTokens = &tokens
1126 }
1127}
1128
1129// WithTemperature sets the temperature for the agent.
1130func WithTemperature(temp float64) AgentOption {
1131 return func(s *agentSettings) {
1132 s.temperature = &temp
1133 }
1134}
1135
1136// WithTopP sets the top-p value for the agent.
1137func WithTopP(topP float64) AgentOption {
1138 return func(s *agentSettings) {
1139 s.topP = &topP
1140 }
1141}
1142
1143// WithTopK sets the top-k value for the agent.
1144func WithTopK(topK int64) AgentOption {
1145 return func(s *agentSettings) {
1146 s.topK = &topK
1147 }
1148}
1149
1150// WithPresencePenalty sets the presence penalty for the agent.
1151func WithPresencePenalty(penalty float64) AgentOption {
1152 return func(s *agentSettings) {
1153 s.presencePenalty = &penalty
1154 }
1155}
1156
1157// WithFrequencyPenalty sets the frequency penalty for the agent.
1158func WithFrequencyPenalty(penalty float64) AgentOption {
1159 return func(s *agentSettings) {
1160 s.frequencyPenalty = &penalty
1161 }
1162}
1163
1164// WithTools sets the tools for the agent.
1165func WithTools(tools ...AgentTool) AgentOption {
1166 return func(s *agentSettings) {
1167 s.tools = append(s.tools, tools...)
1168 }
1169}
1170
1171// WithProviderDefinedTools registers provider-defined tools with the
1172// agent. Provider-executed tools (e.g. web search) are passed through
1173// to the API. Client-executed tools (ExecutableProviderTool) are also
1174// registered for local execution.
1175func WithProviderDefinedTools(tools ...ProviderTool) AgentOption {
1176 return func(s *agentSettings) {
1177 for _, t := range tools {
1178 // Every provider tool goes into providerDefinedTools
1179 // for wire formatting.
1180 s.providerDefinedTools = append(
1181 s.providerDefinedTools, t.providerDefinedTool(),
1182 )
1183 // Executable ones also register for local execution.
1184 if exec, ok := t.(ExecutableProviderTool); ok {
1185 s.executableProviderTools = append(
1186 s.executableProviderTools, exec,
1187 )
1188 }
1189 }
1190 }
1191}
1192
1193// WithStopConditions sets the stop conditions for the agent.
1194func WithStopConditions(conditions ...StopCondition) AgentOption {
1195 return func(s *agentSettings) {
1196 s.stopWhen = append(s.stopWhen, conditions...)
1197 }
1198}
1199
1200// WithPrepareStep sets the prepare step function for the agent.
1201func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1202 return func(s *agentSettings) {
1203 s.prepareStep = fn
1204 }
1205}
1206
1207// WithRepairToolCall sets the repair tool call function for the agent.
1208func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1209 return func(s *agentSettings) {
1210 s.repairToolCall = fn
1211 }
1212}
1213
1214// WithMaxRetries sets the maximum number of retries for the agent.
1215func WithMaxRetries(maxRetries int) AgentOption {
1216 return func(s *agentSettings) {
1217 s.maxRetries = &maxRetries
1218 }
1219}
1220
1221// WithOnRetry sets the retry callback for the agent.
1222func WithOnRetry(callback OnRetryCallback) AgentOption {
1223 return func(s *agentSettings) {
1224 s.onRetry = callback
1225 }
1226}
1227
1228// processStepStream processes a single step's stream and returns the step result.
1229func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult, stepTools []AgentTool, execProviderTools []ExecutableProviderTool) (stepExecutionResult, error) {
1230 var stepContent []Content
1231 var stepToolCalls []ToolCallContent
1232 var stepUsage Usage
1233 stepFinishReason := FinishReasonUnknown
1234 var stepWarnings []CallWarning
1235 var stepProviderMetadata ProviderMetadata
1236
1237 activeToolCalls := make(map[string]*ToolCallContent)
1238 activeTextContent := make(map[string]string)
1239 type reasoningContent struct {
1240 content string
1241 options ProviderMetadata
1242 }
1243 activeReasoningContent := make(map[string]reasoningContent)
1244
1245 // Set up concurrent tool execution
1246 type toolExecutionRequest struct {
1247 toolCall ToolCallContent
1248 parallel bool
1249 }
1250 toolChan := make(chan toolExecutionRequest, 10)
1251 var toolExecutionWg sync.WaitGroup
1252 var toolStateMu sync.Mutex
1253 toolResults := make([]ToolResultContent, 0)
1254 var toolExecutionErr error
1255
1256 // Create a map for quick tool lookup
1257 toolMap := make(map[string]AgentTool)
1258 for _, tool := range stepTools {
1259 toolMap[tool.Info().Name] = tool
1260 }
1261
1262 execProviderToolMap := make(map[string]ExecutableProviderTool, len(execProviderTools))
1263 for _, ept := range execProviderTools {
1264 execProviderToolMap[ept.GetName()] = ept
1265 }
1266
1267 // Semaphores for controlling parallelism
1268 parallelSem := make(chan struct{}, 5)
1269 var sequentialMu sync.Mutex
1270
1271 // Single coordinator goroutine that dispatches tools
1272 toolExecutionWg.Go(func() {
1273 for req := range toolChan {
1274 if req.parallel {
1275 parallelSem <- struct{}{}
1276 toolExecutionWg.Go(func() {
1277 defer func() { <-parallelSem }()
1278 result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, req.toolCall, opts.OnToolResult)
1279 toolStateMu.Lock()
1280 toolResults = append(toolResults, result)
1281 if isCriticalError && toolExecutionErr == nil {
1282 if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
1283 toolExecutionErr = errorResult.Error
1284 }
1285 }
1286 toolStateMu.Unlock()
1287 })
1288 } else {
1289 sequentialMu.Lock()
1290 result, isCriticalError := a.executeSingleTool(ctx, toolMap, execProviderToolMap, req.toolCall, opts.OnToolResult)
1291 toolStateMu.Lock()
1292 toolResults = append(toolResults, result)
1293 if isCriticalError && toolExecutionErr == nil {
1294 if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
1295 toolExecutionErr = errorResult.Error
1296 }
1297 }
1298 toolStateMu.Unlock()
1299 sequentialMu.Unlock()
1300 }
1301 }
1302 })
1303
1304 // Process stream parts
1305 for part := range stream {
1306 // Forward all parts to chunk callback
1307 if opts.OnChunk != nil {
1308 err := opts.OnChunk(part)
1309 if err != nil {
1310 return stepExecutionResult{}, err
1311 }
1312 }
1313
1314 switch part.Type {
1315 case StreamPartTypeWarnings:
1316 stepWarnings = part.Warnings
1317 if opts.OnWarnings != nil {
1318 err := opts.OnWarnings(part.Warnings)
1319 if err != nil {
1320 return stepExecutionResult{}, err
1321 }
1322 }
1323
1324 case StreamPartTypeTextStart:
1325 activeTextContent[part.ID] = ""
1326 if opts.OnTextStart != nil {
1327 err := opts.OnTextStart(part.ID)
1328 if err != nil {
1329 return stepExecutionResult{}, err
1330 }
1331 }
1332
1333 case StreamPartTypeTextDelta:
1334 if _, exists := activeTextContent[part.ID]; exists {
1335 activeTextContent[part.ID] += part.Delta
1336 }
1337 if opts.OnTextDelta != nil {
1338 err := opts.OnTextDelta(part.ID, part.Delta)
1339 if err != nil {
1340 return stepExecutionResult{}, err
1341 }
1342 }
1343
1344 case StreamPartTypeTextEnd:
1345 if text, exists := activeTextContent[part.ID]; exists {
1346 stepContent = append(stepContent, TextContent{
1347 Text: text,
1348 ProviderMetadata: part.ProviderMetadata,
1349 })
1350 delete(activeTextContent, part.ID)
1351 }
1352 if opts.OnTextEnd != nil {
1353 err := opts.OnTextEnd(part.ID)
1354 if err != nil {
1355 return stepExecutionResult{}, err
1356 }
1357 }
1358
1359 case StreamPartTypeReasoningStart:
1360 activeReasoningContent[part.ID] = reasoningContent{content: part.Delta, options: part.ProviderMetadata}
1361 if opts.OnReasoningStart != nil {
1362 content := ReasoningContent{
1363 Text: part.Delta,
1364 ProviderMetadata: part.ProviderMetadata,
1365 }
1366 err := opts.OnReasoningStart(part.ID, content)
1367 if err != nil {
1368 return stepExecutionResult{}, err
1369 }
1370 }
1371
1372 case StreamPartTypeReasoningDelta:
1373 if active, exists := activeReasoningContent[part.ID]; exists {
1374 active.content += part.Delta
1375 if part.ProviderMetadata != nil {
1376 active.options = part.ProviderMetadata
1377 }
1378 activeReasoningContent[part.ID] = active
1379 }
1380 if opts.OnReasoningDelta != nil {
1381 err := opts.OnReasoningDelta(part.ID, part.Delta)
1382 if err != nil {
1383 return stepExecutionResult{}, err
1384 }
1385 }
1386
1387 case StreamPartTypeReasoningEnd:
1388 if active, exists := activeReasoningContent[part.ID]; exists {
1389 if part.ProviderMetadata != nil {
1390 active.options = part.ProviderMetadata
1391 }
1392 content := ReasoningContent{
1393 Text: active.content,
1394 ProviderMetadata: active.options,
1395 }
1396 stepContent = append(stepContent, content)
1397 if opts.OnReasoningEnd != nil {
1398 err := opts.OnReasoningEnd(part.ID, content)
1399 if err != nil {
1400 return stepExecutionResult{}, err
1401 }
1402 }
1403 delete(activeReasoningContent, part.ID)
1404 }
1405
1406 case StreamPartTypeToolInputStart:
1407 activeToolCalls[part.ID] = &ToolCallContent{
1408 ToolCallID: part.ID,
1409 ToolName: part.ToolCallName,
1410 Input: "",
1411 ProviderExecuted: part.ProviderExecuted,
1412 }
1413 if opts.OnToolInputStart != nil {
1414 err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1415 if err != nil {
1416 return stepExecutionResult{}, err
1417 }
1418 }
1419
1420 case StreamPartTypeToolInputDelta:
1421 if toolCall, exists := activeToolCalls[part.ID]; exists {
1422 toolCall.Input += part.Delta
1423 }
1424 if opts.OnToolInputDelta != nil {
1425 err := opts.OnToolInputDelta(part.ID, part.Delta)
1426 if err != nil {
1427 return stepExecutionResult{}, err
1428 }
1429 }
1430
1431 case StreamPartTypeToolInputEnd:
1432 if opts.OnToolInputEnd != nil {
1433 err := opts.OnToolInputEnd(part.ID)
1434 if err != nil {
1435 return stepExecutionResult{}, err
1436 }
1437 }
1438
1439 case StreamPartTypeToolCall:
1440 toolCall := ToolCallContent{
1441 ToolCallID: part.ID,
1442 ToolName: part.ToolCallName,
1443 Input: part.ToolCallInput,
1444 ProviderExecuted: part.ProviderExecuted,
1445 ProviderMetadata: part.ProviderMetadata,
1446 }
1447
1448 // Provider-executed tool calls are handled by the provider
1449 // and should not be validated or executed by the agent.
1450 if toolCall.ProviderExecuted {
1451 stepContent = append(stepContent, toolCall)
1452 if opts.OnToolCall != nil {
1453 err := opts.OnToolCall(toolCall)
1454 if err != nil {
1455 return stepExecutionResult{}, err
1456 }
1457 }
1458 delete(activeToolCalls, part.ID)
1459 } else {
1460 // Validate and potentially repair the tool call
1461 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, execProviderTools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1462 stepToolCalls = append(stepToolCalls, validatedToolCall)
1463 stepContent = append(stepContent, validatedToolCall)
1464
1465 if opts.OnToolCall != nil {
1466 err := opts.OnToolCall(validatedToolCall)
1467 if err != nil {
1468 return stepExecutionResult{}, err
1469 }
1470 }
1471
1472 // Determine if tool can run in parallel
1473 isParallel := false
1474 if tool, exists := toolMap[validatedToolCall.ToolName]; exists {
1475 isParallel = tool.Info().Parallel
1476 }
1477
1478 // Send tool call to execution channel
1479 toolChan <- toolExecutionRequest{toolCall: validatedToolCall, parallel: isParallel}
1480
1481 // Clean up active tool call
1482 delete(activeToolCalls, part.ID)
1483 }
1484
1485 case StreamPartTypeToolResult:
1486 // Provider-executed tool results (e.g. web search)
1487 // are emitted by the provider and added directly
1488 // to the step content for multi-turn round-tripping.
1489 if part.ProviderExecuted {
1490 resultContent := ToolResultContent{
1491 ToolCallID: part.ID,
1492 ToolName: part.ToolCallName,
1493 ProviderExecuted: true,
1494 ProviderMetadata: part.ProviderMetadata,
1495 }
1496 stepContent = append(stepContent, resultContent)
1497 if opts.OnToolResult != nil {
1498 err := opts.OnToolResult(resultContent)
1499 if err != nil {
1500 return stepExecutionResult{}, err
1501 }
1502 }
1503 }
1504
1505 case StreamPartTypeSource:
1506 sourceContent := SourceContent{
1507 SourceType: part.SourceType,
1508 ID: part.ID,
1509 URL: part.URL,
1510 Title: part.Title,
1511 ProviderMetadata: part.ProviderMetadata,
1512 }
1513 stepContent = append(stepContent, sourceContent)
1514 if opts.OnSource != nil {
1515 err := opts.OnSource(sourceContent)
1516 if err != nil {
1517 return stepExecutionResult{}, err
1518 }
1519 }
1520
1521 case StreamPartTypeFinish:
1522 stepUsage = part.Usage
1523 stepFinishReason = part.FinishReason
1524 stepProviderMetadata = part.ProviderMetadata
1525 if opts.OnStreamFinish != nil {
1526 err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1527 if err != nil {
1528 return stepExecutionResult{}, err
1529 }
1530 }
1531
1532 case StreamPartTypeError:
1533 return stepExecutionResult{}, part.Error
1534 }
1535 }
1536
1537 // Close the tool execution channel and wait for all executions to complete
1538 close(toolChan)
1539 toolExecutionWg.Wait()
1540
1541 // Check for tool execution errors
1542 if toolExecutionErr != nil {
1543 return stepExecutionResult{}, toolExecutionErr
1544 }
1545
1546 // Add tool results to content if any
1547 if len(toolResults) > 0 {
1548 for _, result := range toolResults {
1549 stepContent = append(stepContent, result)
1550 }
1551 }
1552
1553 stepResult := StepResult{
1554 Response: Response{
1555 Content: stepContent,
1556 FinishReason: stepFinishReason,
1557 Usage: stepUsage,
1558 Warnings: stepWarnings,
1559 ProviderMetadata: stepProviderMetadata,
1560 },
1561 Messages: toResponseMessages(stepContent),
1562 }
1563
1564 // Determine if we should continue (has tool calls and not stopped)
1565 shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1566
1567 return stepExecutionResult{
1568 StepResult: stepResult,
1569 ShouldContinue: shouldContinue,
1570 }, nil
1571}
1572
1573func addUsage(a, b Usage) Usage {
1574 return Usage{
1575 InputTokens: a.InputTokens + b.InputTokens,
1576 OutputTokens: a.OutputTokens + b.OutputTokens,
1577 TotalTokens: a.TotalTokens + b.TotalTokens,
1578 ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens,
1579 CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1580 CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens,
1581 }
1582}
1583
1584// WithHeaders sets the headers for the agent.
1585func WithHeaders(headers map[string]string) AgentOption {
1586 return func(s *agentSettings) {
1587 s.headers = headers
1588 }
1589}
1590
1591// WithUserAgent sets the User-Agent header for the agent. This overrides any
1592// provider-level User-Agent setting.
1593func WithUserAgent(ua string) AgentOption {
1594 return func(s *agentSettings) {
1595 s.userAgent = ua
1596 }
1597}
1598
1599// WithProviderOptions sets the provider options for the agent.
1600func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1601 return func(s *agentSettings) {
1602 s.providerOptions = providerOptions
1603 }
1604}