1package fantasy
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "slices"
11 "sync"
12)
13
// StepResult represents the result of a single step in an agent execution.
//
// The embedded Response carries the step's content, finish reason, usage,
// warnings and provider metadata. In Generate, Content holds the model's
// output with each tool call replaced by its validated form, followed by the
// tool results executed during that step.
type StepResult struct {
	Response
	// Messages are the conversation messages derived from this step's
	// content (Generate builds them with toResponseMessages).
	Messages []Message
}

// stepExecutionResult encapsulates the result of executing a step with stream processing.
type stepExecutionResult struct {
	// StepResult is the completed step.
	StepResult StepResult
	// ShouldContinue reports whether Stream should run another step; when it
	// is false the loop ends after this step even if no stop condition fired.
	ShouldContinue bool
}

// StopCondition defines a function that determines when an agent should stop executing.
//
// It is evaluated after each step with every step completed so far, in
// execution order (the most recent step last). Returning true ends the
// agent loop.
type StopCondition = func(steps []StepResult) bool
28
29// StepCountIs returns a stop condition that stops after the specified number of steps.
30func StepCountIs(stepCount int) StopCondition {
31 return func(steps []StepResult) bool {
32 return len(steps) >= stepCount
33 }
34}
35
36// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
37func HasToolCall(toolName string) StopCondition {
38 return func(steps []StepResult) bool {
39 if len(steps) == 0 {
40 return false
41 }
42 lastStep := steps[len(steps)-1]
43 toolCalls := lastStep.Content.ToolCalls()
44 for _, toolCall := range toolCalls {
45 if toolCall.ToolName == toolName {
46 return true
47 }
48 }
49 return false
50 }
51}
52
53// HasContent returns a stop condition that stops when the specified content type appears in the last step.
54func HasContent(contentType ContentType) StopCondition {
55 return func(steps []StepResult) bool {
56 if len(steps) == 0 {
57 return false
58 }
59 lastStep := steps[len(steps)-1]
60 for _, content := range lastStep.Content {
61 if content.GetType() == contentType {
62 return true
63 }
64 }
65 return false
66 }
67}
68
69// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
70func FinishReasonIs(reason FinishReason) StopCondition {
71 return func(steps []StepResult) bool {
72 if len(steps) == 0 {
73 return false
74 }
75 lastStep := steps[len(steps)-1]
76 return lastStep.FinishReason == reason
77 }
78}
79
80// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
81func MaxTokensUsed(maxTokens int64) StopCondition {
82 return func(steps []StepResult) bool {
83 var totalTokens int64
84 for _, step := range steps {
85 totalTokens += step.Usage.TotalTokens
86 }
87 return totalTokens >= maxTokens
88 }
89}
90
// PrepareStepFunctionOptions contains the options for preparing a step in an agent execution.
type PrepareStepFunctionOptions struct {
	// Steps are the steps completed so far in this run.
	Steps []StepResult
	// StepNumber is the zero-based index of the step about to run.
	StepNumber int
	// Model is the agent's configured model, the default for this step.
	Model LanguageModel
	// Messages are the messages that would be sent for this step: the
	// initial prompt followed by the messages produced by earlier steps.
	Messages []Message
}

// PrepareStepResult contains the result of preparing a step in an agent execution.
//
// Zero-valued fields keep the corresponding default. Every override applies
// only to the step being prepared; the next step starts from the defaults
// again.
type PrepareStepResult struct {
	// Model, when non-nil, replaces the model for this step.
	Model LanguageModel
	// Messages, when non-nil, replaces the step's input messages.
	Messages []Message
	// System, when non-nil, replaces the system prompt for this step.
	System *string
	// ToolChoice, when non-nil, replaces the default ToolChoiceAuto.
	ToolChoice *ToolChoice
	// ActiveTools, when non-empty, restricts the step to the named tools.
	ActiveTools []string
	// DisableAllTools, when true, sends no tools to the model for this step.
	DisableAllTools bool
	// Tools, when non-nil, replaces the agent's tool set for this step.
	Tools []AgentTool
}
109
// ToolCallRepairOptions contains the options for repairing a tool call.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call that failed validation.
	OriginalToolCall ToolCallContent
	// ValidationError is the error reported when validating OriginalToolCall.
	ValidationError error
	// AvailableTools are the tools the call was validated against.
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt passed in by the caller of the
	// repair step (Generate passes the step's system prompt).
	SystemPrompt string
	// Messages are the messages passed in by the caller of the repair step
	// (Generate passes the step's input messages).
	Messages []Message
}

type (
	// PrepareStepFunction defines a function that prepares a step in an agent execution.
	// It runs before every step. The returned context replaces the run's
	// context for this and all later steps; a non-nil error aborts the run.
	PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)

	// OnStepFinishedFunction defines a function that is called when a step finishes.
	// NOTE(review): not referenced by the code in this section; AgentStreamCall
	// uses OnStepFinishFunc instead.
	OnStepFinishedFunction = func(step StepResult)

	// RepairToolCallFunction defines a function that repairs a tool call.
	// It is consulted only for calls that fail validation. A non-nil result
	// returned with a nil error is used only if it passes validation itself;
	// otherwise the original call is marked invalid.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
