1package fantasy
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "slices"
11 "sync"
12
13 "charm.land/fantasy/schema"
14 "github.com/charmbracelet/x/exp/slice"
15)
16
// StepResult represents the result of a single step in an agent execution.
// It embeds the step's Response and adds the messages recorded for the step.
type StepResult struct {
	// Response is the response recorded for this step. It is embedded, so
	// fields such as Content, FinishReason and Usage are promoted.
	Response
	// Messages holds the conversation messages derived from this step's content.
	Messages []Message
}
22
// stepExecutionResult encapsulates the result of executing a step with stream processing.
type stepExecutionResult struct {
	// StepResult is the completed step.
	StepResult StepResult
	// ShouldContinue reports whether another step should run; Stream leaves
	// its loop as soon as this is false.
	ShouldContinue bool
}
28
// StopCondition defines a function that determines when an agent should stop executing.
// It receives every step completed so far and returns true to stop the loop.
type StopCondition = func(steps []StepResult) bool
31
32// StepCountIs returns a stop condition that stops after the specified number of steps.
33func StepCountIs(stepCount int) StopCondition {
34 return func(steps []StepResult) bool {
35 return len(steps) >= stepCount
36 }
37}
38
39// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
40func HasToolCall(toolName string) StopCondition {
41 return func(steps []StepResult) bool {
42 if len(steps) == 0 {
43 return false
44 }
45 lastStep := steps[len(steps)-1]
46 toolCalls := lastStep.Content.ToolCalls()
47 for _, toolCall := range toolCalls {
48 if toolCall.ToolName == toolName {
49 return true
50 }
51 }
52 return false
53 }
54}
55
56// HasContent returns a stop condition that stops when the specified content type appears in the last step.
57func HasContent(contentType ContentType) StopCondition {
58 return func(steps []StepResult) bool {
59 if len(steps) == 0 {
60 return false
61 }
62 lastStep := steps[len(steps)-1]
63 for _, content := range lastStep.Content {
64 if content.GetType() == contentType {
65 return true
66 }
67 }
68 return false
69 }
70}
71
72// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
73func FinishReasonIs(reason FinishReason) StopCondition {
74 return func(steps []StepResult) bool {
75 if len(steps) == 0 {
76 return false
77 }
78 lastStep := steps[len(steps)-1]
79 return lastStep.FinishReason == reason
80 }
81}
82
83// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
84func MaxTokensUsed(maxTokens int64) StopCondition {
85 return func(steps []StepResult) bool {
86 var totalTokens int64
87 for _, step := range steps {
88 totalTokens += step.Usage.TotalTokens
89 }
90 return totalTokens >= maxTokens
91 }
92}
93
// PrepareStepFunctionOptions contains the options for preparing a step in an agent execution.
type PrepareStepFunctionOptions struct {
	// Steps holds the results of the steps completed so far.
	Steps []StepResult
	// StepNumber is the zero-based index of the step about to run.
	StepNumber int
	// Model is the agent's configured model, before any per-step override.
	Model LanguageModel
	// Messages is the input assembled for the step: the initial prompt
	// followed by the response messages of earlier steps.
	Messages []Message
}
101
// PrepareStepResult contains the result of preparing a step in an agent execution.
// Zero-valued fields leave the step's defaults untouched, except
// DisableAllTools, which is always applied as returned.
type PrepareStepResult struct {
	// Model, when non-nil, replaces the model used for the step.
	Model LanguageModel
	// Messages, when non-nil, replaces the step's input messages.
	Messages []Message
	// System, when non-nil, overrides the system prompt for the step.
	System *string
	// ToolChoice, when non-nil, overrides the step's tool choice
	// (ToolChoiceAuto otherwise).
	ToolChoice *ToolChoice
	// ActiveTools, when non-empty, restricts the step to the named tools.
	ActiveTools []string
	// DisableAllTools, when true, makes the step run with no tools at all.
	DisableAllTools bool
	// Tools, when non-nil, replaces the agent's tool set for the step.
	Tools []AgentTool
}
112
// ToolCallRepairOptions contains the options for repairing a tool call.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call that failed validation.
	OriginalToolCall ToolCallContent
	// ValidationError is the error reported when validating OriginalToolCall.
	ValidationError error
	// AvailableTools lists the tools the call was validated against.
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt in effect for the step.
	SystemPrompt string
	// Messages holds the input messages of the step that produced the call.
	Messages []Message
}
121
type (
	// PrepareStepFunction defines a function that prepares a step in an agent execution.
	// The returned context replaces the one used for the rest of the run, the
	// result carries per-step overrides, and a non-nil error aborts the run.
	PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)

	// OnStepFinishedFunction defines a function that is called when a step finishes.
	OnStepFinishedFunction = func(step StepResult)

	// RepairToolCallFunction defines a function that repairs a tool call.
	// A repair is only used when the function returns a non-nil tool call and
	// a nil error, and the returned call itself passes validation.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
