1package fantasy
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "slices"
11 "sync"
12
13 "charm.land/fantasy/schema"
14)
15
// StepResult represents the result of a single step in an agent execution.
//
// In Generate, the embedded Response holds what the model returned for the
// step (content, finish reason, usage, warnings and provider metadata), with
// client-side tool calls replaced by their validated form and the local tool
// results appended to the content. Messages holds the conversation messages
// built from that content; the agent appends them to the input of the next
// step.
type StepResult struct {
	Response
	// Messages are the assistant and/or tool messages derived from this
	// step's content.
	Messages []Message
}
21
// stepExecutionResult encapsulates the result of executing a step with stream processing.
type stepExecutionResult struct {
	// StepResult is the completed step.
	StepResult StepResult
	// ShouldContinue reports whether the agent loop should run another step;
	// Stream stops looping when it is false.
	ShouldContinue bool
}
27
// StopCondition defines a function that determines when an agent should stop executing.
//
// It is evaluated after each step with every step completed so far, oldest
// first, and returns true to end the run. When several conditions are
// configured, the run stops as soon as any one of them returns true.
type StopCondition = func(steps []StepResult) bool
30
31// StepCountIs returns a stop condition that stops after the specified number of steps.
32func StepCountIs(stepCount int) StopCondition {
33 return func(steps []StepResult) bool {
34 return len(steps) >= stepCount
35 }
36}
37
38// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
39func HasToolCall(toolName string) StopCondition {
40 return func(steps []StepResult) bool {
41 if len(steps) == 0 {
42 return false
43 }
44 lastStep := steps[len(steps)-1]
45 toolCalls := lastStep.Content.ToolCalls()
46 for _, toolCall := range toolCalls {
47 if toolCall.ToolName == toolName {
48 return true
49 }
50 }
51 return false
52 }
53}
54
55// HasContent returns a stop condition that stops when the specified content type appears in the last step.
56func HasContent(contentType ContentType) StopCondition {
57 return func(steps []StepResult) bool {
58 if len(steps) == 0 {
59 return false
60 }
61 lastStep := steps[len(steps)-1]
62 for _, content := range lastStep.Content {
63 if content.GetType() == contentType {
64 return true
65 }
66 }
67 return false
68 }
69}
70
71// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
72func FinishReasonIs(reason FinishReason) StopCondition {
73 return func(steps []StepResult) bool {
74 if len(steps) == 0 {
75 return false
76 }
77 lastStep := steps[len(steps)-1]
78 return lastStep.FinishReason == reason
79 }
80}
81
82// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
83func MaxTokensUsed(maxTokens int64) StopCondition {
84 return func(steps []StepResult) bool {
85 var totalTokens int64
86 for _, step := range steps {
87 totalTokens += step.Usage.TotalTokens
88 }
89 return totalTokens >= maxTokens
90 }
91}
92
// PrepareStepFunctionOptions contains the options for preparing a step in an agent execution.
type PrepareStepFunctionOptions struct {
	// Steps are the steps completed so far, oldest first.
	Steps []StepResult
	// StepNumber is the zero-based index of the step about to run.
	StepNumber int
	// Model is the agent's configured default model.
	Model LanguageModel
	// Messages are the input messages the step would send by default: the
	// initial prompt followed by the messages produced by earlier steps.
	Messages []Message
}
100
// PrepareStepResult contains the result of preparing a step in an agent execution.
//
// Zero-valued fields leave the corresponding step setting at its default.
type PrepareStepResult struct {
	// Model, when non-nil, replaces the model used for this step.
	Model LanguageModel
	// Messages, when non-nil, replaces the step's input messages.
	Messages []Message
	// System, when non-nil, replaces the system prompt for this step.
	System *string
	// ToolChoice, when non-nil, replaces the default ToolChoiceAuto.
	ToolChoice *ToolChoice
	// ActiveTools, when non-empty, restricts the step to the named tools
	// (both function tools and provider-defined tools).
	ActiveTools []string
	// DisableAllTools, when true, sends no tools at all for this step.
	DisableAllTools bool
	// Tools, when non-nil, replaces the agent's function tools for this step.
	Tools []AgentTool
}
111
// ToolCallRepairOptions contains the options for repairing a tool call.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call exactly as the model produced it.
	OriginalToolCall ToolCallContent
	// ValidationError explains why OriginalToolCall failed validation.
	ValidationError error
	// AvailableTools is the step's tool set (before ActiveTools filtering).
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt in effect for the step.
	SystemPrompt string
	// Messages are the input messages sent to the model for the step.
	Messages []Message
}
120
type (
	// PrepareStepFunction defines a function that prepares a step in an agent execution.
	//
	// It is called before every step with that step's default inputs. The
	// returned context replaces the one used for the remainder of the run,
	// fields set in the returned PrepareStepResult override the step's
	// defaults, and a non-nil error aborts the run with that error.
	PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)

	// OnStepFinishedFunction defines a function that is called when a step finishes.
	OnStepFinishedFunction = func(step StepResult)

	// RepairToolCallFunction defines a function that repairs a tool call.
	//
	// It is called only for tool calls that failed validation. The returned
	// call is used when the error is nil, the pointer is non-nil, and the
	// repaired call itself passes validation; otherwise the original call is
	// marked invalid and reported back to the model as an error result.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
