1package fantasy
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "slices"
11)
12
// StepResult represents the result of a single step in an agent execution.
type StepResult struct {
	// Response is embedded, so its fields (for example Content, FinishReason
	// and Usage) can be read directly on the step.
	Response
	// Messages holds the conversation messages built from this step's content.
	Messages []Message
}
18
// StopCondition defines a function that determines when an agent should stop executing.
// It receives every step recorded so far and returns true to end the run.
type StopCondition = func(steps []StepResult) bool
21
22// StepCountIs returns a stop condition that stops after the specified number of steps.
23func StepCountIs(stepCount int) StopCondition {
24 return func(steps []StepResult) bool {
25 return len(steps) >= stepCount
26 }
27}
28
29// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
30func HasToolCall(toolName string) StopCondition {
31 return func(steps []StepResult) bool {
32 if len(steps) == 0 {
33 return false
34 }
35 lastStep := steps[len(steps)-1]
36 toolCalls := lastStep.Content.ToolCalls()
37 for _, toolCall := range toolCalls {
38 if toolCall.ToolName == toolName {
39 return true
40 }
41 }
42 return false
43 }
44}
45
46// HasContent returns a stop condition that stops when the specified content type appears in the last step.
47func HasContent(contentType ContentType) StopCondition {
48 return func(steps []StepResult) bool {
49 if len(steps) == 0 {
50 return false
51 }
52 lastStep := steps[len(steps)-1]
53 for _, content := range lastStep.Content {
54 if content.GetType() == contentType {
55 return true
56 }
57 }
58 return false
59 }
60}
61
62// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
63func FinishReasonIs(reason FinishReason) StopCondition {
64 return func(steps []StepResult) bool {
65 if len(steps) == 0 {
66 return false
67 }
68 lastStep := steps[len(steps)-1]
69 return lastStep.FinishReason == reason
70 }
71}
72
73// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
74func MaxTokensUsed(maxTokens int64) StopCondition {
75 return func(steps []StepResult) bool {
76 var totalTokens int64
77 for _, step := range steps {
78 totalTokens += step.Usage.TotalTokens
79 }
80 return totalTokens >= maxTokens
81 }
82}
83
// PrepareStepFunctionOptions contains the options for preparing a step in an agent execution.
type PrepareStepFunctionOptions struct {
	// Steps holds the results of the steps completed so far.
	Steps []StepResult
	// StepNumber is the zero-based index of the step about to run.
	StepNumber int
	// Model is the agent's configured language model.
	Model LanguageModel
	// Messages is the message list about to be sent for this step: the
	// initial prompt followed by the messages produced by earlier steps.
	Messages []Message
}
91
// PrepareStepResult contains the result of preparing a step in an agent execution.
// Zero-valued fields leave the corresponding default for the step untouched.
type PrepareStepResult struct {
	// Model, when non-nil, is used instead of the agent's model for this step.
	Model LanguageModel
	// Messages, when non-nil, replaces the step's input messages.
	Messages []Message
	// System, when non-nil, overrides the system prompt for this step.
	System *string
	// ToolChoice, when non-nil, overrides the default ToolChoiceAuto.
	ToolChoice *ToolChoice
	// ActiveTools, when non-empty, replaces the call's active tool names for
	// this step.
	ActiveTools []string
	// DisableAllTools, when true, offers no tools to the model for this step.
	DisableAllTools bool
}
101
// ToolCallRepairOptions contains the options for repairing a tool call.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call that failed validation.
	OriginalToolCall ToolCallContent
	// ValidationError is the error returned by validating OriginalToolCall.
	ValidationError error
	// AvailableTools lists the tools the call was validated against.
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt in effect for the step.
	SystemPrompt string
	// Messages holds the input messages of the step that produced the call.
	Messages []Message
}
110
type (
	// PrepareStepFunction defines a function that prepares a step in an agent execution.
	// The returned context replaces the run's context from that step on, and
	// a non-nil error aborts the run.
	PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)

	// OnStepFinishedFunction defines a function that is called when a step finishes.
	OnStepFinishedFunction = func(step StepResult)

	// RepairToolCallFunction defines a function that repairs a tool call.
	// It is invoked for a tool call that failed validation. The repaired call
	// is validated again; if the function returns an error or nil, or the
	// repaired call is still invalid, the original call is marked invalid.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