132
// agentSettings holds the agent-level configuration assembled by the
// AgentOption functions. Most values act as defaults that an individual call
// may override (see prepareCall).
type agentSettings struct {
	// systemPrompt is the default system prompt; empty means none.
	systemPrompt string
	// Sampling defaults; a nil pointer means "not set".
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	// headers are agent-level headers.
	// NOTE(review): nothing in this part of the file forwards them to the
	// model call — confirm where they are applied.
	headers map[string]string
	// userAgent is forwarded to the model as Call.UserAgent.
	userAgent string
	// providerOptions are merged with (and overridden by) per-call options.
	providerOptions ProviderOptions

	// providerDefinedTools are passed through to the model unchanged.
	providerDefinedTools []ProviderDefinedTool
	// tools are the agent-executed tools available by default.
	tools []AgentTool
	// maxRetries overrides the default retry count when non-nil.
	maxRetries *int

	// model is the default language model.
	model LanguageModel

	// stopWhen, prepareStep, repairToolCall and onRetry are used when the
	// call does not provide its own.
	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
156
// AgentCall represents a call to an agent.
// Unset optional fields fall back to the agent's own settings.
//
// NOTE(review): MaxOutputTokens and ProviderOptions carry no json tag, unlike
// their sibling fields — confirm whether that is intentional.
type AgentCall struct {
	// Prompt is the user prompt. It may be empty only when Messages ends
	// with a user or tool message and no Files are attached.
	Prompt string `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files []FilePart `json:"files"`
	// Messages is the prior conversation, placed before the prompt.
	Messages []Message `json:"messages"`
	// Sampling parameters; nil means "use the agent default".
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, restricts the call to the named tools.
	ActiveTools []string `json:"active_tools"`
	// ProviderOptions override agent-level provider options key by key.
	ProviderOptions ProviderOptions
	// OnRetry is installed as the retry callback for model calls.
	OnRetry OnRetryCallback
	// MaxRetries overrides the default retry count when non-nil.
	MaxRetries *int

	// StopWhen lists conditions checked after every step; any one being
	// true stops the loop.
	StopWhen []StopCondition
	// PrepareStep, when set, runs before every step.
	PrepareStep PrepareStepFunction
	// RepairToolCall, when set, is the per-call hook for repairing tool
	// calls that fail validation.
	RepairToolCall RepairToolCallFunction
}
177
// Agent-level callbacks.
type (
	// OnAgentStartFunc is called when agent starts.
	// Stream invokes it once, before the first step.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called when agent finishes.
	// Stream invokes it after OnFinish and discards the returned error.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called when a step starts.
	// stepNumber is zero-based; Stream discards the returned error.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called when a step finishes.
	// Stream discards the returned error.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called when entire agent completes.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called when an error occurs.
	// Stream invokes it when a step fails after retries.
	OnErrorFunc func(error)
)
198
// Stream part callbacks - called for each corresponding stream part type.
//
// NOTE(review): these are invoked by the stream processor, which is outside
// this part of the file; how a returned error is handled should be confirmed
// there.
type (
	// OnChunkFunc is called for each stream part (catch-all).
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when text starts.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for text deltas.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when text ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when reasoning starts.
	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error

	// OnReasoningDeltaFunc is called for reasoning deltas.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning ends.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when tool input starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for tool input deltas.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnToolResultFunc is called when tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when stream finishes.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
246
// AgentStreamCall represents a streaming call to an agent.
// It carries the same request options as AgentCall plus the callbacks that
// observe the run. Unset optional fields fall back to the agent's settings.
type AgentStreamCall struct {
	// Prompt is the user prompt. It may be empty only when Messages ends
	// with a user or tool message and no Files are attached.
	Prompt string `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files []FilePart `json:"files"`
	// Messages is the prior conversation, placed before the prompt.
	Messages []Message `json:"messages"`
	// Sampling parameters; nil means "use the agent default".
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, restricts the call to the named tools.
	ActiveTools []string `json:"active_tools"`
	// Headers are per-call headers.
	// NOTE(review): Stream does not copy this field into the AgentCall it
	// builds — confirm where headers are applied.
	Headers map[string]string
	// ProviderOptions override agent-level provider options key by key.
	ProviderOptions ProviderOptions
	// OnRetry is installed as the retry callback for model calls.
	OnRetry OnRetryCallback
	// MaxRetries overrides the default retry count when non-nil.
	MaxRetries *int

	// StopWhen lists conditions checked after every step; any one being
	// true stops the loop.
	StopWhen []StopCondition
	// PrepareStep, when set, runs before every step.
	PrepareStep PrepareStepFunction
	// RepairToolCall is the per-call tool-call repair hook.
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called when agent starts
	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
	OnWarnings       OnWarningsFunc       // Called for warnings
	OnTextStart      OnTextStartFunc      // Called when text starts
	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
	OnTextEnd        OnTextEndFunc        // Called when text ends
	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
	OnToolCall       OnToolCallFunc       // Called when tool call is complete
	OnToolResult     OnToolResultFunc     // Called when tool execution completes
	OnSource         OnSourceFunc         // Called for source references
	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
}
293
// AgentResult represents the result of an agent execution.
type AgentResult struct {
	// Steps lists every step that ran, in execution order.
	Steps []StepResult
	// Final response
	// (the Response of the last entry in Steps).
	Response Response
	// TotalUsage is the token usage summed over all steps.
	TotalUsage Usage
}
301
// Agent represents an AI agent that can generate responses and stream responses.
type Agent interface {
	// Generate runs the agent loop without streaming and returns the
	// accumulated result.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the agent loop using the model's streaming API, reporting
	// progress through the callbacks on AgentStreamCall, and returns the
	// accumulated result.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
307
// AgentOption defines a function that configures agent settings.
// Options are applied by NewAgent in the order they are passed.
type AgentOption = func(*agentSettings)
310
// agent is the Agent implementation returned by NewAgent.
type agent struct {
	// settings is the configuration captured when the agent was created.
	settings agentSettings
}
314
315// NewAgent creates a new agent with the given language model and options.
316func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
317 settings := agentSettings{
318 model: model,
319 }
320 for _, o := range opts {
321 o(&settings)
322 }
323 return &agent{
324 settings: settings,
325 }
326}
327
328func (a *agent) prepareCall(call AgentCall) AgentCall {
329 call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
330 call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
331 call.TopP = cmp.Or(call.TopP, a.settings.topP)
332 call.TopK = cmp.Or(call.TopK, a.settings.topK)
333 call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
334 call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
335 call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
336
337 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
338 call.StopWhen = a.settings.stopWhen
339 }
340 if call.PrepareStep == nil && a.settings.prepareStep != nil {
341 call.PrepareStep = a.settings.prepareStep
342 }
343 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
344 call.RepairToolCall = a.settings.repairToolCall
345 }
346 if call.OnRetry == nil && a.settings.onRetry != nil {
347 call.OnRetry = a.settings.onRetry
348 }
349
350 providerOptions := ProviderOptions{}
351 if a.settings.providerOptions != nil {
352 maps.Copy(providerOptions, a.settings.providerOptions)
353 }
354 if call.ProviderOptions != nil {
355 maps.Copy(providerOptions, call.ProviderOptions)
356 }
357 call.ProviderOptions = providerOptions
358
359 headers := map[string]string{}
360
361 if a.settings.headers != nil {
362 maps.Copy(headers, a.settings.headers)
363 }
364
365 return call
366}
367
368// Generate implements Agent.
369func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
370 opts = a.prepareCall(opts)
371 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
372 if err != nil {
373 return nil, err
374 }
375 var responseMessages []Message
376 var steps []StepResult
377
378 for {
379 stepInputMessages := append(initialPrompt, responseMessages...)
380 stepModel := a.settings.model
381 stepSystemPrompt := a.settings.systemPrompt
382 stepActiveTools := opts.ActiveTools
383 stepToolChoice := ToolChoiceAuto
384 disableAllTools := false
385 stepTools := a.settings.tools
386 if opts.PrepareStep != nil {
387 updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
388 Model: stepModel,
389 Steps: steps,
390 StepNumber: len(steps),
391 Messages: stepInputMessages,
392 })
393 if err != nil {
394 return nil, err
395 }
396
397 ctx = updatedCtx
398
399 // Apply prepared step modifications
400 if prepared.Messages != nil {
401 stepInputMessages = prepared.Messages
402 }
403 if prepared.Model != nil {
404 stepModel = prepared.Model
405 }
406 if prepared.System != nil {
407 stepSystemPrompt = *prepared.System
408 }
409 if prepared.ToolChoice != nil {
410 stepToolChoice = *prepared.ToolChoice
411 }
412 if len(prepared.ActiveTools) > 0 {
413 stepActiveTools = prepared.ActiveTools
414 }
415 disableAllTools = prepared.DisableAllTools
416 if prepared.Tools != nil {
417 stepTools = prepared.Tools
418 }
419 }
420
421 // Recreate prompt with potentially modified system prompt
422 if stepSystemPrompt != a.settings.systemPrompt {
423 stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
424 if err != nil {
425 return nil, err
426 }
427 // Replace system message part, keep the rest
428 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
429 stepInputMessages[0] = stepPrompt[0] // Replace system message
430 }
431 }
432
433 preparedTools := a.prepareTools(stepTools, a.settings.providerDefinedTools, stepActiveTools, disableAllTools)
434
435 retryOptions := DefaultRetryOptions()
436 if opts.MaxRetries != nil {
437 retryOptions.MaxRetries = *opts.MaxRetries
438 }
439 retryOptions.OnRetry = opts.OnRetry
440 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
441
442 result, err := retry(ctx, func() (*Response, error) {
443 return stepModel.Generate(ctx, Call{
444 Prompt: stepInputMessages,
445 MaxOutputTokens: opts.MaxOutputTokens,
446 Temperature: opts.Temperature,
447 TopP: opts.TopP,
448 TopK: opts.TopK,
449 PresencePenalty: opts.PresencePenalty,
450 FrequencyPenalty: opts.FrequencyPenalty,
451 Tools: preparedTools,
452 ToolChoice: &stepToolChoice,
453 UserAgent: a.settings.userAgent,
454 ProviderOptions: opts.ProviderOptions,
455 })
456 })
457 if err != nil {
458 return nil, err
459 }
460
461 var stepToolCalls []ToolCallContent
462 for _, content := range result.Content {
463 if content.GetType() == ContentTypeToolCall {
464 toolCall, ok := AsContentType[ToolCallContent](content)
465 if !ok {
466 continue
467 }
468 // Provider-executed tool calls (e.g. web search) are
469 // handled by the provider and should not be validated
470 // or executed by the agent.
471 if toolCall.ProviderExecuted {
472 continue
473 }
474 // Validate and potentially repair the tool call
475 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
476 stepToolCalls = append(stepToolCalls, validatedToolCall)
477 }
478 }
479
480 toolResults, err := a.executeTools(ctx, stepTools, stepToolCalls, nil)
481
482 // Build step content with validated tool calls and tool results.
483 // Provider-executed tool calls are kept as-is.
484 stepContent := []Content{}
485 toolCallIndex := 0
486 for _, content := range result.Content {
487 if content.GetType() == ContentTypeToolCall {
488 tc, ok := AsContentType[ToolCallContent](content)
489 if ok && tc.ProviderExecuted {
490 stepContent = append(stepContent, content)
491 continue
492 }
493 // Replace with validated tool call.
494 if toolCallIndex < len(stepToolCalls) {
495 stepContent = append(stepContent, stepToolCalls[toolCallIndex])
496 toolCallIndex++
497 }
498 } else {
499 stepContent = append(stepContent, content)
500 }
501 } // Add tool results
502 for _, result := range toolResults {
503 stepContent = append(stepContent, result)
504 }
505 currentStepMessages := toResponseMessages(stepContent)
506 responseMessages = append(responseMessages, currentStepMessages...)
507
508 stepResult := StepResult{
509 Response: Response{
510 Content: stepContent,
511 FinishReason: result.FinishReason,
512 Usage: result.Usage,
513 Warnings: result.Warnings,
514 ProviderMetadata: result.ProviderMetadata,
515 },
516 Messages: currentStepMessages,
517 }
518 steps = append(steps, stepResult)
519 shouldStop := isStopConditionMet(opts.StopWhen, steps)
520
521 if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
522 break
523 }
524 }
525
526 totalUsage := Usage{}
527
528 for _, step := range steps {
529 usage := step.Usage
530 totalUsage.InputTokens += usage.InputTokens
531 totalUsage.OutputTokens += usage.OutputTokens
532 totalUsage.ReasoningTokens += usage.ReasoningTokens
533 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
534 totalUsage.CacheReadTokens += usage.CacheReadTokens
535 totalUsage.TotalTokens += usage.TotalTokens
536 }
537
538 agentResult := &AgentResult{
539 Steps: steps,
540 Response: steps[len(steps)-1].Response,
541 TotalUsage: totalUsage,
542 }
543 return agentResult, nil
544}
545
546func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
547 if len(conditions) == 0 {
548 return false
549 }
550
551 for _, condition := range conditions {
552 if condition(steps) {
553 return true
554 }
555 }
556 return false
557}
558
559func toResponseMessages(content []Content) []Message {
560 var assistantParts []MessagePart
561 var toolParts []MessagePart
562
563 for _, c := range content {
564 switch c.GetType() {
565 case ContentTypeText:
566 text, ok := AsContentType[TextContent](c)
567 if !ok {
568 continue
569 }
570 assistantParts = append(assistantParts, TextPart{
571 Text: text.Text,
572 ProviderOptions: ProviderOptions(text.ProviderMetadata),
573 })
574 case ContentTypeReasoning:
575 reasoning, ok := AsContentType[ReasoningContent](c)
576 if !ok {
577 continue
578 }
579 assistantParts = append(assistantParts, ReasoningPart{
580 Text: reasoning.Text,
581 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
582 })
583 case ContentTypeToolCall:
584 toolCall, ok := AsContentType[ToolCallContent](c)
585 if !ok {
586 continue
587 }
588 assistantParts = append(assistantParts, ToolCallPart{
589 ToolCallID: toolCall.ToolCallID,
590 ToolName: toolCall.ToolName,
591 Input: toolCall.Input,
592 ProviderExecuted: toolCall.ProviderExecuted,
593 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
594 })
595 case ContentTypeFile:
596 file, ok := AsContentType[FileContent](c)
597 if !ok {
598 continue
599 }
600 assistantParts = append(assistantParts, FilePart{
601 Data: file.Data,
602 MediaType: file.MediaType,
603 ProviderOptions: ProviderOptions(file.ProviderMetadata),
604 })
605 case ContentTypeSource:
606 // Sources are metadata about references used to generate the response.
607 // They don't need to be included in the conversation messages.
608 continue
609 case ContentTypeToolResult:
610 result, ok := AsContentType[ToolResultContent](c)
611 if !ok {
612 continue
613 }
614 resultPart := ToolResultPart{
615 ToolCallID: result.ToolCallID,
616 Output: result.Result,
617 ProviderExecuted: result.ProviderExecuted,
618 ProviderOptions: ProviderOptions(result.ProviderMetadata),
619 }
620 if result.ProviderExecuted {
621 // Provider-executed tool results (e.g. web search)
622 // belong in the assistant message alongside the
623 // server_tool_use block that produced them.
624 assistantParts = append(assistantParts, resultPart)
625 } else {
626 toolParts = append(toolParts, resultPart)
627 }
628 }
629 }
630
631 var messages []Message
632 if len(assistantParts) > 0 {
633 messages = append(messages, Message{
634 Role: MessageRoleAssistant,
635 Content: assistantParts,
636 })
637 }
638 if len(toolParts) > 0 {
639 messages = append(messages, Message{
640 Role: MessageRoleTool,
641 Content: toolParts,
642 })
643 }
644 return messages
645}
646
647func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
648 if len(toolCalls) == 0 {
649 return nil, nil
650 }
651
652 // Create a map for quick tool lookup
653 toolMap := make(map[string]AgentTool)
654 for _, tool := range allTools {
655 toolMap[tool.Info().Name] = tool
656 }
657
658 // Execute all tool calls sequentially in order
659 results := make([]ToolResultContent, 0, len(toolCalls))
660
661 for _, toolCall := range toolCalls {
662 result, isCriticalError := a.executeSingleTool(ctx, toolMap, toolCall, toolResultCallback)
663 results = append(results, result)
664 if isCriticalError {
665 if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
666 return nil, errorResult.Error
667 }
668 }
669 }
670
671 return results, nil
672}
673
674// executeSingleTool executes a single tool and returns its result and a critical error flag.
675func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentTool, toolCall ToolCallContent, toolResultCallback func(result ToolResultContent) error) (ToolResultContent, bool) {
676 result := ToolResultContent{
677 ToolCallID: toolCall.ToolCallID,
678 ToolName: toolCall.ToolName,
679 ProviderExecuted: false,
680 }
681
682 // Skip invalid tool calls - create error result (not critical)
683 if toolCall.Invalid {
684 result.Result = ToolResultOutputContentError{
685 Error: toolCall.ValidationError,
686 }
687 if toolResultCallback != nil {
688 _ = toolResultCallback(result)
689 }
690 return result, false
691 }
692
693 tool, exists := toolMap[toolCall.ToolName]
694 if !exists {
695 result.Result = ToolResultOutputContentError{
696 Error: errors.New("Error: Tool not found: " + toolCall.ToolName),
697 }
698 if toolResultCallback != nil {
699 _ = toolResultCallback(result)
700 }
701 return result, false
702 }
703
704 // Execute the tool
705 toolResult, err := tool.Run(ctx, ToolCall{
706 ID: toolCall.ToolCallID,
707 Name: toolCall.ToolName,
708 Input: toolCall.Input,
709 })
710 if err != nil {
711 result.Result = ToolResultOutputContentError{
712 Error: err,
713 }
714 result.ClientMetadata = toolResult.Metadata
715 if toolResultCallback != nil {
716 _ = toolResultCallback(result)
717 }
718 return result, true
719 }
720
721 result.ClientMetadata = toolResult.Metadata
722 if toolResult.IsError {
723 result.Result = ToolResultOutputContentError{
724 Error: errors.New(toolResult.Content),
725 }
726 } else if toolResult.Type == "image" || toolResult.Type == "media" {
727 result.Result = ToolResultOutputContentMedia{
728 Data: string(toolResult.Data),
729 MediaType: toolResult.MediaType,
730 Text: toolResult.Content,
731 }
732 } else {
733 result.Result = ToolResultOutputContentText{
734 Text: toolResult.Content,
735 }
736 }
737 if toolResultCallback != nil {
738 _ = toolResultCallback(result)
739 }
740 return result, false
741}
742
743// Stream implements Agent.
744func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
745 // Convert AgentStreamCall to AgentCall for preparation
746 call := AgentCall{
747 Prompt: opts.Prompt,
748 Files: opts.Files,
749 Messages: opts.Messages,
750 MaxOutputTokens: opts.MaxOutputTokens,
751 Temperature: opts.Temperature,
752 TopP: opts.TopP,
753 TopK: opts.TopK,
754 PresencePenalty: opts.PresencePenalty,
755 FrequencyPenalty: opts.FrequencyPenalty,
756 ActiveTools: opts.ActiveTools,
757 ProviderOptions: opts.ProviderOptions,
758 MaxRetries: opts.MaxRetries,
759 OnRetry: opts.OnRetry,
760 StopWhen: opts.StopWhen,
761 PrepareStep: opts.PrepareStep,
762 RepairToolCall: opts.RepairToolCall,
763 }
764
765 call = a.prepareCall(call)
766
767 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
768 if err != nil {
769 return nil, err
770 }
771
772 var responseMessages []Message
773 var steps []StepResult
774 var totalUsage Usage
775
776 // Start agent stream
777 if opts.OnAgentStart != nil {
778 opts.OnAgentStart()
779 }
780
781 for stepNumber := 0; ; stepNumber++ {
782 stepInputMessages := append(initialPrompt, responseMessages...)
783 stepModel := a.settings.model
784 stepSystemPrompt := a.settings.systemPrompt
785 stepActiveTools := call.ActiveTools
786 stepToolChoice := ToolChoiceAuto
787 disableAllTools := false
788 stepTools := a.settings.tools
789 // Apply step preparation if provided
790 if call.PrepareStep != nil {
791 updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
792 Model: stepModel,
793 Steps: steps,
794 StepNumber: stepNumber,
795 Messages: stepInputMessages,
796 })
797 if err != nil {
798 return nil, err
799 }
800
801 ctx = updatedCtx
802
803 if prepared.Messages != nil {
804 stepInputMessages = prepared.Messages
805 }
806 if prepared.Model != nil {
807 stepModel = prepared.Model
808 }
809 if prepared.System != nil {
810 stepSystemPrompt = *prepared.System
811 }
812 if prepared.ToolChoice != nil {
813 stepToolChoice = *prepared.ToolChoice
814 }
815 if len(prepared.ActiveTools) > 0 {
816 stepActiveTools = prepared.ActiveTools
817 }
818 disableAllTools = prepared.DisableAllTools
819 if prepared.Tools != nil {
820 stepTools = prepared.Tools
821 }
822 }
823
824 // Recreate prompt with potentially modified system prompt
825 if stepSystemPrompt != a.settings.systemPrompt {
826 stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
827 if err != nil {
828 return nil, err
829 }
830 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
831 stepInputMessages[0] = stepPrompt[0]
832 }
833 }
834
835 preparedTools := a.prepareTools(stepTools, a.settings.providerDefinedTools, stepActiveTools, disableAllTools)
836
837 // Start step stream
838 if opts.OnStepStart != nil {
839 _ = opts.OnStepStart(stepNumber)
840 }
841
842 // Create streaming call
843 streamCall := Call{
844 Prompt: stepInputMessages,
845 MaxOutputTokens: call.MaxOutputTokens,
846 Temperature: call.Temperature,
847 TopP: call.TopP,
848 TopK: call.TopK,
849 PresencePenalty: call.PresencePenalty,
850 FrequencyPenalty: call.FrequencyPenalty,
851 Tools: preparedTools,
852 ToolChoice: &stepToolChoice,
853 UserAgent: a.settings.userAgent,
854 ProviderOptions: call.ProviderOptions,
855 }
856
857 // Execute step with retry logic wrapping both stream creation and processing
858 retryOptions := DefaultRetryOptions()
859 if call.MaxRetries != nil {
860 retryOptions.MaxRetries = *call.MaxRetries
861 }
862 retryOptions.OnRetry = call.OnRetry
863 retry := RetryWithExponentialBackoffRespectingRetryHeaders[stepExecutionResult](retryOptions)
864
865 result, err := retry(ctx, func() (stepExecutionResult, error) {
866 // Create the stream
867 stream, err := stepModel.Stream(ctx, streamCall)
868 if err != nil {
869 return stepExecutionResult{}, err
870 }
871
872 // Process the stream
873 result, err := a.processStepStream(ctx, stream, opts, steps, stepTools)
874 if err != nil {
875 return stepExecutionResult{}, err
876 }
877
878 return result, nil
879 })
880 if err != nil {
881 if opts.OnError != nil {
882 opts.OnError(err)
883 }
884 return nil, err
885 }
886
887 steps = append(steps, result.StepResult)
888 totalUsage = addUsage(totalUsage, result.StepResult.Usage)
889
890 // Call step finished callback
891 if opts.OnStepFinish != nil {
892 _ = opts.OnStepFinish(result.StepResult)
893 }
894
895 // Add step messages to response messages
896 stepMessages := toResponseMessages(result.StepResult.Content)
897 responseMessages = append(responseMessages, stepMessages...)
898
899 // Check stop conditions
900 shouldStop := isStopConditionMet(call.StopWhen, steps)
901 if shouldStop || !result.ShouldContinue {
902 break
903 }
904 }
905
906 // Finish agent stream
907 agentResult := &AgentResult{
908 Steps: steps,
909 Response: steps[len(steps)-1].Response,
910 TotalUsage: totalUsage,
911 }
912
913 if opts.OnFinish != nil {
914 opts.OnFinish(agentResult)
915 }
916
917 if opts.OnAgentFinish != nil {
918 _ = opts.OnAgentFinish(agentResult)
919 }
920
921 return agentResult, nil
922}
923
924func (a *agent) prepareTools(tools []AgentTool, providerDefinedTools []ProviderDefinedTool, activeTools []string, disableAllTools bool) []Tool {
925 preparedTools := make([]Tool, 0, len(tools)+len(providerDefinedTools))
926
927 // If explicitly disabling all tools, return no tools
928 if disableAllTools {
929 return preparedTools
930 }
931
932 for _, tool := range tools {
933 // If activeTools has items, only include tools in the list
934 // If activeTools is empty, include all tools
935 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
936 continue
937 }
938 info := tool.Info()
939 inputSchema := map[string]any{
940 "type": "object",
941 "properties": info.Parameters,
942 "required": info.Required,
943 }
944 schema.Normalize(inputSchema)
945 preparedTools = append(preparedTools, FunctionTool{
946 Name: info.Name,
947 Description: info.Description,
948 InputSchema: inputSchema,
949 ProviderOptions: tool.ProviderOptions(),
950 })
951 }
952 for _, tool := range providerDefinedTools {
953 // If activeTools has items, only include tools in the list. If
954 // activeTools is empty, include all tools
955 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.GetName()) {
956 continue
957 }
958 preparedTools = append(preparedTools, tool)
959 }
960 return preparedTools
961}
962
963// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
964func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
965 if err := a.validateToolCall(toolCall, availableTools); err == nil {
966 return toolCall
967 } else { //nolint: revive
968 if repairFunc != nil {
969 repairOptions := ToolCallRepairOptions{
970 OriginalToolCall: toolCall,
971 ValidationError: err,
972 AvailableTools: availableTools,
973 SystemPrompt: systemPrompt,
974 Messages: messages,
975 }
976
977 if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
978 if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
979 return *repairedToolCall
980 }
981 }
982 }
983
984 invalidToolCall := toolCall
985 invalidToolCall.Invalid = true
986 invalidToolCall.ValidationError = err
987 return invalidToolCall
988 }
989}
990
991// validateToolCall validates a tool call against available tools and their schemas.
992func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
993 var tool AgentTool
994 for _, t := range availableTools {
995 if t.Info().Name == toolCall.ToolName {
996 tool = t
997 break
998 }
999 }
1000
1001 if tool == nil {
1002 return fmt.Errorf("tool not found: %s", toolCall.ToolName)
1003 }
1004
1005 // Validate JSON parsing
1006 var input map[string]any
1007 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
1008 return fmt.Errorf("invalid JSON input: %w", err)
1009 }
1010
1011 // Basic schema validation (check required fields)
1012 // TODO: more robust schema validation using JSON Schema or similar
1013 toolInfo := tool.Info()
1014 for _, required := range toolInfo.Required {
1015 if _, exists := input[required]; !exists {
1016 return fmt.Errorf("missing required parameter: %s", required)
1017 }
1018 }
1019 return nil
1020}
1021
1022func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
1023 // Validation: empty prompt is only allowed when there are messages,
1024 // no files to attach, and the last message is a user or tool message.
1025 if prompt == "" {
1026 lastMessage, hasMessages := slice.Last(messages)
1027
1028 if !hasMessages {
1029 return nil, &Error{
1030 Title: "invalid argument",
1031 Message: "prompt can't be empty when there are no messages",
1032 }
1033 }
1034
1035 if len(files) > 0 {
1036 return nil, &Error{
1037 Title: "invalid argument",
1038 Message: "prompt can't be empty when there are files",
1039 }
1040 }
1041
1042 switch lastMessage.Role {
1043 case MessageRoleUser, MessageRoleTool:
1044 default:
1045 return nil, &Error{
1046 Title: "invalid argument",
1047 Message: "prompt can't be empty when the last message is not a user or tool message",
1048 }
1049 }
1050 }
1051
1052 var preparedPrompt Prompt
1053
1054 if system != "" {
1055 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
1056 }
1057 preparedPrompt = append(preparedPrompt, messages...)
1058 if prompt != "" {
1059 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
1060 }
1061 return preparedPrompt, nil
1062}
1063
1064// WithSystemPrompt sets the system prompt for the agent.
1065func WithSystemPrompt(prompt string) AgentOption {
1066 return func(s *agentSettings) {
1067 s.systemPrompt = prompt
1068 }
1069}
1070
1071// WithMaxOutputTokens sets the maximum output tokens for the agent.
1072func WithMaxOutputTokens(tokens int64) AgentOption {
1073 return func(s *agentSettings) {
1074 s.maxOutputTokens = &tokens
1075 }
1076}
1077
1078// WithTemperature sets the temperature for the agent.
1079func WithTemperature(temp float64) AgentOption {
1080 return func(s *agentSettings) {
1081 s.temperature = &temp
1082 }
1083}
1084
1085// WithTopP sets the top-p value for the agent.
1086func WithTopP(topP float64) AgentOption {
1087 return func(s *agentSettings) {
1088 s.topP = &topP
1089 }
1090}
1091
1092// WithTopK sets the top-k value for the agent.
1093func WithTopK(topK int64) AgentOption {
1094 return func(s *agentSettings) {
1095 s.topK = &topK
1096 }
1097}
1098
1099// WithPresencePenalty sets the presence penalty for the agent.
1100func WithPresencePenalty(penalty float64) AgentOption {
1101 return func(s *agentSettings) {
1102 s.presencePenalty = &penalty
1103 }
1104}
1105
1106// WithFrequencyPenalty sets the frequency penalty for the agent.
1107func WithFrequencyPenalty(penalty float64) AgentOption {
1108 return func(s *agentSettings) {
1109 s.frequencyPenalty = &penalty
1110 }
1111}
1112
1113// WithTools sets the tools for the agent.
1114func WithTools(tools ...AgentTool) AgentOption {
1115 return func(s *agentSettings) {
1116 s.tools = append(s.tools, tools...)
1117 }
1118}
1119
1120// WithProviderDefinedTools sets the provider-defined tools for the agent.
1121// These tools are executed by the provider (e.g. web search) rather
1122// than by the client.
1123func WithProviderDefinedTools(tools ...ProviderDefinedTool) AgentOption {
1124 return func(s *agentSettings) {
1125 s.providerDefinedTools = append(s.providerDefinedTools, tools...)
1126 }
1127}
1128
1129// WithStopConditions sets the stop conditions for the agent.
1130func WithStopConditions(conditions ...StopCondition) AgentOption {
1131 return func(s *agentSettings) {
1132 s.stopWhen = append(s.stopWhen, conditions...)
1133 }
1134}
1135
1136// WithPrepareStep sets the prepare step function for the agent.
1137func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1138 return func(s *agentSettings) {
1139 s.prepareStep = fn
1140 }
1141}
1142
1143// WithRepairToolCall sets the repair tool call function for the agent.
1144func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1145 return func(s *agentSettings) {
1146 s.repairToolCall = fn
1147 }
1148}
1149
1150// WithMaxRetries sets the maximum number of retries for the agent.
1151func WithMaxRetries(maxRetries int) AgentOption {
1152 return func(s *agentSettings) {
1153 s.maxRetries = &maxRetries
1154 }
1155}
1156
1157// WithOnRetry sets the retry callback for the agent.
1158func WithOnRetry(callback OnRetryCallback) AgentOption {
1159 return func(s *agentSettings) {
1160 s.onRetry = callback
1161 }
1162}
1163
1164// processStepStream processes a single step's stream and returns the step result.
1165func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult, stepTools []AgentTool) (stepExecutionResult, error) {
1166 var stepContent []Content
1167 var stepToolCalls []ToolCallContent
1168 var stepUsage Usage
1169 stepFinishReason := FinishReasonUnknown
1170 var stepWarnings []CallWarning
1171 var stepProviderMetadata ProviderMetadata
1172
1173 activeToolCalls := make(map[string]*ToolCallContent)
1174 activeTextContent := make(map[string]string)
1175 type reasoningContent struct {
1176 content string
1177 options ProviderMetadata
1178 }
1179 activeReasoningContent := make(map[string]reasoningContent)
1180
1181 // Set up concurrent tool execution
1182 type toolExecutionRequest struct {
1183 toolCall ToolCallContent
1184 parallel bool
1185 }
1186 toolChan := make(chan toolExecutionRequest, 10)
1187 var toolExecutionWg sync.WaitGroup
1188 var toolStateMu sync.Mutex
1189 toolResults := make([]ToolResultContent, 0)
1190 var toolExecutionErr error
1191
1192 // Create a map for quick tool lookup
1193 toolMap := make(map[string]AgentTool)
1194 for _, tool := range stepTools {
1195 toolMap[tool.Info().Name] = tool
1196 }
1197
1198 // Semaphores for controlling parallelism
1199 parallelSem := make(chan struct{}, 5)
1200 var sequentialMu sync.Mutex
1201
1202 // Single coordinator goroutine that dispatches tools
1203 toolExecutionWg.Go(func() {
1204 for req := range toolChan {
1205 if req.parallel {
1206 parallelSem <- struct{}{}
1207 toolExecutionWg.Go(func() {
1208 defer func() { <-parallelSem }()
1209 result, isCriticalError := a.executeSingleTool(ctx, toolMap, req.toolCall, opts.OnToolResult)
1210 toolStateMu.Lock()
1211 toolResults = append(toolResults, result)
1212 if isCriticalError && toolExecutionErr == nil {
1213 if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
1214 toolExecutionErr = errorResult.Error
1215 }
1216 }
1217 toolStateMu.Unlock()
1218 })
1219 } else {
1220 sequentialMu.Lock()
1221 result, isCriticalError := a.executeSingleTool(ctx, toolMap, req.toolCall, opts.OnToolResult)
1222 toolStateMu.Lock()
1223 toolResults = append(toolResults, result)
1224 if isCriticalError && toolExecutionErr == nil {
1225 if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
1226 toolExecutionErr = errorResult.Error
1227 }
1228 }
1229 toolStateMu.Unlock()
1230 sequentialMu.Unlock()
1231 }
1232 }
1233 })
1234
1235 // Process stream parts
1236 for part := range stream {
1237 // Forward all parts to chunk callback
1238 if opts.OnChunk != nil {
1239 err := opts.OnChunk(part)
1240 if err != nil {
1241 return stepExecutionResult{}, err
1242 }
1243 }
1244
1245 switch part.Type {
1246 case StreamPartTypeWarnings:
1247 stepWarnings = part.Warnings
1248 if opts.OnWarnings != nil {
1249 err := opts.OnWarnings(part.Warnings)
1250 if err != nil {
1251 return stepExecutionResult{}, err
1252 }
1253 }
1254
1255 case StreamPartTypeTextStart:
1256 activeTextContent[part.ID] = ""
1257 if opts.OnTextStart != nil {
1258 err := opts.OnTextStart(part.ID)
1259 if err != nil {
1260 return stepExecutionResult{}, err
1261 }
1262 }
1263
1264 case StreamPartTypeTextDelta:
1265 if _, exists := activeTextContent[part.ID]; exists {
1266 activeTextContent[part.ID] += part.Delta
1267 }
1268 if opts.OnTextDelta != nil {
1269 err := opts.OnTextDelta(part.ID, part.Delta)
1270 if err != nil {
1271 return stepExecutionResult{}, err
1272 }
1273 }
1274
1275 case StreamPartTypeTextEnd:
1276 if text, exists := activeTextContent[part.ID]; exists {
1277 stepContent = append(stepContent, TextContent{
1278 Text: text,
1279 ProviderMetadata: part.ProviderMetadata,
1280 })
1281 delete(activeTextContent, part.ID)
1282 }
1283 if opts.OnTextEnd != nil {
1284 err := opts.OnTextEnd(part.ID)
1285 if err != nil {
1286 return stepExecutionResult{}, err
1287 }
1288 }
1289
1290 case StreamPartTypeReasoningStart:
1291 activeReasoningContent[part.ID] = reasoningContent{content: part.Delta, options: part.ProviderMetadata}
1292 if opts.OnReasoningStart != nil {
1293 content := ReasoningContent{
1294 Text: part.Delta,
1295 ProviderMetadata: part.ProviderMetadata,
1296 }
1297 err := opts.OnReasoningStart(part.ID, content)
1298 if err != nil {
1299 return stepExecutionResult{}, err
1300 }
1301 }
1302
1303 case StreamPartTypeReasoningDelta:
1304 if active, exists := activeReasoningContent[part.ID]; exists {
1305 active.content += part.Delta
1306 if part.ProviderMetadata != nil {
1307 active.options = part.ProviderMetadata
1308 }
1309 activeReasoningContent[part.ID] = active
1310 }
1311 if opts.OnReasoningDelta != nil {
1312 err := opts.OnReasoningDelta(part.ID, part.Delta)
1313 if err != nil {
1314 return stepExecutionResult{}, err
1315 }
1316 }
1317
1318 case StreamPartTypeReasoningEnd:
1319 if active, exists := activeReasoningContent[part.ID]; exists {
1320 if part.ProviderMetadata != nil {
1321 active.options = part.ProviderMetadata
1322 }
1323 content := ReasoningContent{
1324 Text: active.content,
1325 ProviderMetadata: active.options,
1326 }
1327 stepContent = append(stepContent, content)
1328 if opts.OnReasoningEnd != nil {
1329 err := opts.OnReasoningEnd(part.ID, content)
1330 if err != nil {
1331 return stepExecutionResult{}, err
1332 }
1333 }
1334 delete(activeReasoningContent, part.ID)
1335 }
1336
1337 case StreamPartTypeToolInputStart:
1338 activeToolCalls[part.ID] = &ToolCallContent{
1339 ToolCallID: part.ID,
1340 ToolName: part.ToolCallName,
1341 Input: "",
1342 ProviderExecuted: part.ProviderExecuted,
1343 }
1344 if opts.OnToolInputStart != nil {
1345 err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1346 if err != nil {
1347 return stepExecutionResult{}, err
1348 }
1349 }
1350
1351 case StreamPartTypeToolInputDelta:
1352 if toolCall, exists := activeToolCalls[part.ID]; exists {
1353 toolCall.Input += part.Delta
1354 }
1355 if opts.OnToolInputDelta != nil {
1356 err := opts.OnToolInputDelta(part.ID, part.Delta)
1357 if err != nil {
1358 return stepExecutionResult{}, err
1359 }
1360 }
1361
1362 case StreamPartTypeToolInputEnd:
1363 if opts.OnToolInputEnd != nil {
1364 err := opts.OnToolInputEnd(part.ID)
1365 if err != nil {
1366 return stepExecutionResult{}, err
1367 }
1368 }
1369
1370 case StreamPartTypeToolCall:
1371 toolCall := ToolCallContent{
1372 ToolCallID: part.ID,
1373 ToolName: part.ToolCallName,
1374 Input: part.ToolCallInput,
1375 ProviderExecuted: part.ProviderExecuted,
1376 ProviderMetadata: part.ProviderMetadata,
1377 }
1378
1379 // Provider-executed tool calls are handled by the provider
1380 // and should not be validated or executed by the agent.
1381 if toolCall.ProviderExecuted {
1382 stepContent = append(stepContent, toolCall)
1383 if opts.OnToolCall != nil {
1384 err := opts.OnToolCall(toolCall)
1385 if err != nil {
1386 return stepExecutionResult{}, err
1387 }
1388 }
1389 delete(activeToolCalls, part.ID)
1390 } else {
1391 // Validate and potentially repair the tool call
1392 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1393 stepToolCalls = append(stepToolCalls, validatedToolCall)
1394 stepContent = append(stepContent, validatedToolCall)
1395
1396 if opts.OnToolCall != nil {
1397 err := opts.OnToolCall(validatedToolCall)
1398 if err != nil {
1399 return stepExecutionResult{}, err
1400 }
1401 }
1402
1403 // Determine if tool can run in parallel
1404 isParallel := false
1405 if tool, exists := toolMap[validatedToolCall.ToolName]; exists {
1406 isParallel = tool.Info().Parallel
1407 }
1408
1409 // Send tool call to execution channel
1410 toolChan <- toolExecutionRequest{toolCall: validatedToolCall, parallel: isParallel}
1411
1412 // Clean up active tool call
1413 delete(activeToolCalls, part.ID)
1414 }
1415
1416 case StreamPartTypeToolResult:
1417 // Provider-executed tool results (e.g. web search)
1418 // are emitted by the provider and added directly
1419 // to the step content for multi-turn round-tripping.
1420 if part.ProviderExecuted {
1421 resultContent := ToolResultContent{
1422 ToolCallID: part.ID,
1423 ToolName: part.ToolCallName,
1424 ProviderExecuted: true,
1425 ProviderMetadata: part.ProviderMetadata,
1426 }
1427 stepContent = append(stepContent, resultContent)
1428 if opts.OnToolResult != nil {
1429 err := opts.OnToolResult(resultContent)
1430 if err != nil {
1431 return stepExecutionResult{}, err
1432 }
1433 }
1434 }
1435
1436 case StreamPartTypeSource:
1437 sourceContent := SourceContent{
1438 SourceType: part.SourceType,
1439 ID: part.ID,
1440 URL: part.URL,
1441 Title: part.Title,
1442 ProviderMetadata: part.ProviderMetadata,
1443 }
1444 stepContent = append(stepContent, sourceContent)
1445 if opts.OnSource != nil {
1446 err := opts.OnSource(sourceContent)
1447 if err != nil {
1448 return stepExecutionResult{}, err
1449 }
1450 }
1451
1452 case StreamPartTypeFinish:
1453 stepUsage = part.Usage
1454 stepFinishReason = part.FinishReason
1455 stepProviderMetadata = part.ProviderMetadata
1456 if opts.OnStreamFinish != nil {
1457 err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1458 if err != nil {
1459 return stepExecutionResult{}, err
1460 }
1461 }
1462
1463 case StreamPartTypeError:
1464 return stepExecutionResult{}, part.Error
1465 }
1466 }
1467
1468 // Close the tool execution channel and wait for all executions to complete
1469 close(toolChan)
1470 toolExecutionWg.Wait()
1471
1472 // Check for tool execution errors
1473 if toolExecutionErr != nil {
1474 return stepExecutionResult{}, toolExecutionErr
1475 }
1476
1477 // Add tool results to content if any
1478 if len(toolResults) > 0 {
1479 for _, result := range toolResults {
1480 stepContent = append(stepContent, result)
1481 }
1482 }
1483
1484 stepResult := StepResult{
1485 Response: Response{
1486 Content: stepContent,
1487 FinishReason: stepFinishReason,
1488 Usage: stepUsage,
1489 Warnings: stepWarnings,
1490 ProviderMetadata: stepProviderMetadata,
1491 },
1492 Messages: toResponseMessages(stepContent),
1493 }
1494
1495 // Determine if we should continue (has tool calls and not stopped)
1496 shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1497
1498 return stepExecutionResult{
1499 StepResult: stepResult,
1500 ShouldContinue: shouldContinue,
1501 }, nil
1502}
1503
1504func addUsage(a, b Usage) Usage {
1505 return Usage{
1506 InputTokens: a.InputTokens + b.InputTokens,
1507 OutputTokens: a.OutputTokens + b.OutputTokens,
1508 TotalTokens: a.TotalTokens + b.TotalTokens,
1509 ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens,
1510 CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1511 CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens,
1512 }
1513}
1514
1515// WithHeaders sets the headers for the agent.
1516func WithHeaders(headers map[string]string) AgentOption {
1517 return func(s *agentSettings) {
1518 s.headers = headers
1519 }
1520}
1521
1522// WithUserAgent sets the User-Agent header for the agent. This overrides any
1523// provider-level User-Agent setting.
1524func WithUserAgent(ua string) AgentOption {
1525 return func(s *agentSettings) {
1526 s.userAgent = ua
1527 }
1528}
1529
1530// WithProviderOptions sets the provider options for the agent.
1531func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1532 return func(s *agentSettings) {
1533 s.providerOptions = providerOptions
1534 }
1535}