1package fantasy
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "slices"
11)
12
// StepResult represents the result of a single step in an agent execution.
//
// The embedded Response carries the step's content, finish reason, usage,
// warnings and provider metadata. In Generate, the content holds the model
// output with tool calls replaced by their validated versions, followed by
// the tool results. Messages holds the conversation messages derived from
// that content (see toResponseMessages).
type StepResult struct {
	Response
	Messages []Message
}
18
// StopCondition defines a function that determines when an agent should stop executing.
//
// It is evaluated after each step with every step completed so far, oldest
// first. Returning true ends the run after the current step. When several
// conditions are configured, the agent stops as soon as any one returns true.
type StopCondition = func(steps []StepResult) bool
21
22// StepCountIs returns a stop condition that stops after the specified number of steps.
23func StepCountIs(stepCount int) StopCondition {
24 return func(steps []StepResult) bool {
25 return len(steps) >= stepCount
26 }
27}
28
29// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
30func HasToolCall(toolName string) StopCondition {
31 return func(steps []StepResult) bool {
32 if len(steps) == 0 {
33 return false
34 }
35 lastStep := steps[len(steps)-1]
36 toolCalls := lastStep.Content.ToolCalls()
37 for _, toolCall := range toolCalls {
38 if toolCall.ToolName == toolName {
39 return true
40 }
41 }
42 return false
43 }
44}
45
46// HasContent returns a stop condition that stops when the specified content type appears in the last step.
47func HasContent(contentType ContentType) StopCondition {
48 return func(steps []StepResult) bool {
49 if len(steps) == 0 {
50 return false
51 }
52 lastStep := steps[len(steps)-1]
53 for _, content := range lastStep.Content {
54 if content.GetType() == contentType {
55 return true
56 }
57 }
58 return false
59 }
60}
61
62// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
63func FinishReasonIs(reason FinishReason) StopCondition {
64 return func(steps []StepResult) bool {
65 if len(steps) == 0 {
66 return false
67 }
68 lastStep := steps[len(steps)-1]
69 return lastStep.FinishReason == reason
70 }
71}
72
73// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
74func MaxTokensUsed(maxTokens int64) StopCondition {
75 return func(steps []StepResult) bool {
76 var totalTokens int64
77 for _, step := range steps {
78 totalTokens += step.Usage.TotalTokens
79 }
80 return totalTokens >= maxTokens
81 }
82}
83
// PrepareStepFunctionOptions contains the options for preparing a step in an agent execution.
type PrepareStepFunctionOptions struct {
	// Steps holds the results of all steps completed so far, oldest first.
	Steps []StepResult
	// StepNumber is the zero-based index of the step about to run.
	StepNumber int
	// Model is the agent's configured language model.
	Model LanguageModel
	// Messages is the input the step would send to the model: the initial
	// prompt followed by every message produced by earlier steps.
	Messages []Message
}
91
// PrepareStepResult contains the result of preparing a step in an agent execution.
//
// Overrides apply to the current step only; every step starts again from the
// agent/call configuration. Zero-valued fields leave the setting unchanged.
type PrepareStepResult struct {
	// Model, when non-nil, replaces the language model for this step.
	Model LanguageModel
	// Messages, when non-nil, replaces the step's input messages.
	Messages []Message
	// System, when non-nil, replaces the system prompt for this step.
	System *string
	// ToolChoice, when non-nil, replaces the default ToolChoiceAuto.
	ToolChoice *ToolChoice
	// ActiveTools, when non-empty, replaces the call's active tool list.
	ActiveTools []string
	// DisableAllTools, when true, offers no tools to the model for this step.
	DisableAllTools bool
}
101
// ToolCallRepairOptions contains the options for repairing a tool call.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call that failed validation.
	OriginalToolCall ToolCallContent
	// ValidationError describes why OriginalToolCall failed validation.
	ValidationError error
	// AvailableTools are the tools the call was validated against.
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt in effect for the step.
	SystemPrompt string
	// Messages are the input messages that were sent to the model for the step.
	Messages []Message
}
110
type (
	// PrepareStepFunction defines a function that prepares a step in an agent execution.
	//
	// It runs before each model call. The returned context replaces the one
	// used by the agent for this and all later steps. A non-nil error aborts
	// the run and is returned to the caller.
	PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)

	// OnStepFinishedFunction defines a function that is called when a step finishes.
	OnStepFinishedFunction = func(step StepResult)

	// RepairToolCallFunction defines a function that repairs a tool call.
	//
	// It is called when a tool call fails validation. A returned non-nil call
	// that itself passes validation replaces the original. An error, a nil
	// call, or a call that still fails validation leaves the original call
	// marked invalid with its original validation error.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
