1package fantasy
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "slices"
11)
12
// StepResult represents the result of a single step in an agent execution.
//
// It embeds the step's Response, so the response fields (for example Content,
// FinishReason and Usage) are promoted onto StepResult.
type StepResult struct {
	Response
	// Messages holds the conversation messages produced from this step's content.
	Messages []Message
}
18
// stepExecutionResult encapsulates the result of executing a step with stream processing.
type stepExecutionResult struct {
	// StepResult is the completed step that gets appended to the agent's step history.
	StepResult StepResult
	// ShouldContinue reports whether the streaming loop should run another step;
	// when false the loop stops after this step.
	ShouldContinue bool
}
24
// StopCondition defines a function that determines when an agent should stop executing.
// It receives the steps recorded so far and returns true when the agent loop
// should stop after the most recent step.
type StopCondition = func(steps []StepResult) bool
27
28// StepCountIs returns a stop condition that stops after the specified number of steps.
29func StepCountIs(stepCount int) StopCondition {
30 return func(steps []StepResult) bool {
31 return len(steps) >= stepCount
32 }
33}
34
35// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
36func HasToolCall(toolName string) StopCondition {
37 return func(steps []StepResult) bool {
38 if len(steps) == 0 {
39 return false
40 }
41 lastStep := steps[len(steps)-1]
42 toolCalls := lastStep.Content.ToolCalls()
43 for _, toolCall := range toolCalls {
44 if toolCall.ToolName == toolName {
45 return true
46 }
47 }
48 return false
49 }
50}
51
52// HasContent returns a stop condition that stops when the specified content type appears in the last step.
53func HasContent(contentType ContentType) StopCondition {
54 return func(steps []StepResult) bool {
55 if len(steps) == 0 {
56 return false
57 }
58 lastStep := steps[len(steps)-1]
59 for _, content := range lastStep.Content {
60 if content.GetType() == contentType {
61 return true
62 }
63 }
64 return false
65 }
66}
67
68// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
69func FinishReasonIs(reason FinishReason) StopCondition {
70 return func(steps []StepResult) bool {
71 if len(steps) == 0 {
72 return false
73 }
74 lastStep := steps[len(steps)-1]
75 return lastStep.FinishReason == reason
76 }
77}
78
79// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
80func MaxTokensUsed(maxTokens int64) StopCondition {
81 return func(steps []StepResult) bool {
82 var totalTokens int64
83 for _, step := range steps {
84 totalTokens += step.Usage.TotalTokens
85 }
86 return totalTokens >= maxTokens
87 }
88}
89
// PrepareStepFunctionOptions contains the options for preparing a step in an agent execution.
type PrepareStepFunctionOptions struct {
	// Steps are the steps completed before the one being prepared.
	Steps []StepResult
	// StepNumber is the zero-based index of the step being prepared.
	StepNumber int
	// Model is the agent's configured model for the step.
	Model LanguageModel
	// Messages are the input messages assembled for the step.
	Messages []Message
}
97
// PrepareStepResult contains the result of preparing a step in an agent execution.
// Fields left at their zero value keep the agent's defaults for the step.
type PrepareStepResult struct {
	// Model, when non-nil, replaces the model used for the step.
	Model LanguageModel
	// Messages, when non-nil, replaces the step's input messages.
	Messages []Message
	// System, when non-nil, overrides the system prompt for the step.
	System *string
	// ToolChoice, when non-nil, overrides the default tool choice (ToolChoiceAuto).
	ToolChoice *ToolChoice
	// ActiveTools, when non-empty, replaces the list of tool names offered for the step.
	ActiveTools []string
	// DisableAllTools, when true, results in no tools being offered for the step.
	DisableAllTools bool
	// Tools, when non-nil, replaces the agent's tool set for the step.
	Tools []AgentTool
}
108
// ToolCallRepairOptions contains the options for repairing a tool call.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call that failed validation.
	OriginalToolCall ToolCallContent
	// ValidationError is the error returned by validation of OriginalToolCall.
	ValidationError error
	// AvailableTools are the tools the call was validated against.
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt in effect for the step.
	SystemPrompt string
	// Messages are the input messages of the step that produced the call.
	Messages []Message
}
117
type (
	// PrepareStepFunction defines a function that prepares a step in an agent execution.
	// The returned context replaces the agent's context from that step onward, and
	// a non-nil error aborts the agent call.
	PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)

	// OnStepFinishedFunction defines a function that is called when a step finishes.
	OnStepFinishedFunction = func(step StepResult)

	// RepairToolCallFunction defines a function that repairs a tool call.
	// A repaired call is only used when the function returns a nil error and a
	// non-nil call that itself passes validation.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