131
// agentSettings holds the agent-wide configuration assembled by NewAgent from
// its AgentOptions. For the sampling parameters, maxRetries, stopWhen,
// prepareStep, repairToolCall and onRetry, a value set on the individual call
// takes precedence (see prepareCall).
type agentSettings struct {
	// systemPrompt is sent as the leading system message when non-empty.
	systemPrompt string
	// Default sampling parameters; nil means "not set".
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	// headers are extra request headers.
	// NOTE(review): not forwarded into any model Call built in this file —
	// confirm they are consumed elsewhere.
	headers map[string]string
	// userAgent is passed as Call.UserAgent on every model call.
	userAgent string
	// providerOptions are merged under the per-call provider options.
	providerOptions ProviderOptions

	// providerDefinedTools are offered alongside the function tools.
	providerDefinedTools []ProviderDefinedTool
	// tools are the client-side tools the agent validates and executes.
	tools []AgentTool
	// maxRetries overrides DefaultRetryOptions().MaxRetries when non-nil.
	maxRetries *int

	// model is the default model; PrepareStep may override it per step.
	model LanguageModel

	// Default loop hooks used when the call does not provide its own.
	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
155
// AgentCall represents a call to an agent.
//
// Nil pointer fields, nil functions and an empty StopWhen fall back to the
// agent-level defaults configured with AgentOptions (see prepareCall).
//
// NOTE(review): MaxOutputTokens, ProviderOptions and the function fields
// carry no json tags, unlike the other fields — confirm this is intended.
type AgentCall struct {
	// Prompt is the text of the final user message; it must be non-empty.
	Prompt string `json:"prompt"`
	// Files are attached to the final user message.
	Files []FilePart `json:"files"`
	// Messages is prior conversation history, placed between the system
	// prompt and the final user message.
	Messages         []Message `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, restricts the tools offered to the model.
	ActiveTools []string `json:"active_tools"`
	// ProviderOptions are merged over the agent-level provider options;
	// entries here win for the same key.
	ProviderOptions ProviderOptions
	OnRetry         OnRetryCallback
	MaxRetries      *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction
}
176
// Agent-level callbacks.
//
// These are set on AgentStreamCall and invoked by Stream.
type (
	// OnAgentStartFunc is called when agent starts, before the first step.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called when agent finishes, after OnFinish.
	// NOTE(review): Stream ignores the returned error.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called when a step starts, with its zero-based
	// step number. NOTE(review): Stream ignores the returned error.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called when a step finishes.
	// NOTE(review): Stream ignores the returned error.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called when entire agent completes.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called with an error before Stream returns it.
	OnErrorFunc func(error)
)
197
// Stream part callbacks - called for each corresponding stream part type.
//
// They are fields of AgentStreamCall, which Stream hands to processStepStream.
// NOTE(review): the exact invocation order and how returned errors are
// handled live in processStepStream, which is outside this section.
type (
	// OnChunkFunc is called for each stream part (catch-all).
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when text starts.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for text deltas.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when text ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when reasoning starts.
	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error

	// OnReasoningDeltaFunc is called for reasoning deltas.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning ends.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when tool input starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for tool input deltas.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnToolResultFunc is called when tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when stream finishes.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
245
// AgentStreamCall represents a streaming call to an agent.
//
// The request fields have the same meaning as in AgentCall; Stream copies
// them into an AgentCall so agent-level defaults are applied identically.
//
// NOTE(review): Headers is not copied into the call Stream prepares, so it
// has no effect in this file — confirm whether it is consumed elsewhere.
type AgentStreamCall struct {
	Prompt           string    `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called once, before the first step
	OnAgentFinish OnAgentFinishFunc // Called after OnFinish, when the agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
	OnWarnings       OnWarningsFunc       // Called for warnings
	OnTextStart      OnTextStartFunc      // Called when text starts
	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
	OnTextEnd        OnTextEndFunc        // Called when text ends
	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
	OnToolCall       OnToolCallFunc       // Called when tool call is complete
	OnToolResult     OnToolResultFunc     // Called when tool execution completes
	OnSource         OnSourceFunc         // Called for source references
	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
}
292
// AgentResult represents the result of an agent execution.
type AgentResult struct {
	// Steps are all executed steps, oldest first.
	Steps []StepResult
	// Final response
	// Response is the Response of the last step.
	Response Response
	// TotalUsage is the token usage summed over all steps.
	TotalUsage Usage
}
300
// Agent represents an AI agent that can generate responses and stream responses.
type Agent interface {
	// Generate runs the agent loop using non-streaming model calls.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the agent loop using streaming model calls, reporting
	// progress through the callbacks on AgentStreamCall.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
306
// AgentOption defines a function that configures agent settings.
// Options are applied by NewAgent in the order they are passed.
type AgentOption = func(*agentSettings)
309
// agent is the default Agent implementation returned by NewAgent.
type agent struct {
	// settings is the agent-wide configuration built from AgentOptions.
	settings agentSettings
}
313
314// NewAgent creates a new agent with the given language model and options.
315func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
316 settings := agentSettings{
317 model: model,
318 }
319 for _, o := range opts {
320 o(&settings)
321 }
322 return &agent{
323 settings: settings,
324 }
325}
326
327func (a *agent) prepareCall(call AgentCall) AgentCall {
328 call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
329 call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
330 call.TopP = cmp.Or(call.TopP, a.settings.topP)
331 call.TopK = cmp.Or(call.TopK, a.settings.topK)
332 call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
333 call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
334 call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
335
336 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
337 call.StopWhen = a.settings.stopWhen
338 }
339 if call.PrepareStep == nil && a.settings.prepareStep != nil {
340 call.PrepareStep = a.settings.prepareStep
341 }
342 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
343 call.RepairToolCall = a.settings.repairToolCall
344 }
345 if call.OnRetry == nil && a.settings.onRetry != nil {
346 call.OnRetry = a.settings.onRetry
347 }
348
349 providerOptions := ProviderOptions{}
350 if a.settings.providerOptions != nil {
351 maps.Copy(providerOptions, a.settings.providerOptions)
352 }
353 if call.ProviderOptions != nil {
354 maps.Copy(providerOptions, call.ProviderOptions)
355 }
356 call.ProviderOptions = providerOptions
357
358 headers := map[string]string{}
359
360 if a.settings.headers != nil {
361 maps.Copy(headers, a.settings.headers)
362 }
363
364 return call
365}
366
367// Generate implements Agent.
368func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
369 opts = a.prepareCall(opts)
370 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
371 if err != nil {
372 return nil, err
373 }
374 var responseMessages []Message
375 var steps []StepResult
376
377 for {
378 stepInputMessages := append(initialPrompt, responseMessages...)
379 stepModel := a.settings.model
380 stepSystemPrompt := a.settings.systemPrompt
381 stepActiveTools := opts.ActiveTools
382 stepToolChoice := ToolChoiceAuto
383 disableAllTools := false
384 stepTools := a.settings.tools
385 if opts.PrepareStep != nil {
386 updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
387 Model: stepModel,
388 Steps: steps,
389 StepNumber: len(steps),
390 Messages: stepInputMessages,
391 })
392 if err != nil {
393 return nil, err
394 }
395
396 ctx = updatedCtx
397
398 // Apply prepared step modifications
399 if prepared.Messages != nil {
400 stepInputMessages = prepared.Messages
401 }
402 if prepared.Model != nil {
403 stepModel = prepared.Model
404 }
405 if prepared.System != nil {
406 stepSystemPrompt = *prepared.System
407 }
408 if prepared.ToolChoice != nil {
409 stepToolChoice = *prepared.ToolChoice
410 }
411 if len(prepared.ActiveTools) > 0 {
412 stepActiveTools = prepared.ActiveTools
413 }
414 disableAllTools = prepared.DisableAllTools
415 if prepared.Tools != nil {
416 stepTools = prepared.Tools
417 }
418 }
419
420 // Recreate prompt with potentially modified system prompt
421 if stepSystemPrompt != a.settings.systemPrompt {
422 stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
423 if err != nil {
424 return nil, err
425 }
426 // Replace system message part, keep the rest
427 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
428 stepInputMessages[0] = stepPrompt[0] // Replace system message
429 }
430 }
431
432 preparedTools := a.prepareTools(stepTools, a.settings.providerDefinedTools, stepActiveTools, disableAllTools)
433
434 retryOptions := DefaultRetryOptions()
435 if opts.MaxRetries != nil {
436 retryOptions.MaxRetries = *opts.MaxRetries
437 }
438 retryOptions.OnRetry = opts.OnRetry
439 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
440
441 result, err := retry(ctx, func() (*Response, error) {
442 return stepModel.Generate(ctx, Call{
443 Prompt: stepInputMessages,
444 MaxOutputTokens: opts.MaxOutputTokens,
445 Temperature: opts.Temperature,
446 TopP: opts.TopP,
447 TopK: opts.TopK,
448 PresencePenalty: opts.PresencePenalty,
449 FrequencyPenalty: opts.FrequencyPenalty,
450 Tools: preparedTools,
451 ToolChoice: &stepToolChoice,
452 UserAgent: a.settings.userAgent,
453 ProviderOptions: opts.ProviderOptions,
454 })
455 })
456 if err != nil {
457 return nil, err
458 }
459
460 var stepToolCalls []ToolCallContent
461 for _, content := range result.Content {
462 if content.GetType() == ContentTypeToolCall {
463 toolCall, ok := AsContentType[ToolCallContent](content)
464 if !ok {
465 continue
466 }
467 // Provider-executed tool calls (e.g. web search) are
468 // handled by the provider and should not be validated
469 // or executed by the agent.
470 if toolCall.ProviderExecuted {
471 continue
472 }
473 // Validate and potentially repair the tool call
474 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
475 stepToolCalls = append(stepToolCalls, validatedToolCall)
476 }
477 }
478
479 toolResults, err := a.executeTools(ctx, stepTools, stepToolCalls, nil)
480
481 // Build step content with validated tool calls and tool results.
482 // Provider-executed tool calls are kept as-is.
483 stepContent := []Content{}
484 toolCallIndex := 0
485 for _, content := range result.Content {
486 if content.GetType() == ContentTypeToolCall {
487 tc, ok := AsContentType[ToolCallContent](content)
488 if ok && tc.ProviderExecuted {
489 stepContent = append(stepContent, content)
490 continue
491 }
492 // Replace with validated tool call.
493 if toolCallIndex < len(stepToolCalls) {
494 stepContent = append(stepContent, stepToolCalls[toolCallIndex])
495 toolCallIndex++
496 }
497 } else {
498 stepContent = append(stepContent, content)
499 }
500 } // Add tool results
501 for _, result := range toolResults {
502 stepContent = append(stepContent, result)
503 }
504 currentStepMessages := toResponseMessages(stepContent)
505 responseMessages = append(responseMessages, currentStepMessages...)
506
507 stepResult := StepResult{
508 Response: Response{
509 Content: stepContent,
510 FinishReason: result.FinishReason,
511 Usage: result.Usage,
512 Warnings: result.Warnings,
513 ProviderMetadata: result.ProviderMetadata,
514 },
515 Messages: currentStepMessages,
516 }
517 steps = append(steps, stepResult)
518 shouldStop := isStopConditionMet(opts.StopWhen, steps)
519
520 if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
521 break
522 }
523 }
524
525 totalUsage := Usage{}
526
527 for _, step := range steps {
528 usage := step.Usage
529 totalUsage.InputTokens += usage.InputTokens
530 totalUsage.OutputTokens += usage.OutputTokens
531 totalUsage.ReasoningTokens += usage.ReasoningTokens
532 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
533 totalUsage.CacheReadTokens += usage.CacheReadTokens
534 totalUsage.TotalTokens += usage.TotalTokens
535 }
536
537 agentResult := &AgentResult{
538 Steps: steps,
539 Response: steps[len(steps)-1].Response,
540 TotalUsage: totalUsage,
541 }
542 return agentResult, nil
543}
544
545func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
546 if len(conditions) == 0 {
547 return false
548 }
549
550 for _, condition := range conditions {
551 if condition(steps) {
552 return true
553 }
554 }
555 return false
556}
557
558func toResponseMessages(content []Content) []Message {
559 var assistantParts []MessagePart
560 var toolParts []MessagePart
561
562 for _, c := range content {
563 switch c.GetType() {
564 case ContentTypeText:
565 text, ok := AsContentType[TextContent](c)
566 if !ok {
567 continue
568 }
569 assistantParts = append(assistantParts, TextPart{
570 Text: text.Text,
571 ProviderOptions: ProviderOptions(text.ProviderMetadata),
572 })
573 case ContentTypeReasoning:
574 reasoning, ok := AsContentType[ReasoningContent](c)
575 if !ok {
576 continue
577 }
578 assistantParts = append(assistantParts, ReasoningPart{
579 Text: reasoning.Text,
580 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
581 })
582 case ContentTypeToolCall:
583 toolCall, ok := AsContentType[ToolCallContent](c)
584 if !ok {
585 continue
586 }
587 assistantParts = append(assistantParts, ToolCallPart{
588 ToolCallID: toolCall.ToolCallID,
589 ToolName: toolCall.ToolName,
590 Input: toolCall.Input,
591 ProviderExecuted: toolCall.ProviderExecuted,
592 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
593 })
594 case ContentTypeFile:
595 file, ok := AsContentType[FileContent](c)
596 if !ok {
597 continue
598 }
599 assistantParts = append(assistantParts, FilePart{
600 Data: file.Data,
601 MediaType: file.MediaType,
602 ProviderOptions: ProviderOptions(file.ProviderMetadata),
603 })
604 case ContentTypeSource:
605 // Sources are metadata about references used to generate the response.
606 // They don't need to be included in the conversation messages.
607 continue
608 case ContentTypeToolResult:
609 result, ok := AsContentType[ToolResultContent](c)
610 if !ok {
611 continue
612 }
613 resultPart := ToolResultPart{
614 ToolCallID: result.ToolCallID,
615 Output: result.Result,
616 ProviderExecuted: result.ProviderExecuted,
617 ProviderOptions: ProviderOptions(result.ProviderMetadata),
618 }
619 if result.ProviderExecuted {
620 // Provider-executed tool results (e.g. web search)
621 // belong in the assistant message alongside the
622 // server_tool_use block that produced them.
623 assistantParts = append(assistantParts, resultPart)
624 } else {
625 toolParts = append(toolParts, resultPart)
626 }
627 }
628 }
629
630 var messages []Message
631 if len(assistantParts) > 0 {
632 messages = append(messages, Message{
633 Role: MessageRoleAssistant,
634 Content: assistantParts,
635 })
636 }
637 if len(toolParts) > 0 {
638 messages = append(messages, Message{
639 Role: MessageRoleTool,
640 Content: toolParts,
641 })
642 }
643 return messages
644}
645
646func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
647 if len(toolCalls) == 0 {
648 return nil, nil
649 }
650
651 // Create a map for quick tool lookup
652 toolMap := make(map[string]AgentTool)
653 for _, tool := range allTools {
654 toolMap[tool.Info().Name] = tool
655 }
656
657 // Execute all tool calls sequentially in order
658 results := make([]ToolResultContent, 0, len(toolCalls))
659
660 for _, toolCall := range toolCalls {
661 result, isCriticalError := a.executeSingleTool(ctx, toolMap, toolCall, toolResultCallback)
662 results = append(results, result)
663 if isCriticalError {
664 if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
665 return nil, errorResult.Error
666 }
667 }
668 }
669
670 return results, nil
671}
672
673// executeSingleTool executes a single tool and returns its result and a critical error flag.
674func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentTool, toolCall ToolCallContent, toolResultCallback func(result ToolResultContent) error) (ToolResultContent, bool) {
675 result := ToolResultContent{
676 ToolCallID: toolCall.ToolCallID,
677 ToolName: toolCall.ToolName,
678 ProviderExecuted: false,
679 }
680
681 // Skip invalid tool calls - create error result (not critical)
682 if toolCall.Invalid {
683 result.Result = ToolResultOutputContentError{
684 Error: toolCall.ValidationError,
685 }
686 if toolResultCallback != nil {
687 _ = toolResultCallback(result)
688 }
689 return result, false
690 }
691
692 tool, exists := toolMap[toolCall.ToolName]
693 if !exists {
694 result.Result = ToolResultOutputContentError{
695 Error: errors.New("Error: Tool not found: " + toolCall.ToolName),
696 }
697 if toolResultCallback != nil {
698 _ = toolResultCallback(result)
699 }
700 return result, false
701 }
702
703 // Execute the tool
704 toolResult, err := tool.Run(ctx, ToolCall{
705 ID: toolCall.ToolCallID,
706 Name: toolCall.ToolName,
707 Input: toolCall.Input,
708 })
709 if err != nil {
710 result.Result = ToolResultOutputContentError{
711 Error: err,
712 }
713 result.ClientMetadata = toolResult.Metadata
714 if toolResultCallback != nil {
715 _ = toolResultCallback(result)
716 }
717 return result, true
718 }
719
720 result.ClientMetadata = toolResult.Metadata
721 if toolResult.IsError {
722 result.Result = ToolResultOutputContentError{
723 Error: errors.New(toolResult.Content),
724 }
725 } else if toolResult.Type == "image" || toolResult.Type == "media" {
726 result.Result = ToolResultOutputContentMedia{
727 Data: string(toolResult.Data),
728 MediaType: toolResult.MediaType,
729 Text: toolResult.Content,
730 }
731 } else {
732 result.Result = ToolResultOutputContentText{
733 Text: toolResult.Content,
734 }
735 }
736 if toolResultCallback != nil {
737 _ = toolResultCallback(result)
738 }
739 return result, false
740}
741
742// Stream implements Agent.
743func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
744 // Convert AgentStreamCall to AgentCall for preparation
745 call := AgentCall{
746 Prompt: opts.Prompt,
747 Files: opts.Files,
748 Messages: opts.Messages,
749 MaxOutputTokens: opts.MaxOutputTokens,
750 Temperature: opts.Temperature,
751 TopP: opts.TopP,
752 TopK: opts.TopK,
753 PresencePenalty: opts.PresencePenalty,
754 FrequencyPenalty: opts.FrequencyPenalty,
755 ActiveTools: opts.ActiveTools,
756 ProviderOptions: opts.ProviderOptions,
757 MaxRetries: opts.MaxRetries,
758 OnRetry: opts.OnRetry,
759 StopWhen: opts.StopWhen,
760 PrepareStep: opts.PrepareStep,
761 RepairToolCall: opts.RepairToolCall,
762 }
763
764 call = a.prepareCall(call)
765
766 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
767 if err != nil {
768 return nil, err
769 }
770
771 var responseMessages []Message
772 var steps []StepResult
773 var totalUsage Usage
774
775 // Start agent stream
776 if opts.OnAgentStart != nil {
777 opts.OnAgentStart()
778 }
779
780 for stepNumber := 0; ; stepNumber++ {
781 stepInputMessages := append(initialPrompt, responseMessages...)
782 stepModel := a.settings.model
783 stepSystemPrompt := a.settings.systemPrompt
784 stepActiveTools := call.ActiveTools
785 stepToolChoice := ToolChoiceAuto
786 disableAllTools := false
787 stepTools := a.settings.tools
788 // Apply step preparation if provided
789 if call.PrepareStep != nil {
790 updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
791 Model: stepModel,
792 Steps: steps,
793 StepNumber: stepNumber,
794 Messages: stepInputMessages,
795 })
796 if err != nil {
797 return nil, err
798 }
799
800 ctx = updatedCtx
801
802 if prepared.Messages != nil {
803 stepInputMessages = prepared.Messages
804 }
805 if prepared.Model != nil {
806 stepModel = prepared.Model
807 }
808 if prepared.System != nil {
809 stepSystemPrompt = *prepared.System
810 }
811 if prepared.ToolChoice != nil {
812 stepToolChoice = *prepared.ToolChoice
813 }
814 if len(prepared.ActiveTools) > 0 {
815 stepActiveTools = prepared.ActiveTools
816 }
817 disableAllTools = prepared.DisableAllTools
818 if prepared.Tools != nil {
819 stepTools = prepared.Tools
820 }
821 }
822
823 // Recreate prompt with potentially modified system prompt
824 if stepSystemPrompt != a.settings.systemPrompt {
825 stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
826 if err != nil {
827 return nil, err
828 }
829 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
830 stepInputMessages[0] = stepPrompt[0]
831 }
832 }
833
834 preparedTools := a.prepareTools(stepTools, a.settings.providerDefinedTools, stepActiveTools, disableAllTools)
835
836 // Start step stream
837 if opts.OnStepStart != nil {
838 _ = opts.OnStepStart(stepNumber)
839 }
840
841 // Create streaming call
842 streamCall := Call{
843 Prompt: stepInputMessages,
844 MaxOutputTokens: call.MaxOutputTokens,
845 Temperature: call.Temperature,
846 TopP: call.TopP,
847 TopK: call.TopK,
848 PresencePenalty: call.PresencePenalty,
849 FrequencyPenalty: call.FrequencyPenalty,
850 Tools: preparedTools,
851 ToolChoice: &stepToolChoice,
852 UserAgent: a.settings.userAgent,
853 ProviderOptions: call.ProviderOptions,
854 }
855
856 // Execute step with retry logic wrapping both stream creation and processing
857 retryOptions := DefaultRetryOptions()
858 if call.MaxRetries != nil {
859 retryOptions.MaxRetries = *call.MaxRetries
860 }
861 retryOptions.OnRetry = call.OnRetry
862 retry := RetryWithExponentialBackoffRespectingRetryHeaders[stepExecutionResult](retryOptions)
863
864 result, err := retry(ctx, func() (stepExecutionResult, error) {
865 // Create the stream
866 stream, err := stepModel.Stream(ctx, streamCall)
867 if err != nil {
868 return stepExecutionResult{}, err
869 }
870
871 // Process the stream
872 result, err := a.processStepStream(ctx, stream, opts, steps, stepTools)
873 if err != nil {
874 return stepExecutionResult{}, err
875 }
876
877 return result, nil
878 })
879 if err != nil {
880 if opts.OnError != nil {
881 opts.OnError(err)
882 }
883 return nil, err
884 }
885
886 steps = append(steps, result.StepResult)
887 totalUsage = addUsage(totalUsage, result.StepResult.Usage)
888
889 // Call step finished callback
890 if opts.OnStepFinish != nil {
891 _ = opts.OnStepFinish(result.StepResult)
892 }
893
894 // Add step messages to response messages
895 stepMessages := toResponseMessages(result.StepResult.Content)
896 responseMessages = append(responseMessages, stepMessages...)
897
898 // Check stop conditions
899 shouldStop := isStopConditionMet(call.StopWhen, steps)
900 if shouldStop || !result.ShouldContinue {
901 break
902 }
903 }
904
905 // Finish agent stream
906 agentResult := &AgentResult{
907 Steps: steps,
908 Response: steps[len(steps)-1].Response,
909 TotalUsage: totalUsage,
910 }
911
912 if opts.OnFinish != nil {
913 opts.OnFinish(agentResult)
914 }
915
916 if opts.OnAgentFinish != nil {
917 _ = opts.OnAgentFinish(agentResult)
918 }
919
920 return agentResult, nil
921}
922
923func (a *agent) prepareTools(tools []AgentTool, providerDefinedTools []ProviderDefinedTool, activeTools []string, disableAllTools bool) []Tool {
924 preparedTools := make([]Tool, 0, len(tools)+len(providerDefinedTools))
925
926 // If explicitly disabling all tools, return no tools
927 if disableAllTools {
928 return preparedTools
929 }
930
931 for _, tool := range tools {
932 // If activeTools has items, only include tools in the list
933 // If activeTools is empty, include all tools
934 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
935 continue
936 }
937 info := tool.Info()
938 inputSchema := map[string]any{
939 "type": "object",
940 "properties": info.Parameters,
941 "required": info.Required,
942 }
943 schema.Normalize(inputSchema)
944 preparedTools = append(preparedTools, FunctionTool{
945 Name: info.Name,
946 Description: info.Description,
947 InputSchema: inputSchema,
948 ProviderOptions: tool.ProviderOptions(),
949 })
950 }
951 for _, tool := range providerDefinedTools {
952 // If activeTools has items, only include tools in the list. If
953 // activeTools is empty, include all tools
954 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.GetName()) {
955 continue
956 }
957 preparedTools = append(preparedTools, tool)
958 }
959 return preparedTools
960}
961
962// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
963func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
964 if err := a.validateToolCall(toolCall, availableTools); err == nil {
965 return toolCall
966 } else { //nolint: revive
967 if repairFunc != nil {
968 repairOptions := ToolCallRepairOptions{
969 OriginalToolCall: toolCall,
970 ValidationError: err,
971 AvailableTools: availableTools,
972 SystemPrompt: systemPrompt,
973 Messages: messages,
974 }
975
976 if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
977 if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
978 return *repairedToolCall
979 }
980 }
981 }
982
983 invalidToolCall := toolCall
984 invalidToolCall.Invalid = true
985 invalidToolCall.ValidationError = err
986 return invalidToolCall
987 }
988}
989
990// validateToolCall validates a tool call against available tools and their schemas.
991func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
992 var tool AgentTool
993 for _, t := range availableTools {
994 if t.Info().Name == toolCall.ToolName {
995 tool = t
996 break
997 }
998 }
999
1000 if tool == nil {
1001 return fmt.Errorf("tool not found: %s", toolCall.ToolName)
1002 }
1003
1004 // Validate JSON parsing
1005 var input map[string]any
1006 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
1007 return fmt.Errorf("invalid JSON input: %w", err)
1008 }
1009
1010 // Basic schema validation (check required fields)
1011 // TODO: more robust schema validation using JSON Schema or similar
1012 toolInfo := tool.Info()
1013 for _, required := range toolInfo.Required {
1014 if _, exists := input[required]; !exists {
1015 return fmt.Errorf("missing required parameter: %s", required)
1016 }
1017 }
1018 return nil
1019}
1020
1021func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
1022 if prompt == "" {
1023 return nil, &Error{Title: "invalid argument", Message: "prompt can't be empty"}
1024 }
1025
1026 var preparedPrompt Prompt
1027
1028 if system != "" {
1029 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
1030 }
1031 preparedPrompt = append(preparedPrompt, messages...)
1032 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
1033 return preparedPrompt, nil
1034}
1035
1036// WithSystemPrompt sets the system prompt for the agent.
1037func WithSystemPrompt(prompt string) AgentOption {
1038 return func(s *agentSettings) {
1039 s.systemPrompt = prompt
1040 }
1041}
1042
1043// WithMaxOutputTokens sets the maximum output tokens for the agent.
1044func WithMaxOutputTokens(tokens int64) AgentOption {
1045 return func(s *agentSettings) {
1046 s.maxOutputTokens = &tokens
1047 }
1048}
1049
1050// WithTemperature sets the temperature for the agent.
1051func WithTemperature(temp float64) AgentOption {
1052 return func(s *agentSettings) {
1053 s.temperature = &temp
1054 }
1055}
1056
1057// WithTopP sets the top-p value for the agent.
1058func WithTopP(topP float64) AgentOption {
1059 return func(s *agentSettings) {
1060 s.topP = &topP
1061 }
1062}
1063
1064// WithTopK sets the top-k value for the agent.
1065func WithTopK(topK int64) AgentOption {
1066 return func(s *agentSettings) {
1067 s.topK = &topK
1068 }
1069}
1070
1071// WithPresencePenalty sets the presence penalty for the agent.
1072func WithPresencePenalty(penalty float64) AgentOption {
1073 return func(s *agentSettings) {
1074 s.presencePenalty = &penalty
1075 }
1076}
1077
1078// WithFrequencyPenalty sets the frequency penalty for the agent.
1079func WithFrequencyPenalty(penalty float64) AgentOption {
1080 return func(s *agentSettings) {
1081 s.frequencyPenalty = &penalty
1082 }
1083}
1084
1085// WithTools sets the tools for the agent.
1086func WithTools(tools ...AgentTool) AgentOption {
1087 return func(s *agentSettings) {
1088 s.tools = append(s.tools, tools...)
1089 }
1090}
1091
1092// WithProviderDefinedTools sets the provider-defined tools for the agent.
1093// These tools are executed by the provider (e.g. web search) rather
1094// than by the client.
1095func WithProviderDefinedTools(tools ...ProviderDefinedTool) AgentOption {
1096 return func(s *agentSettings) {
1097 s.providerDefinedTools = append(s.providerDefinedTools, tools...)
1098 }
1099}
1100
1101// WithStopConditions sets the stop conditions for the agent.
1102func WithStopConditions(conditions ...StopCondition) AgentOption {
1103 return func(s *agentSettings) {
1104 s.stopWhen = append(s.stopWhen, conditions...)
1105 }
1106}
1107
1108// WithPrepareStep sets the prepare step function for the agent.
1109func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1110 return func(s *agentSettings) {
1111 s.prepareStep = fn
1112 }
1113}
1114
1115// WithRepairToolCall sets the repair tool call function for the agent.
1116func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1117 return func(s *agentSettings) {
1118 s.repairToolCall = fn
1119 }
1120}
1121
1122// WithMaxRetries sets the maximum number of retries for the agent.
1123func WithMaxRetries(maxRetries int) AgentOption {
1124 return func(s *agentSettings) {
1125 s.maxRetries = &maxRetries
1126 }
1127}
1128
1129// WithOnRetry sets the retry callback for the agent.
1130func WithOnRetry(callback OnRetryCallback) AgentOption {
1131 return func(s *agentSettings) {
1132 s.onRetry = callback
1133 }
1134}
1135
1136// processStepStream processes a single step's stream and returns the step result.
1137func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult, stepTools []AgentTool) (stepExecutionResult, error) {
1138 var stepContent []Content
1139 var stepToolCalls []ToolCallContent
1140 var stepUsage Usage
1141 stepFinishReason := FinishReasonUnknown
1142 var stepWarnings []CallWarning
1143 var stepProviderMetadata ProviderMetadata
1144
1145 activeToolCalls := make(map[string]*ToolCallContent)
1146 activeTextContent := make(map[string]string)
1147 type reasoningContent struct {
1148 content string
1149 options ProviderMetadata
1150 }
1151 activeReasoningContent := make(map[string]reasoningContent)
1152
1153 // Set up concurrent tool execution
1154 type toolExecutionRequest struct {
1155 toolCall ToolCallContent
1156 parallel bool
1157 }
1158 toolChan := make(chan toolExecutionRequest, 10)
1159 var toolExecutionWg sync.WaitGroup
1160 var toolStateMu sync.Mutex
1161 toolResults := make([]ToolResultContent, 0)
1162 var toolExecutionErr error
1163
1164 // Create a map for quick tool lookup
1165 toolMap := make(map[string]AgentTool)
1166 for _, tool := range stepTools {
1167 toolMap[tool.Info().Name] = tool
1168 }
1169
1170 // Semaphores for controlling parallelism
1171 parallelSem := make(chan struct{}, 5)
1172 var sequentialMu sync.Mutex
1173
1174 // Single coordinator goroutine that dispatches tools
1175 toolExecutionWg.Go(func() {
1176 for req := range toolChan {
1177 if req.parallel {
1178 parallelSem <- struct{}{}
1179 toolExecutionWg.Go(func() {
1180 defer func() { <-parallelSem }()
1181 result, isCriticalError := a.executeSingleTool(ctx, toolMap, req.toolCall, opts.OnToolResult)
1182 toolStateMu.Lock()
1183 toolResults = append(toolResults, result)
1184 if isCriticalError && toolExecutionErr == nil {
1185 if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
1186 toolExecutionErr = errorResult.Error
1187 }
1188 }
1189 toolStateMu.Unlock()
1190 })
1191 } else {
1192 sequentialMu.Lock()
1193 result, isCriticalError := a.executeSingleTool(ctx, toolMap, req.toolCall, opts.OnToolResult)
1194 toolStateMu.Lock()
1195 toolResults = append(toolResults, result)
1196 if isCriticalError && toolExecutionErr == nil {
1197 if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
1198 toolExecutionErr = errorResult.Error
1199 }
1200 }
1201 toolStateMu.Unlock()
1202 sequentialMu.Unlock()
1203 }
1204 }
1205 })
1206
1207 // Process stream parts
1208 for part := range stream {
1209 // Forward all parts to chunk callback
1210 if opts.OnChunk != nil {
1211 err := opts.OnChunk(part)
1212 if err != nil {
1213 return stepExecutionResult{}, err
1214 }
1215 }
1216
1217 switch part.Type {
1218 case StreamPartTypeWarnings:
1219 stepWarnings = part.Warnings
1220 if opts.OnWarnings != nil {
1221 err := opts.OnWarnings(part.Warnings)
1222 if err != nil {
1223 return stepExecutionResult{}, err
1224 }
1225 }
1226
1227 case StreamPartTypeTextStart:
1228 activeTextContent[part.ID] = ""
1229 if opts.OnTextStart != nil {
1230 err := opts.OnTextStart(part.ID)
1231 if err != nil {
1232 return stepExecutionResult{}, err
1233 }
1234 }
1235
1236 case StreamPartTypeTextDelta:
1237 if _, exists := activeTextContent[part.ID]; exists {
1238 activeTextContent[part.ID] += part.Delta
1239 }
1240 if opts.OnTextDelta != nil {
1241 err := opts.OnTextDelta(part.ID, part.Delta)
1242 if err != nil {
1243 return stepExecutionResult{}, err
1244 }
1245 }
1246
1247 case StreamPartTypeTextEnd:
1248 if text, exists := activeTextContent[part.ID]; exists {
1249 stepContent = append(stepContent, TextContent{
1250 Text: text,
1251 ProviderMetadata: part.ProviderMetadata,
1252 })
1253 delete(activeTextContent, part.ID)
1254 }
1255 if opts.OnTextEnd != nil {
1256 err := opts.OnTextEnd(part.ID)
1257 if err != nil {
1258 return stepExecutionResult{}, err
1259 }
1260 }
1261
1262 case StreamPartTypeReasoningStart:
1263 activeReasoningContent[part.ID] = reasoningContent{content: part.Delta, options: part.ProviderMetadata}
1264 if opts.OnReasoningStart != nil {
1265 content := ReasoningContent{
1266 Text: part.Delta,
1267 ProviderMetadata: part.ProviderMetadata,
1268 }
1269 err := opts.OnReasoningStart(part.ID, content)
1270 if err != nil {
1271 return stepExecutionResult{}, err
1272 }
1273 }
1274
1275 case StreamPartTypeReasoningDelta:
1276 if active, exists := activeReasoningContent[part.ID]; exists {
1277 active.content += part.Delta
1278 if part.ProviderMetadata != nil {
1279 active.options = part.ProviderMetadata
1280 }
1281 activeReasoningContent[part.ID] = active
1282 }
1283 if opts.OnReasoningDelta != nil {
1284 err := opts.OnReasoningDelta(part.ID, part.Delta)
1285 if err != nil {
1286 return stepExecutionResult{}, err
1287 }
1288 }
1289
1290 case StreamPartTypeReasoningEnd:
1291 if active, exists := activeReasoningContent[part.ID]; exists {
1292 if part.ProviderMetadata != nil {
1293 active.options = part.ProviderMetadata
1294 }
1295 content := ReasoningContent{
1296 Text: active.content,
1297 ProviderMetadata: active.options,
1298 }
1299 stepContent = append(stepContent, content)
1300 if opts.OnReasoningEnd != nil {
1301 err := opts.OnReasoningEnd(part.ID, content)
1302 if err != nil {
1303 return stepExecutionResult{}, err
1304 }
1305 }
1306 delete(activeReasoningContent, part.ID)
1307 }
1308
1309 case StreamPartTypeToolInputStart:
1310 activeToolCalls[part.ID] = &ToolCallContent{
1311 ToolCallID: part.ID,
1312 ToolName: part.ToolCallName,
1313 Input: "",
1314 ProviderExecuted: part.ProviderExecuted,
1315 }
1316 if opts.OnToolInputStart != nil {
1317 err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1318 if err != nil {
1319 return stepExecutionResult{}, err
1320 }
1321 }
1322
1323 case StreamPartTypeToolInputDelta:
1324 if toolCall, exists := activeToolCalls[part.ID]; exists {
1325 toolCall.Input += part.Delta
1326 }
1327 if opts.OnToolInputDelta != nil {
1328 err := opts.OnToolInputDelta(part.ID, part.Delta)
1329 if err != nil {
1330 return stepExecutionResult{}, err
1331 }
1332 }
1333
1334 case StreamPartTypeToolInputEnd:
1335 if opts.OnToolInputEnd != nil {
1336 err := opts.OnToolInputEnd(part.ID)
1337 if err != nil {
1338 return stepExecutionResult{}, err
1339 }
1340 }
1341
1342 case StreamPartTypeToolCall:
1343 toolCall := ToolCallContent{
1344 ToolCallID: part.ID,
1345 ToolName: part.ToolCallName,
1346 Input: part.ToolCallInput,
1347 ProviderExecuted: part.ProviderExecuted,
1348 ProviderMetadata: part.ProviderMetadata,
1349 }
1350
1351 // Provider-executed tool calls are handled by the provider
1352 // and should not be validated or executed by the agent.
1353 if toolCall.ProviderExecuted {
1354 stepContent = append(stepContent, toolCall)
1355 if opts.OnToolCall != nil {
1356 err := opts.OnToolCall(toolCall)
1357 if err != nil {
1358 return stepExecutionResult{}, err
1359 }
1360 }
1361 delete(activeToolCalls, part.ID)
1362 } else {
1363 // Validate and potentially repair the tool call
1364 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1365 stepToolCalls = append(stepToolCalls, validatedToolCall)
1366 stepContent = append(stepContent, validatedToolCall)
1367
1368 if opts.OnToolCall != nil {
1369 err := opts.OnToolCall(validatedToolCall)
1370 if err != nil {
1371 return stepExecutionResult{}, err
1372 }
1373 }
1374
1375 // Determine if tool can run in parallel
1376 isParallel := false
1377 if tool, exists := toolMap[validatedToolCall.ToolName]; exists {
1378 isParallel = tool.Info().Parallel
1379 }
1380
1381 // Send tool call to execution channel
1382 toolChan <- toolExecutionRequest{toolCall: validatedToolCall, parallel: isParallel}
1383
1384 // Clean up active tool call
1385 delete(activeToolCalls, part.ID)
1386 }
1387
1388 case StreamPartTypeToolResult:
1389 // Provider-executed tool results (e.g. web search)
1390 // are emitted by the provider and added directly
1391 // to the step content for multi-turn round-tripping.
1392 if part.ProviderExecuted {
1393 resultContent := ToolResultContent{
1394 ToolCallID: part.ID,
1395 ToolName: part.ToolCallName,
1396 ProviderExecuted: true,
1397 ProviderMetadata: part.ProviderMetadata,
1398 }
1399 stepContent = append(stepContent, resultContent)
1400 if opts.OnToolResult != nil {
1401 err := opts.OnToolResult(resultContent)
1402 if err != nil {
1403 return stepExecutionResult{}, err
1404 }
1405 }
1406 }
1407
1408 case StreamPartTypeSource:
1409 sourceContent := SourceContent{
1410 SourceType: part.SourceType,
1411 ID: part.ID,
1412 URL: part.URL,
1413 Title: part.Title,
1414 ProviderMetadata: part.ProviderMetadata,
1415 }
1416 stepContent = append(stepContent, sourceContent)
1417 if opts.OnSource != nil {
1418 err := opts.OnSource(sourceContent)
1419 if err != nil {
1420 return stepExecutionResult{}, err
1421 }
1422 }
1423
1424 case StreamPartTypeFinish:
1425 stepUsage = part.Usage
1426 stepFinishReason = part.FinishReason
1427 stepProviderMetadata = part.ProviderMetadata
1428 if opts.OnStreamFinish != nil {
1429 err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1430 if err != nil {
1431 return stepExecutionResult{}, err
1432 }
1433 }
1434
1435 case StreamPartTypeError:
1436 return stepExecutionResult{}, part.Error
1437 }
1438 }
1439
1440 // Close the tool execution channel and wait for all executions to complete
1441 close(toolChan)
1442 toolExecutionWg.Wait()
1443
1444 // Check for tool execution errors
1445 if toolExecutionErr != nil {
1446 return stepExecutionResult{}, toolExecutionErr
1447 }
1448
1449 // Add tool results to content if any
1450 if len(toolResults) > 0 {
1451 for _, result := range toolResults {
1452 stepContent = append(stepContent, result)
1453 }
1454 }
1455
1456 stepResult := StepResult{
1457 Response: Response{
1458 Content: stepContent,
1459 FinishReason: stepFinishReason,
1460 Usage: stepUsage,
1461 Warnings: stepWarnings,
1462 ProviderMetadata: stepProviderMetadata,
1463 },
1464 Messages: toResponseMessages(stepContent),
1465 }
1466
1467 // Determine if we should continue (has tool calls and not stopped)
1468 shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1469
1470 return stepExecutionResult{
1471 StepResult: stepResult,
1472 ShouldContinue: shouldContinue,
1473 }, nil
1474}
1475
1476func addUsage(a, b Usage) Usage {
1477 return Usage{
1478 InputTokens: a.InputTokens + b.InputTokens,
1479 OutputTokens: a.OutputTokens + b.OutputTokens,
1480 TotalTokens: a.TotalTokens + b.TotalTokens,
1481 ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens,
1482 CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1483 CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens,
1484 }
1485}
1486
1487// WithHeaders sets the headers for the agent.
1488func WithHeaders(headers map[string]string) AgentOption {
1489 return func(s *agentSettings) {
1490 s.headers = headers
1491 }
1492}
1493
1494// WithUserAgent sets the User-Agent header for the agent. This overrides any
1495// provider-level User-Agent setting.
1496func WithUserAgent(ua string) AgentOption {
1497 return func(s *agentSettings) {
1498 s.userAgent = ua
1499 }
1500}
1501
1502// WithProviderOptions sets the provider options for the agent.
1503func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1504 return func(s *agentSettings) {
1505 s.providerOptions = providerOptions
1506 }
1507}