129
// agentSettings holds the agent-level defaults configured via AgentOption
// values. Per-call values in AgentCall take precedence where prepareCall
// merges them.
type agentSettings struct {
	// systemPrompt is the default system prompt; createPrompt emits a system
	// message only when it is non-empty.
	systemPrompt string
	// Sampling/limit defaults. nil means "unset"; each is used only when the
	// call leaves the matching field nil.
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	// headers are agent-level request headers. NOTE(review): prepareCall does
	// not forward these into the call — confirm where they are meant to apply.
	headers map[string]string
	// providerOptions are merged under the call's ProviderOptions (the call
	// wins on key collisions).
	providerOptions ProviderOptions

	// TODO: add support for provider tools
	// tools is the default tool set offered to the model on each step.
	tools []AgentTool
	// maxRetries, when set, overrides the default retry count for model calls.
	maxRetries *int

	// model is the default language model for every step.
	model LanguageModel

	// Defaults for the corresponding AgentCall hooks.
	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
152
// AgentCall represents a call to an agent.
//
// Unset (nil/empty) options fall back to the agent-level defaults in
// prepareCall.
type AgentCall struct {
	// Prompt is the user message sent last in the initial prompt; it must be
	// non-empty.
	Prompt string `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files []FilePart `json:"files"`
	// Messages are prior conversation messages, placed after the system
	// message and before the user prompt.
	Messages []Message `json:"messages"`
	// Sampling/limit options; nil uses the agent default.
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools restricts the offered tools by name; empty offers all.
	ActiveTools []string `json:"active_tools"`
	// ProviderOptions are merged over the agent's provider options.
	ProviderOptions ProviderOptions
	// OnRetry and MaxRetries configure retries of each model call.
	OnRetry    OnRetryCallback
	MaxRetries *int

	// StopWhen ends the loop when any condition returns true.
	StopWhen []StopCondition
	// PrepareStep may override model, messages, system prompt and tools per step.
	PrepareStep PrepareStepFunction
	// RepairToolCall is consulted for tool calls that fail validation.
	RepairToolCall RepairToolCallFunction
}
173
// Agent-level callbacks.
type (
	// OnAgentStartFunc is called when agent starts, before the first step.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called when agent finishes, after OnFinish.
	// Stream discards the returned error.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called when a step starts, with its zero-based number.
	// Stream discards the returned error.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called when a step finishes.
	// Stream discards the returned error.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called when entire agent completes.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called when an error occurs (e.g. in Stream, when a step
	// still fails after retries).
	OnErrorFunc func(error)
)
194
// Stream part callbacks - called for each corresponding stream part type.
//
// Stream hands these to processStepStream (via AgentStreamCall), which is
// outside this section. NOTE(review): confirm there how a non-nil returned
// error is handled (e.g. whether it aborts or retries the step).
type (
	// OnChunkFunc is called for each stream part (catch-all).
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when text starts.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for text deltas.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when text ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when reasoning starts.
	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error

	// OnReasoningDeltaFunc is called for reasoning deltas.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning ends.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when tool input starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for tool input deltas.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnToolResultFunc is called when tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when stream finishes.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
242
// AgentStreamCall represents a streaming call to an agent.
//
// The request options mirror AgentCall; Stream copies them into an AgentCall
// so the same agent-level defaults apply.
type AgentStreamCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	// Headers are per-call request headers. NOTE(review): Stream does not
	// copy this field into the AgentCall it builds — confirm it is used.
	Headers         map[string]string
	ProviderOptions ProviderOptions
	OnRetry         OnRetryCallback
	MaxRetries      *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called when agent starts
	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
	OnWarnings       OnWarningsFunc       // Called for warnings
	OnTextStart      OnTextStartFunc      // Called when text starts
	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
	OnTextEnd        OnTextEndFunc        // Called when text ends
	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
	OnToolCall       OnToolCallFunc       // Called when tool call is complete
	OnToolResult     OnToolResultFunc     // Called when tool execution completes
	OnSource         OnSourceFunc         // Called for source references
	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
}
289
// AgentResult represents the result of an agent execution.
type AgentResult struct {
	// Steps lists every executed step, in order.
	Steps []StepResult
	// Final response: the Response of the last step.
	Response Response
	// TotalUsage is the token usage summed over all steps.
	TotalUsage Usage
}

// Agent represents an AI agent that can generate responses and stream responses.
type Agent interface {
	// Generate runs the agent loop with non-streaming model calls.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the agent loop with streaming model calls, reporting
	// progress through the callbacks in AgentStreamCall.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}

// AgentOption defines a function that configures agent settings.
// Options are applied in order by NewAgent.
type AgentOption = func(*agentSettings)

// agent is the default Agent implementation; its settings are fixed at
// construction time.
type agent struct {
	settings agentSettings
}
310
311// NewAgent creates a new agent with the given language model and options.
312func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
313 settings := agentSettings{
314 model: model,
315 }
316 for _, o := range opts {
317 o(&settings)
318 }
319 return &agent{
320 settings: settings,
321 }
322}
323
// prepareCall returns call with every unset option filled from the agent's
// defaults.
//
// Pointer-valued sampling/limit options fall back to the agent setting when
// nil (cmp.Or keeps the first non-nil pointer). StopWhen falls back when the
// call has none; PrepareStep, RepairToolCall and OnRetry fall back when nil.
// ProviderOptions are merged into a fresh map: agent entries first, then
// call entries, so the call wins on key collisions and neither source map is
// modified.
//
// Agent-level headers are not applied here because AgentCall has no field
// to carry them. NOTE(review): confirm where settings.headers should be used.
func (a *agent) prepareCall(call AgentCall) AgentCall {
	call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
	call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
	call.TopP = cmp.Or(call.TopP, a.settings.topP)
	call.TopK = cmp.Or(call.TopK, a.settings.topK)
	call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
	call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
	call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)

	if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
		call.StopWhen = a.settings.stopWhen
	}
	if call.PrepareStep == nil {
		call.PrepareStep = a.settings.prepareStep
	}
	if call.RepairToolCall == nil {
		call.RepairToolCall = a.settings.repairToolCall
	}
	if call.OnRetry == nil {
		call.OnRetry = a.settings.onRetry
	}

	// maps.Copy is a no-op for a nil source, so no nil checks are needed.
	providerOptions := make(ProviderOptions, len(a.settings.providerOptions)+len(call.ProviderOptions))
	maps.Copy(providerOptions, a.settings.providerOptions)
	maps.Copy(providerOptions, call.ProviderOptions)
	call.ProviderOptions = providerOptions

	return call
}
363
// Generate implements Agent by running the non-streaming agent loop.
//
// Each iteration is one step:
//   - PrepareStep (if set) may override the model, messages, system prompt,
//     tool choice and tools for that step.
//   - The model is called with retries.
//   - Every returned tool call is validated and, via the resolved
//     RepairToolCall hook, possibly repaired, then executed.
//   - The resulting assistant/tool messages are appended to the conversation
//     for the next step.
//
// The loop ends after a step when a StopWhen condition is met, the step
// produced no tool calls, or the finish reason is not FinishReasonToolCalls.
//
// A nil result and an error are returned when the prompt is invalid,
// PrepareStep fails, the model call still fails after retries, or a tool's
// Run method returns an error.
func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
	opts = a.prepareCall(opts)
	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
	if err != nil {
		return nil, err
	}
	var responseMessages []Message
	var steps []StepResult

	for {
		// Build the step input in a fresh backing array. A bare
		// append(initialPrompt, ...) may reuse initialPrompt's spare capacity,
		// so successive steps (and the system-prompt swap below) would write
		// into the same array and corrupt earlier slices.
		stepInputMessages := append(slices.Clone(initialPrompt), responseMessages...)
		stepModel := a.settings.model
		stepSystemPrompt := a.settings.systemPrompt
		stepActiveTools := opts.ActiveTools
		stepToolChoice := ToolChoiceAuto
		disableAllTools := false
		stepTools := a.settings.tools
		if opts.PrepareStep != nil {
			updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
				Model:      stepModel,
				Steps:      steps,
				StepNumber: len(steps),
				Messages:   stepInputMessages,
			})
			if err != nil {
				return nil, err
			}

			// The returned context is used for this and every later step.
			ctx = updatedCtx

			// Apply prepared step modifications; zero values keep the defaults.
			if prepared.Messages != nil {
				stepInputMessages = prepared.Messages
			}
			if prepared.Model != nil {
				stepModel = prepared.Model
			}
			if prepared.System != nil {
				stepSystemPrompt = *prepared.System
			}
			if prepared.ToolChoice != nil {
				stepToolChoice = *prepared.ToolChoice
			}
			if len(prepared.ActiveTools) > 0 {
				stepActiveTools = prepared.ActiveTools
			}
			disableAllTools = prepared.DisableAllTools
			if prepared.Tools != nil {
				stepTools = prepared.Tools
			}
		}

		// Apply a step-specific system prompt. Only a leading system message is
		// replaced. When there is none (no agent system prompt, or PrepareStep
		// supplied messages without one), the new system message is prepended
		// instead of overwriting the first conversation message. An empty
		// override drops the system message. No existing slice is written to.
		if stepSystemPrompt != a.settings.systemPrompt {
			rest := stepInputMessages
			if len(rest) > 0 && rest[0].Role == MessageRoleSystem {
				rest = rest[1:]
			}
			if stepSystemPrompt == "" {
				stepInputMessages = rest
			} else {
				stepInputMessages = append(Prompt{NewSystemMessage(stepSystemPrompt)}, rest...)
			}
		}

		preparedTools := a.prepareTools(stepTools, stepActiveTools, disableAllTools)

		retryOptions := DefaultRetryOptions()
		if opts.MaxRetries != nil {
			retryOptions.MaxRetries = *opts.MaxRetries
		}
		retryOptions.OnRetry = opts.OnRetry
		retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)

		result, err := retry(ctx, func() (*Response, error) {
			return stepModel.Generate(ctx, Call{
				Prompt:           stepInputMessages,
				MaxOutputTokens:  opts.MaxOutputTokens,
				Temperature:      opts.Temperature,
				TopP:             opts.TopP,
				TopK:             opts.TopK,
				PresencePenalty:  opts.PresencePenalty,
				FrequencyPenalty: opts.FrequencyPenalty,
				Tools:            preparedTools,
				ToolChoice:       &stepToolChoice,
				ProviderOptions:  opts.ProviderOptions,
			})
		})
		if err != nil {
			return nil, err
		}

		// Validate (and possibly repair) tool calls while building the step
		// content in a single pass, so each validated call takes the position
		// of the call it replaces. opts.RepairToolCall already falls back to
		// the agent default (see prepareCall), so per-call hooks are honoured.
		var stepToolCalls []ToolCallContent
		stepContent := make([]Content, 0, len(result.Content))
		for _, content := range result.Content {
			if content.GetType() != ContentTypeToolCall {
				stepContent = append(stepContent, content)
				continue
			}
			toolCall, ok := AsContentType[ToolCallContent](content)
			if !ok {
				continue
			}
			validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, stepSystemPrompt, stepInputMessages, opts.RepairToolCall)
			stepToolCalls = append(stepToolCalls, validatedToolCall)
			stepContent = append(stepContent, validatedToolCall)
		}

		toolResults, err := a.executeTools(ctx, stepTools, stepToolCalls, nil)
		if err != nil {
			// A tool's Run failed outright. Report it instead of returning a
			// partial result (tool calls without results) with a nil error.
			return nil, err
		}
		for _, toolResult := range toolResults {
			stepContent = append(stepContent, toolResult)
		}

		currentStepMessages := toResponseMessages(stepContent)
		responseMessages = append(responseMessages, currentStepMessages...)

		stepResult := StepResult{
			Response: Response{
				Content:          stepContent,
				FinishReason:     result.FinishReason,
				Usage:            result.Usage,
				Warnings:         result.Warnings,
				ProviderMetadata: result.ProviderMetadata,
			},
			Messages: currentStepMessages,
		}
		steps = append(steps, stepResult)

		if isStopConditionMet(opts.StopWhen, steps) || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
			break
		}
	}

	totalUsage := Usage{}

	for _, step := range steps {
		usage := step.Usage
		totalUsage.InputTokens += usage.InputTokens
		totalUsage.OutputTokens += usage.OutputTokens
		totalUsage.ReasoningTokens += usage.ReasoningTokens
		totalUsage.CacheCreationTokens += usage.CacheCreationTokens
		totalUsage.CacheReadTokens += usage.CacheReadTokens
		totalUsage.TotalTokens += usage.TotalTokens
	}

	// The loop always records at least one step before breaking.
	agentResult := &AgentResult{
		Steps:      steps,
		Response:   steps[len(steps)-1].Response,
		TotalUsage: totalUsage,
	}
	return agentResult, nil
}
531
532func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
533 if len(conditions) == 0 {
534 return false
535 }
536
537 for _, condition := range conditions {
538 if condition(steps) {
539 return true
540 }
541 }
542 return false
543}
544
545func toResponseMessages(content []Content) []Message {
546 var assistantParts []MessagePart
547 var toolParts []MessagePart
548
549 for _, c := range content {
550 switch c.GetType() {
551 case ContentTypeText:
552 text, ok := AsContentType[TextContent](c)
553 if !ok {
554 continue
555 }
556 assistantParts = append(assistantParts, TextPart{
557 Text: text.Text,
558 ProviderOptions: ProviderOptions(text.ProviderMetadata),
559 })
560 case ContentTypeReasoning:
561 reasoning, ok := AsContentType[ReasoningContent](c)
562 if !ok {
563 continue
564 }
565 assistantParts = append(assistantParts, ReasoningPart{
566 Text: reasoning.Text,
567 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
568 })
569 case ContentTypeToolCall:
570 toolCall, ok := AsContentType[ToolCallContent](c)
571 if !ok {
572 continue
573 }
574 assistantParts = append(assistantParts, ToolCallPart{
575 ToolCallID: toolCall.ToolCallID,
576 ToolName: toolCall.ToolName,
577 Input: toolCall.Input,
578 ProviderExecuted: toolCall.ProviderExecuted,
579 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
580 })
581 case ContentTypeFile:
582 file, ok := AsContentType[FileContent](c)
583 if !ok {
584 continue
585 }
586 assistantParts = append(assistantParts, FilePart{
587 Data: file.Data,
588 MediaType: file.MediaType,
589 ProviderOptions: ProviderOptions(file.ProviderMetadata),
590 })
591 case ContentTypeSource:
592 // Sources are metadata about references used to generate the response.
593 // They don't need to be included in the conversation messages.
594 continue
595 case ContentTypeToolResult:
596 result, ok := AsContentType[ToolResultContent](c)
597 if !ok {
598 continue
599 }
600 toolParts = append(toolParts, ToolResultPart{
601 ToolCallID: result.ToolCallID,
602 Output: result.Result,
603 ProviderOptions: ProviderOptions(result.ProviderMetadata),
604 })
605 }
606 }
607
608 var messages []Message
609 if len(assistantParts) > 0 {
610 messages = append(messages, Message{
611 Role: MessageRoleAssistant,
612 Content: assistantParts,
613 })
614 }
615 if len(toolParts) > 0 {
616 messages = append(messages, Message{
617 Role: MessageRoleTool,
618 Content: toolParts,
619 })
620 }
621 return messages
622}
623
624func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
625 if len(toolCalls) == 0 {
626 return nil, nil
627 }
628
629 // Create a map for quick tool lookup
630 toolMap := make(map[string]AgentTool)
631 for _, tool := range allTools {
632 toolMap[tool.Info().Name] = tool
633 }
634
635 // Execute all tool calls sequentially in order
636 results := make([]ToolResultContent, 0, len(toolCalls))
637
638 for _, toolCall := range toolCalls {
639 // Skip invalid tool calls - create error result
640 if toolCall.Invalid {
641 result := ToolResultContent{
642 ToolCallID: toolCall.ToolCallID,
643 ToolName: toolCall.ToolName,
644 Result: ToolResultOutputContentError{
645 Error: toolCall.ValidationError,
646 },
647 ProviderExecuted: false,
648 }
649 results = append(results, result)
650 if toolResultCallback != nil {
651 if err := toolResultCallback(result); err != nil {
652 return nil, err
653 }
654 }
655 continue
656 }
657
658 tool, exists := toolMap[toolCall.ToolName]
659 if !exists {
660 result := ToolResultContent{
661 ToolCallID: toolCall.ToolCallID,
662 ToolName: toolCall.ToolName,
663 Result: ToolResultOutputContentError{
664 Error: errors.New("Error: Tool not found: " + toolCall.ToolName),
665 },
666 ProviderExecuted: false,
667 }
668 results = append(results, result)
669 if toolResultCallback != nil {
670 if err := toolResultCallback(result); err != nil {
671 return nil, err
672 }
673 }
674 continue
675 }
676
677 // Execute the tool
678 toolResult, err := tool.Run(ctx, ToolCall{
679 ID: toolCall.ToolCallID,
680 Name: toolCall.ToolName,
681 Input: toolCall.Input,
682 })
683 if err != nil {
684 result := ToolResultContent{
685 ToolCallID: toolCall.ToolCallID,
686 ToolName: toolCall.ToolName,
687 Result: ToolResultOutputContentError{
688 Error: err,
689 },
690 ClientMetadata: toolResult.Metadata,
691 ProviderExecuted: false,
692 }
693 if toolResultCallback != nil {
694 if cbErr := toolResultCallback(result); cbErr != nil {
695 return nil, cbErr
696 }
697 }
698 return nil, err
699 }
700
701 var result ToolResultContent
702 if toolResult.IsError {
703 result = ToolResultContent{
704 ToolCallID: toolCall.ToolCallID,
705 ToolName: toolCall.ToolName,
706 Result: ToolResultOutputContentError{
707 Error: errors.New(toolResult.Content),
708 },
709 ClientMetadata: toolResult.Metadata,
710 ProviderExecuted: false,
711 }
712 } else if toolResult.Type == "image" || toolResult.Type == "media" {
713 result = ToolResultContent{
714 ToolCallID: toolCall.ToolCallID,
715 ToolName: toolCall.ToolName,
716 Result: ToolResultOutputContentMedia{
717 Data: string(toolResult.Data),
718 MediaType: toolResult.MediaType,
719 Text: toolResult.Content,
720 },
721 ClientMetadata: toolResult.Metadata,
722 ProviderExecuted: false,
723 }
724 } else {
725 // Default to text response
726 result = ToolResultContent{
727 ToolCallID: toolCall.ToolCallID,
728 ToolName: toolCall.ToolName,
729 Result: ToolResultOutputContentText{
730 Text: toolResult.Content,
731 },
732 ClientMetadata: toolResult.Metadata,
733 ProviderExecuted: false,
734 }
735 }
736 results = append(results, result)
737 if toolResultCallback != nil {
738 if err := toolResultCallback(result); err != nil {
739 return nil, err
740 }
741 }
742 }
743
744 return results, nil
745}
746
// executeSingleTool executes a single tool call and returns its result plus
// a flag that is true only when the tool's Run method itself returned an
// error (a critical failure).
//
// Invalid calls (toolCall.Invalid) and calls naming a tool missing from
// toolMap produce an error result without running anything and are not
// critical. A successful Run is mapped as in executeTools:
//   - IsError yields an error result.
//   - Type "image" or "media" yields a media result.
//   - Anything else is plain text.
//
// toolResultCallback, when non-nil, is invoked with every result; its error
// is ignored.
func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentTool, toolCall ToolCallContent, toolResultCallback func(result ToolResultContent) error) (ToolResultContent, bool) {
	result := ToolResultContent{
		ToolCallID:       toolCall.ToolCallID,
		ToolName:         toolCall.ToolName,
		ProviderExecuted: false,
	}
	notify := func() {
		if toolResultCallback != nil {
			_ = toolResultCallback(result) // best-effort: callback errors are ignored here
		}
	}

	// Skip invalid tool calls - create error result (not critical)
	if toolCall.Invalid {
		result.Result = ToolResultOutputContentError{Error: toolCall.ValidationError}
		notify()
		return result, false
	}

	tool, exists := toolMap[toolCall.ToolName]
	if !exists {
		result.Result = ToolResultOutputContentError{
			Error: errors.New("Error: Tool not found: " + toolCall.ToolName),
		}
		notify()
		return result, false
	}

	toolResult, err := tool.Run(ctx, ToolCall{
		ID:    toolCall.ToolCallID,
		Name:  toolCall.ToolName,
		Input: toolCall.Input,
	})
	result.ClientMetadata = toolResult.Metadata
	if err != nil {
		result.Result = ToolResultOutputContentError{Error: err}
		notify()
		// Critical: tool.Run() itself failed.
		return result, true
	}

	switch {
	case toolResult.IsError:
		result.Result = ToolResultOutputContentError{Error: errors.New(toolResult.Content)}
	case toolResult.Type == "image" || toolResult.Type == "media":
		// Keep binary payloads as media instead of flattening them to text.
		result.Result = ToolResultOutputContentMedia{
			Data:      string(toolResult.Data),
			MediaType: toolResult.MediaType,
			Text:      toolResult.Content,
		}
	default:
		result.Result = ToolResultOutputContentText{Text: toolResult.Content}
	}
	notify()
	// Not critical: the tool ran, even if it reported an error state.
	return result, false
}
811
// Stream implements Agent by running the streaming agent loop.
//
// Each step streams a response from the step's model and hands it, together
// with opts, to processStepStream. Creating and processing the stream are
// retried as a unit. PrepareStep may override the model, messages, system
// prompt, tool choice and tools per step. The loop ends when a StopWhen
// condition is met or the processed step reports that no further step is
// needed.
//
// Lifecycle callbacks run in this order:
//   - OnAgentStart once.
//   - OnStepStart and OnStepFinish for each step.
//   - OnFinish, then OnAgentFinish.
//
// Errors returned by OnStepStart, OnStepFinish and OnAgentFinish are
// discarded. OnError is called when PrepareStep fails or a step still fails
// after retries.
func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
	// Convert AgentStreamCall to AgentCall so prepareCall can fill in the
	// agent-level defaults.
	call := AgentCall{
		Prompt:           opts.Prompt,
		Files:            opts.Files,
		Messages:         opts.Messages,
		MaxOutputTokens:  opts.MaxOutputTokens,
		Temperature:      opts.Temperature,
		TopP:             opts.TopP,
		TopK:             opts.TopK,
		PresencePenalty:  opts.PresencePenalty,
		FrequencyPenalty: opts.FrequencyPenalty,
		ActiveTools:      opts.ActiveTools,
		ProviderOptions:  opts.ProviderOptions,
		MaxRetries:       opts.MaxRetries,
		OnRetry:          opts.OnRetry,
		StopWhen:         opts.StopWhen,
		PrepareStep:      opts.PrepareStep,
		RepairToolCall:   opts.RepairToolCall,
	}

	call = a.prepareCall(call)

	// processStepStream receives opts rather than call. Copy the resolved
	// repair hook back so an agent-level default is visible there too.
	// NOTE(review): assumes processStepStream reads opts.RepairToolCall — confirm.
	opts.RepairToolCall = call.RepairToolCall

	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
	if err != nil {
		return nil, err
	}

	var responseMessages []Message
	var steps []StepResult
	var totalUsage Usage

	// Start agent stream
	if opts.OnAgentStart != nil {
		opts.OnAgentStart()
	}

	for stepNumber := 0; ; stepNumber++ {
		// Build the step input in a fresh backing array. A bare
		// append(initialPrompt, ...) may reuse initialPrompt's spare capacity,
		// so successive steps (and the system-prompt swap below) would write
		// into the same array and corrupt earlier slices.
		stepInputMessages := append(slices.Clone(initialPrompt), responseMessages...)
		stepModel := a.settings.model
		stepSystemPrompt := a.settings.systemPrompt
		stepActiveTools := call.ActiveTools
		stepToolChoice := ToolChoiceAuto
		disableAllTools := false
		stepTools := a.settings.tools
		// Apply step preparation if provided
		if call.PrepareStep != nil {
			updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
				Model:      stepModel,
				Steps:      steps,
				StepNumber: stepNumber,
				Messages:   stepInputMessages,
			})
			if err != nil {
				if opts.OnError != nil {
					opts.OnError(err)
				}
				return nil, err
			}

			// The returned context is used for this and every later step.
			ctx = updatedCtx

			if prepared.Messages != nil {
				stepInputMessages = prepared.Messages
			}
			if prepared.Model != nil {
				stepModel = prepared.Model
			}
			if prepared.System != nil {
				stepSystemPrompt = *prepared.System
			}
			if prepared.ToolChoice != nil {
				stepToolChoice = *prepared.ToolChoice
			}
			if len(prepared.ActiveTools) > 0 {
				stepActiveTools = prepared.ActiveTools
			}
			disableAllTools = prepared.DisableAllTools
			if prepared.Tools != nil {
				stepTools = prepared.Tools
			}
		}

		// Apply a step-specific system prompt. Only a leading system message is
		// replaced. When there is none, the new system message is prepended
		// instead of overwriting the first conversation message. An empty
		// override drops the system message. No existing slice is written to.
		if stepSystemPrompt != a.settings.systemPrompt {
			rest := stepInputMessages
			if len(rest) > 0 && rest[0].Role == MessageRoleSystem {
				rest = rest[1:]
			}
			if stepSystemPrompt == "" {
				stepInputMessages = rest
			} else {
				stepInputMessages = append(Prompt{NewSystemMessage(stepSystemPrompt)}, rest...)
			}
		}

		preparedTools := a.prepareTools(stepTools, stepActiveTools, disableAllTools)

		// Start step stream; the hook's error is discarded.
		if opts.OnStepStart != nil {
			_ = opts.OnStepStart(stepNumber)
		}

		// Create streaming call
		streamCall := Call{
			Prompt:           stepInputMessages,
			MaxOutputTokens:  call.MaxOutputTokens,
			Temperature:      call.Temperature,
			TopP:             call.TopP,
			TopK:             call.TopK,
			PresencePenalty:  call.PresencePenalty,
			FrequencyPenalty: call.FrequencyPenalty,
			Tools:            preparedTools,
			ToolChoice:       &stepToolChoice,
			ProviderOptions:  call.ProviderOptions,
		}

		// Execute step with retry logic wrapping both stream creation and processing
		retryOptions := DefaultRetryOptions()
		if call.MaxRetries != nil {
			retryOptions.MaxRetries = *call.MaxRetries
		}
		retryOptions.OnRetry = call.OnRetry
		retry := RetryWithExponentialBackoffRespectingRetryHeaders[stepExecutionResult](retryOptions)

		result, err := retry(ctx, func() (stepExecutionResult, error) {
			// Create the stream
			stream, err := stepModel.Stream(ctx, streamCall)
			if err != nil {
				return stepExecutionResult{}, err
			}

			// Process the stream
			result, err := a.processStepStream(ctx, stream, opts, steps, stepTools)
			if err != nil {
				return stepExecutionResult{}, err
			}

			return result, nil
		})
		if err != nil {
			if opts.OnError != nil {
				opts.OnError(err)
			}
			return nil, err
		}

		steps = append(steps, result.StepResult)
		totalUsage = addUsage(totalUsage, result.StepResult.Usage)

		// Call step finished callback; its error is discarded.
		if opts.OnStepFinish != nil {
			_ = opts.OnStepFinish(result.StepResult)
		}

		// Add step messages to response messages
		stepMessages := toResponseMessages(result.StepResult.Content)
		responseMessages = append(responseMessages, stepMessages...)

		// Check stop conditions
		shouldStop := isStopConditionMet(call.StopWhen, steps)
		if shouldStop || !result.ShouldContinue {
			break
		}
	}

	// Finish agent stream. The loop always records at least one step.
	agentResult := &AgentResult{
		Steps:      steps,
		Response:   steps[len(steps)-1].Response,
		TotalUsage: totalUsage,
	}

	if opts.OnFinish != nil {
		opts.OnFinish(agentResult)
	}

	if opts.OnAgentFinish != nil {
		_ = opts.OnAgentFinish(agentResult)
	}

	return agentResult, nil
}
991
// prepareTools converts agent tools into the provider-facing tool list sent
// with a model call.
//
// With disableAllTools set, the result is empty. Otherwise a non-empty
// activeTools keeps only tools whose name it lists, while an empty
// activeTools keeps every tool. Each kept tool becomes a FunctionTool whose
// input schema is an "object" built from the tool's Parameters (as
// "properties") and Required list. The returned slice is never nil.
func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
	out := make([]Tool, 0, len(tools))
	if disableAllTools {
		return out
	}

	restrict := len(activeTools) > 0
	for _, t := range tools {
		if restrict && !slices.Contains(activeTools, t.Info().Name) {
			continue
		}
		info := t.Info()
		schema := map[string]any{
			"type":       "object",
			"properties": info.Parameters,
			"required":   info.Required,
		}
		out = append(out, FunctionTool{
			Name:            info.Name,
			Description:     info.Description,
			InputSchema:     schema,
			ProviderOptions: t.ProviderOptions(),
		})
	}
	return out
}
1020
// validateAndRepairToolCall returns toolCall unchanged when it validates.
//
// Otherwise, if repairFunc is set, the repair hook is given the failing call
// plus context. Its result is used only when it is non-nil, returned without
// error, and itself passes validation. In every other case a copy of the
// original call is returned with Invalid set and ValidationError holding the
// original validation error. Errors from repairFunc are not propagated.
func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
	validationErr := a.validateToolCall(toolCall, availableTools)
	if validationErr == nil {
		return toolCall
	}

	if repairFunc != nil {
		repaired, repairErr := repairFunc(ctx, ToolCallRepairOptions{
			OriginalToolCall: toolCall,
			ValidationError:  validationErr,
			AvailableTools:   availableTools,
			SystemPrompt:     systemPrompt,
			Messages:         messages,
		})
		if repairErr == nil && repaired != nil && a.validateToolCall(*repaired, availableTools) == nil {
			return *repaired
		}
	}

	// toolCall is a local copy, so flagging it does not affect the caller's value.
	toolCall.Invalid = true
	toolCall.ValidationError = validationErr
	return toolCall
}
1048
1049// validateToolCall validates a tool call against available tools and their schemas.
1050func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
1051 var tool AgentTool
1052 for _, t := range availableTools {
1053 if t.Info().Name == toolCall.ToolName {
1054 tool = t
1055 break
1056 }
1057 }
1058
1059 if tool == nil {
1060 return fmt.Errorf("tool not found: %s", toolCall.ToolName)
1061 }
1062
1063 // Validate JSON parsing
1064 var input map[string]any
1065 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
1066 return fmt.Errorf("invalid JSON input: %w", err)
1067 }
1068
1069 // Basic schema validation (check required fields)
1070 // TODO: more robust schema validation using JSON Schema or similar
1071 toolInfo := tool.Info()
1072 for _, required := range toolInfo.Required {
1073 if _, exists := input[required]; !exists {
1074 return fmt.Errorf("missing required parameter: %s", required)
1075 }
1076 }
1077 return nil
1078}
1079
1080func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
1081 if prompt == "" {
1082 return nil, &Error{Title: "invalid argument", Message: "prompt can't be empty"}
1083 }
1084
1085 var preparedPrompt Prompt
1086
1087 if system != "" {
1088 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
1089 }
1090 preparedPrompt = append(preparedPrompt, messages...)
1091 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
1092 return preparedPrompt, nil
1093}
1094
1095// WithSystemPrompt sets the system prompt for the agent.
1096func WithSystemPrompt(prompt string) AgentOption {
1097 return func(s *agentSettings) {
1098 s.systemPrompt = prompt
1099 }
1100}
1101
1102// WithMaxOutputTokens sets the maximum output tokens for the agent.
1103func WithMaxOutputTokens(tokens int64) AgentOption {
1104 return func(s *agentSettings) {
1105 s.maxOutputTokens = &tokens
1106 }
1107}
1108
1109// WithTemperature sets the temperature for the agent.
1110func WithTemperature(temp float64) AgentOption {
1111 return func(s *agentSettings) {
1112 s.temperature = &temp
1113 }
1114}
1115
1116// WithTopP sets the top-p value for the agent.
1117func WithTopP(topP float64) AgentOption {
1118 return func(s *agentSettings) {
1119 s.topP = &topP
1120 }
1121}
1122
1123// WithTopK sets the top-k value for the agent.
1124func WithTopK(topK int64) AgentOption {
1125 return func(s *agentSettings) {
1126 s.topK = &topK
1127 }
1128}
1129
1130// WithPresencePenalty sets the presence penalty for the agent.
1131func WithPresencePenalty(penalty float64) AgentOption {
1132 return func(s *agentSettings) {
1133 s.presencePenalty = &penalty
1134 }
1135}
1136
1137// WithFrequencyPenalty sets the frequency penalty for the agent.
1138func WithFrequencyPenalty(penalty float64) AgentOption {
1139 return func(s *agentSettings) {
1140 s.frequencyPenalty = &penalty
1141 }
1142}
1143
1144// WithTools sets the tools for the agent.
1145func WithTools(tools ...AgentTool) AgentOption {
1146 return func(s *agentSettings) {
1147 s.tools = append(s.tools, tools...)
1148 }
1149}
1150
1151// WithStopConditions sets the stop conditions for the agent.
1152func WithStopConditions(conditions ...StopCondition) AgentOption {
1153 return func(s *agentSettings) {
1154 s.stopWhen = append(s.stopWhen, conditions...)
1155 }
1156}
1157
1158// WithPrepareStep sets the prepare step function for the agent.
1159func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1160 return func(s *agentSettings) {
1161 s.prepareStep = fn
1162 }
1163}
1164
1165// WithRepairToolCall sets the repair tool call function for the agent.
1166func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1167 return func(s *agentSettings) {
1168 s.repairToolCall = fn
1169 }
1170}
1171
1172// WithMaxRetries sets the maximum number of retries for the agent.
1173func WithMaxRetries(maxRetries int) AgentOption {
1174 return func(s *agentSettings) {
1175 s.maxRetries = &maxRetries
1176 }
1177}
1178
1179// WithOnRetry sets the retry callback for the agent.
1180func WithOnRetry(callback OnRetryCallback) AgentOption {
1181 return func(s *agentSettings) {
1182 s.onRetry = callback
1183 }
1184}
1185
1186// processStepStream processes a single step's stream and returns the step result.
1187func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult, stepTools []AgentTool) (stepExecutionResult, error) {
1188 var stepContent []Content
1189 var stepToolCalls []ToolCallContent
1190 var stepUsage Usage
1191 stepFinishReason := FinishReasonUnknown
1192 var stepWarnings []CallWarning
1193 var stepProviderMetadata ProviderMetadata
1194
1195 activeToolCalls := make(map[string]*ToolCallContent)
1196 activeTextContent := make(map[string]string)
1197 type reasoningContent struct {
1198 content string
1199 options ProviderMetadata
1200 }
1201 activeReasoningContent := make(map[string]reasoningContent)
1202
1203 // Set up concurrent tool execution
1204 type toolExecutionRequest struct {
1205 toolCall ToolCallContent
1206 parallel bool
1207 }
1208 toolChan := make(chan toolExecutionRequest, 10)
1209 var toolExecutionWg sync.WaitGroup
1210 var toolStateMu sync.Mutex
1211 toolResults := make([]ToolResultContent, 0)
1212 var toolExecutionErr error
1213
1214 // Create a map for quick tool lookup
1215 toolMap := make(map[string]AgentTool)
1216 for _, tool := range stepTools {
1217 toolMap[tool.Info().Name] = tool
1218 }
1219
1220 // Semaphores for controlling parallelism
1221 parallelSem := make(chan struct{}, 5)
1222 var sequentialMu sync.Mutex
1223
1224 // Single coordinator goroutine that dispatches tools
1225 toolExecutionWg.Go(func() {
1226 for req := range toolChan {
1227 if req.parallel {
1228 parallelSem <- struct{}{}
1229 toolExecutionWg.Go(func() {
1230 defer func() { <-parallelSem }()
1231 result, isCriticalError := a.executeSingleTool(ctx, toolMap, req.toolCall, opts.OnToolResult)
1232 toolStateMu.Lock()
1233 toolResults = append(toolResults, result)
1234 if isCriticalError && toolExecutionErr == nil {
1235 if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
1236 toolExecutionErr = errorResult.Error
1237 }
1238 }
1239 toolStateMu.Unlock()
1240 })
1241 } else {
1242 sequentialMu.Lock()
1243 result, isCriticalError := a.executeSingleTool(ctx, toolMap, req.toolCall, opts.OnToolResult)
1244 toolStateMu.Lock()
1245 toolResults = append(toolResults, result)
1246 if isCriticalError && toolExecutionErr == nil {
1247 if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
1248 toolExecutionErr = errorResult.Error
1249 }
1250 }
1251 toolStateMu.Unlock()
1252 sequentialMu.Unlock()
1253 }
1254 }
1255 })
1256
1257 // Process stream parts
1258 for part := range stream {
1259 // Forward all parts to chunk callback
1260 if opts.OnChunk != nil {
1261 err := opts.OnChunk(part)
1262 if err != nil {
1263 return stepExecutionResult{}, err
1264 }
1265 }
1266
1267 switch part.Type {
1268 case StreamPartTypeWarnings:
1269 stepWarnings = part.Warnings
1270 if opts.OnWarnings != nil {
1271 err := opts.OnWarnings(part.Warnings)
1272 if err != nil {
1273 return stepExecutionResult{}, err
1274 }
1275 }
1276
1277 case StreamPartTypeTextStart:
1278 activeTextContent[part.ID] = ""
1279 if opts.OnTextStart != nil {
1280 err := opts.OnTextStart(part.ID)
1281 if err != nil {
1282 return stepExecutionResult{}, err
1283 }
1284 }
1285
1286 case StreamPartTypeTextDelta:
1287 if _, exists := activeTextContent[part.ID]; exists {
1288 activeTextContent[part.ID] += part.Delta
1289 }
1290 if opts.OnTextDelta != nil {
1291 err := opts.OnTextDelta(part.ID, part.Delta)
1292 if err != nil {
1293 return stepExecutionResult{}, err
1294 }
1295 }
1296
1297 case StreamPartTypeTextEnd:
1298 if text, exists := activeTextContent[part.ID]; exists {
1299 stepContent = append(stepContent, TextContent{
1300 Text: text,
1301 ProviderMetadata: part.ProviderMetadata,
1302 })
1303 delete(activeTextContent, part.ID)
1304 }
1305 if opts.OnTextEnd != nil {
1306 err := opts.OnTextEnd(part.ID)
1307 if err != nil {
1308 return stepExecutionResult{}, err
1309 }
1310 }
1311
1312 case StreamPartTypeReasoningStart:
1313 activeReasoningContent[part.ID] = reasoningContent{content: part.Delta, options: part.ProviderMetadata}
1314 if opts.OnReasoningStart != nil {
1315 content := ReasoningContent{
1316 Text: part.Delta,
1317 ProviderMetadata: part.ProviderMetadata,
1318 }
1319 err := opts.OnReasoningStart(part.ID, content)
1320 if err != nil {
1321 return stepExecutionResult{}, err
1322 }
1323 }
1324
1325 case StreamPartTypeReasoningDelta:
1326 if active, exists := activeReasoningContent[part.ID]; exists {
1327 active.content += part.Delta
1328 active.options = part.ProviderMetadata
1329 activeReasoningContent[part.ID] = active
1330 }
1331 if opts.OnReasoningDelta != nil {
1332 err := opts.OnReasoningDelta(part.ID, part.Delta)
1333 if err != nil {
1334 return stepExecutionResult{}, err
1335 }
1336 }
1337
1338 case StreamPartTypeReasoningEnd:
1339 if active, exists := activeReasoningContent[part.ID]; exists {
1340 if part.ProviderMetadata != nil {
1341 active.options = part.ProviderMetadata
1342 }
1343 content := ReasoningContent{
1344 Text: active.content,
1345 ProviderMetadata: active.options,
1346 }
1347 stepContent = append(stepContent, content)
1348 if opts.OnReasoningEnd != nil {
1349 err := opts.OnReasoningEnd(part.ID, content)
1350 if err != nil {
1351 return stepExecutionResult{}, err
1352 }
1353 }
1354 delete(activeReasoningContent, part.ID)
1355 }
1356
1357 case StreamPartTypeToolInputStart:
1358 activeToolCalls[part.ID] = &ToolCallContent{
1359 ToolCallID: part.ID,
1360 ToolName: part.ToolCallName,
1361 Input: "",
1362 ProviderExecuted: part.ProviderExecuted,
1363 }
1364 if opts.OnToolInputStart != nil {
1365 err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1366 if err != nil {
1367 return stepExecutionResult{}, err
1368 }
1369 }
1370
1371 case StreamPartTypeToolInputDelta:
1372 if toolCall, exists := activeToolCalls[part.ID]; exists {
1373 toolCall.Input += part.Delta
1374 }
1375 if opts.OnToolInputDelta != nil {
1376 err := opts.OnToolInputDelta(part.ID, part.Delta)
1377 if err != nil {
1378 return stepExecutionResult{}, err
1379 }
1380 }
1381
1382 case StreamPartTypeToolInputEnd:
1383 if opts.OnToolInputEnd != nil {
1384 err := opts.OnToolInputEnd(part.ID)
1385 if err != nil {
1386 return stepExecutionResult{}, err
1387 }
1388 }
1389
1390 case StreamPartTypeToolCall:
1391 toolCall := ToolCallContent{
1392 ToolCallID: part.ID,
1393 ToolName: part.ToolCallName,
1394 Input: part.ToolCallInput,
1395 ProviderExecuted: part.ProviderExecuted,
1396 ProviderMetadata: part.ProviderMetadata,
1397 }
1398
1399 // Validate and potentially repair the tool call
1400 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1401 stepToolCalls = append(stepToolCalls, validatedToolCall)
1402 stepContent = append(stepContent, validatedToolCall)
1403
1404 if opts.OnToolCall != nil {
1405 err := opts.OnToolCall(validatedToolCall)
1406 if err != nil {
1407 return stepExecutionResult{}, err
1408 }
1409 }
1410
1411 // Determine if tool can run in parallel
1412 isParallel := false
1413 if tool, exists := toolMap[validatedToolCall.ToolName]; exists {
1414 isParallel = tool.Info().Parallel
1415 }
1416
1417 // Send tool call to execution channel
1418 toolChan <- toolExecutionRequest{toolCall: validatedToolCall, parallel: isParallel}
1419
1420 // Clean up active tool call
1421 delete(activeToolCalls, part.ID)
1422
1423 case StreamPartTypeSource:
1424 sourceContent := SourceContent{
1425 SourceType: part.SourceType,
1426 ID: part.ID,
1427 URL: part.URL,
1428 Title: part.Title,
1429 ProviderMetadata: part.ProviderMetadata,
1430 }
1431 stepContent = append(stepContent, sourceContent)
1432 if opts.OnSource != nil {
1433 err := opts.OnSource(sourceContent)
1434 if err != nil {
1435 return stepExecutionResult{}, err
1436 }
1437 }
1438
1439 case StreamPartTypeFinish:
1440 stepUsage = part.Usage
1441 stepFinishReason = part.FinishReason
1442 stepProviderMetadata = part.ProviderMetadata
1443 if opts.OnStreamFinish != nil {
1444 err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1445 if err != nil {
1446 return stepExecutionResult{}, err
1447 }
1448 }
1449
1450 case StreamPartTypeError:
1451 return stepExecutionResult{}, part.Error
1452 }
1453 }
1454
1455 // Close the tool execution channel and wait for all executions to complete
1456 close(toolChan)
1457 toolExecutionWg.Wait()
1458
1459 // Check for tool execution errors
1460 if toolExecutionErr != nil {
1461 return stepExecutionResult{}, toolExecutionErr
1462 }
1463
1464 // Add tool results to content if any
1465 if len(toolResults) > 0 {
1466 for _, result := range toolResults {
1467 stepContent = append(stepContent, result)
1468 }
1469 }
1470
1471 stepResult := StepResult{
1472 Response: Response{
1473 Content: stepContent,
1474 FinishReason: stepFinishReason,
1475 Usage: stepUsage,
1476 Warnings: stepWarnings,
1477 ProviderMetadata: stepProviderMetadata,
1478 },
1479 Messages: toResponseMessages(stepContent),
1480 }
1481
1482 // Determine if we should continue (has tool calls and not stopped)
1483 shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1484
1485 return stepExecutionResult{
1486 StepResult: stepResult,
1487 ShouldContinue: shouldContinue,
1488 }, nil
1489}
1490
1491func addUsage(a, b Usage) Usage {
1492 return Usage{
1493 InputTokens: a.InputTokens + b.InputTokens,
1494 OutputTokens: a.OutputTokens + b.OutputTokens,
1495 TotalTokens: a.TotalTokens + b.TotalTokens,
1496 ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens,
1497 CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1498 CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens,
1499 }
1500}
1501
1502// WithHeaders sets the headers for the agent.
1503func WithHeaders(headers map[string]string) AgentOption {
1504 return func(s *agentSettings) {
1505 s.headers = headers
1506 }
1507}
1508
1509// WithProviderOptions sets the provider options for the agent.
1510func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1511 return func(s *agentSettings) {
1512 s.providerOptions = providerOptions
1513 }
1514}