1package fantasy
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "slices"
11)
12
// StepResult represents the result of a single step in an agent execution.
// It embeds the step's Response, so fields such as Content, FinishReason and
// Usage are reachable directly on the StepResult.
type StepResult struct {
	// Response is the response recorded for this step.
	Response
	// Messages holds the conversation messages associated with this step.
	// Generate fills it with the messages built from the step's content.
	Messages []Message
}
18
// StopCondition defines a function that determines when an agent should stop executing.
// It receives every step recorded so far and returns true to end the agent loop.
type StopCondition = func(steps []StepResult) bool
21
22// StepCountIs returns a stop condition that stops after the specified number of steps.
23func StepCountIs(stepCount int) StopCondition {
24 return func(steps []StepResult) bool {
25 return len(steps) >= stepCount
26 }
27}
28
29// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
30func HasToolCall(toolName string) StopCondition {
31 return func(steps []StepResult) bool {
32 if len(steps) == 0 {
33 return false
34 }
35 lastStep := steps[len(steps)-1]
36 toolCalls := lastStep.Content.ToolCalls()
37 for _, toolCall := range toolCalls {
38 if toolCall.ToolName == toolName {
39 return true
40 }
41 }
42 return false
43 }
44}
45
46// HasContent returns a stop condition that stops when the specified content type appears in the last step.
47func HasContent(contentType ContentType) StopCondition {
48 return func(steps []StepResult) bool {
49 if len(steps) == 0 {
50 return false
51 }
52 lastStep := steps[len(steps)-1]
53 for _, content := range lastStep.Content {
54 if content.GetType() == contentType {
55 return true
56 }
57 }
58 return false
59 }
60}
61
62// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
63func FinishReasonIs(reason FinishReason) StopCondition {
64 return func(steps []StepResult) bool {
65 if len(steps) == 0 {
66 return false
67 }
68 lastStep := steps[len(steps)-1]
69 return lastStep.FinishReason == reason
70 }
71}
72
73// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
74func MaxTokensUsed(maxTokens int64) StopCondition {
75 return func(steps []StepResult) bool {
76 var totalTokens int64
77 for _, step := range steps {
78 totalTokens += step.Usage.TotalTokens
79 }
80 return totalTokens >= maxTokens
81 }
82}
83
// PrepareStepFunctionOptions contains the options for preparing a step in an agent execution.
// The agent passes one of these to the PrepareStepFunction before each model call.
type PrepareStepFunctionOptions struct {
	// Steps holds the results of the steps recorded so far.
	Steps []StepResult
	// StepNumber is the zero-based index of the step about to run.
	StepNumber int
	// Model is the agent's configured language model.
	Model LanguageModel
	// Messages is the message list assembled for the step: the initial prompt
	// followed by the response messages of earlier steps.
	Messages []Message
}
91
// PrepareStepResult contains the result of preparing a step in an agent execution.
// Fields left at their zero value keep the agent's default for that step.
type PrepareStepResult struct {
	// Model, when non-nil, replaces the model used for the step.
	Model LanguageModel
	// Messages, when non-nil, replaces the messages sent for the step.
	Messages []Message
	// System, when non-nil, overrides the system prompt for the step.
	System *string
	// ToolChoice, when non-nil, replaces the default ToolChoiceAuto.
	ToolChoice *ToolChoice
	// ActiveTools, when non-empty, replaces the call's active tool list.
	ActiveTools []string
	// DisableAllTools makes the step run without any tools.
	DisableAllTools bool
}
101
// ToolCallRepairOptions contains the options for repairing a tool call.
// It is handed to a RepairToolCallFunction after a tool call fails validation.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call that failed validation.
	OriginalToolCall ToolCallContent
	// ValidationError is the error validation reported for OriginalToolCall.
	ValidationError error
	// AvailableTools lists the tools the call was validated against.
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt in effect for the step.
	SystemPrompt string
	// Messages is the message list that was sent for the step.
	Messages []Message
}
110
type (
	// PrepareStepFunction defines a function that prepares a step in an agent execution.
	// The returned context is the one the agent uses from that step onward, and
	// a non-nil error aborts the run.
	PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)

	// OnStepFinishedFunction defines a function that is called when a step finishes.
	OnStepFinishedFunction = func(step StepResult)

	// RepairToolCallFunction defines a function that repairs a tool call.
	// A returned tool call is used only if the error is nil, the pointer is
	// non-nil, and the repaired call itself passes validation.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
