1package fantasy
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "slices"
11 "sync"
12)
13
// StepResult represents the result of a single step in an agent execution.
// It embeds the step's Response and additionally carries the conversation
// messages derived from that response.
type StepResult struct {
	Response
	// Messages holds the messages produced for this step. Generate fills it
	// with the output of toResponseMessages for the step's content.
	Messages []Message
}
19
// StopCondition defines a function that determines when an agent should stop
// executing. It receives every step completed so far and returns true when
// the agent loop should end.
type StopCondition = func(steps []StepResult) bool
22
23// StepCountIs returns a stop condition that stops after the specified number of steps.
24func StepCountIs(stepCount int) StopCondition {
25 return func(steps []StepResult) bool {
26 return len(steps) >= stepCount
27 }
28}
29
30// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
31func HasToolCall(toolName string) StopCondition {
32 return func(steps []StepResult) bool {
33 if len(steps) == 0 {
34 return false
35 }
36 lastStep := steps[len(steps)-1]
37 toolCalls := lastStep.Content.ToolCalls()
38 for _, toolCall := range toolCalls {
39 if toolCall.ToolName == toolName {
40 return true
41 }
42 }
43 return false
44 }
45}
46
47// HasContent returns a stop condition that stops when the specified content type appears in the last step.
48func HasContent(contentType ContentType) StopCondition {
49 return func(steps []StepResult) bool {
50 if len(steps) == 0 {
51 return false
52 }
53 lastStep := steps[len(steps)-1]
54 for _, content := range lastStep.Content {
55 if content.GetType() == contentType {
56 return true
57 }
58 }
59 return false
60 }
61}
62
63// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
64func FinishReasonIs(reason FinishReason) StopCondition {
65 return func(steps []StepResult) bool {
66 if len(steps) == 0 {
67 return false
68 }
69 lastStep := steps[len(steps)-1]
70 return lastStep.FinishReason == reason
71 }
72}
73
74// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
75func MaxTokensUsed(maxTokens int64) StopCondition {
76 return func(steps []StepResult) bool {
77 var totalTokens int64
78 for _, step := range steps {
79 totalTokens += step.Usage.TotalTokens
80 }
81 return totalTokens >= maxTokens
82 }
83}
84
// PrepareStepFunctionOptions contains the options for preparing a step in an agent execution.
// It is the input handed to a PrepareStepFunction before each model call.
type PrepareStepFunctionOptions struct {
	// Steps holds the steps completed before the one being prepared.
	Steps []StepResult
	// StepNumber is the zero-based index of the step being prepared.
	StepNumber int
	// Model is the agent's configured model for this step.
	Model LanguageModel
	// Messages is the message list about to be sent to the model
	// (initial prompt followed by the responses of earlier steps).
	Messages []Message
}
92
// PrepareStepResult contains the result of preparing a step in an agent execution.
// Zero-valued fields leave the corresponding step setting unchanged.
type PrepareStepResult struct {
	// Model, when non-nil, replaces the model used for this step.
	Model LanguageModel
	// Messages, when non-nil, replaces the messages sent for this step.
	Messages []Message
	// System, when non-nil, replaces the system prompt for this step.
	System *string
	// ToolChoice, when non-nil, replaces the default ToolChoiceAuto.
	ToolChoice *ToolChoice
	// ActiveTools, when non-empty, replaces the call's active tool names.
	ActiveTools []string
	// DisableAllTools, when true, sends no tools to the model for this step.
	DisableAllTools bool
}
102
// ToolCallRepairOptions contains the options for repairing a tool call.
// It is passed to a RepairToolCallFunction after a tool call failed validation.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call that failed validation.
	OriginalToolCall ToolCallContent
	// ValidationError is the error reported by validation.
	ValidationError error
	// AvailableTools lists the tools the call was validated against.
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt in effect for the step.
	SystemPrompt string
	// Messages is the message list that was sent to the model for the step.
	Messages []Message
}
111
type (
	// PrepareStepFunction defines a function that prepares a step in an agent execution.
	// The returned context replaces the context used for the rest of the run,
	// and a non-nil error aborts the run.
	PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)

	// OnStepFinishedFunction defines a function that is called when a step finishes.
	OnStepFinishedFunction = func(step StepResult)

	// RepairToolCallFunction defines a function that repairs a tool call.
	// Returning a nil tool call or a non-nil error means the call could not
	// be repaired; a repaired call is validated again before it is used.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