121
// agentSettings holds the agent-level configuration built by NewAgent from its
// AgentOptions. Per-call values in AgentCall/AgentStreamCall take precedence;
// prepareCall falls back to these values when a call leaves an option unset.
type agentSettings struct {
	// systemPrompt becomes the leading system message of the prompt (omitted when empty).
	systemPrompt string
	// Default generation parameters; nil means "not set".
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	// NOTE(review): no code visible in this file passes headers on to the
	// model call — confirm whether they are applied elsewhere.
	headers map[string]string
	// providerOptions are merged with per-call options; per-call entries win.
	providerOptions ProviderOptions

	// TODO: add support for provider tools
	tools      []AgentTool
	maxRetries *int

	model LanguageModel

	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
144
// AgentCall represents a call to an agent.
//
// Prompt must be non-empty. It becomes the final user message, with Files
// attached, after the optional system prompt and Messages. Nil/empty options
// fall back to the agent's defaults. ProviderOptions are merged with the
// agent's, and per-call entries win on key collisions. ActiveTools, when
// non-empty, restricts which of the agent's tools are offered to the model.
//
// NOTE(review): MaxOutputTokens, ProviderOptions and MaxRetries have no json
// tags, and encoding/json cannot marshal the func-typed fields (OnRetry,
// StopWhen, PrepareStep, RepairToolCall). Confirm whether JSON
// (de)serialization of this struct is meant to be supported.
type AgentCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	// StopWhen ends the run after any step for which a condition returns true.
	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction
}
165
// Agent-level callbacks, set on AgentStreamCall and invoked by Stream.
type (
	// OnAgentStartFunc is called when agent starts.
	// Stream calls it once, before the first step.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called when agent finishes.
	// Stream calls it after OnFinish with the final result; the returned
	// error is currently discarded.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called when a step starts.
	// It receives the zero-based step number; Stream currently discards the
	// returned error.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called when a step finishes.
	// Stream currently discards the returned error.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called when entire agent completes.
	// Stream calls it with the final result before OnAgentFinish. It is not
	// called when the run ends with an error.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called when an error occurs.
	// Stream calls it (e.g. for model-call and stream-processing failures)
	// with the error it is about to return.
	OnErrorFunc func(error)
)
186
// Stream part callbacks - called for each corresponding stream part type.
// Apart from OnBeforeToolExecutionFunc, the callbacks that return an error
// can report a failure back to the agent.
// NOTE(review): how those errors are handled lives in processStepStream,
// which is not shown here — confirm there.
type (
	// OnChunkFunc is called for each stream part (catch-all).
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when text starts.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for text deltas.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when text ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when reasoning starts.
	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error

	// OnReasoningDeltaFunc is called for reasoning deltas.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning ends.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when tool input starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for tool input deltas.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnBeforeToolExecutionFunc is called right before a tool is executed.
	// It is not called for invalid tool calls or calls to unknown tools.
	// Its results (modifiedInput, skipExecution, err) are handled as follows:
	//   - err != nil: a synthetic error result carrying err is recorded. If
	//     skipExecution is also true, the remaining tool calls still run.
	//     Otherwise tool execution stops and err is returned.
	//   - err == nil and skipExecution: the tool is not run and an error
	//     result ("Tool execution skipped by hook") is recorded instead.
	//   - otherwise, a non-empty modifiedInput replaces the original tool input.
	OnBeforeToolExecutionFunc func(toolCall ToolCallContent) (modifiedInput string, skipExecution bool, err error)

	// OnToolResultFunc is called when tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when stream finishes.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
240
// AgentStreamCall represents a streaming call to an agent.
//
// It carries the same per-call options as AgentCall, and unset options fall
// back to the agent's defaults. It also holds callbacks that Stream invokes as
// the run progresses. Agent-level callbacks are skipped when nil.
//
// NOTE(review): Stream does not appear to forward Headers to the model call —
// confirm before relying on it.
type AgentStreamCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called when agent starts
	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk               OnChunkFunc               // Called for each stream part (catch-all)
	OnWarnings            OnWarningsFunc            // Called for warnings
	OnTextStart           OnTextStartFunc           // Called when text starts
	OnTextDelta           OnTextDeltaFunc           // Called for text deltas
	OnTextEnd             OnTextEndFunc             // Called when text ends
	OnReasoningStart      OnReasoningStartFunc      // Called when reasoning starts
	OnReasoningDelta      OnReasoningDeltaFunc      // Called for reasoning deltas
	OnReasoningEnd        OnReasoningEndFunc        // Called when reasoning ends
	OnToolInputStart      OnToolInputStartFunc      // Called when tool input starts
	OnToolInputDelta      OnToolInputDeltaFunc      // Called for tool input deltas
	OnToolInputEnd        OnToolInputEndFunc        // Called when tool input ends
	OnToolCall            OnToolCallFunc            // Called when tool call is complete
	OnBeforeToolExecution OnBeforeToolExecutionFunc // Called right before tool execution
	OnToolResult          OnToolResultFunc          // Called when tool execution completes
	OnSource              OnSourceFunc              // Called for source references
	OnStreamFinish        OnStreamFinishFunc        // Called when stream finishes
}
288
// AgentResult represents the result of an agent execution.
type AgentResult struct {
	// Steps holds every step that ran, in execution order.
	Steps []StepResult
	// Final response: the Response of the last step in Steps.
	Response Response
	// TotalUsage is the token usage accumulated over all steps.
	TotalUsage Usage
}
296
// Agent represents an AI agent that can generate responses and stream responses.
//
// Both methods run a multi-step loop: call the model, execute any requested
// tools, and repeat until a stop condition is met or the model stops calling
// tools. They return the aggregated result. Stream also reports progress
// through the callbacks on AgentStreamCall.
type Agent interface {
	Generate(context.Context, AgentCall) (*AgentResult, error)
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
302
// AgentOption defines a function that configures agent settings.
// NewAgent applies options in the order given, after setting the model.
type AgentOption = func(*agentSettings)
305
// agent is the Agent implementation returned by NewAgent. settings holds the
// agent-level defaults that prepareCall applies to every call.
type agent struct {
	settings agentSettings
}
309
310// NewAgent creates a new agent with the given language model and options.
311func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
312 settings := agentSettings{
313 model: model,
314 }
315 for _, o := range opts {
316 o(&settings)
317 }
318 return &agent{
319 settings: settings,
320 }
321}
322
323func (a *agent) prepareCall(call AgentCall) AgentCall {
324 call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
325 call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
326 call.TopP = cmp.Or(call.TopP, a.settings.topP)
327 call.TopK = cmp.Or(call.TopK, a.settings.topK)
328 call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
329 call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
330 call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
331
332 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
333 call.StopWhen = a.settings.stopWhen
334 }
335 if call.PrepareStep == nil && a.settings.prepareStep != nil {
336 call.PrepareStep = a.settings.prepareStep
337 }
338 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
339 call.RepairToolCall = a.settings.repairToolCall
340 }
341 if call.OnRetry == nil && a.settings.onRetry != nil {
342 call.OnRetry = a.settings.onRetry
343 }
344
345 providerOptions := ProviderOptions{}
346 if a.settings.providerOptions != nil {
347 maps.Copy(providerOptions, a.settings.providerOptions)
348 }
349 if call.ProviderOptions != nil {
350 maps.Copy(providerOptions, call.ProviderOptions)
351 }
352 call.ProviderOptions = providerOptions
353
354 headers := map[string]string{}
355
356 if a.settings.headers != nil {
357 maps.Copy(headers, a.settings.headers)
358 }
359
360 return call
361}
362
363// Generate implements Agent.
364func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
365 opts = a.prepareCall(opts)
366 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
367 if err != nil {
368 return nil, err
369 }
370 var responseMessages []Message
371 var steps []StepResult
372
373 for {
374 stepInputMessages := append(initialPrompt, responseMessages...)
375 stepModel := a.settings.model
376 stepSystemPrompt := a.settings.systemPrompt
377 stepActiveTools := opts.ActiveTools
378 stepToolChoice := ToolChoiceAuto
379 disableAllTools := false
380
381 if opts.PrepareStep != nil {
382 updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
383 Model: stepModel,
384 Steps: steps,
385 StepNumber: len(steps),
386 Messages: stepInputMessages,
387 })
388 if err != nil {
389 return nil, err
390 }
391
392 ctx = updatedCtx
393
394 // Apply prepared step modifications
395 if prepared.Messages != nil {
396 stepInputMessages = prepared.Messages
397 }
398 if prepared.Model != nil {
399 stepModel = prepared.Model
400 }
401 if prepared.System != nil {
402 stepSystemPrompt = *prepared.System
403 }
404 if prepared.ToolChoice != nil {
405 stepToolChoice = *prepared.ToolChoice
406 }
407 if len(prepared.ActiveTools) > 0 {
408 stepActiveTools = prepared.ActiveTools
409 }
410 disableAllTools = prepared.DisableAllTools
411 }
412
413 // Recreate prompt with potentially modified system prompt
414 if stepSystemPrompt != a.settings.systemPrompt {
415 stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
416 if err != nil {
417 return nil, err
418 }
419 // Replace system message part, keep the rest
420 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
421 stepInputMessages[0] = stepPrompt[0] // Replace system message
422 }
423 }
424
425 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
426
427 retryOptions := DefaultRetryOptions()
428 if opts.MaxRetries != nil {
429 retryOptions.MaxRetries = *opts.MaxRetries
430 }
431 retryOptions.OnRetry = opts.OnRetry
432 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
433
434 result, err := retry(ctx, func() (*Response, error) {
435 return stepModel.Generate(ctx, Call{
436 Prompt: stepInputMessages,
437 MaxOutputTokens: opts.MaxOutputTokens,
438 Temperature: opts.Temperature,
439 TopP: opts.TopP,
440 TopK: opts.TopK,
441 PresencePenalty: opts.PresencePenalty,
442 FrequencyPenalty: opts.FrequencyPenalty,
443 Tools: preparedTools,
444 ToolChoice: &stepToolChoice,
445 ProviderOptions: opts.ProviderOptions,
446 })
447 })
448 if err != nil {
449 return nil, err
450 }
451
452 var stepToolCalls []ToolCallContent
453 for _, content := range result.Content {
454 if content.GetType() == ContentTypeToolCall {
455 toolCall, ok := AsContentType[ToolCallContent](content)
456 if !ok {
457 continue
458 }
459
460 // Validate and potentially repair the tool call
461 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
462 stepToolCalls = append(stepToolCalls, validatedToolCall)
463 }
464 }
465
466 toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil, nil)
467
468 // Build step content with validated tool calls and tool results
469 stepContent := []Content{}
470 toolCallIndex := 0
471 for _, content := range result.Content {
472 if content.GetType() == ContentTypeToolCall {
473 // Replace with validated tool call
474 if toolCallIndex < len(stepToolCalls) {
475 stepContent = append(stepContent, stepToolCalls[toolCallIndex])
476 toolCallIndex++
477 }
478 } else {
479 // Keep other content as-is
480 stepContent = append(stepContent, content)
481 }
482 }
483 // Add tool results
484 for _, result := range toolResults {
485 stepContent = append(stepContent, result)
486 }
487 currentStepMessages := toResponseMessages(stepContent)
488 responseMessages = append(responseMessages, currentStepMessages...)
489
490 stepResult := StepResult{
491 Response: Response{
492 Content: stepContent,
493 FinishReason: result.FinishReason,
494 Usage: result.Usage,
495 Warnings: result.Warnings,
496 ProviderMetadata: result.ProviderMetadata,
497 },
498 Messages: currentStepMessages,
499 }
500 steps = append(steps, stepResult)
501 shouldStop := isStopConditionMet(opts.StopWhen, steps)
502
503 if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
504 break
505 }
506 }
507
508 totalUsage := Usage{}
509
510 for _, step := range steps {
511 usage := step.Usage
512 totalUsage.InputTokens += usage.InputTokens
513 totalUsage.OutputTokens += usage.OutputTokens
514 totalUsage.ReasoningTokens += usage.ReasoningTokens
515 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
516 totalUsage.CacheReadTokens += usage.CacheReadTokens
517 totalUsage.TotalTokens += usage.TotalTokens
518 }
519
520 agentResult := &AgentResult{
521 Steps: steps,
522 Response: steps[len(steps)-1].Response,
523 TotalUsage: totalUsage,
524 }
525 return agentResult, nil
526}
527
528func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
529 if len(conditions) == 0 {
530 return false
531 }
532
533 for _, condition := range conditions {
534 if condition(steps) {
535 return true
536 }
537 }
538 return false
539}
540
541func toResponseMessages(content []Content) []Message {
542 var assistantParts []MessagePart
543 var toolParts []MessagePart
544
545 for _, c := range content {
546 switch c.GetType() {
547 case ContentTypeText:
548 text, ok := AsContentType[TextContent](c)
549 if !ok {
550 continue
551 }
552 assistantParts = append(assistantParts, TextPart{
553 Text: text.Text,
554 ProviderOptions: ProviderOptions(text.ProviderMetadata),
555 })
556 case ContentTypeReasoning:
557 reasoning, ok := AsContentType[ReasoningContent](c)
558 if !ok {
559 continue
560 }
561 assistantParts = append(assistantParts, ReasoningPart{
562 Text: reasoning.Text,
563 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
564 })
565 case ContentTypeToolCall:
566 toolCall, ok := AsContentType[ToolCallContent](c)
567 if !ok {
568 continue
569 }
570 assistantParts = append(assistantParts, ToolCallPart{
571 ToolCallID: toolCall.ToolCallID,
572 ToolName: toolCall.ToolName,
573 Input: toolCall.Input,
574 ProviderExecuted: toolCall.ProviderExecuted,
575 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
576 })
577 case ContentTypeFile:
578 file, ok := AsContentType[FileContent](c)
579 if !ok {
580 continue
581 }
582 assistantParts = append(assistantParts, FilePart{
583 Data: file.Data,
584 MediaType: file.MediaType,
585 ProviderOptions: ProviderOptions(file.ProviderMetadata),
586 })
587 case ContentTypeSource:
588 // Sources are metadata about references used to generate the response.
589 // They don't need to be included in the conversation messages.
590 continue
591 case ContentTypeToolResult:
592 result, ok := AsContentType[ToolResultContent](c)
593 if !ok {
594 continue
595 }
596 toolParts = append(toolParts, ToolResultPart{
597 ToolCallID: result.ToolCallID,
598 Output: result.Result,
599 ProviderOptions: ProviderOptions(result.ProviderMetadata),
600 })
601 }
602 }
603
604 var messages []Message
605 if len(assistantParts) > 0 {
606 messages = append(messages, Message{
607 Role: MessageRoleAssistant,
608 Content: assistantParts,
609 })
610 }
611 if len(toolParts) > 0 {
612 messages = append(messages, Message{
613 Role: MessageRoleTool,
614 Content: toolParts,
615 })
616 }
617 return messages
618}
619
620func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, beforeExecCallback OnBeforeToolExecutionFunc, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
621 if len(toolCalls) == 0 {
622 return nil, nil
623 }
624
625 // Create a map for quick tool lookup
626 toolMap := make(map[string]AgentTool)
627 for _, tool := range allTools {
628 toolMap[tool.Info().Name] = tool
629 }
630
631 // Execute all tool calls sequentially in order
632 results := make([]ToolResultContent, 0, len(toolCalls))
633
634 for _, toolCall := range toolCalls {
635 // Skip invalid tool calls - create error result
636 if toolCall.Invalid {
637 result := ToolResultContent{
638 ToolCallID: toolCall.ToolCallID,
639 ToolName: toolCall.ToolName,
640 Result: ToolResultOutputContentError{
641 Error: toolCall.ValidationError,
642 },
643 ProviderExecuted: false,
644 }
645 results = append(results, result)
646 if toolResultCallback != nil {
647 if err := toolResultCallback(result); err != nil {
648 return nil, err
649 }
650 }
651 continue
652 }
653
654 tool, exists := toolMap[toolCall.ToolName]
655 if !exists {
656 result := ToolResultContent{
657 ToolCallID: toolCall.ToolCallID,
658 ToolName: toolCall.ToolName,
659 Result: ToolResultOutputContentError{
660 Error: errors.New("Error: Tool not found: " + toolCall.ToolName),
661 },
662 ProviderExecuted: false,
663 }
664 results = append(results, result)
665 if toolResultCallback != nil {
666 if err := toolResultCallback(result); err != nil {
667 return nil, err
668 }
669 }
670 continue
671 }
672
673 // Call OnBeforeToolExecution hook to allow modification or skipping
674 inputToUse := toolCall.Input
675 if beforeExecCallback != nil {
676 modifiedInput, skipExecution, err := beforeExecCallback(toolCall)
677 if err != nil {
678 // Callback returned error - create synthetic error result
679 result := ToolResultContent{
680 ToolCallID: toolCall.ToolCallID,
681 ToolName: toolCall.ToolName,
682 Result: ToolResultOutputContentError{
683 Error: err,
684 },
685 ProviderExecuted: false,
686 }
687 results = append(results, result)
688 if toolResultCallback != nil {
689 if cbErr := toolResultCallback(result); cbErr != nil {
690 return nil, cbErr
691 }
692 }
693 if skipExecution {
694 // Hook wants to skip execution
695 continue
696 } else {
697 // Hook returned error but doesn't want to skip - stop execution
698 return nil, err
699 }
700 }
701 if skipExecution {
702 // Hook wants to skip execution without error
703 result := ToolResultContent{
704 ToolCallID: toolCall.ToolCallID,
705 ToolName: toolCall.ToolName,
706 Result: ToolResultOutputContentError{
707 Error: errors.New("Tool execution skipped by hook"),
708 },
709 ProviderExecuted: false,
710 }
711 results = append(results, result)
712 if toolResultCallback != nil {
713 if cbErr := toolResultCallback(result); cbErr != nil {
714 return nil, cbErr
715 }
716 }
717 continue
718 }
719 if modifiedInput != "" {
720 inputToUse = modifiedInput
721 }
722 }
723
724 // Execute the tool
725 toolResult, err := tool.Run(ctx, ToolCall{
726 ID: toolCall.ToolCallID,
727 Name: toolCall.ToolName,
728 Input: inputToUse,
729 })
730 if err != nil {
731 result := ToolResultContent{
732 ToolCallID: toolCall.ToolCallID,
733 ToolName: toolCall.ToolName,
734 Result: ToolResultOutputContentError{
735 Error: err,
736 },
737 ClientMetadata: toolResult.Metadata,
738 ProviderExecuted: false,
739 }
740 if toolResultCallback != nil {
741 if cbErr := toolResultCallback(result); cbErr != nil {
742 return nil, cbErr
743 }
744 }
745 return nil, err
746 }
747
748 var result ToolResultContent
749 if toolResult.IsError {
750 result = ToolResultContent{
751 ToolCallID: toolCall.ToolCallID,
752 ToolName: toolCall.ToolName,
753 Result: ToolResultOutputContentError{
754 Error: errors.New(toolResult.Content),
755 },
756 ClientMetadata: toolResult.Metadata,
757 ProviderExecuted: false,
758 }
759 } else {
760 result = ToolResultContent{
761 ToolCallID: toolCall.ToolCallID,
762 ToolName: toolCall.ToolName,
763 Result: ToolResultOutputContentText{
764 Text: toolResult.Content,
765 },
766 ClientMetadata: toolResult.Metadata,
767 ProviderExecuted: false,
768 }
769 }
770 results = append(results, result)
771 if toolResultCallback != nil {
772 if err := toolResultCallback(result); err != nil {
773 return nil, err
774 }
775 }
776 }
777
778 return results, nil
779}
780
781// Stream implements Agent.
782func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
783 // Convert AgentStreamCall to AgentCall for preparation
784 call := AgentCall{
785 Prompt: opts.Prompt,
786 Files: opts.Files,
787 Messages: opts.Messages,
788 MaxOutputTokens: opts.MaxOutputTokens,
789 Temperature: opts.Temperature,
790 TopP: opts.TopP,
791 TopK: opts.TopK,
792 PresencePenalty: opts.PresencePenalty,
793 FrequencyPenalty: opts.FrequencyPenalty,
794 ActiveTools: opts.ActiveTools,
795 ProviderOptions: opts.ProviderOptions,
796 MaxRetries: opts.MaxRetries,
797 StopWhen: opts.StopWhen,
798 PrepareStep: opts.PrepareStep,
799 RepairToolCall: opts.RepairToolCall,
800 }
801
802 call = a.prepareCall(call)
803
804 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
805 if err != nil {
806 return nil, err
807 }
808
809 var responseMessages []Message
810 var steps []StepResult
811 var totalUsage Usage
812
813 // Start agent stream
814 if opts.OnAgentStart != nil {
815 opts.OnAgentStart()
816 }
817
818 for stepNumber := 0; ; stepNumber++ {
819 stepInputMessages := append(initialPrompt, responseMessages...)
820 stepModel := a.settings.model
821 stepSystemPrompt := a.settings.systemPrompt
822 stepActiveTools := call.ActiveTools
823 stepToolChoice := ToolChoiceAuto
824 disableAllTools := false
825
826 // Apply step preparation if provided
827 if call.PrepareStep != nil {
828 updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
829 Model: stepModel,
830 Steps: steps,
831 StepNumber: stepNumber,
832 Messages: stepInputMessages,
833 })
834 if err != nil {
835 return nil, err
836 }
837
838 ctx = updatedCtx
839
840 if prepared.Messages != nil {
841 stepInputMessages = prepared.Messages
842 }
843 if prepared.Model != nil {
844 stepModel = prepared.Model
845 }
846 if prepared.System != nil {
847 stepSystemPrompt = *prepared.System
848 }
849 if prepared.ToolChoice != nil {
850 stepToolChoice = *prepared.ToolChoice
851 }
852 if len(prepared.ActiveTools) > 0 {
853 stepActiveTools = prepared.ActiveTools
854 }
855 disableAllTools = prepared.DisableAllTools
856 }
857
858 // Recreate prompt with potentially modified system prompt
859 if stepSystemPrompt != a.settings.systemPrompt {
860 stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
861 if err != nil {
862 return nil, err
863 }
864 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
865 stepInputMessages[0] = stepPrompt[0]
866 }
867 }
868
869 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
870
871 // Start step stream
872 if opts.OnStepStart != nil {
873 _ = opts.OnStepStart(stepNumber)
874 }
875
876 // Create streaming call
877 streamCall := Call{
878 Prompt: stepInputMessages,
879 MaxOutputTokens: call.MaxOutputTokens,
880 Temperature: call.Temperature,
881 TopP: call.TopP,
882 TopK: call.TopK,
883 PresencePenalty: call.PresencePenalty,
884 FrequencyPenalty: call.FrequencyPenalty,
885 Tools: preparedTools,
886 ToolChoice: &stepToolChoice,
887 ProviderOptions: call.ProviderOptions,
888 }
889
890 // Get streaming response with retry logic
891 retryOptions := DefaultRetryOptions()
892 if call.MaxRetries != nil {
893 retryOptions.MaxRetries = *call.MaxRetries
894 }
895 retryOptions.OnRetry = call.OnRetry
896 retry := RetryWithExponentialBackoffRespectingRetryHeaders[StreamResponse](retryOptions)
897
898 stream, err := retry(ctx, func() (StreamResponse, error) {
899 return stepModel.Stream(ctx, streamCall)
900 })
901 if err != nil {
902 if opts.OnError != nil {
903 opts.OnError(err)
904 }
905 return nil, err
906 }
907
908 // Process stream with tool execution
909 stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
910 if err != nil {
911 if opts.OnError != nil {
912 opts.OnError(err)
913 }
914 return nil, err
915 }
916
917 steps = append(steps, stepResult)
918 totalUsage = addUsage(totalUsage, stepResult.Usage)
919
920 // Call step finished callback
921 if opts.OnStepFinish != nil {
922 _ = opts.OnStepFinish(stepResult)
923 }
924
925 // Add step messages to response messages
926 stepMessages := toResponseMessages(stepResult.Content)
927 responseMessages = append(responseMessages, stepMessages...)
928
929 // Check stop conditions
930 shouldStop := isStopConditionMet(call.StopWhen, steps)
931 if shouldStop || !shouldContinue {
932 break
933 }
934 }
935
936 // Finish agent stream
937 agentResult := &AgentResult{
938 Steps: steps,
939 Response: steps[len(steps)-1].Response,
940 TotalUsage: totalUsage,
941 }
942
943 if opts.OnFinish != nil {
944 opts.OnFinish(agentResult)
945 }
946
947 if opts.OnAgentFinish != nil {
948 _ = opts.OnAgentFinish(agentResult)
949 }
950
951 return agentResult, nil
952}
953
954func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
955 preparedTools := make([]Tool, 0, len(tools))
956
957 // If explicitly disabling all tools, return no tools
958 if disableAllTools {
959 return preparedTools
960 }
961
962 for _, tool := range tools {
963 // If activeTools has items, only include tools in the list
964 // If activeTools is empty, include all tools
965 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
966 continue
967 }
968 info := tool.Info()
969 preparedTools = append(preparedTools, FunctionTool{
970 Name: info.Name,
971 Description: info.Description,
972 InputSchema: map[string]any{
973 "type": "object",
974 "properties": info.Parameters,
975 "required": info.Required,
976 },
977 ProviderOptions: tool.ProviderOptions(),
978 })
979 }
980 return preparedTools
981}
982
983// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
984func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
985 if err := a.validateToolCall(toolCall, availableTools); err == nil {
986 return toolCall
987 } else { //nolint: revive
988 if repairFunc != nil {
989 repairOptions := ToolCallRepairOptions{
990 OriginalToolCall: toolCall,
991 ValidationError: err,
992 AvailableTools: availableTools,
993 SystemPrompt: systemPrompt,
994 Messages: messages,
995 }
996
997 if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
998 if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
999 return *repairedToolCall
1000 }
1001 }
1002 }
1003
1004 invalidToolCall := toolCall
1005 invalidToolCall.Invalid = true
1006 invalidToolCall.ValidationError = err
1007 return invalidToolCall
1008 }
1009}
1010
1011// validateToolCall validates a tool call against available tools and their schemas.
1012func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
1013 var tool AgentTool
1014 for _, t := range availableTools {
1015 if t.Info().Name == toolCall.ToolName {
1016 tool = t
1017 break
1018 }
1019 }
1020
1021 if tool == nil {
1022 return fmt.Errorf("tool not found: %s", toolCall.ToolName)
1023 }
1024
1025 // Validate JSON parsing
1026 var input map[string]any
1027 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
1028 return fmt.Errorf("invalid JSON input: %w", err)
1029 }
1030
1031 // Basic schema validation (check required fields)
1032 // TODO: more robust schema validation using JSON Schema or similar
1033 toolInfo := tool.Info()
1034 for _, required := range toolInfo.Required {
1035 if _, exists := input[required]; !exists {
1036 return fmt.Errorf("missing required parameter: %s", required)
1037 }
1038 }
1039 return nil
1040}
1041
1042func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
1043 if prompt == "" {
1044 return nil, &Error{Title: "invalid argument", Message: "prompt can't be empty"}
1045 }
1046
1047 var preparedPrompt Prompt
1048
1049 if system != "" {
1050 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
1051 }
1052 preparedPrompt = append(preparedPrompt, messages...)
1053 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
1054 return preparedPrompt, nil
1055}
1056
1057// WithSystemPrompt sets the system prompt for the agent.
1058func WithSystemPrompt(prompt string) AgentOption {
1059 return func(s *agentSettings) {
1060 s.systemPrompt = prompt
1061 }
1062}
1063
1064// WithMaxOutputTokens sets the maximum output tokens for the agent.
1065func WithMaxOutputTokens(tokens int64) AgentOption {
1066 return func(s *agentSettings) {
1067 s.maxOutputTokens = &tokens
1068 }
1069}
1070
1071// WithTemperature sets the temperature for the agent.
1072func WithTemperature(temp float64) AgentOption {
1073 return func(s *agentSettings) {
1074 s.temperature = &temp
1075 }
1076}
1077
1078// WithTopP sets the top-p value for the agent.
1079func WithTopP(topP float64) AgentOption {
1080 return func(s *agentSettings) {
1081 s.topP = &topP
1082 }
1083}
1084
1085// WithTopK sets the top-k value for the agent.
1086func WithTopK(topK int64) AgentOption {
1087 return func(s *agentSettings) {
1088 s.topK = &topK
1089 }
1090}
1091
1092// WithPresencePenalty sets the presence penalty for the agent.
1093func WithPresencePenalty(penalty float64) AgentOption {
1094 return func(s *agentSettings) {
1095 s.presencePenalty = &penalty
1096 }
1097}
1098
1099// WithFrequencyPenalty sets the frequency penalty for the agent.
1100func WithFrequencyPenalty(penalty float64) AgentOption {
1101 return func(s *agentSettings) {
1102 s.frequencyPenalty = &penalty
1103 }
1104}
1105
1106// WithTools sets the tools for the agent.
1107func WithTools(tools ...AgentTool) AgentOption {
1108 return func(s *agentSettings) {
1109 s.tools = append(s.tools, tools...)
1110 }
1111}
1112
1113// WithStopConditions sets the stop conditions for the agent.
1114func WithStopConditions(conditions ...StopCondition) AgentOption {
1115 return func(s *agentSettings) {
1116 s.stopWhen = append(s.stopWhen, conditions...)
1117 }
1118}
1119
1120// WithPrepareStep sets the prepare step function for the agent.
1121func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1122 return func(s *agentSettings) {
1123 s.prepareStep = fn
1124 }
1125}
1126
1127// WithRepairToolCall sets the repair tool call function for the agent.
1128func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1129 return func(s *agentSettings) {
1130 s.repairToolCall = fn
1131 }
1132}
1133
1134// WithMaxRetries sets the maximum number of retries for the agent.
1135func WithMaxRetries(maxRetries int) AgentOption {
1136 return func(s *agentSettings) {
1137 s.maxRetries = &maxRetries
1138 }
1139}
1140
1141// WithOnRetry sets the retry callback for the agent.
1142func WithOnRetry(callback OnRetryCallback) AgentOption {
1143 return func(s *agentSettings) {
1144 s.onRetry = callback
1145 }
1146}
1147
1148// processStepStream processes a single step's stream and returns the step result.
1149func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
1150 var stepContent []Content
1151 var stepToolCalls []ToolCallContent
1152 var stepUsage Usage
1153 stepFinishReason := FinishReasonUnknown
1154 var stepWarnings []CallWarning
1155 var stepProviderMetadata ProviderMetadata
1156
1157 activeToolCalls := make(map[string]*ToolCallContent)
1158 activeTextContent := make(map[string]string)
1159 type reasoningContent struct {
1160 content string
1161 options ProviderMetadata
1162 }
1163 activeReasoningContent := make(map[string]reasoningContent)
1164
1165 // Process stream parts
1166 for part := range stream {
1167 // Forward all parts to chunk callback
1168 if opts.OnChunk != nil {
1169 err := opts.OnChunk(part)
1170 if err != nil {
1171 return StepResult{}, false, err
1172 }
1173 }
1174
1175 switch part.Type {
1176 case StreamPartTypeWarnings:
1177 stepWarnings = part.Warnings
1178 if opts.OnWarnings != nil {
1179 err := opts.OnWarnings(part.Warnings)
1180 if err != nil {
1181 return StepResult{}, false, err
1182 }
1183 }
1184
1185 case StreamPartTypeTextStart:
1186 activeTextContent[part.ID] = ""
1187 if opts.OnTextStart != nil {
1188 err := opts.OnTextStart(part.ID)
1189 if err != nil {
1190 return StepResult{}, false, err
1191 }
1192 }
1193
1194 case StreamPartTypeTextDelta:
1195 if _, exists := activeTextContent[part.ID]; exists {
1196 activeTextContent[part.ID] += part.Delta
1197 }
1198 if opts.OnTextDelta != nil {
1199 err := opts.OnTextDelta(part.ID, part.Delta)
1200 if err != nil {
1201 return StepResult{}, false, err
1202 }
1203 }
1204
1205 case StreamPartTypeTextEnd:
1206 if text, exists := activeTextContent[part.ID]; exists {
1207 stepContent = append(stepContent, TextContent{
1208 Text: text,
1209 ProviderMetadata: part.ProviderMetadata,
1210 })
1211 delete(activeTextContent, part.ID)
1212 }
1213 if opts.OnTextEnd != nil {
1214 err := opts.OnTextEnd(part.ID)
1215 if err != nil {
1216 return StepResult{}, false, err
1217 }
1218 }
1219
1220 case StreamPartTypeReasoningStart:
1221 activeReasoningContent[part.ID] = reasoningContent{content: part.Delta, options: part.ProviderMetadata}
1222 if opts.OnReasoningStart != nil {
1223 content := ReasoningContent{
1224 Text: part.Delta,
1225 ProviderMetadata: part.ProviderMetadata,
1226 }
1227 err := opts.OnReasoningStart(part.ID, content)
1228 if err != nil {
1229 return StepResult{}, false, err
1230 }
1231 }
1232
1233 case StreamPartTypeReasoningDelta:
1234 if active, exists := activeReasoningContent[part.ID]; exists {
1235 active.content += part.Delta
1236 active.options = part.ProviderMetadata
1237 activeReasoningContent[part.ID] = active
1238 }
1239 if opts.OnReasoningDelta != nil {
1240 err := opts.OnReasoningDelta(part.ID, part.Delta)
1241 if err != nil {
1242 return StepResult{}, false, err
1243 }
1244 }
1245
1246 case StreamPartTypeReasoningEnd:
1247 if active, exists := activeReasoningContent[part.ID]; exists {
1248 if part.ProviderMetadata != nil {
1249 active.options = part.ProviderMetadata
1250 }
1251 content := ReasoningContent{
1252 Text: active.content,
1253 ProviderMetadata: active.options,
1254 }
1255 stepContent = append(stepContent, content)
1256 if opts.OnReasoningEnd != nil {
1257 err := opts.OnReasoningEnd(part.ID, content)
1258 if err != nil {
1259 return StepResult{}, false, err
1260 }
1261 }
1262 delete(activeReasoningContent, part.ID)
1263 }
1264
1265 case StreamPartTypeToolInputStart:
1266 activeToolCalls[part.ID] = &ToolCallContent{
1267 ToolCallID: part.ID,
1268 ToolName: part.ToolCallName,
1269 Input: "",
1270 ProviderExecuted: part.ProviderExecuted,
1271 }
1272 if opts.OnToolInputStart != nil {
1273 err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1274 if err != nil {
1275 return StepResult{}, false, err
1276 }
1277 }
1278
1279 case StreamPartTypeToolInputDelta:
1280 if toolCall, exists := activeToolCalls[part.ID]; exists {
1281 toolCall.Input += part.Delta
1282 }
1283 if opts.OnToolInputDelta != nil {
1284 err := opts.OnToolInputDelta(part.ID, part.Delta)
1285 if err != nil {
1286 return StepResult{}, false, err
1287 }
1288 }
1289
1290 case StreamPartTypeToolInputEnd:
1291 if opts.OnToolInputEnd != nil {
1292 err := opts.OnToolInputEnd(part.ID)
1293 if err != nil {
1294 return StepResult{}, false, err
1295 }
1296 }
1297
1298 case StreamPartTypeToolCall:
1299 toolCall := ToolCallContent{
1300 ToolCallID: part.ID,
1301 ToolName: part.ToolCallName,
1302 Input: part.ToolCallInput,
1303 ProviderExecuted: part.ProviderExecuted,
1304 ProviderMetadata: part.ProviderMetadata,
1305 }
1306
1307 // Validate and potentially repair the tool call
1308 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1309 stepToolCalls = append(stepToolCalls, validatedToolCall)
1310 stepContent = append(stepContent, validatedToolCall)
1311
1312 if opts.OnToolCall != nil {
1313 err := opts.OnToolCall(validatedToolCall)
1314 if err != nil {
1315 return StepResult{}, false, err
1316 }
1317 }
1318
1319 // Clean up active tool call
1320 delete(activeToolCalls, part.ID)
1321
1322 case StreamPartTypeSource:
1323 sourceContent := SourceContent{
1324 SourceType: part.SourceType,
1325 ID: part.ID,
1326 URL: part.URL,
1327 Title: part.Title,
1328 ProviderMetadata: part.ProviderMetadata,
1329 }
1330 stepContent = append(stepContent, sourceContent)
1331 if opts.OnSource != nil {
1332 err := opts.OnSource(sourceContent)
1333 if err != nil {
1334 return StepResult{}, false, err
1335 }
1336 }
1337
1338 case StreamPartTypeFinish:
1339 stepUsage = part.Usage
1340 stepFinishReason = part.FinishReason
1341 stepProviderMetadata = part.ProviderMetadata
1342 if opts.OnStreamFinish != nil {
1343 err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1344 if err != nil {
1345 return StepResult{}, false, err
1346 }
1347 }
1348
1349 case StreamPartTypeError:
1350 if opts.OnError != nil {
1351 opts.OnError(part.Error)
1352 }
1353 return StepResult{}, false, part.Error
1354 }
1355 }
1356
1357 // Execute tools if any
1358 var toolResults []ToolResultContent
1359 if len(stepToolCalls) > 0 {
1360 var err error
1361 toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnBeforeToolExecution, opts.OnToolResult)
1362 if err != nil {
1363 return StepResult{}, false, err
1364 }
1365 // Add tool results to content
1366 for _, result := range toolResults {
1367 stepContent = append(stepContent, result)
1368 }
1369 }
1370
1371 stepResult := StepResult{
1372 Response: Response{
1373 Content: stepContent,
1374 FinishReason: stepFinishReason,
1375 Usage: stepUsage,
1376 Warnings: stepWarnings,
1377 ProviderMetadata: stepProviderMetadata,
1378 },
1379 Messages: toResponseMessages(stepContent),
1380 }
1381
1382 // Determine if we should continue (has tool calls and not stopped)
1383 shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1384
1385 return stepResult, shouldContinue, nil
1386}
1387
1388func addUsage(a, b Usage) Usage {
1389 return Usage{
1390 InputTokens: a.InputTokens + b.InputTokens,
1391 OutputTokens: a.OutputTokens + b.OutputTokens,
1392 TotalTokens: a.TotalTokens + b.TotalTokens,
1393 ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens,
1394 CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1395 CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens,
1396 }
1397}
1398
1399// WithHeaders sets the headers for the agent.
1400func WithHeaders(headers map[string]string) AgentOption {
1401 return func(s *agentSettings) {
1402 s.headers = headers
1403 }
1404}
1405
1406// WithProviderOptions sets the provider options for the agent.
1407func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1408 return func(s *agentSettings) {
1409 s.providerOptions = providerOptions
1410 }
1411}