128
// agentSettings holds the agent-wide configuration set through AgentOption
// functions. Most values act as defaults that individual calls may override.
type agentSettings struct {
	systemPrompt     string
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	headers          map[string]string
	providerOptions  ProviderOptions

	// TODO: add support for provider tools
	tools      []AgentTool
	maxRetries *int

	model LanguageModel

	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
151
// AgentCall represents a call to an agent.
//
// Pointer-valued sampling parameters, MaxRetries, OnRetry, StopWhen,
// PrepareStep and RepairToolCall fall back to the agent's settings when left
// unset. ProviderOptions are merged over the agent's provider options, with
// the call's entries taking precedence.
type AgentCall struct {
	// Prompt is the user prompt; it must not be empty.
	Prompt string `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files []FilePart `json:"files"`
	// Messages are placed before the user message built from Prompt.
	Messages         []Message `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, restricts the offered tools to these names.
	ActiveTools     []string `json:"active_tools"`
	ProviderOptions ProviderOptions
	OnRetry         OnRetryCallback
	MaxRetries      *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction
}
172
// Agent-level callbacks.
type (
	// OnAgentStartFunc is called when agent starts.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called when agent finishes.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called when a step starts. It receives the
	// zero-based step number.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called when a step finishes.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called when entire agent completes.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called when an error occurs.
	OnErrorFunc func(error)
)
193
// Stream part callbacks - called for each corresponding stream part type.
type (
	// OnChunkFunc is called for each stream part (catch-all).
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when text starts.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for text deltas.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when text ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when reasoning starts.
	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error

	// OnReasoningDeltaFunc is called for reasoning deltas.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning ends.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when tool input starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for tool input deltas.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnToolResultFunc is called when tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when stream finishes.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
241
// AgentStreamCall represents a streaming call to an agent.
//
// The request fields mirror AgentCall and are merged with the agent's
// settings in the same way; the callback fields are all optional.
type AgentStreamCall struct {
	// Prompt is the user prompt; it must not be empty.
	Prompt string `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files []FilePart `json:"files"`
	// Messages are placed before the user message built from Prompt.
	Messages         []Message `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, restricts the offered tools to these names.
	ActiveTools []string `json:"active_tools"`
	// NOTE(review): Headers is not read directly by Stream in this file —
	// confirm where (or whether) it is applied.
	Headers         map[string]string
	ProviderOptions ProviderOptions
	OnRetry         OnRetryCallback
	MaxRetries      *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called when agent starts
	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
	OnWarnings       OnWarningsFunc       // Called for warnings
	OnTextStart      OnTextStartFunc      // Called when text starts
	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
	OnTextEnd        OnTextEndFunc        // Called when text ends
	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
	OnToolCall       OnToolCallFunc       // Called when tool call is complete
	OnToolResult     OnToolResultFunc     // Called when tool execution completes
	OnSource         OnSourceFunc         // Called for source references
	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
}
288
// AgentResult represents the result of an agent execution.
type AgentResult struct {
	// Steps lists every executed step in order.
	Steps []StepResult
	// Final response
	Response Response
	// TotalUsage is the token usage summed over all steps.
	TotalUsage Usage
}
296
// Agent represents an AI agent that can generate responses and stream responses.
type Agent interface {
	// Generate runs the agent loop without streaming and returns the collected result.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the agent loop using the model's streaming API, invoking the
	// callbacks configured on the call, and returns the collected result.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
302
// AgentOption defines a function that configures agent settings.
// Options are applied by NewAgent in the order they are passed.
type AgentOption = func(*agentSettings)
305
// agent is the Agent implementation returned by NewAgent.
type agent struct {
	// settings holds the configuration assembled from the constructor's options.
	settings agentSettings
}
309
310// NewAgent creates a new agent with the given language model and options.
311func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
312 settings := agentSettings{
313 model: model,
314 }
315 for _, o := range opts {
316 o(&settings)
317 }
318 return &agent{
319 settings: settings,
320 }
321}
322
323func (a *agent) prepareCall(call AgentCall) AgentCall {
324 call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
325 call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
326 call.TopP = cmp.Or(call.TopP, a.settings.topP)
327 call.TopK = cmp.Or(call.TopK, a.settings.topK)
328 call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
329 call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
330 call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
331
332 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
333 call.StopWhen = a.settings.stopWhen
334 }
335 if call.PrepareStep == nil && a.settings.prepareStep != nil {
336 call.PrepareStep = a.settings.prepareStep
337 }
338 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
339 call.RepairToolCall = a.settings.repairToolCall
340 }
341 if call.OnRetry == nil && a.settings.onRetry != nil {
342 call.OnRetry = a.settings.onRetry
343 }
344
345 providerOptions := ProviderOptions{}
346 if a.settings.providerOptions != nil {
347 maps.Copy(providerOptions, a.settings.providerOptions)
348 }
349 if call.ProviderOptions != nil {
350 maps.Copy(providerOptions, call.ProviderOptions)
351 }
352 call.ProviderOptions = providerOptions
353
354 headers := map[string]string{}
355
356 if a.settings.headers != nil {
357 maps.Copy(headers, a.settings.headers)
358 }
359
360 return call
361}
362
363// Generate implements Agent.
364func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
365 opts = a.prepareCall(opts)
366 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
367 if err != nil {
368 return nil, err
369 }
370 var responseMessages []Message
371 var steps []StepResult
372
373 for {
374 stepInputMessages := append(initialPrompt, responseMessages...)
375 stepModel := a.settings.model
376 stepSystemPrompt := a.settings.systemPrompt
377 stepActiveTools := opts.ActiveTools
378 stepToolChoice := ToolChoiceAuto
379 disableAllTools := false
380 stepTools := a.settings.tools
381 if opts.PrepareStep != nil {
382 updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
383 Model: stepModel,
384 Steps: steps,
385 StepNumber: len(steps),
386 Messages: stepInputMessages,
387 })
388 if err != nil {
389 return nil, err
390 }
391
392 ctx = updatedCtx
393
394 // Apply prepared step modifications
395 if prepared.Messages != nil {
396 stepInputMessages = prepared.Messages
397 }
398 if prepared.Model != nil {
399 stepModel = prepared.Model
400 }
401 if prepared.System != nil {
402 stepSystemPrompt = *prepared.System
403 }
404 if prepared.ToolChoice != nil {
405 stepToolChoice = *prepared.ToolChoice
406 }
407 if len(prepared.ActiveTools) > 0 {
408 stepActiveTools = prepared.ActiveTools
409 }
410 disableAllTools = prepared.DisableAllTools
411 if prepared.Tools != nil {
412 stepTools = prepared.Tools
413 }
414 }
415
416 // Recreate prompt with potentially modified system prompt
417 if stepSystemPrompt != a.settings.systemPrompt {
418 stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
419 if err != nil {
420 return nil, err
421 }
422 // Replace system message part, keep the rest
423 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
424 stepInputMessages[0] = stepPrompt[0] // Replace system message
425 }
426 }
427
428 preparedTools := a.prepareTools(stepTools, stepActiveTools, disableAllTools)
429
430 retryOptions := DefaultRetryOptions()
431 if opts.MaxRetries != nil {
432 retryOptions.MaxRetries = *opts.MaxRetries
433 }
434 retryOptions.OnRetry = opts.OnRetry
435 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
436
437 result, err := retry(ctx, func() (*Response, error) {
438 return stepModel.Generate(ctx, Call{
439 Prompt: stepInputMessages,
440 MaxOutputTokens: opts.MaxOutputTokens,
441 Temperature: opts.Temperature,
442 TopP: opts.TopP,
443 TopK: opts.TopK,
444 PresencePenalty: opts.PresencePenalty,
445 FrequencyPenalty: opts.FrequencyPenalty,
446 Tools: preparedTools,
447 ToolChoice: &stepToolChoice,
448 ProviderOptions: opts.ProviderOptions,
449 })
450 })
451 if err != nil {
452 return nil, err
453 }
454
455 var stepToolCalls []ToolCallContent
456 for _, content := range result.Content {
457 if content.GetType() == ContentTypeToolCall {
458 toolCall, ok := AsContentType[ToolCallContent](content)
459 if !ok {
460 continue
461 }
462
463 // Validate and potentially repair the tool call
464 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
465 stepToolCalls = append(stepToolCalls, validatedToolCall)
466 }
467 }
468
469 toolResults, err := a.executeTools(ctx, stepTools, stepToolCalls, nil)
470
471 // Build step content with validated tool calls and tool results
472 stepContent := []Content{}
473 toolCallIndex := 0
474 for _, content := range result.Content {
475 if content.GetType() == ContentTypeToolCall {
476 // Replace with validated tool call
477 if toolCallIndex < len(stepToolCalls) {
478 stepContent = append(stepContent, stepToolCalls[toolCallIndex])
479 toolCallIndex++
480 }
481 } else {
482 // Keep other content as-is
483 stepContent = append(stepContent, content)
484 }
485 }
486 // Add tool results
487 for _, result := range toolResults {
488 stepContent = append(stepContent, result)
489 }
490 currentStepMessages := toResponseMessages(stepContent)
491 responseMessages = append(responseMessages, currentStepMessages...)
492
493 stepResult := StepResult{
494 Response: Response{
495 Content: stepContent,
496 FinishReason: result.FinishReason,
497 Usage: result.Usage,
498 Warnings: result.Warnings,
499 ProviderMetadata: result.ProviderMetadata,
500 },
501 Messages: currentStepMessages,
502 }
503 steps = append(steps, stepResult)
504 shouldStop := isStopConditionMet(opts.StopWhen, steps)
505
506 if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
507 break
508 }
509 }
510
511 totalUsage := Usage{}
512
513 for _, step := range steps {
514 usage := step.Usage
515 totalUsage.InputTokens += usage.InputTokens
516 totalUsage.OutputTokens += usage.OutputTokens
517 totalUsage.ReasoningTokens += usage.ReasoningTokens
518 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
519 totalUsage.CacheReadTokens += usage.CacheReadTokens
520 totalUsage.TotalTokens += usage.TotalTokens
521 }
522
523 agentResult := &AgentResult{
524 Steps: steps,
525 Response: steps[len(steps)-1].Response,
526 TotalUsage: totalUsage,
527 }
528 return agentResult, nil
529}
530
531func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
532 if len(conditions) == 0 {
533 return false
534 }
535
536 for _, condition := range conditions {
537 if condition(steps) {
538 return true
539 }
540 }
541 return false
542}
543
544func toResponseMessages(content []Content) []Message {
545 var assistantParts []MessagePart
546 var toolParts []MessagePart
547
548 for _, c := range content {
549 switch c.GetType() {
550 case ContentTypeText:
551 text, ok := AsContentType[TextContent](c)
552 if !ok {
553 continue
554 }
555 assistantParts = append(assistantParts, TextPart{
556 Text: text.Text,
557 ProviderOptions: ProviderOptions(text.ProviderMetadata),
558 })
559 case ContentTypeReasoning:
560 reasoning, ok := AsContentType[ReasoningContent](c)
561 if !ok {
562 continue
563 }
564 assistantParts = append(assistantParts, ReasoningPart{
565 Text: reasoning.Text,
566 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
567 })
568 case ContentTypeToolCall:
569 toolCall, ok := AsContentType[ToolCallContent](c)
570 if !ok {
571 continue
572 }
573 assistantParts = append(assistantParts, ToolCallPart{
574 ToolCallID: toolCall.ToolCallID,
575 ToolName: toolCall.ToolName,
576 Input: toolCall.Input,
577 ProviderExecuted: toolCall.ProviderExecuted,
578 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
579 })
580 case ContentTypeFile:
581 file, ok := AsContentType[FileContent](c)
582 if !ok {
583 continue
584 }
585 assistantParts = append(assistantParts, FilePart{
586 Data: file.Data,
587 MediaType: file.MediaType,
588 ProviderOptions: ProviderOptions(file.ProviderMetadata),
589 })
590 case ContentTypeSource:
591 // Sources are metadata about references used to generate the response.
592 // They don't need to be included in the conversation messages.
593 continue
594 case ContentTypeToolResult:
595 result, ok := AsContentType[ToolResultContent](c)
596 if !ok {
597 continue
598 }
599 toolParts = append(toolParts, ToolResultPart{
600 ToolCallID: result.ToolCallID,
601 Output: result.Result,
602 ProviderOptions: ProviderOptions(result.ProviderMetadata),
603 })
604 }
605 }
606
607 var messages []Message
608 if len(assistantParts) > 0 {
609 messages = append(messages, Message{
610 Role: MessageRoleAssistant,
611 Content: assistantParts,
612 })
613 }
614 if len(toolParts) > 0 {
615 messages = append(messages, Message{
616 Role: MessageRoleTool,
617 Content: toolParts,
618 })
619 }
620 return messages
621}
622
623func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
624 if len(toolCalls) == 0 {
625 return nil, nil
626 }
627
628 // Create a map for quick tool lookup
629 toolMap := make(map[string]AgentTool)
630 for _, tool := range allTools {
631 toolMap[tool.Info().Name] = tool
632 }
633
634 // Execute all tool calls sequentially in order
635 results := make([]ToolResultContent, 0, len(toolCalls))
636
637 for _, toolCall := range toolCalls {
638 // Skip invalid tool calls - create error result
639 if toolCall.Invalid {
640 result := ToolResultContent{
641 ToolCallID: toolCall.ToolCallID,
642 ToolName: toolCall.ToolName,
643 Result: ToolResultOutputContentError{
644 Error: toolCall.ValidationError,
645 },
646 ProviderExecuted: false,
647 }
648 results = append(results, result)
649 if toolResultCallback != nil {
650 if err := toolResultCallback(result); err != nil {
651 return nil, err
652 }
653 }
654 continue
655 }
656
657 tool, exists := toolMap[toolCall.ToolName]
658 if !exists {
659 result := ToolResultContent{
660 ToolCallID: toolCall.ToolCallID,
661 ToolName: toolCall.ToolName,
662 Result: ToolResultOutputContentError{
663 Error: errors.New("Error: Tool not found: " + toolCall.ToolName),
664 },
665 ProviderExecuted: false,
666 }
667 results = append(results, result)
668 if toolResultCallback != nil {
669 if err := toolResultCallback(result); err != nil {
670 return nil, err
671 }
672 }
673 continue
674 }
675
676 // Execute the tool
677 toolResult, err := tool.Run(ctx, ToolCall{
678 ID: toolCall.ToolCallID,
679 Name: toolCall.ToolName,
680 Input: toolCall.Input,
681 })
682 if err != nil {
683 result := ToolResultContent{
684 ToolCallID: toolCall.ToolCallID,
685 ToolName: toolCall.ToolName,
686 Result: ToolResultOutputContentError{
687 Error: err,
688 },
689 ClientMetadata: toolResult.Metadata,
690 ProviderExecuted: false,
691 }
692 if toolResultCallback != nil {
693 if cbErr := toolResultCallback(result); cbErr != nil {
694 return nil, cbErr
695 }
696 }
697 return nil, err
698 }
699
700 var result ToolResultContent
701 if toolResult.IsError {
702 result = ToolResultContent{
703 ToolCallID: toolCall.ToolCallID,
704 ToolName: toolCall.ToolName,
705 Result: ToolResultOutputContentError{
706 Error: errors.New(toolResult.Content),
707 },
708 ClientMetadata: toolResult.Metadata,
709 ProviderExecuted: false,
710 }
711 } else {
712 result = ToolResultContent{
713 ToolCallID: toolCall.ToolCallID,
714 ToolName: toolCall.ToolName,
715 Result: ToolResultOutputContentText{
716 Text: toolResult.Content,
717 },
718 ClientMetadata: toolResult.Metadata,
719 ProviderExecuted: false,
720 }
721 }
722 results = append(results, result)
723 if toolResultCallback != nil {
724 if err := toolResultCallback(result); err != nil {
725 return nil, err
726 }
727 }
728 }
729
730 return results, nil
731}
732
733// Stream implements Agent.
734func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
735 // Convert AgentStreamCall to AgentCall for preparation
736 call := AgentCall{
737 Prompt: opts.Prompt,
738 Files: opts.Files,
739 Messages: opts.Messages,
740 MaxOutputTokens: opts.MaxOutputTokens,
741 Temperature: opts.Temperature,
742 TopP: opts.TopP,
743 TopK: opts.TopK,
744 PresencePenalty: opts.PresencePenalty,
745 FrequencyPenalty: opts.FrequencyPenalty,
746 ActiveTools: opts.ActiveTools,
747 ProviderOptions: opts.ProviderOptions,
748 MaxRetries: opts.MaxRetries,
749 OnRetry: opts.OnRetry,
750 StopWhen: opts.StopWhen,
751 PrepareStep: opts.PrepareStep,
752 RepairToolCall: opts.RepairToolCall,
753 }
754
755 call = a.prepareCall(call)
756
757 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
758 if err != nil {
759 return nil, err
760 }
761
762 var responseMessages []Message
763 var steps []StepResult
764 var totalUsage Usage
765
766 // Start agent stream
767 if opts.OnAgentStart != nil {
768 opts.OnAgentStart()
769 }
770
771 for stepNumber := 0; ; stepNumber++ {
772 stepInputMessages := append(initialPrompt, responseMessages...)
773 stepModel := a.settings.model
774 stepSystemPrompt := a.settings.systemPrompt
775 stepActiveTools := call.ActiveTools
776 stepToolChoice := ToolChoiceAuto
777 disableAllTools := false
778 stepTools := a.settings.tools
779 // Apply step preparation if provided
780 if call.PrepareStep != nil {
781 updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
782 Model: stepModel,
783 Steps: steps,
784 StepNumber: stepNumber,
785 Messages: stepInputMessages,
786 })
787 if err != nil {
788 return nil, err
789 }
790
791 ctx = updatedCtx
792
793 if prepared.Messages != nil {
794 stepInputMessages = prepared.Messages
795 }
796 if prepared.Model != nil {
797 stepModel = prepared.Model
798 }
799 if prepared.System != nil {
800 stepSystemPrompt = *prepared.System
801 }
802 if prepared.ToolChoice != nil {
803 stepToolChoice = *prepared.ToolChoice
804 }
805 if len(prepared.ActiveTools) > 0 {
806 stepActiveTools = prepared.ActiveTools
807 }
808 disableAllTools = prepared.DisableAllTools
809 if prepared.Tools != nil {
810 stepTools = prepared.Tools
811 }
812 }
813
814 // Recreate prompt with potentially modified system prompt
815 if stepSystemPrompt != a.settings.systemPrompt {
816 stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
817 if err != nil {
818 return nil, err
819 }
820 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
821 stepInputMessages[0] = stepPrompt[0]
822 }
823 }
824
825 preparedTools := a.prepareTools(stepTools, stepActiveTools, disableAllTools)
826
827 // Start step stream
828 if opts.OnStepStart != nil {
829 _ = opts.OnStepStart(stepNumber)
830 }
831
832 // Create streaming call
833 streamCall := Call{
834 Prompt: stepInputMessages,
835 MaxOutputTokens: call.MaxOutputTokens,
836 Temperature: call.Temperature,
837 TopP: call.TopP,
838 TopK: call.TopK,
839 PresencePenalty: call.PresencePenalty,
840 FrequencyPenalty: call.FrequencyPenalty,
841 Tools: preparedTools,
842 ToolChoice: &stepToolChoice,
843 ProviderOptions: call.ProviderOptions,
844 }
845
846 // Execute step with retry logic wrapping both stream creation and processing
847 retryOptions := DefaultRetryOptions()
848 if call.MaxRetries != nil {
849 retryOptions.MaxRetries = *call.MaxRetries
850 }
851 retryOptions.OnRetry = call.OnRetry
852 retry := RetryWithExponentialBackoffRespectingRetryHeaders[stepExecutionResult](retryOptions)
853
854 result, err := retry(ctx, func() (stepExecutionResult, error) {
855 // Create the stream
856 stream, err := stepModel.Stream(ctx, streamCall)
857 if err != nil {
858 return stepExecutionResult{}, err
859 }
860
861 // Process the stream
862 result, err := a.processStepStream(ctx, stream, opts, steps, stepTools)
863 if err != nil {
864 return stepExecutionResult{}, err
865 }
866
867 return result, nil
868 })
869 if err != nil {
870 if opts.OnError != nil {
871 opts.OnError(err)
872 }
873 return nil, err
874 }
875
876 steps = append(steps, result.StepResult)
877 totalUsage = addUsage(totalUsage, result.StepResult.Usage)
878
879 // Call step finished callback
880 if opts.OnStepFinish != nil {
881 _ = opts.OnStepFinish(result.StepResult)
882 }
883
884 // Add step messages to response messages
885 stepMessages := toResponseMessages(result.StepResult.Content)
886 responseMessages = append(responseMessages, stepMessages...)
887
888 // Check stop conditions
889 shouldStop := isStopConditionMet(call.StopWhen, steps)
890 if shouldStop || !result.ShouldContinue {
891 break
892 }
893 }
894
895 // Finish agent stream
896 agentResult := &AgentResult{
897 Steps: steps,
898 Response: steps[len(steps)-1].Response,
899 TotalUsage: totalUsage,
900 }
901
902 if opts.OnFinish != nil {
903 opts.OnFinish(agentResult)
904 }
905
906 if opts.OnAgentFinish != nil {
907 _ = opts.OnAgentFinish(agentResult)
908 }
909
910 return agentResult, nil
911}
912
913func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
914 preparedTools := make([]Tool, 0, len(tools))
915
916 // If explicitly disabling all tools, return no tools
917 if disableAllTools {
918 return preparedTools
919 }
920
921 for _, tool := range tools {
922 // If activeTools has items, only include tools in the list
923 // If activeTools is empty, include all tools
924 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
925 continue
926 }
927 info := tool.Info()
928 preparedTools = append(preparedTools, FunctionTool{
929 Name: info.Name,
930 Description: info.Description,
931 InputSchema: map[string]any{
932 "type": "object",
933 "properties": info.Parameters,
934 "required": info.Required,
935 },
936 ProviderOptions: tool.ProviderOptions(),
937 })
938 }
939 return preparedTools
940}
941
942// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
943func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
944 if err := a.validateToolCall(toolCall, availableTools); err == nil {
945 return toolCall
946 } else { //nolint: revive
947 if repairFunc != nil {
948 repairOptions := ToolCallRepairOptions{
949 OriginalToolCall: toolCall,
950 ValidationError: err,
951 AvailableTools: availableTools,
952 SystemPrompt: systemPrompt,
953 Messages: messages,
954 }
955
956 if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
957 if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
958 return *repairedToolCall
959 }
960 }
961 }
962
963 invalidToolCall := toolCall
964 invalidToolCall.Invalid = true
965 invalidToolCall.ValidationError = err
966 return invalidToolCall
967 }
968}
969
970// validateToolCall validates a tool call against available tools and their schemas.
971func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
972 var tool AgentTool
973 for _, t := range availableTools {
974 if t.Info().Name == toolCall.ToolName {
975 tool = t
976 break
977 }
978 }
979
980 if tool == nil {
981 return fmt.Errorf("tool not found: %s", toolCall.ToolName)
982 }
983
984 // Validate JSON parsing
985 var input map[string]any
986 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
987 return fmt.Errorf("invalid JSON input: %w", err)
988 }
989
990 // Basic schema validation (check required fields)
991 // TODO: more robust schema validation using JSON Schema or similar
992 toolInfo := tool.Info()
993 for _, required := range toolInfo.Required {
994 if _, exists := input[required]; !exists {
995 return fmt.Errorf("missing required parameter: %s", required)
996 }
997 }
998 return nil
999}
1000
1001func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
1002 if prompt == "" {
1003 return nil, &Error{Title: "invalid argument", Message: "prompt can't be empty"}
1004 }
1005
1006 var preparedPrompt Prompt
1007
1008 if system != "" {
1009 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
1010 }
1011 preparedPrompt = append(preparedPrompt, messages...)
1012 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
1013 return preparedPrompt, nil
1014}
1015
1016// WithSystemPrompt sets the system prompt for the agent.
1017func WithSystemPrompt(prompt string) AgentOption {
1018 return func(s *agentSettings) {
1019 s.systemPrompt = prompt
1020 }
1021}
1022
1023// WithMaxOutputTokens sets the maximum output tokens for the agent.
1024func WithMaxOutputTokens(tokens int64) AgentOption {
1025 return func(s *agentSettings) {
1026 s.maxOutputTokens = &tokens
1027 }
1028}
1029
1030// WithTemperature sets the temperature for the agent.
1031func WithTemperature(temp float64) AgentOption {
1032 return func(s *agentSettings) {
1033 s.temperature = &temp
1034 }
1035}
1036
1037// WithTopP sets the top-p value for the agent.
1038func WithTopP(topP float64) AgentOption {
1039 return func(s *agentSettings) {
1040 s.topP = &topP
1041 }
1042}
1043
1044// WithTopK sets the top-k value for the agent.
1045func WithTopK(topK int64) AgentOption {
1046 return func(s *agentSettings) {
1047 s.topK = &topK
1048 }
1049}
1050
1051// WithPresencePenalty sets the presence penalty for the agent.
1052func WithPresencePenalty(penalty float64) AgentOption {
1053 return func(s *agentSettings) {
1054 s.presencePenalty = &penalty
1055 }
1056}
1057
1058// WithFrequencyPenalty sets the frequency penalty for the agent.
1059func WithFrequencyPenalty(penalty float64) AgentOption {
1060 return func(s *agentSettings) {
1061 s.frequencyPenalty = &penalty
1062 }
1063}
1064
1065// WithTools sets the tools for the agent.
1066func WithTools(tools ...AgentTool) AgentOption {
1067 return func(s *agentSettings) {
1068 s.tools = append(s.tools, tools...)
1069 }
1070}
1071
1072// WithStopConditions sets the stop conditions for the agent.
1073func WithStopConditions(conditions ...StopCondition) AgentOption {
1074 return func(s *agentSettings) {
1075 s.stopWhen = append(s.stopWhen, conditions...)
1076 }
1077}
1078
1079// WithPrepareStep sets the prepare step function for the agent.
1080func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1081 return func(s *agentSettings) {
1082 s.prepareStep = fn
1083 }
1084}
1085
1086// WithRepairToolCall sets the repair tool call function for the agent.
1087func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1088 return func(s *agentSettings) {
1089 s.repairToolCall = fn
1090 }
1091}
1092
1093// WithMaxRetries sets the maximum number of retries for the agent.
1094func WithMaxRetries(maxRetries int) AgentOption {
1095 return func(s *agentSettings) {
1096 s.maxRetries = &maxRetries
1097 }
1098}
1099
1100// WithOnRetry sets the retry callback for the agent.
1101func WithOnRetry(callback OnRetryCallback) AgentOption {
1102 return func(s *agentSettings) {
1103 s.onRetry = callback
1104 }
1105}
1106
1107// processStepStream processes a single step's stream and returns the step result.
1108func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult, stepTools []AgentTool) (stepExecutionResult, error) {
1109 var stepContent []Content
1110 var stepToolCalls []ToolCallContent
1111 var stepUsage Usage
1112 stepFinishReason := FinishReasonUnknown
1113 var stepWarnings []CallWarning
1114 var stepProviderMetadata ProviderMetadata
1115
1116 activeToolCalls := make(map[string]*ToolCallContent)
1117 activeTextContent := make(map[string]string)
1118 type reasoningContent struct {
1119 content string
1120 options ProviderMetadata
1121 }
1122 activeReasoningContent := make(map[string]reasoningContent)
1123
1124 // Process stream parts
1125 for part := range stream {
1126 // Forward all parts to chunk callback
1127 if opts.OnChunk != nil {
1128 err := opts.OnChunk(part)
1129 if err != nil {
1130 return stepExecutionResult{}, err
1131 }
1132 }
1133
1134 switch part.Type {
1135 case StreamPartTypeWarnings:
1136 stepWarnings = part.Warnings
1137 if opts.OnWarnings != nil {
1138 err := opts.OnWarnings(part.Warnings)
1139 if err != nil {
1140 return stepExecutionResult{}, err
1141 }
1142 }
1143
1144 case StreamPartTypeTextStart:
1145 activeTextContent[part.ID] = ""
1146 if opts.OnTextStart != nil {
1147 err := opts.OnTextStart(part.ID)
1148 if err != nil {
1149 return stepExecutionResult{}, err
1150 }
1151 }
1152
1153 case StreamPartTypeTextDelta:
1154 if _, exists := activeTextContent[part.ID]; exists {
1155 activeTextContent[part.ID] += part.Delta
1156 }
1157 if opts.OnTextDelta != nil {
1158 err := opts.OnTextDelta(part.ID, part.Delta)
1159 if err != nil {
1160 return stepExecutionResult{}, err
1161 }
1162 }
1163
1164 case StreamPartTypeTextEnd:
1165 if text, exists := activeTextContent[part.ID]; exists {
1166 stepContent = append(stepContent, TextContent{
1167 Text: text,
1168 ProviderMetadata: part.ProviderMetadata,
1169 })
1170 delete(activeTextContent, part.ID)
1171 }
1172 if opts.OnTextEnd != nil {
1173 err := opts.OnTextEnd(part.ID)
1174 if err != nil {
1175 return stepExecutionResult{}, err
1176 }
1177 }
1178
1179 case StreamPartTypeReasoningStart:
1180 activeReasoningContent[part.ID] = reasoningContent{content: part.Delta, options: part.ProviderMetadata}
1181 if opts.OnReasoningStart != nil {
1182 content := ReasoningContent{
1183 Text: part.Delta,
1184 ProviderMetadata: part.ProviderMetadata,
1185 }
1186 err := opts.OnReasoningStart(part.ID, content)
1187 if err != nil {
1188 return stepExecutionResult{}, err
1189 }
1190 }
1191
1192 case StreamPartTypeReasoningDelta:
1193 if active, exists := activeReasoningContent[part.ID]; exists {
1194 active.content += part.Delta
1195 active.options = part.ProviderMetadata
1196 activeReasoningContent[part.ID] = active
1197 }
1198 if opts.OnReasoningDelta != nil {
1199 err := opts.OnReasoningDelta(part.ID, part.Delta)
1200 if err != nil {
1201 return stepExecutionResult{}, err
1202 }
1203 }
1204
1205 case StreamPartTypeReasoningEnd:
1206 if active, exists := activeReasoningContent[part.ID]; exists {
1207 if part.ProviderMetadata != nil {
1208 active.options = part.ProviderMetadata
1209 }
1210 content := ReasoningContent{
1211 Text: active.content,
1212 ProviderMetadata: active.options,
1213 }
1214 stepContent = append(stepContent, content)
1215 if opts.OnReasoningEnd != nil {
1216 err := opts.OnReasoningEnd(part.ID, content)
1217 if err != nil {
1218 return stepExecutionResult{}, err
1219 }
1220 }
1221 delete(activeReasoningContent, part.ID)
1222 }
1223
1224 case StreamPartTypeToolInputStart:
1225 activeToolCalls[part.ID] = &ToolCallContent{
1226 ToolCallID: part.ID,
1227 ToolName: part.ToolCallName,
1228 Input: "",
1229 ProviderExecuted: part.ProviderExecuted,
1230 }
1231 if opts.OnToolInputStart != nil {
1232 err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1233 if err != nil {
1234 return stepExecutionResult{}, err
1235 }
1236 }
1237
1238 case StreamPartTypeToolInputDelta:
1239 if toolCall, exists := activeToolCalls[part.ID]; exists {
1240 toolCall.Input += part.Delta
1241 }
1242 if opts.OnToolInputDelta != nil {
1243 err := opts.OnToolInputDelta(part.ID, part.Delta)
1244 if err != nil {
1245 return stepExecutionResult{}, err
1246 }
1247 }
1248
1249 case StreamPartTypeToolInputEnd:
1250 if opts.OnToolInputEnd != nil {
1251 err := opts.OnToolInputEnd(part.ID)
1252 if err != nil {
1253 return stepExecutionResult{}, err
1254 }
1255 }
1256
1257 case StreamPartTypeToolCall:
1258 toolCall := ToolCallContent{
1259 ToolCallID: part.ID,
1260 ToolName: part.ToolCallName,
1261 Input: part.ToolCallInput,
1262 ProviderExecuted: part.ProviderExecuted,
1263 ProviderMetadata: part.ProviderMetadata,
1264 }
1265
1266 // Validate and potentially repair the tool call
1267 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1268 stepToolCalls = append(stepToolCalls, validatedToolCall)
1269 stepContent = append(stepContent, validatedToolCall)
1270
1271 if opts.OnToolCall != nil {
1272 err := opts.OnToolCall(validatedToolCall)
1273 if err != nil {
1274 return stepExecutionResult{}, err
1275 }
1276 }
1277
1278 // Clean up active tool call
1279 delete(activeToolCalls, part.ID)
1280
1281 case StreamPartTypeSource:
1282 sourceContent := SourceContent{
1283 SourceType: part.SourceType,
1284 ID: part.ID,
1285 URL: part.URL,
1286 Title: part.Title,
1287 ProviderMetadata: part.ProviderMetadata,
1288 }
1289 stepContent = append(stepContent, sourceContent)
1290 if opts.OnSource != nil {
1291 err := opts.OnSource(sourceContent)
1292 if err != nil {
1293 return stepExecutionResult{}, err
1294 }
1295 }
1296
1297 case StreamPartTypeFinish:
1298 stepUsage = part.Usage
1299 stepFinishReason = part.FinishReason
1300 stepProviderMetadata = part.ProviderMetadata
1301 if opts.OnStreamFinish != nil {
1302 err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1303 if err != nil {
1304 return stepExecutionResult{}, err
1305 }
1306 }
1307
1308 case StreamPartTypeError:
1309 return stepExecutionResult{}, part.Error
1310 }
1311 }
1312
1313 // Execute tools if any
1314 var toolResults []ToolResultContent
1315 if len(stepToolCalls) > 0 {
1316 var err error
1317 toolResults, err = a.executeTools(ctx, stepTools, stepToolCalls, opts.OnToolResult)
1318 if err != nil {
1319 return stepExecutionResult{}, err
1320 }
1321 // Add tool results to content
1322 for _, result := range toolResults {
1323 stepContent = append(stepContent, result)
1324 }
1325 }
1326
1327 stepResult := StepResult{
1328 Response: Response{
1329 Content: stepContent,
1330 FinishReason: stepFinishReason,
1331 Usage: stepUsage,
1332 Warnings: stepWarnings,
1333 ProviderMetadata: stepProviderMetadata,
1334 },
1335 Messages: toResponseMessages(stepContent),
1336 }
1337
1338 // Determine if we should continue (has tool calls and not stopped)
1339 shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1340
1341 return stepExecutionResult{
1342 StepResult: stepResult,
1343 ShouldContinue: shouldContinue,
1344 }, nil
1345}
1346
1347func addUsage(a, b Usage) Usage {
1348 return Usage{
1349 InputTokens: a.InputTokens + b.InputTokens,
1350 OutputTokens: a.OutputTokens + b.OutputTokens,
1351 TotalTokens: a.TotalTokens + b.TotalTokens,
1352 ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens,
1353 CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1354 CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens,
1355 }
1356}
1357
1358// WithHeaders sets the headers for the agent.
1359func WithHeaders(headers map[string]string) AgentOption {
1360 return func(s *agentSettings) {
1361 s.headers = headers
1362 }
1363}
1364
1365// WithProviderOptions sets the provider options for the agent.
1366func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1367 return func(s *agentSettings) {
1368 s.providerOptions = providerOptions
1369 }
1370}