121
// agentSettings holds the agent-level defaults configured through
// AgentOption functions. Per-call values take precedence over them when a
// call is prepared.
type agentSettings struct {
	// systemPrompt is the system message text; empty means no system message.
	systemPrompt string
	// Sampling defaults; a nil pointer means "not set".
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	// NOTE(review): no option visible in this part of the file populates
	// headers — confirm where it is set and consumed.
	headers map[string]string
	// providerOptions is merged into each call's ProviderOptions; the
	// call's own entries win on key conflicts.
	providerOptions ProviderOptions

	// TODO: add support for provider tools
	tools      []AgentTool
	maxRetries *int

	// model is the language model used unless a step overrides it.
	model LanguageModel

	// Hooks used when the call does not supply its own.
	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
144
// AgentCall represents a call to an agent.
//
// Sampling parameters, MaxRetries, StopWhen, PrepareStep, RepairToolCall and
// OnRetry that are left unset fall back to the agent-level settings.
// ProviderOptions are merged with the agent's options, with the call's
// entries taking precedence.
type AgentCall struct {
	// Prompt is the user prompt; it must not be empty.
	Prompt string `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files []FilePart `json:"files"`
	// Messages is prior conversation placed between the system message and
	// the user prompt.
	Messages         []Message `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, restricts the tools offered to the model
	// to those whose names are listed.
	ActiveTools     []string `json:"active_tools"`
	ProviderOptions ProviderOptions
	// OnRetry is handed to the retry helper that wraps each model call.
	OnRetry OnRetryCallback
	// MaxRetries, when non-nil, overrides the default retry count.
	MaxRetries *int

	// StopWhen lists conditions checked after every step; any one being met
	// ends the run.
	StopWhen []StopCondition
	// PrepareStep, when set, is called before every step.
	PrepareStep PrepareStepFunction
	// RepairToolCall, when set, is offered tool calls that fail validation.
	RepairToolCall RepairToolCallFunction
}
165
// Agent-level callbacks.
//
// In Stream, the errors returned by the OnStepStart, OnStepFinish and
// OnAgentFinish callbacks are discarded.
type (
	// OnAgentStartFunc is called when agent starts.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called when agent finishes.
	// Stream invokes it after OnFinishFunc, with the same result.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called when a step starts.
	// stepNumber is zero-based.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called when a step finishes.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called when entire agent completes.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called when an error occurs.
	// Stream calls it when the model stream cannot be obtained or when
	// processing a step's stream fails.
	OnErrorFunc func(error)
)
186
// Stream part callbacks - called for each corresponding stream part type.
//
// NOTE(review): the code that dispatches these callbacks is not visible in
// this part of the file; confirm how a returned error is handled there.
type (
	// OnChunkFunc is called for each stream part (catch-all).
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when text starts.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for text deltas.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when text ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when reasoning starts.
	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error

	// OnReasoningDeltaFunc is called for reasoning deltas.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning ends.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when tool input starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for tool input deltas.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnToolResultFunc is called when tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when stream finishes.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
234
// AgentStreamCall represents a streaming call to an agent.
//
// The request fields mirror AgentCall: Stream copies them into an AgentCall
// so that the agent-level defaults can be merged in. The callback fields are
// optional; a nil callback is simply not invoked.
type AgentStreamCall struct {
	// Prompt is the user prompt; it must not be empty.
	Prompt string `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files []FilePart `json:"files"`
	// Messages is prior conversation placed between the system message and
	// the user prompt.
	Messages         []Message `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, restricts the tools offered to the model
	// to those whose names are listed.
	ActiveTools []string `json:"active_tools"`
	// NOTE(review): Headers is not read by Stream in the visible code —
	// confirm where it is meant to be applied.
	Headers         map[string]string
	ProviderOptions ProviderOptions
	OnRetry         OnRetryCallback
	// MaxRetries, when non-nil, overrides the default retry count.
	MaxRetries *int

	// StopWhen lists conditions checked after every step; any one being met
	// ends the run.
	StopWhen []StopCondition
	// PrepareStep, when set, is called before every step.
	PrepareStep PrepareStepFunction
	// RepairToolCall, when set, is offered tool calls that fail validation.
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called when agent starts
	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
	OnWarnings       OnWarningsFunc       // Called for warnings
	OnTextStart      OnTextStartFunc      // Called when text starts
	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
	OnTextEnd        OnTextEndFunc        // Called when text ends
	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
	OnToolCall       OnToolCallFunc       // Called when tool call is complete
	OnToolResult     OnToolResultFunc     // Called when tool execution completes
	OnSource         OnSourceFunc         // Called for source references
	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
}
281
// AgentResult represents the result of an agent execution.
type AgentResult struct {
	// Steps holds every step of the run, in execution order.
	Steps []StepResult
	// Final response
	// (the Response of the last entry in Steps).
	Response Response
	// TotalUsage is the token usage summed over all steps.
	TotalUsage Usage
}
289
// Agent represents an AI agent that can generate responses and stream responses.
type Agent interface {
	// Generate runs the agent to completion and returns the collected steps,
	// the final response and the total usage.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the agent using the model's streaming API, reporting
	// progress through the callbacks on the call, and returns the same kind
	// of result as Generate.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
295
// AgentOption defines a function that configures agent settings.
// Options are applied by NewAgent in the order they are passed.
type AgentOption = func(*agentSettings)
298
// agent is the Agent implementation returned by NewAgent.
type agent struct {
	// settings holds the model and the agent-level defaults.
	settings agentSettings
}
302
303// NewAgent creates a new agent with the given language model and options.
304func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
305 settings := agentSettings{
306 model: model,
307 }
308 for _, o := range opts {
309 o(&settings)
310 }
311 return &agent{
312 settings: settings,
313 }
314}
315
316func (a *agent) prepareCall(call AgentCall) AgentCall {
317 call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
318 call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
319 call.TopP = cmp.Or(call.TopP, a.settings.topP)
320 call.TopK = cmp.Or(call.TopK, a.settings.topK)
321 call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
322 call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
323 call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
324
325 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
326 call.StopWhen = a.settings.stopWhen
327 }
328 if call.PrepareStep == nil && a.settings.prepareStep != nil {
329 call.PrepareStep = a.settings.prepareStep
330 }
331 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
332 call.RepairToolCall = a.settings.repairToolCall
333 }
334 if call.OnRetry == nil && a.settings.onRetry != nil {
335 call.OnRetry = a.settings.onRetry
336 }
337
338 providerOptions := ProviderOptions{}
339 if a.settings.providerOptions != nil {
340 maps.Copy(providerOptions, a.settings.providerOptions)
341 }
342 if call.ProviderOptions != nil {
343 maps.Copy(providerOptions, call.ProviderOptions)
344 }
345 call.ProviderOptions = providerOptions
346
347 headers := map[string]string{}
348
349 if a.settings.headers != nil {
350 maps.Copy(headers, a.settings.headers)
351 }
352
353 return call
354}
355
356// Generate implements Agent.
357func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
358 opts = a.prepareCall(opts)
359 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
360 if err != nil {
361 return nil, err
362 }
363 var responseMessages []Message
364 var steps []StepResult
365
366 for {
367 stepInputMessages := append(initialPrompt, responseMessages...)
368 stepModel := a.settings.model
369 stepSystemPrompt := a.settings.systemPrompt
370 stepActiveTools := opts.ActiveTools
371 stepToolChoice := ToolChoiceAuto
372 disableAllTools := false
373
374 if opts.PrepareStep != nil {
375 updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
376 Model: stepModel,
377 Steps: steps,
378 StepNumber: len(steps),
379 Messages: stepInputMessages,
380 })
381 if err != nil {
382 return nil, err
383 }
384
385 ctx = updatedCtx
386
387 // Apply prepared step modifications
388 if prepared.Messages != nil {
389 stepInputMessages = prepared.Messages
390 }
391 if prepared.Model != nil {
392 stepModel = prepared.Model
393 }
394 if prepared.System != nil {
395 stepSystemPrompt = *prepared.System
396 }
397 if prepared.ToolChoice != nil {
398 stepToolChoice = *prepared.ToolChoice
399 }
400 if len(prepared.ActiveTools) > 0 {
401 stepActiveTools = prepared.ActiveTools
402 }
403 disableAllTools = prepared.DisableAllTools
404 }
405
406 // Recreate prompt with potentially modified system prompt
407 if stepSystemPrompt != a.settings.systemPrompt {
408 stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
409 if err != nil {
410 return nil, err
411 }
412 // Replace system message part, keep the rest
413 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
414 stepInputMessages[0] = stepPrompt[0] // Replace system message
415 }
416 }
417
418 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
419
420 retryOptions := DefaultRetryOptions()
421 if opts.MaxRetries != nil {
422 retryOptions.MaxRetries = *opts.MaxRetries
423 }
424 retryOptions.OnRetry = opts.OnRetry
425 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
426
427 result, err := retry(ctx, func() (*Response, error) {
428 return stepModel.Generate(ctx, Call{
429 Prompt: stepInputMessages,
430 MaxOutputTokens: opts.MaxOutputTokens,
431 Temperature: opts.Temperature,
432 TopP: opts.TopP,
433 TopK: opts.TopK,
434 PresencePenalty: opts.PresencePenalty,
435 FrequencyPenalty: opts.FrequencyPenalty,
436 Tools: preparedTools,
437 ToolChoice: &stepToolChoice,
438 ProviderOptions: opts.ProviderOptions,
439 })
440 })
441 if err != nil {
442 return nil, err
443 }
444
445 var stepToolCalls []ToolCallContent
446 for _, content := range result.Content {
447 if content.GetType() == ContentTypeToolCall {
448 toolCall, ok := AsContentType[ToolCallContent](content)
449 if !ok {
450 continue
451 }
452
453 // Validate and potentially repair the tool call
454 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
455 stepToolCalls = append(stepToolCalls, validatedToolCall)
456 }
457 }
458
459 toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
460
461 // Build step content with validated tool calls and tool results
462 stepContent := []Content{}
463 toolCallIndex := 0
464 for _, content := range result.Content {
465 if content.GetType() == ContentTypeToolCall {
466 // Replace with validated tool call
467 if toolCallIndex < len(stepToolCalls) {
468 stepContent = append(stepContent, stepToolCalls[toolCallIndex])
469 toolCallIndex++
470 }
471 } else {
472 // Keep other content as-is
473 stepContent = append(stepContent, content)
474 }
475 }
476 // Add tool results
477 for _, result := range toolResults {
478 stepContent = append(stepContent, result)
479 }
480 currentStepMessages := toResponseMessages(stepContent)
481 responseMessages = append(responseMessages, currentStepMessages...)
482
483 stepResult := StepResult{
484 Response: Response{
485 Content: stepContent,
486 FinishReason: result.FinishReason,
487 Usage: result.Usage,
488 Warnings: result.Warnings,
489 ProviderMetadata: result.ProviderMetadata,
490 },
491 Messages: currentStepMessages,
492 }
493 steps = append(steps, stepResult)
494 shouldStop := isStopConditionMet(opts.StopWhen, steps)
495
496 if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
497 break
498 }
499 }
500
501 totalUsage := Usage{}
502
503 for _, step := range steps {
504 usage := step.Usage
505 totalUsage.InputTokens += usage.InputTokens
506 totalUsage.OutputTokens += usage.OutputTokens
507 totalUsage.ReasoningTokens += usage.ReasoningTokens
508 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
509 totalUsage.CacheReadTokens += usage.CacheReadTokens
510 totalUsage.TotalTokens += usage.TotalTokens
511 }
512
513 agentResult := &AgentResult{
514 Steps: steps,
515 Response: steps[len(steps)-1].Response,
516 TotalUsage: totalUsage,
517 }
518 return agentResult, nil
519}
520
521func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
522 if len(conditions) == 0 {
523 return false
524 }
525
526 for _, condition := range conditions {
527 if condition(steps) {
528 return true
529 }
530 }
531 return false
532}
533
534func toResponseMessages(content []Content) []Message {
535 var assistantParts []MessagePart
536 var toolParts []MessagePart
537
538 for _, c := range content {
539 switch c.GetType() {
540 case ContentTypeText:
541 text, ok := AsContentType[TextContent](c)
542 if !ok {
543 continue
544 }
545 assistantParts = append(assistantParts, TextPart{
546 Text: text.Text,
547 ProviderOptions: ProviderOptions(text.ProviderMetadata),
548 })
549 case ContentTypeReasoning:
550 reasoning, ok := AsContentType[ReasoningContent](c)
551 if !ok {
552 continue
553 }
554 assistantParts = append(assistantParts, ReasoningPart{
555 Text: reasoning.Text,
556 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
557 })
558 case ContentTypeToolCall:
559 toolCall, ok := AsContentType[ToolCallContent](c)
560 if !ok {
561 continue
562 }
563 assistantParts = append(assistantParts, ToolCallPart{
564 ToolCallID: toolCall.ToolCallID,
565 ToolName: toolCall.ToolName,
566 Input: toolCall.Input,
567 ProviderExecuted: toolCall.ProviderExecuted,
568 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
569 })
570 case ContentTypeFile:
571 file, ok := AsContentType[FileContent](c)
572 if !ok {
573 continue
574 }
575 assistantParts = append(assistantParts, FilePart{
576 Data: file.Data,
577 MediaType: file.MediaType,
578 ProviderOptions: ProviderOptions(file.ProviderMetadata),
579 })
580 case ContentTypeSource:
581 // Sources are metadata about references used to generate the response.
582 // They don't need to be included in the conversation messages.
583 continue
584 case ContentTypeToolResult:
585 result, ok := AsContentType[ToolResultContent](c)
586 if !ok {
587 continue
588 }
589 toolParts = append(toolParts, ToolResultPart{
590 ToolCallID: result.ToolCallID,
591 Output: result.Result,
592 ProviderOptions: ProviderOptions(result.ProviderMetadata),
593 })
594 }
595 }
596
597 var messages []Message
598 if len(assistantParts) > 0 {
599 messages = append(messages, Message{
600 Role: MessageRoleAssistant,
601 Content: assistantParts,
602 })
603 }
604 if len(toolParts) > 0 {
605 messages = append(messages, Message{
606 Role: MessageRoleTool,
607 Content: toolParts,
608 })
609 }
610 return messages
611}
612
613func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
614 if len(toolCalls) == 0 {
615 return nil, nil
616 }
617
618 // Create a map for quick tool lookup
619 toolMap := make(map[string]AgentTool)
620 for _, tool := range allTools {
621 toolMap[tool.Info().Name] = tool
622 }
623
624 // Execute all tool calls sequentially in order
625 results := make([]ToolResultContent, 0, len(toolCalls))
626
627 for _, toolCall := range toolCalls {
628 // Skip invalid tool calls - create error result
629 if toolCall.Invalid {
630 result := ToolResultContent{
631 ToolCallID: toolCall.ToolCallID,
632 ToolName: toolCall.ToolName,
633 Result: ToolResultOutputContentError{
634 Error: toolCall.ValidationError,
635 },
636 ProviderExecuted: false,
637 }
638 results = append(results, result)
639 if toolResultCallback != nil {
640 if err := toolResultCallback(result); err != nil {
641 return nil, err
642 }
643 }
644 continue
645 }
646
647 tool, exists := toolMap[toolCall.ToolName]
648 if !exists {
649 result := ToolResultContent{
650 ToolCallID: toolCall.ToolCallID,
651 ToolName: toolCall.ToolName,
652 Result: ToolResultOutputContentError{
653 Error: errors.New("Error: Tool not found: " + toolCall.ToolName),
654 },
655 ProviderExecuted: false,
656 }
657 results = append(results, result)
658 if toolResultCallback != nil {
659 if err := toolResultCallback(result); err != nil {
660 return nil, err
661 }
662 }
663 continue
664 }
665
666 // Execute the tool
667 toolResult, err := tool.Run(ctx, ToolCall{
668 ID: toolCall.ToolCallID,
669 Name: toolCall.ToolName,
670 Input: toolCall.Input,
671 })
672 if err != nil {
673 result := ToolResultContent{
674 ToolCallID: toolCall.ToolCallID,
675 ToolName: toolCall.ToolName,
676 Result: ToolResultOutputContentError{
677 Error: err,
678 },
679 ClientMetadata: toolResult.Metadata,
680 ProviderExecuted: false,
681 }
682 if toolResultCallback != nil {
683 if cbErr := toolResultCallback(result); cbErr != nil {
684 return nil, cbErr
685 }
686 }
687 return nil, err
688 }
689
690 var result ToolResultContent
691 if toolResult.IsError {
692 result = ToolResultContent{
693 ToolCallID: toolCall.ToolCallID,
694 ToolName: toolCall.ToolName,
695 Result: ToolResultOutputContentError{
696 Error: errors.New(toolResult.Content),
697 },
698 ClientMetadata: toolResult.Metadata,
699 ProviderExecuted: false,
700 }
701 } else {
702 result = ToolResultContent{
703 ToolCallID: toolCall.ToolCallID,
704 ToolName: toolCall.ToolName,
705 Result: ToolResultOutputContentText{
706 Text: toolResult.Content,
707 },
708 ClientMetadata: toolResult.Metadata,
709 ProviderExecuted: false,
710 }
711 }
712 results = append(results, result)
713 if toolResultCallback != nil {
714 if err := toolResultCallback(result); err != nil {
715 return nil, err
716 }
717 }
718 }
719
720 return results, nil
721}
722
723// Stream implements Agent.
724func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
725 // Convert AgentStreamCall to AgentCall for preparation
726 call := AgentCall{
727 Prompt: opts.Prompt,
728 Files: opts.Files,
729 Messages: opts.Messages,
730 MaxOutputTokens: opts.MaxOutputTokens,
731 Temperature: opts.Temperature,
732 TopP: opts.TopP,
733 TopK: opts.TopK,
734 PresencePenalty: opts.PresencePenalty,
735 FrequencyPenalty: opts.FrequencyPenalty,
736 ActiveTools: opts.ActiveTools,
737 ProviderOptions: opts.ProviderOptions,
738 MaxRetries: opts.MaxRetries,
739 StopWhen: opts.StopWhen,
740 PrepareStep: opts.PrepareStep,
741 RepairToolCall: opts.RepairToolCall,
742 }
743
744 call = a.prepareCall(call)
745
746 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
747 if err != nil {
748 return nil, err
749 }
750
751 var responseMessages []Message
752 var steps []StepResult
753 var totalUsage Usage
754
755 // Start agent stream
756 if opts.OnAgentStart != nil {
757 opts.OnAgentStart()
758 }
759
760 for stepNumber := 0; ; stepNumber++ {
761 stepInputMessages := append(initialPrompt, responseMessages...)
762 stepModel := a.settings.model
763 stepSystemPrompt := a.settings.systemPrompt
764 stepActiveTools := call.ActiveTools
765 stepToolChoice := ToolChoiceAuto
766 disableAllTools := false
767
768 // Apply step preparation if provided
769 if call.PrepareStep != nil {
770 updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
771 Model: stepModel,
772 Steps: steps,
773 StepNumber: stepNumber,
774 Messages: stepInputMessages,
775 })
776 if err != nil {
777 return nil, err
778 }
779
780 ctx = updatedCtx
781
782 if prepared.Messages != nil {
783 stepInputMessages = prepared.Messages
784 }
785 if prepared.Model != nil {
786 stepModel = prepared.Model
787 }
788 if prepared.System != nil {
789 stepSystemPrompt = *prepared.System
790 }
791 if prepared.ToolChoice != nil {
792 stepToolChoice = *prepared.ToolChoice
793 }
794 if len(prepared.ActiveTools) > 0 {
795 stepActiveTools = prepared.ActiveTools
796 }
797 disableAllTools = prepared.DisableAllTools
798 }
799
800 // Recreate prompt with potentially modified system prompt
801 if stepSystemPrompt != a.settings.systemPrompt {
802 stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
803 if err != nil {
804 return nil, err
805 }
806 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
807 stepInputMessages[0] = stepPrompt[0]
808 }
809 }
810
811 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
812
813 // Start step stream
814 if opts.OnStepStart != nil {
815 _ = opts.OnStepStart(stepNumber)
816 }
817
818 // Create streaming call
819 streamCall := Call{
820 Prompt: stepInputMessages,
821 MaxOutputTokens: call.MaxOutputTokens,
822 Temperature: call.Temperature,
823 TopP: call.TopP,
824 TopK: call.TopK,
825 PresencePenalty: call.PresencePenalty,
826 FrequencyPenalty: call.FrequencyPenalty,
827 Tools: preparedTools,
828 ToolChoice: &stepToolChoice,
829 ProviderOptions: call.ProviderOptions,
830 }
831
832 // Get streaming response with retry logic
833 retryOptions := DefaultRetryOptions()
834 if call.MaxRetries != nil {
835 retryOptions.MaxRetries = *call.MaxRetries
836 }
837 retryOptions.OnRetry = call.OnRetry
838 retry := RetryWithExponentialBackoffRespectingRetryHeaders[StreamResponse](retryOptions)
839
840 stream, err := retry(ctx, func() (StreamResponse, error) {
841 return stepModel.Stream(ctx, streamCall)
842 })
843 if err != nil {
844 if opts.OnError != nil {
845 opts.OnError(err)
846 }
847 return nil, err
848 }
849
850 // Process stream with tool execution
851 stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
852 if err != nil {
853 if opts.OnError != nil {
854 opts.OnError(err)
855 }
856 return nil, err
857 }
858
859 steps = append(steps, stepResult)
860 totalUsage = addUsage(totalUsage, stepResult.Usage)
861
862 // Call step finished callback
863 if opts.OnStepFinish != nil {
864 _ = opts.OnStepFinish(stepResult)
865 }
866
867 // Add step messages to response messages
868 stepMessages := toResponseMessages(stepResult.Content)
869 responseMessages = append(responseMessages, stepMessages...)
870
871 // Check stop conditions
872 shouldStop := isStopConditionMet(call.StopWhen, steps)
873 if shouldStop || !shouldContinue {
874 break
875 }
876 }
877
878 // Finish agent stream
879 agentResult := &AgentResult{
880 Steps: steps,
881 Response: steps[len(steps)-1].Response,
882 TotalUsage: totalUsage,
883 }
884
885 if opts.OnFinish != nil {
886 opts.OnFinish(agentResult)
887 }
888
889 if opts.OnAgentFinish != nil {
890 _ = opts.OnAgentFinish(agentResult)
891 }
892
893 return agentResult, nil
894}
895
896func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
897 preparedTools := make([]Tool, 0, len(tools))
898
899 // If explicitly disabling all tools, return no tools
900 if disableAllTools {
901 return preparedTools
902 }
903
904 for _, tool := range tools {
905 // If activeTools has items, only include tools in the list
906 // If activeTools is empty, include all tools
907 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
908 continue
909 }
910 info := tool.Info()
911 preparedTools = append(preparedTools, FunctionTool{
912 Name: info.Name,
913 Description: info.Description,
914 InputSchema: map[string]any{
915 "type": "object",
916 "properties": info.Parameters,
917 "required": info.Required,
918 },
919 ProviderOptions: tool.ProviderOptions(),
920 })
921 }
922 return preparedTools
923}
924
925// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
926func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
927 if err := a.validateToolCall(toolCall, availableTools); err == nil {
928 return toolCall
929 } else { //nolint: revive
930 if repairFunc != nil {
931 repairOptions := ToolCallRepairOptions{
932 OriginalToolCall: toolCall,
933 ValidationError: err,
934 AvailableTools: availableTools,
935 SystemPrompt: systemPrompt,
936 Messages: messages,
937 }
938
939 if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
940 if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
941 return *repairedToolCall
942 }
943 }
944 }
945
946 invalidToolCall := toolCall
947 invalidToolCall.Invalid = true
948 invalidToolCall.ValidationError = err
949 return invalidToolCall
950 }
951}
952
953// validateToolCall validates a tool call against available tools and their schemas.
954func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
955 var tool AgentTool
956 for _, t := range availableTools {
957 if t.Info().Name == toolCall.ToolName {
958 tool = t
959 break
960 }
961 }
962
963 if tool == nil {
964 return fmt.Errorf("tool not found: %s", toolCall.ToolName)
965 }
966
967 // Validate JSON parsing
968 var input map[string]any
969 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
970 return fmt.Errorf("invalid JSON input: %w", err)
971 }
972
973 // Basic schema validation (check required fields)
974 // TODO: more robust schema validation using JSON Schema or similar
975 toolInfo := tool.Info()
976 for _, required := range toolInfo.Required {
977 if _, exists := input[required]; !exists {
978 return fmt.Errorf("missing required parameter: %s", required)
979 }
980 }
981 return nil
982}
983
984func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
985 if prompt == "" {
986 return nil, &Error{Title: "invalid argument", Message: "prompt can't be empty"}
987 }
988
989 var preparedPrompt Prompt
990
991 if system != "" {
992 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
993 }
994 preparedPrompt = append(preparedPrompt, messages...)
995 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
996 return preparedPrompt, nil
997}
998
999// WithSystemPrompt sets the system prompt for the agent.
1000func WithSystemPrompt(prompt string) AgentOption {
1001 return func(s *agentSettings) {
1002 s.systemPrompt = prompt
1003 }
1004}
1005
1006// WithMaxOutputTokens sets the maximum output tokens for the agent.
1007func WithMaxOutputTokens(tokens int64) AgentOption {
1008 return func(s *agentSettings) {
1009 s.maxOutputTokens = &tokens
1010 }
1011}
1012
1013// WithTemperature sets the temperature for the agent.
1014func WithTemperature(temp float64) AgentOption {
1015 return func(s *agentSettings) {
1016 s.temperature = &temp
1017 }
1018}
1019
1020// WithTopP sets the top-p value for the agent.
1021func WithTopP(topP float64) AgentOption {
1022 return func(s *agentSettings) {
1023 s.topP = &topP
1024 }
1025}
1026
1027// WithTopK sets the top-k value for the agent.
1028func WithTopK(topK int64) AgentOption {
1029 return func(s *agentSettings) {
1030 s.topK = &topK
1031 }
1032}
1033
1034// WithPresencePenalty sets the presence penalty for the agent.
1035func WithPresencePenalty(penalty float64) AgentOption {
1036 return func(s *agentSettings) {
1037 s.presencePenalty = &penalty
1038 }
1039}
1040
1041// WithFrequencyPenalty sets the frequency penalty for the agent.
1042func WithFrequencyPenalty(penalty float64) AgentOption {
1043 return func(s *agentSettings) {
1044 s.frequencyPenalty = &penalty
1045 }
1046}
1047
1048// WithTools sets the tools for the agent.
1049func WithTools(tools ...AgentTool) AgentOption {
1050 return func(s *agentSettings) {
1051 s.tools = append(s.tools, tools...)
1052 }
1053}
1054
1055// WithStopConditions sets the stop conditions for the agent.
1056func WithStopConditions(conditions ...StopCondition) AgentOption {
1057 return func(s *agentSettings) {
1058 s.stopWhen = append(s.stopWhen, conditions...)
1059 }
1060}
1061
1062// WithPrepareStep sets the prepare step function for the agent.
1063func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1064 return func(s *agentSettings) {
1065 s.prepareStep = fn
1066 }
1067}
1068
1069// WithRepairToolCall sets the repair tool call function for the agent.
1070func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1071 return func(s *agentSettings) {
1072 s.repairToolCall = fn
1073 }
1074}
1075
1076// WithMaxRetries sets the maximum number of retries for the agent.
1077func WithMaxRetries(maxRetries int) AgentOption {
1078 return func(s *agentSettings) {
1079 s.maxRetries = &maxRetries
1080 }
1081}
1082
1083// WithOnRetry sets the retry callback for the agent.
1084func WithOnRetry(callback OnRetryCallback) AgentOption {
1085 return func(s *agentSettings) {
1086 s.onRetry = callback
1087 }
1088}
1089
1090// processStepStream processes a single step's stream and returns the step result.
1091func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
1092 var stepContent []Content
1093 var stepToolCalls []ToolCallContent
1094 var stepUsage Usage
1095 stepFinishReason := FinishReasonUnknown
1096 var stepWarnings []CallWarning
1097 var stepProviderMetadata ProviderMetadata
1098
1099 activeToolCalls := make(map[string]*ToolCallContent)
1100 activeTextContent := make(map[string]string)
1101 type reasoningContent struct {
1102 content string
1103 options ProviderMetadata
1104 }
1105 activeReasoningContent := make(map[string]reasoningContent)
1106
1107 // Process stream parts
1108 for part := range stream {
1109 // Forward all parts to chunk callback
1110 if opts.OnChunk != nil {
1111 err := opts.OnChunk(part)
1112 if err != nil {
1113 return StepResult{}, false, err
1114 }
1115 }
1116
1117 switch part.Type {
1118 case StreamPartTypeWarnings:
1119 stepWarnings = part.Warnings
1120 if opts.OnWarnings != nil {
1121 err := opts.OnWarnings(part.Warnings)
1122 if err != nil {
1123 return StepResult{}, false, err
1124 }
1125 }
1126
1127 case StreamPartTypeTextStart:
1128 activeTextContent[part.ID] = ""
1129 if opts.OnTextStart != nil {
1130 err := opts.OnTextStart(part.ID)
1131 if err != nil {
1132 return StepResult{}, false, err
1133 }
1134 }
1135
1136 case StreamPartTypeTextDelta:
1137 if _, exists := activeTextContent[part.ID]; exists {
1138 activeTextContent[part.ID] += part.Delta
1139 }
1140 if opts.OnTextDelta != nil {
1141 err := opts.OnTextDelta(part.ID, part.Delta)
1142 if err != nil {
1143 return StepResult{}, false, err
1144 }
1145 }
1146
1147 case StreamPartTypeTextEnd:
1148 if text, exists := activeTextContent[part.ID]; exists {
1149 stepContent = append(stepContent, TextContent{
1150 Text: text,
1151 ProviderMetadata: part.ProviderMetadata,
1152 })
1153 delete(activeTextContent, part.ID)
1154 }
1155 if opts.OnTextEnd != nil {
1156 err := opts.OnTextEnd(part.ID)
1157 if err != nil {
1158 return StepResult{}, false, err
1159 }
1160 }
1161
1162 case StreamPartTypeReasoningStart:
1163 activeReasoningContent[part.ID] = reasoningContent{content: part.Delta, options: part.ProviderMetadata}
1164 if opts.OnReasoningStart != nil {
1165 content := ReasoningContent{
1166 Text: part.Delta,
1167 ProviderMetadata: part.ProviderMetadata,
1168 }
1169 err := opts.OnReasoningStart(part.ID, content)
1170 if err != nil {
1171 return StepResult{}, false, err
1172 }
1173 }
1174
1175 case StreamPartTypeReasoningDelta:
1176 if active, exists := activeReasoningContent[part.ID]; exists {
1177 active.content += part.Delta
1178 active.options = part.ProviderMetadata
1179 activeReasoningContent[part.ID] = active
1180 }
1181 if opts.OnReasoningDelta != nil {
1182 err := opts.OnReasoningDelta(part.ID, part.Delta)
1183 if err != nil {
1184 return StepResult{}, false, err
1185 }
1186 }
1187
1188 case StreamPartTypeReasoningEnd:
1189 if active, exists := activeReasoningContent[part.ID]; exists {
1190 if part.ProviderMetadata != nil {
1191 active.options = part.ProviderMetadata
1192 }
1193 content := ReasoningContent{
1194 Text: active.content,
1195 ProviderMetadata: active.options,
1196 }
1197 stepContent = append(stepContent, content)
1198 if opts.OnReasoningEnd != nil {
1199 err := opts.OnReasoningEnd(part.ID, content)
1200 if err != nil {
1201 return StepResult{}, false, err
1202 }
1203 }
1204 delete(activeReasoningContent, part.ID)
1205 }
1206
1207 case StreamPartTypeToolInputStart:
1208 activeToolCalls[part.ID] = &ToolCallContent{
1209 ToolCallID: part.ID,
1210 ToolName: part.ToolCallName,
1211 Input: "",
1212 ProviderExecuted: part.ProviderExecuted,
1213 }
1214 if opts.OnToolInputStart != nil {
1215 err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1216 if err != nil {
1217 return StepResult{}, false, err
1218 }
1219 }
1220
1221 case StreamPartTypeToolInputDelta:
1222 if toolCall, exists := activeToolCalls[part.ID]; exists {
1223 toolCall.Input += part.Delta
1224 }
1225 if opts.OnToolInputDelta != nil {
1226 err := opts.OnToolInputDelta(part.ID, part.Delta)
1227 if err != nil {
1228 return StepResult{}, false, err
1229 }
1230 }
1231
1232 case StreamPartTypeToolInputEnd:
1233 if opts.OnToolInputEnd != nil {
1234 err := opts.OnToolInputEnd(part.ID)
1235 if err != nil {
1236 return StepResult{}, false, err
1237 }
1238 }
1239
1240 case StreamPartTypeToolCall:
1241 toolCall := ToolCallContent{
1242 ToolCallID: part.ID,
1243 ToolName: part.ToolCallName,
1244 Input: part.ToolCallInput,
1245 ProviderExecuted: part.ProviderExecuted,
1246 ProviderMetadata: part.ProviderMetadata,
1247 }
1248
1249 // Validate and potentially repair the tool call
1250 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1251 stepToolCalls = append(stepToolCalls, validatedToolCall)
1252 stepContent = append(stepContent, validatedToolCall)
1253
1254 if opts.OnToolCall != nil {
1255 err := opts.OnToolCall(validatedToolCall)
1256 if err != nil {
1257 return StepResult{}, false, err
1258 }
1259 }
1260
1261 // Clean up active tool call
1262 delete(activeToolCalls, part.ID)
1263
1264 case StreamPartTypeSource:
1265 sourceContent := SourceContent{
1266 SourceType: part.SourceType,
1267 ID: part.ID,
1268 URL: part.URL,
1269 Title: part.Title,
1270 ProviderMetadata: part.ProviderMetadata,
1271 }
1272 stepContent = append(stepContent, sourceContent)
1273 if opts.OnSource != nil {
1274 err := opts.OnSource(sourceContent)
1275 if err != nil {
1276 return StepResult{}, false, err
1277 }
1278 }
1279
1280 case StreamPartTypeFinish:
1281 stepUsage = part.Usage
1282 stepFinishReason = part.FinishReason
1283 stepProviderMetadata = part.ProviderMetadata
1284 if opts.OnStreamFinish != nil {
1285 err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1286 if err != nil {
1287 return StepResult{}, false, err
1288 }
1289 }
1290
1291 case StreamPartTypeError:
1292 if opts.OnError != nil {
1293 opts.OnError(part.Error)
1294 }
1295 return StepResult{}, false, part.Error
1296 }
1297 }
1298
1299 // Execute tools if any
1300 var toolResults []ToolResultContent
1301 if len(stepToolCalls) > 0 {
1302 var err error
1303 toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
1304 if err != nil {
1305 return StepResult{}, false, err
1306 }
1307 // Add tool results to content
1308 for _, result := range toolResults {
1309 stepContent = append(stepContent, result)
1310 }
1311 }
1312
1313 stepResult := StepResult{
1314 Response: Response{
1315 Content: stepContent,
1316 FinishReason: stepFinishReason,
1317 Usage: stepUsage,
1318 Warnings: stepWarnings,
1319 ProviderMetadata: stepProviderMetadata,
1320 },
1321 Messages: toResponseMessages(stepContent),
1322 }
1323
1324 // Determine if we should continue (has tool calls and not stopped)
1325 shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1326
1327 return stepResult, shouldContinue, nil
1328}
1329
1330func addUsage(a, b Usage) Usage {
1331 return Usage{
1332 InputTokens: a.InputTokens + b.InputTokens,
1333 OutputTokens: a.OutputTokens + b.OutputTokens,
1334 TotalTokens: a.TotalTokens + b.TotalTokens,
1335 ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens,
1336 CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1337 CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens,
1338 }
1339}
1340
1341// WithHeaders sets the headers for the agent.
1342func WithHeaders(headers map[string]string) AgentOption {
1343 return func(s *agentSettings) {
1344 s.headers = headers
1345 }
1346}
1347
1348// WithProviderOptions sets the provider options for the agent.
1349func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1350 return func(s *agentSettings) {
1351 s.providerOptions = providerOptions
1352 }
1353}