122
// agentSettings holds the agent-wide configuration assembled by AgentOption
// functions. prepareCall uses most of these values as defaults for fields a
// call leaves unset.
type agentSettings struct {
	// systemPrompt is sent as the system message when non-empty.
	systemPrompt string
	// Sampling defaults; a nil pointer means "not configured".
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	headers          map[string]string
	// providerOptions are merged with per-call provider options; per-call
	// entries win on key collisions.
	providerOptions ProviderOptions

	// TODO: add support for provider tools
	tools      []AgentTool
	maxRetries *int

	// model is the language model used unless a step overrides it.
	model LanguageModel

	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
145
// AgentCall represents a call to an agent.
// Pointer, slice and function fields left unset fall back to the agent's
// configured defaults (see prepareCall).
type AgentCall struct {
	// Prompt is the user prompt; createPrompt rejects an empty value.
	Prompt string `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files []FilePart `json:"files"`
	// Messages are placed between the system message and the user prompt.
	Messages         []Message `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools restricts the tools offered to the model to these names;
	// when empty, every configured tool is offered.
	ActiveTools []string `json:"active_tools"`
	// ProviderOptions are merged over the agent-level provider options.
	ProviderOptions ProviderOptions
	OnRetry         OnRetryCallback
	MaxRetries      *int

	// StopWhen lists conditions checked after every step; any true result
	// ends the run.
	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction
}
166
// Agent-level callbacks.
type (
	// OnAgentStartFunc is called when agent starts.
	// Stream invokes it once, before the first step.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called when agent finishes.
	// Stream invokes it after OnFinishFunc and discards the returned error.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called when a step starts.
	// Stream passes the zero-based step number and discards the returned error.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called when a step finishes.
	// Stream discards the returned error.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called when entire agent completes.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called when an error occurs.
	// Stream invokes it when opening the model stream or processing a step
	// fails, just before returning that error.
	OnErrorFunc func(error)
)
187
// Stream part callbacks - called for each corresponding stream part type.
// NOTE(review): except for OnChunkFunc, the code that dispatches these
// callbacks is outside the visible part of this file; confirm how returned
// errors are handled before relying on them.
type (
	// OnChunkFunc is called for each stream part (catch-all).
	// A non-nil error aborts processing of the current step.
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when text starts.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for text deltas.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when text ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when reasoning starts.
	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error

	// OnReasoningDeltaFunc is called for reasoning deltas.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning ends.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when tool input starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for tool input deltas.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnToolResultFunc is called when tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when stream finishes.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
235
// AgentStreamCall represents a streaming call to an agent.
// It carries the same request settings as AgentCall plus the callbacks that
// observe the run. Stream copies the request settings into an AgentCall so
// that agent-level defaults are applied the same way as for Generate.
type AgentStreamCall struct {
	// Prompt is the user prompt; createPrompt rejects an empty value.
	Prompt string `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files []FilePart `json:"files"`
	// Messages are placed between the system message and the user prompt.
	Messages         []Message `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools restricts the tools offered to the model to these names;
	// when empty, every configured tool is offered.
	ActiveTools []string `json:"active_tools"`
	// NOTE(review): Stream does not copy Headers or OnRetry into the
	// AgentCall it builds, so neither takes part in its default merging.
	Headers         map[string]string
	ProviderOptions ProviderOptions
	OnRetry         OnRetryCallback
	MaxRetries      *int

	// StopWhen lists conditions checked after every step; any true result
	// ends the run.
	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called when agent starts
	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
	OnWarnings       OnWarningsFunc       // Called for warnings
	OnTextStart      OnTextStartFunc      // Called when text starts
	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
	OnTextEnd        OnTextEndFunc        // Called when text ends
	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
	OnToolCall       OnToolCallFunc       // Called when tool call is complete
	OnToolResult     OnToolResultFunc     // Called when tool execution completes
	OnSource         OnSourceFunc         // Called for source references
	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
}
282
// AgentResult represents the result of an agent execution.
type AgentResult struct {
	// Steps lists every step that ran, in execution order.
	Steps []StepResult
	// Final response
	// (the Response of the last entry in Steps).
	Response Response
	// TotalUsage is the token usage summed over all steps.
	TotalUsage Usage
}
290
// Agent represents an AI agent that can generate responses and stream responses.
type Agent interface {
	// Generate runs the agent loop with non-streaming model calls.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the agent loop with streaming model calls, reporting
	// progress through the callbacks on AgentStreamCall.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
296
// AgentOption defines a function that configures agent settings.
// NewAgent applies options in the order they are passed.
type AgentOption = func(*agentSettings)
299
// agent is the Agent implementation returned by NewAgent.
type agent struct {
	// settings holds the configuration assembled from the AgentOptions.
	settings agentSettings
}
303
304// NewAgent creates a new agent with the given language model and options.
305func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
306 settings := agentSettings{
307 model: model,
308 }
309 for _, o := range opts {
310 o(&settings)
311 }
312 return &agent{
313 settings: settings,
314 }
315}
316
317func (a *agent) prepareCall(call AgentCall) AgentCall {
318 call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
319 call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
320 call.TopP = cmp.Or(call.TopP, a.settings.topP)
321 call.TopK = cmp.Or(call.TopK, a.settings.topK)
322 call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
323 call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
324 call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
325
326 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
327 call.StopWhen = a.settings.stopWhen
328 }
329 if call.PrepareStep == nil && a.settings.prepareStep != nil {
330 call.PrepareStep = a.settings.prepareStep
331 }
332 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
333 call.RepairToolCall = a.settings.repairToolCall
334 }
335 if call.OnRetry == nil && a.settings.onRetry != nil {
336 call.OnRetry = a.settings.onRetry
337 }
338
339 providerOptions := ProviderOptions{}
340 if a.settings.providerOptions != nil {
341 maps.Copy(providerOptions, a.settings.providerOptions)
342 }
343 if call.ProviderOptions != nil {
344 maps.Copy(providerOptions, call.ProviderOptions)
345 }
346 call.ProviderOptions = providerOptions
347
348 headers := map[string]string{}
349
350 if a.settings.headers != nil {
351 maps.Copy(headers, a.settings.headers)
352 }
353
354 return call
355}
356
357// Generate implements Agent.
358func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
359 opts = a.prepareCall(opts)
360 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
361 if err != nil {
362 return nil, err
363 }
364 var responseMessages []Message
365 var steps []StepResult
366
367 for {
368 stepInputMessages := append(initialPrompt, responseMessages...)
369 stepModel := a.settings.model
370 stepSystemPrompt := a.settings.systemPrompt
371 stepActiveTools := opts.ActiveTools
372 stepToolChoice := ToolChoiceAuto
373 disableAllTools := false
374
375 if opts.PrepareStep != nil {
376 updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
377 Model: stepModel,
378 Steps: steps,
379 StepNumber: len(steps),
380 Messages: stepInputMessages,
381 })
382 if err != nil {
383 return nil, err
384 }
385
386 ctx = updatedCtx
387
388 // Apply prepared step modifications
389 if prepared.Messages != nil {
390 stepInputMessages = prepared.Messages
391 }
392 if prepared.Model != nil {
393 stepModel = prepared.Model
394 }
395 if prepared.System != nil {
396 stepSystemPrompt = *prepared.System
397 }
398 if prepared.ToolChoice != nil {
399 stepToolChoice = *prepared.ToolChoice
400 }
401 if len(prepared.ActiveTools) > 0 {
402 stepActiveTools = prepared.ActiveTools
403 }
404 disableAllTools = prepared.DisableAllTools
405 }
406
407 // Recreate prompt with potentially modified system prompt
408 if stepSystemPrompt != a.settings.systemPrompt {
409 stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
410 if err != nil {
411 return nil, err
412 }
413 // Replace system message part, keep the rest
414 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
415 stepInputMessages[0] = stepPrompt[0] // Replace system message
416 }
417 }
418
419 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
420
421 retryOptions := DefaultRetryOptions()
422 retryOptions.OnRetry = opts.OnRetry
423 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
424
425 result, err := retry(ctx, func() (*Response, error) {
426 return stepModel.Generate(ctx, Call{
427 Prompt: stepInputMessages,
428 MaxOutputTokens: opts.MaxOutputTokens,
429 Temperature: opts.Temperature,
430 TopP: opts.TopP,
431 TopK: opts.TopK,
432 PresencePenalty: opts.PresencePenalty,
433 FrequencyPenalty: opts.FrequencyPenalty,
434 Tools: preparedTools,
435 ToolChoice: &stepToolChoice,
436 ProviderOptions: opts.ProviderOptions,
437 })
438 })
439 if err != nil {
440 return nil, err
441 }
442
443 var stepToolCalls []ToolCallContent
444 for _, content := range result.Content {
445 if content.GetType() == ContentTypeToolCall {
446 toolCall, ok := AsContentType[ToolCallContent](content)
447 if !ok {
448 continue
449 }
450
451 // Validate and potentially repair the tool call
452 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
453 stepToolCalls = append(stepToolCalls, validatedToolCall)
454 }
455 }
456
457 toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
458
459 // Build step content with validated tool calls and tool results
460 stepContent := []Content{}
461 toolCallIndex := 0
462 for _, content := range result.Content {
463 if content.GetType() == ContentTypeToolCall {
464 // Replace with validated tool call
465 if toolCallIndex < len(stepToolCalls) {
466 stepContent = append(stepContent, stepToolCalls[toolCallIndex])
467 toolCallIndex++
468 }
469 } else {
470 // Keep other content as-is
471 stepContent = append(stepContent, content)
472 }
473 }
474 // Add tool results
475 for _, result := range toolResults {
476 stepContent = append(stepContent, result)
477 }
478 currentStepMessages := toResponseMessages(stepContent)
479 responseMessages = append(responseMessages, currentStepMessages...)
480
481 stepResult := StepResult{
482 Response: Response{
483 Content: stepContent,
484 FinishReason: result.FinishReason,
485 Usage: result.Usage,
486 Warnings: result.Warnings,
487 ProviderMetadata: result.ProviderMetadata,
488 },
489 Messages: currentStepMessages,
490 }
491 steps = append(steps, stepResult)
492 shouldStop := isStopConditionMet(opts.StopWhen, steps)
493
494 if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
495 break
496 }
497 }
498
499 totalUsage := Usage{}
500
501 for _, step := range steps {
502 usage := step.Usage
503 totalUsage.InputTokens += usage.InputTokens
504 totalUsage.OutputTokens += usage.OutputTokens
505 totalUsage.ReasoningTokens += usage.ReasoningTokens
506 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
507 totalUsage.CacheReadTokens += usage.CacheReadTokens
508 totalUsage.TotalTokens += usage.TotalTokens
509 }
510
511 agentResult := &AgentResult{
512 Steps: steps,
513 Response: steps[len(steps)-1].Response,
514 TotalUsage: totalUsage,
515 }
516 return agentResult, nil
517}
518
519func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
520 if len(conditions) == 0 {
521 return false
522 }
523
524 for _, condition := range conditions {
525 if condition(steps) {
526 return true
527 }
528 }
529 return false
530}
531
532func toResponseMessages(content []Content) []Message {
533 var assistantParts []MessagePart
534 var toolParts []MessagePart
535
536 for _, c := range content {
537 switch c.GetType() {
538 case ContentTypeText:
539 text, ok := AsContentType[TextContent](c)
540 if !ok {
541 continue
542 }
543 assistantParts = append(assistantParts, TextPart{
544 Text: text.Text,
545 ProviderOptions: ProviderOptions(text.ProviderMetadata),
546 })
547 case ContentTypeReasoning:
548 reasoning, ok := AsContentType[ReasoningContent](c)
549 if !ok {
550 continue
551 }
552 assistantParts = append(assistantParts, ReasoningPart{
553 Text: reasoning.Text,
554 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
555 })
556 case ContentTypeToolCall:
557 toolCall, ok := AsContentType[ToolCallContent](c)
558 if !ok {
559 continue
560 }
561 assistantParts = append(assistantParts, ToolCallPart{
562 ToolCallID: toolCall.ToolCallID,
563 ToolName: toolCall.ToolName,
564 Input: toolCall.Input,
565 ProviderExecuted: toolCall.ProviderExecuted,
566 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
567 })
568 case ContentTypeFile:
569 file, ok := AsContentType[FileContent](c)
570 if !ok {
571 continue
572 }
573 assistantParts = append(assistantParts, FilePart{
574 Data: file.Data,
575 MediaType: file.MediaType,
576 ProviderOptions: ProviderOptions(file.ProviderMetadata),
577 })
578 case ContentTypeSource:
579 // Sources are metadata about references used to generate the response.
580 // They don't need to be included in the conversation messages.
581 continue
582 case ContentTypeToolResult:
583 result, ok := AsContentType[ToolResultContent](c)
584 if !ok {
585 continue
586 }
587 toolParts = append(toolParts, ToolResultPart{
588 ToolCallID: result.ToolCallID,
589 Output: result.Result,
590 ProviderOptions: ProviderOptions(result.ProviderMetadata),
591 })
592 }
593 }
594
595 var messages []Message
596 if len(assistantParts) > 0 {
597 messages = append(messages, Message{
598 Role: MessageRoleAssistant,
599 Content: assistantParts,
600 })
601 }
602 if len(toolParts) > 0 {
603 messages = append(messages, Message{
604 Role: MessageRoleTool,
605 Content: toolParts,
606 })
607 }
608 return messages
609}
610
611func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
612 if len(toolCalls) == 0 {
613 return nil, nil
614 }
615
616 // Create a map for quick tool lookup
617 toolMap := make(map[string]AgentTool)
618 for _, tool := range allTools {
619 toolMap[tool.Info().Name] = tool
620 }
621
622 // Execute all tool calls in parallel
623 results := make([]ToolResultContent, len(toolCalls))
624 executeErrors := make([]error, len(toolCalls))
625
626 var wg sync.WaitGroup
627
628 for i, toolCall := range toolCalls {
629 wg.Add(1)
630 go func(index int, call ToolCallContent) {
631 defer wg.Done()
632
633 // Skip invalid tool calls - create error result
634 if call.Invalid {
635 results[index] = ToolResultContent{
636 ToolCallID: call.ToolCallID,
637 ToolName: call.ToolName,
638 Result: ToolResultOutputContentError{
639 Error: call.ValidationError,
640 },
641 ProviderExecuted: false,
642 }
643 if toolResultCallback != nil {
644 err := toolResultCallback(results[index])
645 if err != nil {
646 executeErrors[index] = err
647 }
648 }
649
650 return
651 }
652
653 tool, exists := toolMap[call.ToolName]
654 if !exists {
655 results[index] = ToolResultContent{
656 ToolCallID: call.ToolCallID,
657 ToolName: call.ToolName,
658 Result: ToolResultOutputContentError{
659 Error: errors.New("Error: Tool not found: " + call.ToolName),
660 },
661 ProviderExecuted: false,
662 }
663
664 if toolResultCallback != nil {
665 err := toolResultCallback(results[index])
666 if err != nil {
667 executeErrors[index] = err
668 }
669 }
670 return
671 }
672
673 // Execute the tool
674 result, err := tool.Run(ctx, ToolCall{
675 ID: call.ToolCallID,
676 Name: call.ToolName,
677 Input: call.Input,
678 })
679 if err != nil {
680 results[index] = ToolResultContent{
681 ToolCallID: call.ToolCallID,
682 ToolName: call.ToolName,
683 Result: ToolResultOutputContentError{
684 Error: err,
685 },
686 ClientMetadata: result.Metadata,
687 ProviderExecuted: false,
688 }
689 if toolResultCallback != nil {
690 cbErr := toolResultCallback(results[index])
691 if cbErr != nil {
692 executeErrors[index] = cbErr
693 }
694 }
695 executeErrors[index] = err
696 return
697 }
698
699 if result.IsError {
700 results[index] = ToolResultContent{
701 ToolCallID: call.ToolCallID,
702 ToolName: call.ToolName,
703 Result: ToolResultOutputContentError{
704 Error: errors.New(result.Content),
705 },
706 ClientMetadata: result.Metadata,
707 ProviderExecuted: false,
708 }
709
710 if toolResultCallback != nil {
711 err := toolResultCallback(results[index])
712 if err != nil {
713 executeErrors[index] = err
714 }
715 }
716 } else {
717 results[index] = ToolResultContent{
718 ToolCallID: call.ToolCallID,
719 ToolName: toolCall.ToolName,
720 Result: ToolResultOutputContentText{
721 Text: result.Content,
722 },
723 ClientMetadata: result.Metadata,
724 ProviderExecuted: false,
725 }
726 if toolResultCallback != nil {
727 err := toolResultCallback(results[index])
728 if err != nil {
729 executeErrors[index] = err
730 }
731 }
732 }
733 }(i, toolCall)
734 }
735
736 // Wait for all tool executions to complete
737 wg.Wait()
738
739 for _, err := range executeErrors {
740 if err != nil {
741 return nil, err
742 }
743 }
744
745 return results, nil
746}
747
748// Stream implements Agent.
749func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
750 // Convert AgentStreamCall to AgentCall for preparation
751 call := AgentCall{
752 Prompt: opts.Prompt,
753 Files: opts.Files,
754 Messages: opts.Messages,
755 MaxOutputTokens: opts.MaxOutputTokens,
756 Temperature: opts.Temperature,
757 TopP: opts.TopP,
758 TopK: opts.TopK,
759 PresencePenalty: opts.PresencePenalty,
760 FrequencyPenalty: opts.FrequencyPenalty,
761 ActiveTools: opts.ActiveTools,
762 ProviderOptions: opts.ProviderOptions,
763 MaxRetries: opts.MaxRetries,
764 StopWhen: opts.StopWhen,
765 PrepareStep: opts.PrepareStep,
766 RepairToolCall: opts.RepairToolCall,
767 }
768
769 call = a.prepareCall(call)
770
771 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
772 if err != nil {
773 return nil, err
774 }
775
776 var responseMessages []Message
777 var steps []StepResult
778 var totalUsage Usage
779
780 // Start agent stream
781 if opts.OnAgentStart != nil {
782 opts.OnAgentStart()
783 }
784
785 for stepNumber := 0; ; stepNumber++ {
786 stepInputMessages := append(initialPrompt, responseMessages...)
787 stepModel := a.settings.model
788 stepSystemPrompt := a.settings.systemPrompt
789 stepActiveTools := call.ActiveTools
790 stepToolChoice := ToolChoiceAuto
791 disableAllTools := false
792
793 // Apply step preparation if provided
794 if call.PrepareStep != nil {
795 updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
796 Model: stepModel,
797 Steps: steps,
798 StepNumber: stepNumber,
799 Messages: stepInputMessages,
800 })
801 if err != nil {
802 return nil, err
803 }
804
805 ctx = updatedCtx
806
807 if prepared.Messages != nil {
808 stepInputMessages = prepared.Messages
809 }
810 if prepared.Model != nil {
811 stepModel = prepared.Model
812 }
813 if prepared.System != nil {
814 stepSystemPrompt = *prepared.System
815 }
816 if prepared.ToolChoice != nil {
817 stepToolChoice = *prepared.ToolChoice
818 }
819 if len(prepared.ActiveTools) > 0 {
820 stepActiveTools = prepared.ActiveTools
821 }
822 disableAllTools = prepared.DisableAllTools
823 }
824
825 // Recreate prompt with potentially modified system prompt
826 if stepSystemPrompt != a.settings.systemPrompt {
827 stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
828 if err != nil {
829 return nil, err
830 }
831 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
832 stepInputMessages[0] = stepPrompt[0]
833 }
834 }
835
836 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
837
838 // Start step stream
839 if opts.OnStepStart != nil {
840 _ = opts.OnStepStart(stepNumber)
841 }
842
843 // Create streaming call
844 streamCall := Call{
845 Prompt: stepInputMessages,
846 MaxOutputTokens: call.MaxOutputTokens,
847 Temperature: call.Temperature,
848 TopP: call.TopP,
849 TopK: call.TopK,
850 PresencePenalty: call.PresencePenalty,
851 FrequencyPenalty: call.FrequencyPenalty,
852 Tools: preparedTools,
853 ToolChoice: &stepToolChoice,
854 ProviderOptions: call.ProviderOptions,
855 }
856
857 // Get streaming response
858 stream, err := stepModel.Stream(ctx, streamCall)
859 if err != nil {
860 if opts.OnError != nil {
861 opts.OnError(err)
862 }
863 return nil, err
864 }
865
866 // Process stream with tool execution
867 stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
868 if err != nil {
869 if opts.OnError != nil {
870 opts.OnError(err)
871 }
872 return nil, err
873 }
874
875 steps = append(steps, stepResult)
876 totalUsage = addUsage(totalUsage, stepResult.Usage)
877
878 // Call step finished callback
879 if opts.OnStepFinish != nil {
880 _ = opts.OnStepFinish(stepResult)
881 }
882
883 // Add step messages to response messages
884 stepMessages := toResponseMessages(stepResult.Content)
885 responseMessages = append(responseMessages, stepMessages...)
886
887 // Check stop conditions
888 shouldStop := isStopConditionMet(call.StopWhen, steps)
889 if shouldStop || !shouldContinue {
890 break
891 }
892 }
893
894 // Finish agent stream
895 agentResult := &AgentResult{
896 Steps: steps,
897 Response: steps[len(steps)-1].Response,
898 TotalUsage: totalUsage,
899 }
900
901 if opts.OnFinish != nil {
902 opts.OnFinish(agentResult)
903 }
904
905 if opts.OnAgentFinish != nil {
906 _ = opts.OnAgentFinish(agentResult)
907 }
908
909 return agentResult, nil
910}
911
912func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
913 preparedTools := make([]Tool, 0, len(tools))
914
915 // If explicitly disabling all tools, return no tools
916 if disableAllTools {
917 return preparedTools
918 }
919
920 for _, tool := range tools {
921 // If activeTools has items, only include tools in the list
922 // If activeTools is empty, include all tools
923 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
924 continue
925 }
926 info := tool.Info()
927 preparedTools = append(preparedTools, FunctionTool{
928 Name: info.Name,
929 Description: info.Description,
930 InputSchema: map[string]any{
931 "type": "object",
932 "properties": info.Parameters,
933 "required": info.Required,
934 },
935 ProviderOptions: tool.ProviderOptions(),
936 })
937 }
938 return preparedTools
939}
940
941// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
942func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
943 if err := a.validateToolCall(toolCall, availableTools); err == nil {
944 return toolCall
945 } else { //nolint: revive
946 if repairFunc != nil {
947 repairOptions := ToolCallRepairOptions{
948 OriginalToolCall: toolCall,
949 ValidationError: err,
950 AvailableTools: availableTools,
951 SystemPrompt: systemPrompt,
952 Messages: messages,
953 }
954
955 if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
956 if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
957 return *repairedToolCall
958 }
959 }
960 }
961
962 invalidToolCall := toolCall
963 invalidToolCall.Invalid = true
964 invalidToolCall.ValidationError = err
965 return invalidToolCall
966 }
967}
968
969// validateToolCall validates a tool call against available tools and their schemas.
970func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
971 var tool AgentTool
972 for _, t := range availableTools {
973 if t.Info().Name == toolCall.ToolName {
974 tool = t
975 break
976 }
977 }
978
979 if tool == nil {
980 return fmt.Errorf("tool not found: %s", toolCall.ToolName)
981 }
982
983 // Validate JSON parsing
984 var input map[string]any
985 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
986 return fmt.Errorf("invalid JSON input: %w", err)
987 }
988
989 // Basic schema validation (check required fields)
990 // TODO: more robust schema validation using JSON Schema or similar
991 toolInfo := tool.Info()
992 for _, required := range toolInfo.Required {
993 if _, exists := input[required]; !exists {
994 return fmt.Errorf("missing required parameter: %s", required)
995 }
996 }
997 return nil
998}
999
1000func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
1001 if prompt == "" {
1002 return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
1003 }
1004
1005 var preparedPrompt Prompt
1006
1007 if system != "" {
1008 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
1009 }
1010 preparedPrompt = append(preparedPrompt, messages...)
1011 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
1012 return preparedPrompt, nil
1013}
1014
1015// WithSystemPrompt sets the system prompt for the agent.
1016func WithSystemPrompt(prompt string) AgentOption {
1017 return func(s *agentSettings) {
1018 s.systemPrompt = prompt
1019 }
1020}
1021
1022// WithMaxOutputTokens sets the maximum output tokens for the agent.
1023func WithMaxOutputTokens(tokens int64) AgentOption {
1024 return func(s *agentSettings) {
1025 s.maxOutputTokens = &tokens
1026 }
1027}
1028
1029// WithTemperature sets the temperature for the agent.
1030func WithTemperature(temp float64) AgentOption {
1031 return func(s *agentSettings) {
1032 s.temperature = &temp
1033 }
1034}
1035
1036// WithTopP sets the top-p value for the agent.
1037func WithTopP(topP float64) AgentOption {
1038 return func(s *agentSettings) {
1039 s.topP = &topP
1040 }
1041}
1042
1043// WithTopK sets the top-k value for the agent.
1044func WithTopK(topK int64) AgentOption {
1045 return func(s *agentSettings) {
1046 s.topK = &topK
1047 }
1048}
1049
1050// WithPresencePenalty sets the presence penalty for the agent.
1051func WithPresencePenalty(penalty float64) AgentOption {
1052 return func(s *agentSettings) {
1053 s.presencePenalty = &penalty
1054 }
1055}
1056
1057// WithFrequencyPenalty sets the frequency penalty for the agent.
1058func WithFrequencyPenalty(penalty float64) AgentOption {
1059 return func(s *agentSettings) {
1060 s.frequencyPenalty = &penalty
1061 }
1062}
1063
1064// WithTools sets the tools for the agent.
1065func WithTools(tools ...AgentTool) AgentOption {
1066 return func(s *agentSettings) {
1067 s.tools = append(s.tools, tools...)
1068 }
1069}
1070
1071// WithStopConditions sets the stop conditions for the agent.
1072func WithStopConditions(conditions ...StopCondition) AgentOption {
1073 return func(s *agentSettings) {
1074 s.stopWhen = append(s.stopWhen, conditions...)
1075 }
1076}
1077
1078// WithPrepareStep sets the prepare step function for the agent.
1079func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1080 return func(s *agentSettings) {
1081 s.prepareStep = fn
1082 }
1083}
1084
1085// WithRepairToolCall sets the repair tool call function for the agent.
1086func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1087 return func(s *agentSettings) {
1088 s.repairToolCall = fn
1089 }
1090}
1091
1092// processStepStream processes a single step's stream and returns the step result.
1093func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
1094 var stepContent []Content
1095 var stepToolCalls []ToolCallContent
1096 var stepUsage Usage
1097 stepFinishReason := FinishReasonUnknown
1098 var stepWarnings []CallWarning
1099 var stepProviderMetadata ProviderMetadata
1100
1101 activeToolCalls := make(map[string]*ToolCallContent)
1102 activeTextContent := make(map[string]string)
1103 type reasoningContent struct {
1104 content string
1105 options ProviderMetadata
1106 }
1107 activeReasoningContent := make(map[string]reasoningContent)
1108
1109 // Process stream parts
1110 for part := range stream {
1111 // Forward all parts to chunk callback
1112 if opts.OnChunk != nil {
1113 err := opts.OnChunk(part)
1114 if err != nil {
1115 return StepResult{}, false, err
1116 }
1117 }
1118
1119 switch part.Type {
1120 case StreamPartTypeWarnings:
1121 stepWarnings = part.Warnings
1122 if opts.OnWarnings != nil {
1123 err := opts.OnWarnings(part.Warnings)
1124 if err != nil {
1125 return StepResult{}, false, err
1126 }
1127 }
1128
1129 case StreamPartTypeTextStart:
1130 activeTextContent[part.ID] = ""
1131 if opts.OnTextStart != nil {
1132 err := opts.OnTextStart(part.ID)
1133 if err != nil {
1134 return StepResult{}, false, err
1135 }
1136 }
1137
1138 case StreamPartTypeTextDelta:
1139 if _, exists := activeTextContent[part.ID]; exists {
1140 activeTextContent[part.ID] += part.Delta
1141 }
1142 if opts.OnTextDelta != nil {
1143 err := opts.OnTextDelta(part.ID, part.Delta)
1144 if err != nil {
1145 return StepResult{}, false, err
1146 }
1147 }
1148
1149 case StreamPartTypeTextEnd:
1150 if text, exists := activeTextContent[part.ID]; exists {
1151 stepContent = append(stepContent, TextContent{
1152 Text: text,
1153 ProviderMetadata: part.ProviderMetadata,
1154 })
1155 delete(activeTextContent, part.ID)
1156 }
1157 if opts.OnTextEnd != nil {
1158 err := opts.OnTextEnd(part.ID)
1159 if err != nil {
1160 return StepResult{}, false, err
1161 }
1162 }
1163
1164 case StreamPartTypeReasoningStart:
1165 activeReasoningContent[part.ID] = reasoningContent{content: "", options: part.ProviderMetadata}
1166 if opts.OnReasoningStart != nil {
1167 content := ReasoningContent{
1168 Text: "",
1169 ProviderMetadata: part.ProviderMetadata,
1170 }
1171 err := opts.OnReasoningStart(part.ID, content)
1172 if err != nil {
1173 return StepResult{}, false, err
1174 }
1175 }
1176
1177 case StreamPartTypeReasoningDelta:
1178 if active, exists := activeReasoningContent[part.ID]; exists {
1179 active.content += part.Delta
1180 active.options = part.ProviderMetadata
1181 activeReasoningContent[part.ID] = active
1182 }
1183 if opts.OnReasoningDelta != nil {
1184 err := opts.OnReasoningDelta(part.ID, part.Delta)
1185 if err != nil {
1186 return StepResult{}, false, err
1187 }
1188 }
1189
1190 case StreamPartTypeReasoningEnd:
1191 if active, exists := activeReasoningContent[part.ID]; exists {
1192 if part.ProviderMetadata != nil {
1193 active.options = part.ProviderMetadata
1194 }
1195 content := ReasoningContent{
1196 Text: active.content,
1197 ProviderMetadata: active.options,
1198 }
1199 stepContent = append(stepContent, content)
1200 if opts.OnReasoningEnd != nil {
1201 err := opts.OnReasoningEnd(part.ID, content)
1202 if err != nil {
1203 return StepResult{}, false, err
1204 }
1205 }
1206 delete(activeReasoningContent, part.ID)
1207 }
1208
1209 case StreamPartTypeToolInputStart:
1210 activeToolCalls[part.ID] = &ToolCallContent{
1211 ToolCallID: part.ID,
1212 ToolName: part.ToolCallName,
1213 Input: "",
1214 ProviderExecuted: part.ProviderExecuted,
1215 }
1216 if opts.OnToolInputStart != nil {
1217 err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1218 if err != nil {
1219 return StepResult{}, false, err
1220 }
1221 }
1222
1223 case StreamPartTypeToolInputDelta:
1224 if toolCall, exists := activeToolCalls[part.ID]; exists {
1225 toolCall.Input += part.Delta
1226 }
1227 if opts.OnToolInputDelta != nil {
1228 err := opts.OnToolInputDelta(part.ID, part.Delta)
1229 if err != nil {
1230 return StepResult{}, false, err
1231 }
1232 }
1233
1234 case StreamPartTypeToolInputEnd:
1235 if opts.OnToolInputEnd != nil {
1236 err := opts.OnToolInputEnd(part.ID)
1237 if err != nil {
1238 return StepResult{}, false, err
1239 }
1240 }
1241
1242 case StreamPartTypeToolCall:
1243 toolCall := ToolCallContent{
1244 ToolCallID: part.ID,
1245 ToolName: part.ToolCallName,
1246 Input: part.ToolCallInput,
1247 ProviderExecuted: part.ProviderExecuted,
1248 ProviderMetadata: part.ProviderMetadata,
1249 }
1250
1251 // Validate and potentially repair the tool call
1252 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1253 stepToolCalls = append(stepToolCalls, validatedToolCall)
1254 stepContent = append(stepContent, validatedToolCall)
1255
1256 if opts.OnToolCall != nil {
1257 err := opts.OnToolCall(validatedToolCall)
1258 if err != nil {
1259 return StepResult{}, false, err
1260 }
1261 }
1262
1263 // Clean up active tool call
1264 delete(activeToolCalls, part.ID)
1265
1266 case StreamPartTypeSource:
1267 sourceContent := SourceContent{
1268 SourceType: part.SourceType,
1269 ID: part.ID,
1270 URL: part.URL,
1271 Title: part.Title,
1272 ProviderMetadata: part.ProviderMetadata,
1273 }
1274 stepContent = append(stepContent, sourceContent)
1275 if opts.OnSource != nil {
1276 err := opts.OnSource(sourceContent)
1277 if err != nil {
1278 return StepResult{}, false, err
1279 }
1280 }
1281
1282 case StreamPartTypeFinish:
1283 stepUsage = part.Usage
1284 stepFinishReason = part.FinishReason
1285 stepProviderMetadata = part.ProviderMetadata
1286 if opts.OnStreamFinish != nil {
1287 err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1288 if err != nil {
1289 return StepResult{}, false, err
1290 }
1291 }
1292
1293 case StreamPartTypeError:
1294 if opts.OnError != nil {
1295 opts.OnError(part.Error)
1296 }
1297 return StepResult{}, false, part.Error
1298 }
1299 }
1300
1301 // Execute tools if any
1302 var toolResults []ToolResultContent
1303 if len(stepToolCalls) > 0 {
1304 var err error
1305 toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
1306 if err != nil {
1307 return StepResult{}, false, err
1308 }
1309 // Add tool results to content
1310 for _, result := range toolResults {
1311 stepContent = append(stepContent, result)
1312 }
1313 }
1314
1315 stepResult := StepResult{
1316 Response: Response{
1317 Content: stepContent,
1318 FinishReason: stepFinishReason,
1319 Usage: stepUsage,
1320 Warnings: stepWarnings,
1321 ProviderMetadata: stepProviderMetadata,
1322 },
1323 Messages: toResponseMessages(stepContent),
1324 }
1325
1326 // Determine if we should continue (has tool calls and not stopped)
1327 shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1328
1329 return stepResult, shouldContinue, nil
1330}
1331
1332func addUsage(a, b Usage) Usage {
1333 return Usage{
1334 InputTokens: a.InputTokens + b.InputTokens,
1335 OutputTokens: a.OutputTokens + b.OutputTokens,
1336 TotalTokens: a.TotalTokens + b.TotalTokens,
1337 ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens,
1338 CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1339 CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens,
1340 }
1341}
1342
1343// WithHeaders sets the headers for the agent.
1344func WithHeaders(headers map[string]string) AgentOption {
1345 return func(s *agentSettings) {
1346 s.headers = headers
1347 }
1348}
1349
1350// WithProviderOptions sets the provider options for the agent.
1351func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1352 return func(s *agentSettings) {
1353 s.providerOptions = providerOptions
1354 }
1355}