121
// agentSettings holds the agent-wide configuration set through AgentOption
// values. prepareCall uses the sampling, retry, stop, prepare, repair and
// retry-callback fields as fallbacks for anything a call leaves unset.
type agentSettings struct {
	// systemPrompt is the default system prompt; an empty value means no system message.
	systemPrompt string
	// maxOutputTokens is the default output token limit.
	maxOutputTokens *int64
	// temperature is the default sampling temperature.
	temperature *float64
	// topP is the default top-p value.
	topP *float64
	// topK is the default top-k value.
	topK *int64
	// presencePenalty is the default presence penalty.
	presencePenalty *float64
	// frequencyPenalty is the default frequency penalty.
	frequencyPenalty *float64
	// headers holds agent-level headers.
	headers map[string]string
	// providerOptions holds agent-level provider options; per-call options are merged over them.
	providerOptions ProviderOptions

	// TODO: add support for provider tools
	// tools lists the tools the agent can expose to the model and execute.
	tools []AgentTool
	// maxRetries is the default retry limit.
	maxRetries *int

	// model is the language model used unless a step overrides it.
	model LanguageModel

	// stopWhen lists the default stop conditions.
	stopWhen []StopCondition
	// prepareStep is the default step preparation hook.
	prepareStep PrepareStepFunction
	// repairToolCall is the default tool call repair hook.
	repairToolCall RepairToolCallFunction
	// onRetry is the default retry callback.
	onRetry OnRetryCallback
}
144
// AgentCall represents a call to an agent.
// Settings left unset fall back to the agent's defaults when the call is prepared.
type AgentCall struct {
	// Prompt is the user prompt; it must not be empty.
	Prompt string `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files []FilePart `json:"files"`
	// Messages are placed after the system message and before the user prompt.
	Messages []Message `json:"messages"`
	// MaxOutputTokens limits the output tokens; nil uses the agent default.
	MaxOutputTokens *int64
	// Temperature is the sampling temperature; nil uses the agent default.
	Temperature *float64 `json:"temperature"`
	// TopP is the top-p value; nil uses the agent default.
	TopP *float64 `json:"top_p"`
	// TopK is the top-k value; nil uses the agent default.
	TopK *int64 `json:"top_k"`
	// PresencePenalty is the presence penalty; nil uses the agent default.
	PresencePenalty *float64 `json:"presence_penalty"`
	// FrequencyPenalty is the frequency penalty; nil uses the agent default.
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, restricts the tools offered to the model to the named ones.
	ActiveTools []string `json:"active_tools"`
	// ProviderOptions are merged over the agent-level provider options.
	ProviderOptions ProviderOptions
	// OnRetry is the retry callback; nil uses the agent default.
	OnRetry OnRetryCallback
	// MaxRetries is the retry limit; nil uses the agent default.
	MaxRetries *int

	// StopWhen lists the stop conditions; when empty the agent defaults apply.
	StopWhen []StopCondition
	// PrepareStep is the step preparation hook; nil uses the agent default.
	PrepareStep PrepareStepFunction
	// RepairToolCall is the tool call repair hook; nil uses the agent default.
	RepairToolCall RepairToolCallFunction
}
165
// Agent-level callbacks.
type (
	// OnAgentStartFunc is called when agent starts.
	// Stream invokes it once, before the first step.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called when agent finishes.
	// Stream invokes it after OnFinishFunc and discards the returned error.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called when a step starts.
	// Stream discards the returned error.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called when a step finishes.
	// Stream discards the returned error.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called when entire agent completes.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called when an error occurs.
	// Stream invokes it when opening or processing a step's stream fails.
	OnErrorFunc func(error)
)
186
// Stream part callbacks - called for each corresponding stream part type.
type (
	// OnChunkFunc is called for each stream part (catch-all).
	// A non-nil error aborts processing of the current step.
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	// A non-nil error aborts processing of the current step.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when text starts; id identifies the text block.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for text deltas.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when text ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when reasoning starts.
	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error

	// OnReasoningDeltaFunc is called for reasoning deltas.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning ends.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when tool input starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for tool input deltas.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnToolResultFunc is called when tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when stream finishes.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
234
// AgentStreamCall represents a streaming call to an agent.
// It carries the same request settings as AgentCall plus optional callbacks
// that are invoked while the agent runs; nil callbacks are skipped.
type AgentStreamCall struct {
	// Request settings. Stream copies Prompt through RepairToolCall (except
	// Headers and OnRetry) into an AgentCall before applying agent defaults.
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	// Loop control hooks; unset values fall back to the agent defaults.
	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called when agent starts
	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
	OnWarnings       OnWarningsFunc       // Called for warnings
	OnTextStart      OnTextStartFunc      // Called when text starts
	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
	OnTextEnd        OnTextEndFunc        // Called when text ends
	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
	OnToolCall       OnToolCallFunc       // Called when tool call is complete
	OnToolResult     OnToolResultFunc     // Called when tool execution completes
	OnSource         OnSourceFunc         // Called for source references
	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
}
281
// AgentResult represents the result of an agent execution.
type AgentResult struct {
	// Steps lists every step that ran, in execution order.
	Steps []StepResult
	// Final response
	// Response is the Response of the last entry in Steps.
	Response Response
	// TotalUsage is the token usage summed over all steps.
	TotalUsage Usage
}
289
// Agent represents an AI agent that can generate responses and stream responses.
type Agent interface {
	// Generate runs the agent loop with non-streaming model calls and returns
	// the collected steps.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the agent loop with streaming model calls, invoking the
	// callbacks set on the AgentStreamCall along the way.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
295
// AgentOption defines a function that configures agent settings.
// NewAgent applies the options in the order they are passed.
type AgentOption = func(*agentSettings)
298
// agent is the Agent implementation returned by NewAgent.
type agent struct {
	// settings holds the configuration assembled from the AgentOption values.
	settings agentSettings
}
302
303// NewAgent creates a new agent with the given language model and options.
304func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
305 settings := agentSettings{
306 model: model,
307 }
308 for _, o := range opts {
309 o(&settings)
310 }
311 return &agent{
312 settings: settings,
313 }
314}
315
316func (a *agent) prepareCall(call AgentCall) AgentCall {
317 call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
318 call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
319 call.TopP = cmp.Or(call.TopP, a.settings.topP)
320 call.TopK = cmp.Or(call.TopK, a.settings.topK)
321 call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
322 call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
323 call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
324
325 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
326 call.StopWhen = a.settings.stopWhen
327 }
328 if call.PrepareStep == nil && a.settings.prepareStep != nil {
329 call.PrepareStep = a.settings.prepareStep
330 }
331 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
332 call.RepairToolCall = a.settings.repairToolCall
333 }
334 if call.OnRetry == nil && a.settings.onRetry != nil {
335 call.OnRetry = a.settings.onRetry
336 }
337
338 providerOptions := ProviderOptions{}
339 if a.settings.providerOptions != nil {
340 maps.Copy(providerOptions, a.settings.providerOptions)
341 }
342 if call.ProviderOptions != nil {
343 maps.Copy(providerOptions, call.ProviderOptions)
344 }
345 call.ProviderOptions = providerOptions
346
347 headers := map[string]string{}
348
349 if a.settings.headers != nil {
350 maps.Copy(headers, a.settings.headers)
351 }
352
353 return call
354}
355
356// Generate implements Agent.
357func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
358 opts = a.prepareCall(opts)
359 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
360 if err != nil {
361 return nil, err
362 }
363 var responseMessages []Message
364 var steps []StepResult
365
366 for {
367 stepInputMessages := append(initialPrompt, responseMessages...)
368 stepModel := a.settings.model
369 stepSystemPrompt := a.settings.systemPrompt
370 stepActiveTools := opts.ActiveTools
371 stepToolChoice := ToolChoiceAuto
372 disableAllTools := false
373
374 if opts.PrepareStep != nil {
375 updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
376 Model: stepModel,
377 Steps: steps,
378 StepNumber: len(steps),
379 Messages: stepInputMessages,
380 })
381 if err != nil {
382 return nil, err
383 }
384
385 ctx = updatedCtx
386
387 // Apply prepared step modifications
388 if prepared.Messages != nil {
389 stepInputMessages = prepared.Messages
390 }
391 if prepared.Model != nil {
392 stepModel = prepared.Model
393 }
394 if prepared.System != nil {
395 stepSystemPrompt = *prepared.System
396 }
397 if prepared.ToolChoice != nil {
398 stepToolChoice = *prepared.ToolChoice
399 }
400 if len(prepared.ActiveTools) > 0 {
401 stepActiveTools = prepared.ActiveTools
402 }
403 disableAllTools = prepared.DisableAllTools
404 }
405
406 // Recreate prompt with potentially modified system prompt
407 if stepSystemPrompt != a.settings.systemPrompt {
408 stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
409 if err != nil {
410 return nil, err
411 }
412 // Replace system message part, keep the rest
413 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
414 stepInputMessages[0] = stepPrompt[0] // Replace system message
415 }
416 }
417
418 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
419
420 retryOptions := DefaultRetryOptions()
421 retryOptions.OnRetry = opts.OnRetry
422 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
423
424 result, err := retry(ctx, func() (*Response, error) {
425 return stepModel.Generate(ctx, Call{
426 Prompt: stepInputMessages,
427 MaxOutputTokens: opts.MaxOutputTokens,
428 Temperature: opts.Temperature,
429 TopP: opts.TopP,
430 TopK: opts.TopK,
431 PresencePenalty: opts.PresencePenalty,
432 FrequencyPenalty: opts.FrequencyPenalty,
433 Tools: preparedTools,
434 ToolChoice: &stepToolChoice,
435 ProviderOptions: opts.ProviderOptions,
436 })
437 })
438 if err != nil {
439 return nil, err
440 }
441
442 var stepToolCalls []ToolCallContent
443 for _, content := range result.Content {
444 if content.GetType() == ContentTypeToolCall {
445 toolCall, ok := AsContentType[ToolCallContent](content)
446 if !ok {
447 continue
448 }
449
450 // Validate and potentially repair the tool call
451 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
452 stepToolCalls = append(stepToolCalls, validatedToolCall)
453 }
454 }
455
456 toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
457
458 // Build step content with validated tool calls and tool results
459 stepContent := []Content{}
460 toolCallIndex := 0
461 for _, content := range result.Content {
462 if content.GetType() == ContentTypeToolCall {
463 // Replace with validated tool call
464 if toolCallIndex < len(stepToolCalls) {
465 stepContent = append(stepContent, stepToolCalls[toolCallIndex])
466 toolCallIndex++
467 }
468 } else {
469 // Keep other content as-is
470 stepContent = append(stepContent, content)
471 }
472 }
473 // Add tool results
474 for _, result := range toolResults {
475 stepContent = append(stepContent, result)
476 }
477 currentStepMessages := toResponseMessages(stepContent)
478 responseMessages = append(responseMessages, currentStepMessages...)
479
480 stepResult := StepResult{
481 Response: Response{
482 Content: stepContent,
483 FinishReason: result.FinishReason,
484 Usage: result.Usage,
485 Warnings: result.Warnings,
486 ProviderMetadata: result.ProviderMetadata,
487 },
488 Messages: currentStepMessages,
489 }
490 steps = append(steps, stepResult)
491 shouldStop := isStopConditionMet(opts.StopWhen, steps)
492
493 if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
494 break
495 }
496 }
497
498 totalUsage := Usage{}
499
500 for _, step := range steps {
501 usage := step.Usage
502 totalUsage.InputTokens += usage.InputTokens
503 totalUsage.OutputTokens += usage.OutputTokens
504 totalUsage.ReasoningTokens += usage.ReasoningTokens
505 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
506 totalUsage.CacheReadTokens += usage.CacheReadTokens
507 totalUsage.TotalTokens += usage.TotalTokens
508 }
509
510 agentResult := &AgentResult{
511 Steps: steps,
512 Response: steps[len(steps)-1].Response,
513 TotalUsage: totalUsage,
514 }
515 return agentResult, nil
516}
517
518func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
519 if len(conditions) == 0 {
520 return false
521 }
522
523 for _, condition := range conditions {
524 if condition(steps) {
525 return true
526 }
527 }
528 return false
529}
530
531func toResponseMessages(content []Content) []Message {
532 var assistantParts []MessagePart
533 var toolParts []MessagePart
534
535 for _, c := range content {
536 switch c.GetType() {
537 case ContentTypeText:
538 text, ok := AsContentType[TextContent](c)
539 if !ok {
540 continue
541 }
542 assistantParts = append(assistantParts, TextPart{
543 Text: text.Text,
544 ProviderOptions: ProviderOptions(text.ProviderMetadata),
545 })
546 case ContentTypeReasoning:
547 reasoning, ok := AsContentType[ReasoningContent](c)
548 if !ok {
549 continue
550 }
551 assistantParts = append(assistantParts, ReasoningPart{
552 Text: reasoning.Text,
553 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
554 })
555 case ContentTypeToolCall:
556 toolCall, ok := AsContentType[ToolCallContent](c)
557 if !ok {
558 continue
559 }
560 assistantParts = append(assistantParts, ToolCallPart{
561 ToolCallID: toolCall.ToolCallID,
562 ToolName: toolCall.ToolName,
563 Input: toolCall.Input,
564 ProviderExecuted: toolCall.ProviderExecuted,
565 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
566 })
567 case ContentTypeFile:
568 file, ok := AsContentType[FileContent](c)
569 if !ok {
570 continue
571 }
572 assistantParts = append(assistantParts, FilePart{
573 Data: file.Data,
574 MediaType: file.MediaType,
575 ProviderOptions: ProviderOptions(file.ProviderMetadata),
576 })
577 case ContentTypeSource:
578 // Sources are metadata about references used to generate the response.
579 // They don't need to be included in the conversation messages.
580 continue
581 case ContentTypeToolResult:
582 result, ok := AsContentType[ToolResultContent](c)
583 if !ok {
584 continue
585 }
586 toolParts = append(toolParts, ToolResultPart{
587 ToolCallID: result.ToolCallID,
588 Output: result.Result,
589 ProviderOptions: ProviderOptions(result.ProviderMetadata),
590 })
591 }
592 }
593
594 var messages []Message
595 if len(assistantParts) > 0 {
596 messages = append(messages, Message{
597 Role: MessageRoleAssistant,
598 Content: assistantParts,
599 })
600 }
601 if len(toolParts) > 0 {
602 messages = append(messages, Message{
603 Role: MessageRoleTool,
604 Content: toolParts,
605 })
606 }
607 return messages
608}
609
610func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
611 if len(toolCalls) == 0 {
612 return nil, nil
613 }
614
615 // Create a map for quick tool lookup
616 toolMap := make(map[string]AgentTool)
617 for _, tool := range allTools {
618 toolMap[tool.Info().Name] = tool
619 }
620
621 // Execute all tool calls sequentially in order
622 results := make([]ToolResultContent, 0, len(toolCalls))
623
624 for _, toolCall := range toolCalls {
625 // Skip invalid tool calls - create error result
626 if toolCall.Invalid {
627 result := ToolResultContent{
628 ToolCallID: toolCall.ToolCallID,
629 ToolName: toolCall.ToolName,
630 Result: ToolResultOutputContentError{
631 Error: toolCall.ValidationError,
632 },
633 ProviderExecuted: false,
634 }
635 results = append(results, result)
636 if toolResultCallback != nil {
637 if err := toolResultCallback(result); err != nil {
638 return nil, err
639 }
640 }
641 continue
642 }
643
644 tool, exists := toolMap[toolCall.ToolName]
645 if !exists {
646 result := ToolResultContent{
647 ToolCallID: toolCall.ToolCallID,
648 ToolName: toolCall.ToolName,
649 Result: ToolResultOutputContentError{
650 Error: errors.New("Error: Tool not found: " + toolCall.ToolName),
651 },
652 ProviderExecuted: false,
653 }
654 results = append(results, result)
655 if toolResultCallback != nil {
656 if err := toolResultCallback(result); err != nil {
657 return nil, err
658 }
659 }
660 continue
661 }
662
663 // Execute the tool
664 toolResult, err := tool.Run(ctx, ToolCall{
665 ID: toolCall.ToolCallID,
666 Name: toolCall.ToolName,
667 Input: toolCall.Input,
668 })
669 if err != nil {
670 result := ToolResultContent{
671 ToolCallID: toolCall.ToolCallID,
672 ToolName: toolCall.ToolName,
673 Result: ToolResultOutputContentError{
674 Error: err,
675 },
676 ClientMetadata: toolResult.Metadata,
677 ProviderExecuted: false,
678 }
679 if toolResultCallback != nil {
680 if cbErr := toolResultCallback(result); cbErr != nil {
681 return nil, cbErr
682 }
683 }
684 return nil, err
685 }
686
687 var result ToolResultContent
688 if toolResult.IsError {
689 result = ToolResultContent{
690 ToolCallID: toolCall.ToolCallID,
691 ToolName: toolCall.ToolName,
692 Result: ToolResultOutputContentError{
693 Error: errors.New(toolResult.Content),
694 },
695 ClientMetadata: toolResult.Metadata,
696 ProviderExecuted: false,
697 }
698 } else {
699 result = ToolResultContent{
700 ToolCallID: toolCall.ToolCallID,
701 ToolName: toolCall.ToolName,
702 Result: ToolResultOutputContentText{
703 Text: toolResult.Content,
704 },
705 ClientMetadata: toolResult.Metadata,
706 ProviderExecuted: false,
707 }
708 }
709 results = append(results, result)
710 if toolResultCallback != nil {
711 if err := toolResultCallback(result); err != nil {
712 return nil, err
713 }
714 }
715 }
716
717 return results, nil
718}
719
720// Stream implements Agent.
721func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
722 // Convert AgentStreamCall to AgentCall for preparation
723 call := AgentCall{
724 Prompt: opts.Prompt,
725 Files: opts.Files,
726 Messages: opts.Messages,
727 MaxOutputTokens: opts.MaxOutputTokens,
728 Temperature: opts.Temperature,
729 TopP: opts.TopP,
730 TopK: opts.TopK,
731 PresencePenalty: opts.PresencePenalty,
732 FrequencyPenalty: opts.FrequencyPenalty,
733 ActiveTools: opts.ActiveTools,
734 ProviderOptions: opts.ProviderOptions,
735 MaxRetries: opts.MaxRetries,
736 StopWhen: opts.StopWhen,
737 PrepareStep: opts.PrepareStep,
738 RepairToolCall: opts.RepairToolCall,
739 }
740
741 call = a.prepareCall(call)
742
743 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
744 if err != nil {
745 return nil, err
746 }
747
748 var responseMessages []Message
749 var steps []StepResult
750 var totalUsage Usage
751
752 // Start agent stream
753 if opts.OnAgentStart != nil {
754 opts.OnAgentStart()
755 }
756
757 for stepNumber := 0; ; stepNumber++ {
758 stepInputMessages := append(initialPrompt, responseMessages...)
759 stepModel := a.settings.model
760 stepSystemPrompt := a.settings.systemPrompt
761 stepActiveTools := call.ActiveTools
762 stepToolChoice := ToolChoiceAuto
763 disableAllTools := false
764
765 // Apply step preparation if provided
766 if call.PrepareStep != nil {
767 updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
768 Model: stepModel,
769 Steps: steps,
770 StepNumber: stepNumber,
771 Messages: stepInputMessages,
772 })
773 if err != nil {
774 return nil, err
775 }
776
777 ctx = updatedCtx
778
779 if prepared.Messages != nil {
780 stepInputMessages = prepared.Messages
781 }
782 if prepared.Model != nil {
783 stepModel = prepared.Model
784 }
785 if prepared.System != nil {
786 stepSystemPrompt = *prepared.System
787 }
788 if prepared.ToolChoice != nil {
789 stepToolChoice = *prepared.ToolChoice
790 }
791 if len(prepared.ActiveTools) > 0 {
792 stepActiveTools = prepared.ActiveTools
793 }
794 disableAllTools = prepared.DisableAllTools
795 }
796
797 // Recreate prompt with potentially modified system prompt
798 if stepSystemPrompt != a.settings.systemPrompt {
799 stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
800 if err != nil {
801 return nil, err
802 }
803 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
804 stepInputMessages[0] = stepPrompt[0]
805 }
806 }
807
808 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
809
810 // Start step stream
811 if opts.OnStepStart != nil {
812 _ = opts.OnStepStart(stepNumber)
813 }
814
815 // Create streaming call
816 streamCall := Call{
817 Prompt: stepInputMessages,
818 MaxOutputTokens: call.MaxOutputTokens,
819 Temperature: call.Temperature,
820 TopP: call.TopP,
821 TopK: call.TopK,
822 PresencePenalty: call.PresencePenalty,
823 FrequencyPenalty: call.FrequencyPenalty,
824 Tools: preparedTools,
825 ToolChoice: &stepToolChoice,
826 ProviderOptions: call.ProviderOptions,
827 }
828
829 // Get streaming response
830 stream, err := stepModel.Stream(ctx, streamCall)
831 if err != nil {
832 if opts.OnError != nil {
833 opts.OnError(err)
834 }
835 return nil, err
836 }
837
838 // Process stream with tool execution
839 stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
840 if err != nil {
841 if opts.OnError != nil {
842 opts.OnError(err)
843 }
844 return nil, err
845 }
846
847 steps = append(steps, stepResult)
848 totalUsage = addUsage(totalUsage, stepResult.Usage)
849
850 // Call step finished callback
851 if opts.OnStepFinish != nil {
852 _ = opts.OnStepFinish(stepResult)
853 }
854
855 // Add step messages to response messages
856 stepMessages := toResponseMessages(stepResult.Content)
857 responseMessages = append(responseMessages, stepMessages...)
858
859 // Check stop conditions
860 shouldStop := isStopConditionMet(call.StopWhen, steps)
861 if shouldStop || !shouldContinue {
862 break
863 }
864 }
865
866 // Finish agent stream
867 agentResult := &AgentResult{
868 Steps: steps,
869 Response: steps[len(steps)-1].Response,
870 TotalUsage: totalUsage,
871 }
872
873 if opts.OnFinish != nil {
874 opts.OnFinish(agentResult)
875 }
876
877 if opts.OnAgentFinish != nil {
878 _ = opts.OnAgentFinish(agentResult)
879 }
880
881 return agentResult, nil
882}
883
884func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
885 preparedTools := make([]Tool, 0, len(tools))
886
887 // If explicitly disabling all tools, return no tools
888 if disableAllTools {
889 return preparedTools
890 }
891
892 for _, tool := range tools {
893 // If activeTools has items, only include tools in the list
894 // If activeTools is empty, include all tools
895 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
896 continue
897 }
898 info := tool.Info()
899 preparedTools = append(preparedTools, FunctionTool{
900 Name: info.Name,
901 Description: info.Description,
902 InputSchema: map[string]any{
903 "type": "object",
904 "properties": info.Parameters,
905 "required": info.Required,
906 },
907 ProviderOptions: tool.ProviderOptions(),
908 })
909 }
910 return preparedTools
911}
912
913// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
914func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
915 if err := a.validateToolCall(toolCall, availableTools); err == nil {
916 return toolCall
917 } else { //nolint: revive
918 if repairFunc != nil {
919 repairOptions := ToolCallRepairOptions{
920 OriginalToolCall: toolCall,
921 ValidationError: err,
922 AvailableTools: availableTools,
923 SystemPrompt: systemPrompt,
924 Messages: messages,
925 }
926
927 if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
928 if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
929 return *repairedToolCall
930 }
931 }
932 }
933
934 invalidToolCall := toolCall
935 invalidToolCall.Invalid = true
936 invalidToolCall.ValidationError = err
937 return invalidToolCall
938 }
939}
940
941// validateToolCall validates a tool call against available tools and their schemas.
942func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
943 var tool AgentTool
944 for _, t := range availableTools {
945 if t.Info().Name == toolCall.ToolName {
946 tool = t
947 break
948 }
949 }
950
951 if tool == nil {
952 return fmt.Errorf("tool not found: %s", toolCall.ToolName)
953 }
954
955 // Validate JSON parsing
956 var input map[string]any
957 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
958 return fmt.Errorf("invalid JSON input: %w", err)
959 }
960
961 // Basic schema validation (check required fields)
962 // TODO: more robust schema validation using JSON Schema or similar
963 toolInfo := tool.Info()
964 for _, required := range toolInfo.Required {
965 if _, exists := input[required]; !exists {
966 return fmt.Errorf("missing required parameter: %s", required)
967 }
968 }
969 return nil
970}
971
972func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
973 if prompt == "" {
974 return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
975 }
976
977 var preparedPrompt Prompt
978
979 if system != "" {
980 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
981 }
982 preparedPrompt = append(preparedPrompt, messages...)
983 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
984 return preparedPrompt, nil
985}
986
987// WithSystemPrompt sets the system prompt for the agent.
988func WithSystemPrompt(prompt string) AgentOption {
989 return func(s *agentSettings) {
990 s.systemPrompt = prompt
991 }
992}
993
994// WithMaxOutputTokens sets the maximum output tokens for the agent.
995func WithMaxOutputTokens(tokens int64) AgentOption {
996 return func(s *agentSettings) {
997 s.maxOutputTokens = &tokens
998 }
999}
1000
1001// WithTemperature sets the temperature for the agent.
1002func WithTemperature(temp float64) AgentOption {
1003 return func(s *agentSettings) {
1004 s.temperature = &temp
1005 }
1006}
1007
1008// WithTopP sets the top-p value for the agent.
1009func WithTopP(topP float64) AgentOption {
1010 return func(s *agentSettings) {
1011 s.topP = &topP
1012 }
1013}
1014
1015// WithTopK sets the top-k value for the agent.
1016func WithTopK(topK int64) AgentOption {
1017 return func(s *agentSettings) {
1018 s.topK = &topK
1019 }
1020}
1021
1022// WithPresencePenalty sets the presence penalty for the agent.
1023func WithPresencePenalty(penalty float64) AgentOption {
1024 return func(s *agentSettings) {
1025 s.presencePenalty = &penalty
1026 }
1027}
1028
1029// WithFrequencyPenalty sets the frequency penalty for the agent.
1030func WithFrequencyPenalty(penalty float64) AgentOption {
1031 return func(s *agentSettings) {
1032 s.frequencyPenalty = &penalty
1033 }
1034}
1035
1036// WithTools sets the tools for the agent.
1037func WithTools(tools ...AgentTool) AgentOption {
1038 return func(s *agentSettings) {
1039 s.tools = append(s.tools, tools...)
1040 }
1041}
1042
1043// WithStopConditions sets the stop conditions for the agent.
1044func WithStopConditions(conditions ...StopCondition) AgentOption {
1045 return func(s *agentSettings) {
1046 s.stopWhen = append(s.stopWhen, conditions...)
1047 }
1048}
1049
1050// WithPrepareStep sets the prepare step function for the agent.
1051func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1052 return func(s *agentSettings) {
1053 s.prepareStep = fn
1054 }
1055}
1056
1057// WithRepairToolCall sets the repair tool call function for the agent.
1058func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1059 return func(s *agentSettings) {
1060 s.repairToolCall = fn
1061 }
1062}
1063
1064// processStepStream processes a single step's stream and returns the step result.
1065func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
1066 var stepContent []Content
1067 var stepToolCalls []ToolCallContent
1068 var stepUsage Usage
1069 stepFinishReason := FinishReasonUnknown
1070 var stepWarnings []CallWarning
1071 var stepProviderMetadata ProviderMetadata
1072
1073 activeToolCalls := make(map[string]*ToolCallContent)
1074 activeTextContent := make(map[string]string)
1075 type reasoningContent struct {
1076 content string
1077 options ProviderMetadata
1078 }
1079 activeReasoningContent := make(map[string]reasoningContent)
1080
1081 // Process stream parts
1082 for part := range stream {
1083 // Forward all parts to chunk callback
1084 if opts.OnChunk != nil {
1085 err := opts.OnChunk(part)
1086 if err != nil {
1087 return StepResult{}, false, err
1088 }
1089 }
1090
1091 switch part.Type {
1092 case StreamPartTypeWarnings:
1093 stepWarnings = part.Warnings
1094 if opts.OnWarnings != nil {
1095 err := opts.OnWarnings(part.Warnings)
1096 if err != nil {
1097 return StepResult{}, false, err
1098 }
1099 }
1100
1101 case StreamPartTypeTextStart:
1102 activeTextContent[part.ID] = ""
1103 if opts.OnTextStart != nil {
1104 err := opts.OnTextStart(part.ID)
1105 if err != nil {
1106 return StepResult{}, false, err
1107 }
1108 }
1109
1110 case StreamPartTypeTextDelta:
1111 if _, exists := activeTextContent[part.ID]; exists {
1112 activeTextContent[part.ID] += part.Delta
1113 }
1114 if opts.OnTextDelta != nil {
1115 err := opts.OnTextDelta(part.ID, part.Delta)
1116 if err != nil {
1117 return StepResult{}, false, err
1118 }
1119 }
1120
1121 case StreamPartTypeTextEnd:
1122 if text, exists := activeTextContent[part.ID]; exists {
1123 stepContent = append(stepContent, TextContent{
1124 Text: text,
1125 ProviderMetadata: part.ProviderMetadata,
1126 })
1127 delete(activeTextContent, part.ID)
1128 }
1129 if opts.OnTextEnd != nil {
1130 err := opts.OnTextEnd(part.ID)
1131 if err != nil {
1132 return StepResult{}, false, err
1133 }
1134 }
1135
1136 case StreamPartTypeReasoningStart:
1137 activeReasoningContent[part.ID] = reasoningContent{content: part.Delta, options: part.ProviderMetadata}
1138 if opts.OnReasoningStart != nil {
1139 content := ReasoningContent{
1140 Text: part.Delta,
1141 ProviderMetadata: part.ProviderMetadata,
1142 }
1143 err := opts.OnReasoningStart(part.ID, content)
1144 if err != nil {
1145 return StepResult{}, false, err
1146 }
1147 }
1148
1149 case StreamPartTypeReasoningDelta:
1150 if active, exists := activeReasoningContent[part.ID]; exists {
1151 active.content += part.Delta
1152 active.options = part.ProviderMetadata
1153 activeReasoningContent[part.ID] = active
1154 }
1155 if opts.OnReasoningDelta != nil {
1156 err := opts.OnReasoningDelta(part.ID, part.Delta)
1157 if err != nil {
1158 return StepResult{}, false, err
1159 }
1160 }
1161
1162 case StreamPartTypeReasoningEnd:
1163 if active, exists := activeReasoningContent[part.ID]; exists {
1164 if part.ProviderMetadata != nil {
1165 active.options = part.ProviderMetadata
1166 }
1167 content := ReasoningContent{
1168 Text: active.content,
1169 ProviderMetadata: active.options,
1170 }
1171 stepContent = append(stepContent, content)
1172 if opts.OnReasoningEnd != nil {
1173 err := opts.OnReasoningEnd(part.ID, content)
1174 if err != nil {
1175 return StepResult{}, false, err
1176 }
1177 }
1178 delete(activeReasoningContent, part.ID)
1179 }
1180
1181 case StreamPartTypeToolInputStart:
1182 activeToolCalls[part.ID] = &ToolCallContent{
1183 ToolCallID: part.ID,
1184 ToolName: part.ToolCallName,
1185 Input: "",
1186 ProviderExecuted: part.ProviderExecuted,
1187 }
1188 if opts.OnToolInputStart != nil {
1189 err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1190 if err != nil {
1191 return StepResult{}, false, err
1192 }
1193 }
1194
1195 case StreamPartTypeToolInputDelta:
1196 if toolCall, exists := activeToolCalls[part.ID]; exists {
1197 toolCall.Input += part.Delta
1198 }
1199 if opts.OnToolInputDelta != nil {
1200 err := opts.OnToolInputDelta(part.ID, part.Delta)
1201 if err != nil {
1202 return StepResult{}, false, err
1203 }
1204 }
1205
1206 case StreamPartTypeToolInputEnd:
1207 if opts.OnToolInputEnd != nil {
1208 err := opts.OnToolInputEnd(part.ID)
1209 if err != nil {
1210 return StepResult{}, false, err
1211 }
1212 }
1213
1214 case StreamPartTypeToolCall:
1215 toolCall := ToolCallContent{
1216 ToolCallID: part.ID,
1217 ToolName: part.ToolCallName,
1218 Input: part.ToolCallInput,
1219 ProviderExecuted: part.ProviderExecuted,
1220 ProviderMetadata: part.ProviderMetadata,
1221 }
1222
1223 // Validate and potentially repair the tool call
1224 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1225 stepToolCalls = append(stepToolCalls, validatedToolCall)
1226 stepContent = append(stepContent, validatedToolCall)
1227
1228 if opts.OnToolCall != nil {
1229 err := opts.OnToolCall(validatedToolCall)
1230 if err != nil {
1231 return StepResult{}, false, err
1232 }
1233 }
1234
1235 // Clean up active tool call
1236 delete(activeToolCalls, part.ID)
1237
1238 case StreamPartTypeSource:
1239 sourceContent := SourceContent{
1240 SourceType: part.SourceType,
1241 ID: part.ID,
1242 URL: part.URL,
1243 Title: part.Title,
1244 ProviderMetadata: part.ProviderMetadata,
1245 }
1246 stepContent = append(stepContent, sourceContent)
1247 if opts.OnSource != nil {
1248 err := opts.OnSource(sourceContent)
1249 if err != nil {
1250 return StepResult{}, false, err
1251 }
1252 }
1253
1254 case StreamPartTypeFinish:
1255 stepUsage = part.Usage
1256 stepFinishReason = part.FinishReason
1257 stepProviderMetadata = part.ProviderMetadata
1258 if opts.OnStreamFinish != nil {
1259 err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1260 if err != nil {
1261 return StepResult{}, false, err
1262 }
1263 }
1264
1265 case StreamPartTypeError:
1266 if opts.OnError != nil {
1267 opts.OnError(part.Error)
1268 }
1269 return StepResult{}, false, part.Error
1270 }
1271 }
1272
1273 // Execute tools if any
1274 var toolResults []ToolResultContent
1275 if len(stepToolCalls) > 0 {
1276 var err error
1277 toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
1278 if err != nil {
1279 return StepResult{}, false, err
1280 }
1281 // Add tool results to content
1282 for _, result := range toolResults {
1283 stepContent = append(stepContent, result)
1284 }
1285 }
1286
1287 stepResult := StepResult{
1288 Response: Response{
1289 Content: stepContent,
1290 FinishReason: stepFinishReason,
1291 Usage: stepUsage,
1292 Warnings: stepWarnings,
1293 ProviderMetadata: stepProviderMetadata,
1294 },
1295 Messages: toResponseMessages(stepContent),
1296 }
1297
1298 // Determine if we should continue (has tool calls and not stopped)
1299 shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1300
1301 return stepResult, shouldContinue, nil
1302}
1303
1304func addUsage(a, b Usage) Usage {
1305 return Usage{
1306 InputTokens: a.InputTokens + b.InputTokens,
1307 OutputTokens: a.OutputTokens + b.OutputTokens,
1308 TotalTokens: a.TotalTokens + b.TotalTokens,
1309 ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens,
1310 CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1311 CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens,
1312 }
1313}
1314
1315// WithHeaders sets the headers for the agent.
1316func WithHeaders(headers map[string]string) AgentOption {
1317 return func(s *agentSettings) {
1318 s.headers = headers
1319 }
1320}
1321
1322// WithProviderOptions sets the provider options for the agent.
1323func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1324 return func(s *agentSettings) {
1325 s.providerOptions = providerOptions
1326 }
1327}