1package fantasy
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "slices"
11)
12
13// StepResult represents the result of a single step in an agent execution.
14type StepResult struct {
15 Response
16 Messages []Message
17}
18
19// stepExecutionResult encapsulates the result of executing a step with stream processing.
20type stepExecutionResult struct {
21 StepResult StepResult
22 ShouldContinue bool
23}
24
25// StopCondition defines a function that determines when an agent should stop executing.
26type StopCondition = func(steps []StepResult) bool
27
28// StepCountIs returns a stop condition that stops after the specified number of steps.
29func StepCountIs(stepCount int) StopCondition {
30 return func(steps []StepResult) bool {
31 return len(steps) >= stepCount
32 }
33}
34
35// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
36func HasToolCall(toolName string) StopCondition {
37 return func(steps []StepResult) bool {
38 if len(steps) == 0 {
39 return false
40 }
41 lastStep := steps[len(steps)-1]
42 toolCalls := lastStep.Content.ToolCalls()
43 for _, toolCall := range toolCalls {
44 if toolCall.ToolName == toolName {
45 return true
46 }
47 }
48 return false
49 }
50}
51
52// HasContent returns a stop condition that stops when the specified content type appears in the last step.
53func HasContent(contentType ContentType) StopCondition {
54 return func(steps []StepResult) bool {
55 if len(steps) == 0 {
56 return false
57 }
58 lastStep := steps[len(steps)-1]
59 for _, content := range lastStep.Content {
60 if content.GetType() == contentType {
61 return true
62 }
63 }
64 return false
65 }
66}
67
68// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
69func FinishReasonIs(reason FinishReason) StopCondition {
70 return func(steps []StepResult) bool {
71 if len(steps) == 0 {
72 return false
73 }
74 lastStep := steps[len(steps)-1]
75 return lastStep.FinishReason == reason
76 }
77}
78
79// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
80func MaxTokensUsed(maxTokens int64) StopCondition {
81 return func(steps []StepResult) bool {
82 var totalTokens int64
83 for _, step := range steps {
84 totalTokens += step.Usage.TotalTokens
85 }
86 return totalTokens >= maxTokens
87 }
88}
89
90// PrepareStepFunctionOptions contains the options for preparing a step in an agent execution.
91type PrepareStepFunctionOptions struct {
92 Steps []StepResult
93 StepNumber int
94 Model LanguageModel
95 Messages []Message
96}
97
98// PrepareStepResult contains the result of preparing a step in an agent execution.
99type PrepareStepResult struct {
100 Model LanguageModel
101 Messages []Message
102 System *string
103 ToolChoice *ToolChoice
104 ActiveTools []string
105 DisableAllTools bool
106}
107
108// ToolCallRepairOptions contains the options for repairing a tool call.
109type ToolCallRepairOptions struct {
110 OriginalToolCall ToolCallContent
111 ValidationError error
112 AvailableTools []AgentTool
113 SystemPrompt string
114 Messages []Message
115}
116
117type (
118 // PrepareStepFunction defines a function that prepares a step in an agent execution.
119 PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)
120
121 // OnStepFinishedFunction defines a function that is called when a step finishes.
122 OnStepFinishedFunction = func(step StepResult)
123
124 // RepairToolCallFunction defines a function that repairs a tool call.
125 RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
126)
127
128type agentSettings struct {
129 systemPrompt string
130 maxOutputTokens *int64
131 temperature *float64
132 topP *float64
133 topK *int64
134 presencePenalty *float64
135 frequencyPenalty *float64
136 headers map[string]string
137 providerOptions ProviderOptions
138
139 // TODO: add support for provider tools
140 tools []AgentTool
141 maxRetries *int
142
143 model LanguageModel
144
145 stopWhen []StopCondition
146 prepareStep PrepareStepFunction
147 repairToolCall RepairToolCallFunction
148 onRetry OnRetryCallback
149}
150
151// AgentCall represents a call to an agent.
152type AgentCall struct {
153 Prompt string `json:"prompt"`
154 Files []FilePart `json:"files"`
155 Messages []Message `json:"messages"`
156 MaxOutputTokens *int64
157 Temperature *float64 `json:"temperature"`
158 TopP *float64 `json:"top_p"`
159 TopK *int64 `json:"top_k"`
160 PresencePenalty *float64 `json:"presence_penalty"`
161 FrequencyPenalty *float64 `json:"frequency_penalty"`
162 ActiveTools []string `json:"active_tools"`
163 ProviderOptions ProviderOptions
164 OnRetry OnRetryCallback
165 MaxRetries *int
166
167 StopWhen []StopCondition
168 PrepareStep PrepareStepFunction
169 RepairToolCall RepairToolCallFunction
170}
171
172// Agent-level callbacks.
173type (
174 // OnAgentStartFunc is called when agent starts.
175 OnAgentStartFunc func()
176
177 // OnAgentFinishFunc is called when agent finishes.
178 OnAgentFinishFunc func(result *AgentResult) error
179
180 // OnStepStartFunc is called when a step starts.
181 OnStepStartFunc func(stepNumber int) error
182
183 // OnStepFinishFunc is called when a step finishes.
184 OnStepFinishFunc func(stepResult StepResult) error
185
186 // OnFinishFunc is called when entire agent completes.
187 OnFinishFunc func(result *AgentResult)
188
189 // OnErrorFunc is called when an error occurs.
190 OnErrorFunc func(error)
191)
192
193// Stream part callbacks - called for each corresponding stream part type.
194type (
195 // OnChunkFunc is called for each stream part (catch-all).
196 OnChunkFunc func(StreamPart) error
197
198 // OnWarningsFunc is called for warnings.
199 OnWarningsFunc func(warnings []CallWarning) error
200
201 // OnTextStartFunc is called when text starts.
202 OnTextStartFunc func(id string) error
203
204 // OnTextDeltaFunc is called for text deltas.
205 OnTextDeltaFunc func(id, text string) error
206
207 // OnTextEndFunc is called when text ends.
208 OnTextEndFunc func(id string) error
209
210 // OnReasoningStartFunc is called when reasoning starts.
211 OnReasoningStartFunc func(id string, reasoning ReasoningContent) error
212
213 // OnReasoningDeltaFunc is called for reasoning deltas.
214 OnReasoningDeltaFunc func(id, text string) error
215
216 // OnReasoningEndFunc is called when reasoning ends.
217 OnReasoningEndFunc func(id string, reasoning ReasoningContent) error
218
219 // OnToolInputStartFunc is called when tool input starts.
220 OnToolInputStartFunc func(id, toolName string) error
221
222 // OnToolInputDeltaFunc is called for tool input deltas.
223 OnToolInputDeltaFunc func(id, delta string) error
224
225 // OnToolInputEndFunc is called when tool input ends.
226 OnToolInputEndFunc func(id string) error
227
228 // OnToolCallFunc is called when tool call is complete.
229 OnToolCallFunc func(toolCall ToolCallContent) error
230
231 // OnToolResultFunc is called when tool execution completes.
232 OnToolResultFunc func(result ToolResultContent) error
233
234 // OnSourceFunc is called for source references.
235 OnSourceFunc func(source SourceContent) error
236
237 // OnStreamFinishFunc is called when stream finishes.
238 OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
239)
240
241// AgentStreamCall represents a streaming call to an agent.
242type AgentStreamCall struct {
243 Prompt string `json:"prompt"`
244 Files []FilePart `json:"files"`
245 Messages []Message `json:"messages"`
246 MaxOutputTokens *int64
247 Temperature *float64 `json:"temperature"`
248 TopP *float64 `json:"top_p"`
249 TopK *int64 `json:"top_k"`
250 PresencePenalty *float64 `json:"presence_penalty"`
251 FrequencyPenalty *float64 `json:"frequency_penalty"`
252 ActiveTools []string `json:"active_tools"`
253 Headers map[string]string
254 ProviderOptions ProviderOptions
255 OnRetry OnRetryCallback
256 MaxRetries *int
257
258 StopWhen []StopCondition
259 PrepareStep PrepareStepFunction
260 RepairToolCall RepairToolCallFunction
261
262 // Agent-level callbacks
263 OnAgentStart OnAgentStartFunc // Called when agent starts
264 OnAgentFinish OnAgentFinishFunc // Called when agent finishes
265 OnStepStart OnStepStartFunc // Called when a step starts
266 OnStepFinish OnStepFinishFunc // Called when a step finishes
267 OnFinish OnFinishFunc // Called when entire agent completes
268 OnError OnErrorFunc // Called when an error occurs
269
270 // Stream part callbacks - called for each corresponding stream part type
271 OnChunk OnChunkFunc // Called for each stream part (catch-all)
272 OnWarnings OnWarningsFunc // Called for warnings
273 OnTextStart OnTextStartFunc // Called when text starts
274 OnTextDelta OnTextDeltaFunc // Called for text deltas
275 OnTextEnd OnTextEndFunc // Called when text ends
276 OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
277 OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
278 OnReasoningEnd OnReasoningEndFunc // Called when reasoning ends
279 OnToolInputStart OnToolInputStartFunc // Called when tool input starts
280 OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
281 OnToolInputEnd OnToolInputEndFunc // Called when tool input ends
282 OnToolCall OnToolCallFunc // Called when tool call is complete
283 OnToolResult OnToolResultFunc // Called when tool execution completes
284 OnSource OnSourceFunc // Called for source references
285 OnStreamFinish OnStreamFinishFunc // Called when stream finishes
286}
287
288// AgentResult represents the result of an agent execution.
289type AgentResult struct {
290 Steps []StepResult
291 // Final response
292 Response Response
293 TotalUsage Usage
294}
295
296// Agent represents an AI agent that can generate responses and stream responses.
297type Agent interface {
298 Generate(context.Context, AgentCall) (*AgentResult, error)
299 Stream(context.Context, AgentStreamCall) (*AgentResult, error)
300}
301
302// AgentOption defines a function that configures agent settings.
303type AgentOption = func(*agentSettings)
304
305type agent struct {
306 settings agentSettings
307}
308
309// NewAgent creates a new agent with the given language model and options.
310func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
311 settings := agentSettings{
312 model: model,
313 }
314 for _, o := range opts {
315 o(&settings)
316 }
317 return &agent{
318 settings: settings,
319 }
320}
321
322func (a *agent) prepareCall(call AgentCall) AgentCall {
323 call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
324 call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
325 call.TopP = cmp.Or(call.TopP, a.settings.topP)
326 call.TopK = cmp.Or(call.TopK, a.settings.topK)
327 call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
328 call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
329 call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
330
331 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
332 call.StopWhen = a.settings.stopWhen
333 }
334 if call.PrepareStep == nil && a.settings.prepareStep != nil {
335 call.PrepareStep = a.settings.prepareStep
336 }
337 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
338 call.RepairToolCall = a.settings.repairToolCall
339 }
340 if call.OnRetry == nil && a.settings.onRetry != nil {
341 call.OnRetry = a.settings.onRetry
342 }
343
344 providerOptions := ProviderOptions{}
345 if a.settings.providerOptions != nil {
346 maps.Copy(providerOptions, a.settings.providerOptions)
347 }
348 if call.ProviderOptions != nil {
349 maps.Copy(providerOptions, call.ProviderOptions)
350 }
351 call.ProviderOptions = providerOptions
352
353 headers := map[string]string{}
354
355 if a.settings.headers != nil {
356 maps.Copy(headers, a.settings.headers)
357 }
358
359 return call
360}
361
362// Generate implements Agent.
363func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
364 opts = a.prepareCall(opts)
365 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
366 if err != nil {
367 return nil, err
368 }
369 var responseMessages []Message
370 var steps []StepResult
371
372 for {
373 stepInputMessages := append(initialPrompt, responseMessages...)
374 stepModel := a.settings.model
375 stepSystemPrompt := a.settings.systemPrompt
376 stepActiveTools := opts.ActiveTools
377 stepToolChoice := ToolChoiceAuto
378 disableAllTools := false
379
380 if opts.PrepareStep != nil {
381 updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
382 Model: stepModel,
383 Steps: steps,
384 StepNumber: len(steps),
385 Messages: stepInputMessages,
386 })
387 if err != nil {
388 return nil, err
389 }
390
391 ctx = updatedCtx
392
393 // Apply prepared step modifications
394 if prepared.Messages != nil {
395 stepInputMessages = prepared.Messages
396 }
397 if prepared.Model != nil {
398 stepModel = prepared.Model
399 }
400 if prepared.System != nil {
401 stepSystemPrompt = *prepared.System
402 }
403 if prepared.ToolChoice != nil {
404 stepToolChoice = *prepared.ToolChoice
405 }
406 if len(prepared.ActiveTools) > 0 {
407 stepActiveTools = prepared.ActiveTools
408 }
409 disableAllTools = prepared.DisableAllTools
410 }
411
412 // Recreate prompt with potentially modified system prompt
413 if stepSystemPrompt != a.settings.systemPrompt {
414 stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
415 if err != nil {
416 return nil, err
417 }
418 // Replace system message part, keep the rest
419 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
420 stepInputMessages[0] = stepPrompt[0] // Replace system message
421 }
422 }
423
424 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
425
426 retryOptions := DefaultRetryOptions()
427 if opts.MaxRetries != nil {
428 retryOptions.MaxRetries = *opts.MaxRetries
429 }
430 retryOptions.OnRetry = opts.OnRetry
431 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
432
433 result, err := retry(ctx, func() (*Response, error) {
434 return stepModel.Generate(ctx, Call{
435 Prompt: stepInputMessages,
436 MaxOutputTokens: opts.MaxOutputTokens,
437 Temperature: opts.Temperature,
438 TopP: opts.TopP,
439 TopK: opts.TopK,
440 PresencePenalty: opts.PresencePenalty,
441 FrequencyPenalty: opts.FrequencyPenalty,
442 Tools: preparedTools,
443 ToolChoice: &stepToolChoice,
444 ProviderOptions: opts.ProviderOptions,
445 })
446 })
447 if err != nil {
448 return nil, err
449 }
450
451 var stepToolCalls []ToolCallContent
452 for _, content := range result.Content {
453 if content.GetType() == ContentTypeToolCall {
454 toolCall, ok := AsContentType[ToolCallContent](content)
455 if !ok {
456 continue
457 }
458
459 // Validate and potentially repair the tool call
460 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
461 stepToolCalls = append(stepToolCalls, validatedToolCall)
462 }
463 }
464
465 toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
466
467 // Build step content with validated tool calls and tool results
468 stepContent := []Content{}
469 toolCallIndex := 0
470 for _, content := range result.Content {
471 if content.GetType() == ContentTypeToolCall {
472 // Replace with validated tool call
473 if toolCallIndex < len(stepToolCalls) {
474 stepContent = append(stepContent, stepToolCalls[toolCallIndex])
475 toolCallIndex++
476 }
477 } else {
478 // Keep other content as-is
479 stepContent = append(stepContent, content)
480 }
481 }
482 // Add tool results
483 for _, result := range toolResults {
484 stepContent = append(stepContent, result)
485 }
486 currentStepMessages := toResponseMessages(stepContent)
487 responseMessages = append(responseMessages, currentStepMessages...)
488
489 stepResult := StepResult{
490 Response: Response{
491 Content: stepContent,
492 FinishReason: result.FinishReason,
493 Usage: result.Usage,
494 Warnings: result.Warnings,
495 ProviderMetadata: result.ProviderMetadata,
496 },
497 Messages: currentStepMessages,
498 }
499 steps = append(steps, stepResult)
500 shouldStop := isStopConditionMet(opts.StopWhen, steps)
501
502 if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
503 break
504 }
505 }
506
507 totalUsage := Usage{}
508
509 for _, step := range steps {
510 usage := step.Usage
511 totalUsage.InputTokens += usage.InputTokens
512 totalUsage.OutputTokens += usage.OutputTokens
513 totalUsage.ReasoningTokens += usage.ReasoningTokens
514 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
515 totalUsage.CacheReadTokens += usage.CacheReadTokens
516 totalUsage.TotalTokens += usage.TotalTokens
517 }
518
519 agentResult := &AgentResult{
520 Steps: steps,
521 Response: steps[len(steps)-1].Response,
522 TotalUsage: totalUsage,
523 }
524 return agentResult, nil
525}
526
527func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
528 if len(conditions) == 0 {
529 return false
530 }
531
532 for _, condition := range conditions {
533 if condition(steps) {
534 return true
535 }
536 }
537 return false
538}
539
540func toResponseMessages(content []Content) []Message {
541 var assistantParts []MessagePart
542 var toolParts []MessagePart
543
544 for _, c := range content {
545 switch c.GetType() {
546 case ContentTypeText:
547 text, ok := AsContentType[TextContent](c)
548 if !ok {
549 continue
550 }
551 assistantParts = append(assistantParts, TextPart{
552 Text: text.Text,
553 ProviderOptions: ProviderOptions(text.ProviderMetadata),
554 })
555 case ContentTypeReasoning:
556 reasoning, ok := AsContentType[ReasoningContent](c)
557 if !ok {
558 continue
559 }
560 assistantParts = append(assistantParts, ReasoningPart{
561 Text: reasoning.Text,
562 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
563 })
564 case ContentTypeToolCall:
565 toolCall, ok := AsContentType[ToolCallContent](c)
566 if !ok {
567 continue
568 }
569 assistantParts = append(assistantParts, ToolCallPart{
570 ToolCallID: toolCall.ToolCallID,
571 ToolName: toolCall.ToolName,
572 Input: toolCall.Input,
573 ProviderExecuted: toolCall.ProviderExecuted,
574 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
575 })
576 case ContentTypeFile:
577 file, ok := AsContentType[FileContent](c)
578 if !ok {
579 continue
580 }
581 assistantParts = append(assistantParts, FilePart{
582 Data: file.Data,
583 MediaType: file.MediaType,
584 ProviderOptions: ProviderOptions(file.ProviderMetadata),
585 })
586 case ContentTypeSource:
587 // Sources are metadata about references used to generate the response.
588 // They don't need to be included in the conversation messages.
589 continue
590 case ContentTypeToolResult:
591 result, ok := AsContentType[ToolResultContent](c)
592 if !ok {
593 continue
594 }
595 toolParts = append(toolParts, ToolResultPart{
596 ToolCallID: result.ToolCallID,
597 Output: result.Result,
598 ProviderOptions: ProviderOptions(result.ProviderMetadata),
599 })
600 }
601 }
602
603 var messages []Message
604 if len(assistantParts) > 0 {
605 messages = append(messages, Message{
606 Role: MessageRoleAssistant,
607 Content: assistantParts,
608 })
609 }
610 if len(toolParts) > 0 {
611 messages = append(messages, Message{
612 Role: MessageRoleTool,
613 Content: toolParts,
614 })
615 }
616 return messages
617}
618
619func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
620 if len(toolCalls) == 0 {
621 return nil, nil
622 }
623
624 // Create a map for quick tool lookup
625 toolMap := make(map[string]AgentTool)
626 for _, tool := range allTools {
627 toolMap[tool.Info().Name] = tool
628 }
629
630 // Execute all tool calls sequentially in order
631 results := make([]ToolResultContent, 0, len(toolCalls))
632
633 for _, toolCall := range toolCalls {
634 // Skip invalid tool calls - create error result
635 if toolCall.Invalid {
636 result := ToolResultContent{
637 ToolCallID: toolCall.ToolCallID,
638 ToolName: toolCall.ToolName,
639 Result: ToolResultOutputContentError{
640 Error: toolCall.ValidationError,
641 },
642 ProviderExecuted: false,
643 }
644 results = append(results, result)
645 if toolResultCallback != nil {
646 if err := toolResultCallback(result); err != nil {
647 return nil, err
648 }
649 }
650 continue
651 }
652
653 tool, exists := toolMap[toolCall.ToolName]
654 if !exists {
655 result := ToolResultContent{
656 ToolCallID: toolCall.ToolCallID,
657 ToolName: toolCall.ToolName,
658 Result: ToolResultOutputContentError{
659 Error: errors.New("Error: Tool not found: " + toolCall.ToolName),
660 },
661 ProviderExecuted: false,
662 }
663 results = append(results, result)
664 if toolResultCallback != nil {
665 if err := toolResultCallback(result); err != nil {
666 return nil, err
667 }
668 }
669 continue
670 }
671
672 // Execute the tool
673 toolResult, err := tool.Run(ctx, ToolCall{
674 ID: toolCall.ToolCallID,
675 Name: toolCall.ToolName,
676 Input: toolCall.Input,
677 })
678 if err != nil {
679 result := ToolResultContent{
680 ToolCallID: toolCall.ToolCallID,
681 ToolName: toolCall.ToolName,
682 Result: ToolResultOutputContentError{
683 Error: err,
684 },
685 ClientMetadata: toolResult.Metadata,
686 ProviderExecuted: false,
687 }
688 if toolResultCallback != nil {
689 if cbErr := toolResultCallback(result); cbErr != nil {
690 return nil, cbErr
691 }
692 }
693 return nil, err
694 }
695
696 var result ToolResultContent
697 if toolResult.IsError {
698 result = ToolResultContent{
699 ToolCallID: toolCall.ToolCallID,
700 ToolName: toolCall.ToolName,
701 Result: ToolResultOutputContentError{
702 Error: errors.New(toolResult.Content),
703 },
704 ClientMetadata: toolResult.Metadata,
705 ProviderExecuted: false,
706 }
707 } else {
708 result = ToolResultContent{
709 ToolCallID: toolCall.ToolCallID,
710 ToolName: toolCall.ToolName,
711 Result: ToolResultOutputContentText{
712 Text: toolResult.Content,
713 },
714 ClientMetadata: toolResult.Metadata,
715 ProviderExecuted: false,
716 }
717 }
718 results = append(results, result)
719 if toolResultCallback != nil {
720 if err := toolResultCallback(result); err != nil {
721 return nil, err
722 }
723 }
724 }
725
726 return results, nil
727}
728
729// Stream implements Agent.
730func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
731 // Convert AgentStreamCall to AgentCall for preparation
732 call := AgentCall{
733 Prompt: opts.Prompt,
734 Files: opts.Files,
735 Messages: opts.Messages,
736 MaxOutputTokens: opts.MaxOutputTokens,
737 Temperature: opts.Temperature,
738 TopP: opts.TopP,
739 TopK: opts.TopK,
740 PresencePenalty: opts.PresencePenalty,
741 FrequencyPenalty: opts.FrequencyPenalty,
742 ActiveTools: opts.ActiveTools,
743 ProviderOptions: opts.ProviderOptions,
744 MaxRetries: opts.MaxRetries,
745 OnRetry: opts.OnRetry,
746 StopWhen: opts.StopWhen,
747 PrepareStep: opts.PrepareStep,
748 RepairToolCall: opts.RepairToolCall,
749 }
750
751 call = a.prepareCall(call)
752
753 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
754 if err != nil {
755 return nil, err
756 }
757
758 var responseMessages []Message
759 var steps []StepResult
760 var totalUsage Usage
761
762 // Start agent stream
763 if opts.OnAgentStart != nil {
764 opts.OnAgentStart()
765 }
766
767 for stepNumber := 0; ; stepNumber++ {
768 stepInputMessages := append(initialPrompt, responseMessages...)
769 stepModel := a.settings.model
770 stepSystemPrompt := a.settings.systemPrompt
771 stepActiveTools := call.ActiveTools
772 stepToolChoice := ToolChoiceAuto
773 disableAllTools := false
774
775 // Apply step preparation if provided
776 if call.PrepareStep != nil {
777 updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
778 Model: stepModel,
779 Steps: steps,
780 StepNumber: stepNumber,
781 Messages: stepInputMessages,
782 })
783 if err != nil {
784 return nil, err
785 }
786
787 ctx = updatedCtx
788
789 if prepared.Messages != nil {
790 stepInputMessages = prepared.Messages
791 }
792 if prepared.Model != nil {
793 stepModel = prepared.Model
794 }
795 if prepared.System != nil {
796 stepSystemPrompt = *prepared.System
797 }
798 if prepared.ToolChoice != nil {
799 stepToolChoice = *prepared.ToolChoice
800 }
801 if len(prepared.ActiveTools) > 0 {
802 stepActiveTools = prepared.ActiveTools
803 }
804 disableAllTools = prepared.DisableAllTools
805 }
806
807 // Recreate prompt with potentially modified system prompt
808 if stepSystemPrompt != a.settings.systemPrompt {
809 stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
810 if err != nil {
811 return nil, err
812 }
813 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
814 stepInputMessages[0] = stepPrompt[0]
815 }
816 }
817
818 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
819
820 // Start step stream
821 if opts.OnStepStart != nil {
822 _ = opts.OnStepStart(stepNumber)
823 }
824
825 // Create streaming call
826 streamCall := Call{
827 Prompt: stepInputMessages,
828 MaxOutputTokens: call.MaxOutputTokens,
829 Temperature: call.Temperature,
830 TopP: call.TopP,
831 TopK: call.TopK,
832 PresencePenalty: call.PresencePenalty,
833 FrequencyPenalty: call.FrequencyPenalty,
834 Tools: preparedTools,
835 ToolChoice: &stepToolChoice,
836 ProviderOptions: call.ProviderOptions,
837 }
838
839 // Execute step with retry logic wrapping both stream creation and processing
840 retryOptions := DefaultRetryOptions()
841 if call.MaxRetries != nil {
842 retryOptions.MaxRetries = *call.MaxRetries
843 }
844 retryOptions.OnRetry = call.OnRetry
845 retry := RetryWithExponentialBackoffRespectingRetryHeaders[stepExecutionResult](retryOptions)
846
847 result, err := retry(ctx, func() (stepExecutionResult, error) {
848 // Create the stream
849 stream, err := stepModel.Stream(ctx, streamCall)
850 if err != nil {
851 return stepExecutionResult{}, err
852 }
853
854 // Process the stream
855 result, err := a.processStepStream(ctx, stream, opts, steps)
856 if err != nil {
857 return stepExecutionResult{}, err
858 }
859
860 return result, nil
861 })
862 if err != nil {
863 if opts.OnError != nil {
864 opts.OnError(err)
865 }
866 return nil, err
867 }
868
869 steps = append(steps, result.StepResult)
870 totalUsage = addUsage(totalUsage, result.StepResult.Usage)
871
872 // Call step finished callback
873 if opts.OnStepFinish != nil {
874 _ = opts.OnStepFinish(result.StepResult)
875 }
876
877 // Add step messages to response messages
878 stepMessages := toResponseMessages(result.StepResult.Content)
879 responseMessages = append(responseMessages, stepMessages...)
880
881 // Check stop conditions
882 shouldStop := isStopConditionMet(call.StopWhen, steps)
883 if shouldStop || !result.ShouldContinue {
884 break
885 }
886 }
887
888 // Finish agent stream
889 agentResult := &AgentResult{
890 Steps: steps,
891 Response: steps[len(steps)-1].Response,
892 TotalUsage: totalUsage,
893 }
894
895 if opts.OnFinish != nil {
896 opts.OnFinish(agentResult)
897 }
898
899 if opts.OnAgentFinish != nil {
900 _ = opts.OnAgentFinish(agentResult)
901 }
902
903 return agentResult, nil
904}
905
906func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
907 preparedTools := make([]Tool, 0, len(tools))
908
909 // If explicitly disabling all tools, return no tools
910 if disableAllTools {
911 return preparedTools
912 }
913
914 for _, tool := range tools {
915 // If activeTools has items, only include tools in the list
916 // If activeTools is empty, include all tools
917 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
918 continue
919 }
920 info := tool.Info()
921 preparedTools = append(preparedTools, FunctionTool{
922 Name: info.Name,
923 Description: info.Description,
924 InputSchema: map[string]any{
925 "type": "object",
926 "properties": info.Parameters,
927 "required": info.Required,
928 },
929 ProviderOptions: tool.ProviderOptions(),
930 })
931 }
932 return preparedTools
933}
934
935// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
936func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
937 if err := a.validateToolCall(toolCall, availableTools); err == nil {
938 return toolCall
939 } else { //nolint: revive
940 if repairFunc != nil {
941 repairOptions := ToolCallRepairOptions{
942 OriginalToolCall: toolCall,
943 ValidationError: err,
944 AvailableTools: availableTools,
945 SystemPrompt: systemPrompt,
946 Messages: messages,
947 }
948
949 if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
950 if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
951 return *repairedToolCall
952 }
953 }
954 }
955
956 invalidToolCall := toolCall
957 invalidToolCall.Invalid = true
958 invalidToolCall.ValidationError = err
959 return invalidToolCall
960 }
961}
962
963// validateToolCall validates a tool call against available tools and their schemas.
964func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
965 var tool AgentTool
966 for _, t := range availableTools {
967 if t.Info().Name == toolCall.ToolName {
968 tool = t
969 break
970 }
971 }
972
973 if tool == nil {
974 return fmt.Errorf("tool not found: %s", toolCall.ToolName)
975 }
976
977 // Validate JSON parsing
978 var input map[string]any
979 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
980 return fmt.Errorf("invalid JSON input: %w", err)
981 }
982
983 // Basic schema validation (check required fields)
984 // TODO: more robust schema validation using JSON Schema or similar
985 toolInfo := tool.Info()
986 for _, required := range toolInfo.Required {
987 if _, exists := input[required]; !exists {
988 return fmt.Errorf("missing required parameter: %s", required)
989 }
990 }
991 return nil
992}
993
994func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
995 if prompt == "" {
996 return nil, &Error{Title: "invalid argument", Message: "prompt can't be empty"}
997 }
998
999 var preparedPrompt Prompt
1000
1001 if system != "" {
1002 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
1003 }
1004 preparedPrompt = append(preparedPrompt, messages...)
1005 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
1006 return preparedPrompt, nil
1007}
1008
1009// WithSystemPrompt sets the system prompt for the agent.
1010func WithSystemPrompt(prompt string) AgentOption {
1011 return func(s *agentSettings) {
1012 s.systemPrompt = prompt
1013 }
1014}
1015
1016// WithMaxOutputTokens sets the maximum output tokens for the agent.
1017func WithMaxOutputTokens(tokens int64) AgentOption {
1018 return func(s *agentSettings) {
1019 s.maxOutputTokens = &tokens
1020 }
1021}
1022
1023// WithTemperature sets the temperature for the agent.
1024func WithTemperature(temp float64) AgentOption {
1025 return func(s *agentSettings) {
1026 s.temperature = &temp
1027 }
1028}
1029
1030// WithTopP sets the top-p value for the agent.
1031func WithTopP(topP float64) AgentOption {
1032 return func(s *agentSettings) {
1033 s.topP = &topP
1034 }
1035}
1036
1037// WithTopK sets the top-k value for the agent.
1038func WithTopK(topK int64) AgentOption {
1039 return func(s *agentSettings) {
1040 s.topK = &topK
1041 }
1042}
1043
1044// WithPresencePenalty sets the presence penalty for the agent.
1045func WithPresencePenalty(penalty float64) AgentOption {
1046 return func(s *agentSettings) {
1047 s.presencePenalty = &penalty
1048 }
1049}
1050
1051// WithFrequencyPenalty sets the frequency penalty for the agent.
1052func WithFrequencyPenalty(penalty float64) AgentOption {
1053 return func(s *agentSettings) {
1054 s.frequencyPenalty = &penalty
1055 }
1056}
1057
1058// WithTools sets the tools for the agent.
1059func WithTools(tools ...AgentTool) AgentOption {
1060 return func(s *agentSettings) {
1061 s.tools = append(s.tools, tools...)
1062 }
1063}
1064
1065// WithStopConditions sets the stop conditions for the agent.
1066func WithStopConditions(conditions ...StopCondition) AgentOption {
1067 return func(s *agentSettings) {
1068 s.stopWhen = append(s.stopWhen, conditions...)
1069 }
1070}
1071
1072// WithPrepareStep sets the prepare step function for the agent.
1073func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1074 return func(s *agentSettings) {
1075 s.prepareStep = fn
1076 }
1077}
1078
1079// WithRepairToolCall sets the repair tool call function for the agent.
1080func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1081 return func(s *agentSettings) {
1082 s.repairToolCall = fn
1083 }
1084}
1085
1086// WithMaxRetries sets the maximum number of retries for the agent.
1087func WithMaxRetries(maxRetries int) AgentOption {
1088 return func(s *agentSettings) {
1089 s.maxRetries = &maxRetries
1090 }
1091}
1092
1093// WithOnRetry sets the retry callback for the agent.
1094func WithOnRetry(callback OnRetryCallback) AgentOption {
1095 return func(s *agentSettings) {
1096 s.onRetry = callback
1097 }
1098}
1099
1100// processStepStream processes a single step's stream and returns the step result.
1101func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (stepExecutionResult, error) {
1102 var stepContent []Content
1103 var stepToolCalls []ToolCallContent
1104 var stepUsage Usage
1105 stepFinishReason := FinishReasonUnknown
1106 var stepWarnings []CallWarning
1107 var stepProviderMetadata ProviderMetadata
1108
1109 activeToolCalls := make(map[string]*ToolCallContent)
1110 activeTextContent := make(map[string]string)
1111 type reasoningContent struct {
1112 content string
1113 options ProviderMetadata
1114 }
1115 activeReasoningContent := make(map[string]reasoningContent)
1116
1117 // Process stream parts
1118 for part := range stream {
1119 // Forward all parts to chunk callback
1120 if opts.OnChunk != nil {
1121 err := opts.OnChunk(part)
1122 if err != nil {
1123 return stepExecutionResult{}, err
1124 }
1125 }
1126
1127 switch part.Type {
1128 case StreamPartTypeWarnings:
1129 stepWarnings = part.Warnings
1130 if opts.OnWarnings != nil {
1131 err := opts.OnWarnings(part.Warnings)
1132 if err != nil {
1133 return stepExecutionResult{}, err
1134 }
1135 }
1136
1137 case StreamPartTypeTextStart:
1138 activeTextContent[part.ID] = ""
1139 if opts.OnTextStart != nil {
1140 err := opts.OnTextStart(part.ID)
1141 if err != nil {
1142 return stepExecutionResult{}, err
1143 }
1144 }
1145
1146 case StreamPartTypeTextDelta:
1147 if _, exists := activeTextContent[part.ID]; exists {
1148 activeTextContent[part.ID] += part.Delta
1149 }
1150 if opts.OnTextDelta != nil {
1151 err := opts.OnTextDelta(part.ID, part.Delta)
1152 if err != nil {
1153 return stepExecutionResult{}, err
1154 }
1155 }
1156
1157 case StreamPartTypeTextEnd:
1158 if text, exists := activeTextContent[part.ID]; exists {
1159 stepContent = append(stepContent, TextContent{
1160 Text: text,
1161 ProviderMetadata: part.ProviderMetadata,
1162 })
1163 delete(activeTextContent, part.ID)
1164 }
1165 if opts.OnTextEnd != nil {
1166 err := opts.OnTextEnd(part.ID)
1167 if err != nil {
1168 return stepExecutionResult{}, err
1169 }
1170 }
1171
1172 case StreamPartTypeReasoningStart:
1173 activeReasoningContent[part.ID] = reasoningContent{content: part.Delta, options: part.ProviderMetadata}
1174 if opts.OnReasoningStart != nil {
1175 content := ReasoningContent{
1176 Text: part.Delta,
1177 ProviderMetadata: part.ProviderMetadata,
1178 }
1179 err := opts.OnReasoningStart(part.ID, content)
1180 if err != nil {
1181 return stepExecutionResult{}, err
1182 }
1183 }
1184
1185 case StreamPartTypeReasoningDelta:
1186 if active, exists := activeReasoningContent[part.ID]; exists {
1187 active.content += part.Delta
1188 active.options = part.ProviderMetadata
1189 activeReasoningContent[part.ID] = active
1190 }
1191 if opts.OnReasoningDelta != nil {
1192 err := opts.OnReasoningDelta(part.ID, part.Delta)
1193 if err != nil {
1194 return stepExecutionResult{}, err
1195 }
1196 }
1197
1198 case StreamPartTypeReasoningEnd:
1199 if active, exists := activeReasoningContent[part.ID]; exists {
1200 if part.ProviderMetadata != nil {
1201 active.options = part.ProviderMetadata
1202 }
1203 content := ReasoningContent{
1204 Text: active.content,
1205 ProviderMetadata: active.options,
1206 }
1207 stepContent = append(stepContent, content)
1208 if opts.OnReasoningEnd != nil {
1209 err := opts.OnReasoningEnd(part.ID, content)
1210 if err != nil {
1211 return stepExecutionResult{}, err
1212 }
1213 }
1214 delete(activeReasoningContent, part.ID)
1215 }
1216
1217 case StreamPartTypeToolInputStart:
1218 activeToolCalls[part.ID] = &ToolCallContent{
1219 ToolCallID: part.ID,
1220 ToolName: part.ToolCallName,
1221 Input: "",
1222 ProviderExecuted: part.ProviderExecuted,
1223 }
1224 if opts.OnToolInputStart != nil {
1225 err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1226 if err != nil {
1227 return stepExecutionResult{}, err
1228 }
1229 }
1230
1231 case StreamPartTypeToolInputDelta:
1232 if toolCall, exists := activeToolCalls[part.ID]; exists {
1233 toolCall.Input += part.Delta
1234 }
1235 if opts.OnToolInputDelta != nil {
1236 err := opts.OnToolInputDelta(part.ID, part.Delta)
1237 if err != nil {
1238 return stepExecutionResult{}, err
1239 }
1240 }
1241
1242 case StreamPartTypeToolInputEnd:
1243 if opts.OnToolInputEnd != nil {
1244 err := opts.OnToolInputEnd(part.ID)
1245 if err != nil {
1246 return stepExecutionResult{}, err
1247 }
1248 }
1249
1250 case StreamPartTypeToolCall:
1251 toolCall := ToolCallContent{
1252 ToolCallID: part.ID,
1253 ToolName: part.ToolCallName,
1254 Input: part.ToolCallInput,
1255 ProviderExecuted: part.ProviderExecuted,
1256 ProviderMetadata: part.ProviderMetadata,
1257 }
1258
1259 // Validate and potentially repair the tool call
1260 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1261 stepToolCalls = append(stepToolCalls, validatedToolCall)
1262 stepContent = append(stepContent, validatedToolCall)
1263
1264 if opts.OnToolCall != nil {
1265 err := opts.OnToolCall(validatedToolCall)
1266 if err != nil {
1267 return stepExecutionResult{}, err
1268 }
1269 }
1270
1271 // Clean up active tool call
1272 delete(activeToolCalls, part.ID)
1273
1274 case StreamPartTypeSource:
1275 sourceContent := SourceContent{
1276 SourceType: part.SourceType,
1277 ID: part.ID,
1278 URL: part.URL,
1279 Title: part.Title,
1280 ProviderMetadata: part.ProviderMetadata,
1281 }
1282 stepContent = append(stepContent, sourceContent)
1283 if opts.OnSource != nil {
1284 err := opts.OnSource(sourceContent)
1285 if err != nil {
1286 return stepExecutionResult{}, err
1287 }
1288 }
1289
1290 case StreamPartTypeFinish:
1291 stepUsage = part.Usage
1292 stepFinishReason = part.FinishReason
1293 stepProviderMetadata = part.ProviderMetadata
1294 if opts.OnStreamFinish != nil {
1295 err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1296 if err != nil {
1297 return stepExecutionResult{}, err
1298 }
1299 }
1300
1301 case StreamPartTypeError:
1302 return stepExecutionResult{}, part.Error
1303 }
1304 }
1305
1306 // Execute tools if any
1307 var toolResults []ToolResultContent
1308 if len(stepToolCalls) > 0 {
1309 var err error
1310 toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
1311 if err != nil {
1312 return stepExecutionResult{}, err
1313 }
1314 // Add tool results to content
1315 for _, result := range toolResults {
1316 stepContent = append(stepContent, result)
1317 }
1318 }
1319
1320 stepResult := StepResult{
1321 Response: Response{
1322 Content: stepContent,
1323 FinishReason: stepFinishReason,
1324 Usage: stepUsage,
1325 Warnings: stepWarnings,
1326 ProviderMetadata: stepProviderMetadata,
1327 },
1328 Messages: toResponseMessages(stepContent),
1329 }
1330
1331 // Determine if we should continue (has tool calls and not stopped)
1332 shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1333
1334 return stepExecutionResult{
1335 StepResult: stepResult,
1336 ShouldContinue: shouldContinue,
1337 }, nil
1338}
1339
1340func addUsage(a, b Usage) Usage {
1341 return Usage{
1342 InputTokens: a.InputTokens + b.InputTokens,
1343 OutputTokens: a.OutputTokens + b.OutputTokens,
1344 TotalTokens: a.TotalTokens + b.TotalTokens,
1345 ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens,
1346 CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1347 CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens,
1348 }
1349}
1350
1351// WithHeaders sets the headers for the agent.
1352func WithHeaders(headers map[string]string) AgentOption {
1353 return func(s *agentSettings) {
1354 s.headers = headers
1355 }
1356}
1357
1358// WithProviderOptions sets the provider options for the agent.
1359func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1360 return func(s *agentSettings) {
1361 s.providerOptions = providerOptions
1362 }
1363}