1package fantasy
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "slices"
11 "sync"
12
13 "charm.land/fantasy/schema"
14)
15
// StepResult represents the result of a single step in an agent execution.
// It embeds the step's model Response (content, finish reason, usage,
// warnings and provider metadata) and adds the conversation messages derived
// from that content.
type StepResult struct {
	Response
	// Messages are the response messages produced by this step, derived from
	// the step's content.
	Messages []Message
}

// stepExecutionResult encapsulates the result of executing a step with stream processing.
type stepExecutionResult struct {
	// StepResult is the completed step.
	StepResult StepResult
	// ShouldContinue reports whether the agent loop should run another step;
	// Stream stops when it is false.
	ShouldContinue bool
}

// StopCondition defines a function that determines when an agent should stop executing.
// It receives every step completed so far and returns true to stop the loop.
type StopCondition = func(steps []StepResult) bool
30
31// StepCountIs returns a stop condition that stops after the specified number of steps.
32func StepCountIs(stepCount int) StopCondition {
33 return func(steps []StepResult) bool {
34 return len(steps) >= stepCount
35 }
36}
37
38// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
39func HasToolCall(toolName string) StopCondition {
40 return func(steps []StepResult) bool {
41 if len(steps) == 0 {
42 return false
43 }
44 lastStep := steps[len(steps)-1]
45 toolCalls := lastStep.Content.ToolCalls()
46 for _, toolCall := range toolCalls {
47 if toolCall.ToolName == toolName {
48 return true
49 }
50 }
51 return false
52 }
53}
54
55// HasContent returns a stop condition that stops when the specified content type appears in the last step.
56func HasContent(contentType ContentType) StopCondition {
57 return func(steps []StepResult) bool {
58 if len(steps) == 0 {
59 return false
60 }
61 lastStep := steps[len(steps)-1]
62 for _, content := range lastStep.Content {
63 if content.GetType() == contentType {
64 return true
65 }
66 }
67 return false
68 }
69}
70
71// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
72func FinishReasonIs(reason FinishReason) StopCondition {
73 return func(steps []StepResult) bool {
74 if len(steps) == 0 {
75 return false
76 }
77 lastStep := steps[len(steps)-1]
78 return lastStep.FinishReason == reason
79 }
80}
81
82// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
83func MaxTokensUsed(maxTokens int64) StopCondition {
84 return func(steps []StepResult) bool {
85 var totalTokens int64
86 for _, step := range steps {
87 totalTokens += step.Usage.TotalTokens
88 }
89 return totalTokens >= maxTokens
90 }
91}
92
// PrepareStepFunctionOptions contains the options for preparing a step in an agent execution.
type PrepareStepFunctionOptions struct {
	// Steps holds the results of the steps completed so far.
	Steps []StepResult
	// StepNumber is the zero-based index of the step being prepared.
	StepNumber int
	// Model is the model the step will use unless PrepareStepResult.Model overrides it.
	Model LanguageModel
	// Messages is the prompt the step will send unless PrepareStepResult.Messages overrides it.
	Messages []Message
}

// PrepareStepResult contains the result of preparing a step in an agent execution.
// Zero values leave the corresponding step setting unchanged, except
// DisableAllTools, which is applied as given.
type PrepareStepResult struct {
	// Model replaces the step's model when non-nil.
	Model LanguageModel
	// Messages replaces the step's input messages when non-nil.
	Messages []Message
	// System replaces the system prompt for this step when non-nil.
	System *string
	// ToolChoice replaces the default ToolChoiceAuto when non-nil.
	ToolChoice *ToolChoice
	// ActiveTools restricts the step to the named tools when non-empty.
	ActiveTools []string
	// DisableAllTools sends no tools to the model for this step when true.
	DisableAllTools bool
	// Tools replaces the agent's tools for this step when non-nil.
	Tools []AgentTool
}

// ToolCallRepairOptions contains the options for repairing a tool call.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call that failed validation.
	OriginalToolCall ToolCallContent
	// ValidationError describes why the original call was rejected.
	ValidationError error
	// AvailableTools are the tools the step was allowed to call.
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt in effect for the step.
	SystemPrompt string
	// Messages is the input the step sent to the model.
	Messages []Message
}

type (
	// PrepareStepFunction defines a function that prepares a step in an agent execution.
	// The returned context is used for this and all later steps; a non-nil
	// error aborts the run.
	PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)

	// OnStepFinishedFunction defines a function that is called when a step finishes.
	OnStepFinishedFunction = func(step StepResult)

	// RepairToolCallFunction defines a function that repairs a tool call.
	// A returned call is used only when the error is nil, the call is non-nil,
	// and the call itself passes validation; otherwise the original call is
	// marked invalid.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
131
// agentSettings holds the agent-level configuration that NewAgent assembles
// from its AgentOption values. Nil pointer fields mean "not set"; per-call
// values take precedence over these defaults (see prepareCall).
type agentSettings struct {
	// systemPrompt, when non-empty, is sent as the leading system message.
	systemPrompt     string
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	// NOTE(review): headers is not forwarded to model calls by any code in
	// this file — confirm whether that is intended.
	headers map[string]string
	// providerOptions are merged under (and overridden by) per-call ProviderOptions.
	providerOptions ProviderOptions

	// TODO: add support for provider tools
	tools      []AgentTool
	maxRetries *int

	model LanguageModel

	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
154
// AgentCall represents a call to an agent.
//
// Prompt is required; createPrompt rejects an empty one. The following fall
// back to the agent's configured defaults: nil pointer fields, an empty
// StopWhen, and nil PrepareStep/RepairToolCall/OnRetry. ProviderOptions are
// merged over the agent's options, and per-call entries win.
type AgentCall struct {
	// Prompt becomes the final user message of the conversation.
	Prompt string `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files []FilePart `json:"files"`
	// Messages is prior conversation, placed after the system message and
	// before the user prompt.
	Messages         []Message `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, restricts the tools offered to the model
	// to these names.
	ActiveTools     []string `json:"active_tools"`
	ProviderOptions ProviderOptions
	OnRetry         OnRetryCallback
	MaxRetries      *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction
}
175
// Agent-level callbacks.
type (
	// OnAgentStartFunc is called when agent starts.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called when agent finishes. Stream calls it after
	// OnFinish and ignores the returned error.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called when a step starts, with the zero-based step
	// number. Stream ignores the returned error.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called when a step finishes. Stream ignores the
	// returned error.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called when entire agent completes.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called when an error occurs. In Stream it is invoked when
	// a step fails after retries; errors from prompt creation or PrepareStep
	// are returned without calling it.
	OnErrorFunc func(error)
)

// Stream part callbacks - called for each corresponding stream part type.
type (
	// OnChunkFunc is called for each stream part (catch-all).
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when text starts.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for text deltas.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when text ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when reasoning starts.
	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error

	// OnReasoningDeltaFunc is called for reasoning deltas.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning ends.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when tool input starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for tool input deltas.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnToolResultFunc is called when tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when stream finishes.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
244
// AgentStreamCall represents a streaming call to an agent.
//
// The request fields mirror AgentCall. They receive the same agent-level
// defaults, because Stream copies them into an AgentCall and runs
// prepareCall. Stream checks each agent- and step-level callback for nil
// before invoking it.
type AgentStreamCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	// NOTE(review): Headers is not copied into the AgentCall that Stream
	// prepares; confirm whether it reaches the model request.
	Headers         map[string]string
	ProviderOptions ProviderOptions
	OnRetry         OnRetryCallback
	MaxRetries      *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called when agent starts
	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
	OnWarnings       OnWarningsFunc       // Called for warnings
	OnTextStart      OnTextStartFunc      // Called when text starts
	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
	OnTextEnd        OnTextEndFunc        // Called when text ends
	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
	OnToolCall       OnToolCallFunc       // Called when tool call is complete
	OnToolResult     OnToolResultFunc     // Called when tool execution completes
	OnSource         OnSourceFunc         // Called for source references
	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
}
291
// AgentResult represents the result of an agent execution.
type AgentResult struct {
	// Steps holds every executed step in order.
	Steps []StepResult
	// Final response: the Response of the last step.
	Response Response
	// TotalUsage is the token usage summed over all steps.
	TotalUsage Usage
}

// Agent represents an AI agent that can generate responses and stream responses.
type Agent interface {
	Generate(context.Context, AgentCall) (*AgentResult, error)
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}

// AgentOption defines a function that configures agent settings.
type AgentOption = func(*agentSettings)

// agent is the Agent implementation returned by NewAgent.
type agent struct {
	settings agentSettings
}
312
313// NewAgent creates a new agent with the given language model and options.
314func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
315 settings := agentSettings{
316 model: model,
317 }
318 for _, o := range opts {
319 o(&settings)
320 }
321 return &agent{
322 settings: settings,
323 }
324}
325
326func (a *agent) prepareCall(call AgentCall) AgentCall {
327 call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
328 call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
329 call.TopP = cmp.Or(call.TopP, a.settings.topP)
330 call.TopK = cmp.Or(call.TopK, a.settings.topK)
331 call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
332 call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
333 call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
334
335 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
336 call.StopWhen = a.settings.stopWhen
337 }
338 if call.PrepareStep == nil && a.settings.prepareStep != nil {
339 call.PrepareStep = a.settings.prepareStep
340 }
341 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
342 call.RepairToolCall = a.settings.repairToolCall
343 }
344 if call.OnRetry == nil && a.settings.onRetry != nil {
345 call.OnRetry = a.settings.onRetry
346 }
347
348 providerOptions := ProviderOptions{}
349 if a.settings.providerOptions != nil {
350 maps.Copy(providerOptions, a.settings.providerOptions)
351 }
352 if call.ProviderOptions != nil {
353 maps.Copy(providerOptions, call.ProviderOptions)
354 }
355 call.ProviderOptions = providerOptions
356
357 headers := map[string]string{}
358
359 if a.settings.headers != nil {
360 maps.Copy(headers, a.settings.headers)
361 }
362
363 return call
364}
365
366// Generate implements Agent.
367func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
368 opts = a.prepareCall(opts)
369 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
370 if err != nil {
371 return nil, err
372 }
373 var responseMessages []Message
374 var steps []StepResult
375
376 for {
377 stepInputMessages := append(initialPrompt, responseMessages...)
378 stepModel := a.settings.model
379 stepSystemPrompt := a.settings.systemPrompt
380 stepActiveTools := opts.ActiveTools
381 stepToolChoice := ToolChoiceAuto
382 disableAllTools := false
383 stepTools := a.settings.tools
384 if opts.PrepareStep != nil {
385 updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
386 Model: stepModel,
387 Steps: steps,
388 StepNumber: len(steps),
389 Messages: stepInputMessages,
390 })
391 if err != nil {
392 return nil, err
393 }
394
395 ctx = updatedCtx
396
397 // Apply prepared step modifications
398 if prepared.Messages != nil {
399 stepInputMessages = prepared.Messages
400 }
401 if prepared.Model != nil {
402 stepModel = prepared.Model
403 }
404 if prepared.System != nil {
405 stepSystemPrompt = *prepared.System
406 }
407 if prepared.ToolChoice != nil {
408 stepToolChoice = *prepared.ToolChoice
409 }
410 if len(prepared.ActiveTools) > 0 {
411 stepActiveTools = prepared.ActiveTools
412 }
413 disableAllTools = prepared.DisableAllTools
414 if prepared.Tools != nil {
415 stepTools = prepared.Tools
416 }
417 }
418
419 // Recreate prompt with potentially modified system prompt
420 if stepSystemPrompt != a.settings.systemPrompt {
421 stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
422 if err != nil {
423 return nil, err
424 }
425 // Replace system message part, keep the rest
426 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
427 stepInputMessages[0] = stepPrompt[0] // Replace system message
428 }
429 }
430
431 preparedTools := a.prepareTools(stepTools, stepActiveTools, disableAllTools)
432
433 retryOptions := DefaultRetryOptions()
434 if opts.MaxRetries != nil {
435 retryOptions.MaxRetries = *opts.MaxRetries
436 }
437 retryOptions.OnRetry = opts.OnRetry
438 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
439
440 result, err := retry(ctx, func() (*Response, error) {
441 return stepModel.Generate(ctx, Call{
442 Prompt: stepInputMessages,
443 MaxOutputTokens: opts.MaxOutputTokens,
444 Temperature: opts.Temperature,
445 TopP: opts.TopP,
446 TopK: opts.TopK,
447 PresencePenalty: opts.PresencePenalty,
448 FrequencyPenalty: opts.FrequencyPenalty,
449 Tools: preparedTools,
450 ToolChoice: &stepToolChoice,
451 ProviderOptions: opts.ProviderOptions,
452 })
453 })
454 if err != nil {
455 return nil, err
456 }
457
458 var stepToolCalls []ToolCallContent
459 for _, content := range result.Content {
460 if content.GetType() == ContentTypeToolCall {
461 toolCall, ok := AsContentType[ToolCallContent](content)
462 if !ok {
463 continue
464 }
465
466 // Validate and potentially repair the tool call
467 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
468 stepToolCalls = append(stepToolCalls, validatedToolCall)
469 }
470 }
471
472 toolResults, err := a.executeTools(ctx, stepTools, stepToolCalls, nil)
473
474 // Build step content with validated tool calls and tool results
475 stepContent := []Content{}
476 toolCallIndex := 0
477 for _, content := range result.Content {
478 if content.GetType() == ContentTypeToolCall {
479 // Replace with validated tool call
480 if toolCallIndex < len(stepToolCalls) {
481 stepContent = append(stepContent, stepToolCalls[toolCallIndex])
482 toolCallIndex++
483 }
484 } else {
485 // Keep other content as-is
486 stepContent = append(stepContent, content)
487 }
488 }
489 // Add tool results
490 for _, result := range toolResults {
491 stepContent = append(stepContent, result)
492 }
493 currentStepMessages := toResponseMessages(stepContent)
494 responseMessages = append(responseMessages, currentStepMessages...)
495
496 stepResult := StepResult{
497 Response: Response{
498 Content: stepContent,
499 FinishReason: result.FinishReason,
500 Usage: result.Usage,
501 Warnings: result.Warnings,
502 ProviderMetadata: result.ProviderMetadata,
503 },
504 Messages: currentStepMessages,
505 }
506 steps = append(steps, stepResult)
507 shouldStop := isStopConditionMet(opts.StopWhen, steps)
508
509 if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
510 break
511 }
512 }
513
514 totalUsage := Usage{}
515
516 for _, step := range steps {
517 usage := step.Usage
518 totalUsage.InputTokens += usage.InputTokens
519 totalUsage.OutputTokens += usage.OutputTokens
520 totalUsage.ReasoningTokens += usage.ReasoningTokens
521 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
522 totalUsage.CacheReadTokens += usage.CacheReadTokens
523 totalUsage.TotalTokens += usage.TotalTokens
524 }
525
526 agentResult := &AgentResult{
527 Steps: steps,
528 Response: steps[len(steps)-1].Response,
529 TotalUsage: totalUsage,
530 }
531 return agentResult, nil
532}
533
534func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
535 if len(conditions) == 0 {
536 return false
537 }
538
539 for _, condition := range conditions {
540 if condition(steps) {
541 return true
542 }
543 }
544 return false
545}
546
547func toResponseMessages(content []Content) []Message {
548 var assistantParts []MessagePart
549 var toolParts []MessagePart
550
551 for _, c := range content {
552 switch c.GetType() {
553 case ContentTypeText:
554 text, ok := AsContentType[TextContent](c)
555 if !ok {
556 continue
557 }
558 assistantParts = append(assistantParts, TextPart{
559 Text: text.Text,
560 ProviderOptions: ProviderOptions(text.ProviderMetadata),
561 })
562 case ContentTypeReasoning:
563 reasoning, ok := AsContentType[ReasoningContent](c)
564 if !ok {
565 continue
566 }
567 assistantParts = append(assistantParts, ReasoningPart{
568 Text: reasoning.Text,
569 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
570 })
571 case ContentTypeToolCall:
572 toolCall, ok := AsContentType[ToolCallContent](c)
573 if !ok {
574 continue
575 }
576 assistantParts = append(assistantParts, ToolCallPart{
577 ToolCallID: toolCall.ToolCallID,
578 ToolName: toolCall.ToolName,
579 Input: toolCall.Input,
580 ProviderExecuted: toolCall.ProviderExecuted,
581 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
582 })
583 case ContentTypeFile:
584 file, ok := AsContentType[FileContent](c)
585 if !ok {
586 continue
587 }
588 assistantParts = append(assistantParts, FilePart{
589 Data: file.Data,
590 MediaType: file.MediaType,
591 ProviderOptions: ProviderOptions(file.ProviderMetadata),
592 })
593 case ContentTypeSource:
594 // Sources are metadata about references used to generate the response.
595 // They don't need to be included in the conversation messages.
596 continue
597 case ContentTypeToolResult:
598 result, ok := AsContentType[ToolResultContent](c)
599 if !ok {
600 continue
601 }
602 toolParts = append(toolParts, ToolResultPart{
603 ToolCallID: result.ToolCallID,
604 Output: result.Result,
605 ProviderOptions: ProviderOptions(result.ProviderMetadata),
606 })
607 }
608 }
609
610 var messages []Message
611 if len(assistantParts) > 0 {
612 messages = append(messages, Message{
613 Role: MessageRoleAssistant,
614 Content: assistantParts,
615 })
616 }
617 if len(toolParts) > 0 {
618 messages = append(messages, Message{
619 Role: MessageRoleTool,
620 Content: toolParts,
621 })
622 }
623 return messages
624}
625
626func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
627 if len(toolCalls) == 0 {
628 return nil, nil
629 }
630
631 // Create a map for quick tool lookup
632 toolMap := make(map[string]AgentTool)
633 for _, tool := range allTools {
634 toolMap[tool.Info().Name] = tool
635 }
636
637 // Execute all tool calls sequentially in order
638 results := make([]ToolResultContent, 0, len(toolCalls))
639
640 for _, toolCall := range toolCalls {
641 result, isCriticalError := a.executeSingleTool(ctx, toolMap, toolCall, toolResultCallback)
642 results = append(results, result)
643 if isCriticalError {
644 if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
645 return nil, errorResult.Error
646 }
647 }
648 }
649
650 return results, nil
651}
652
653// executeSingleTool executes a single tool and returns its result and a critical error flag.
654func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentTool, toolCall ToolCallContent, toolResultCallback func(result ToolResultContent) error) (ToolResultContent, bool) {
655 result := ToolResultContent{
656 ToolCallID: toolCall.ToolCallID,
657 ToolName: toolCall.ToolName,
658 ProviderExecuted: false,
659 }
660
661 // Skip invalid tool calls - create error result (not critical)
662 if toolCall.Invalid {
663 result.Result = ToolResultOutputContentError{
664 Error: toolCall.ValidationError,
665 }
666 if toolResultCallback != nil {
667 _ = toolResultCallback(result)
668 }
669 return result, false
670 }
671
672 tool, exists := toolMap[toolCall.ToolName]
673 if !exists {
674 result.Result = ToolResultOutputContentError{
675 Error: errors.New("Error: Tool not found: " + toolCall.ToolName),
676 }
677 if toolResultCallback != nil {
678 _ = toolResultCallback(result)
679 }
680 return result, false
681 }
682
683 // Execute the tool
684 toolResult, err := tool.Run(ctx, ToolCall{
685 ID: toolCall.ToolCallID,
686 Name: toolCall.ToolName,
687 Input: toolCall.Input,
688 })
689 if err != nil {
690 result.Result = ToolResultOutputContentError{
691 Error: err,
692 }
693 result.ClientMetadata = toolResult.Metadata
694 if toolResultCallback != nil {
695 _ = toolResultCallback(result)
696 }
697 return result, true
698 }
699
700 result.ClientMetadata = toolResult.Metadata
701 if toolResult.IsError {
702 result.Result = ToolResultOutputContentError{
703 Error: errors.New(toolResult.Content),
704 }
705 } else if toolResult.Type == "image" || toolResult.Type == "media" {
706 result.Result = ToolResultOutputContentMedia{
707 Data: string(toolResult.Data),
708 MediaType: toolResult.MediaType,
709 Text: toolResult.Content,
710 }
711 } else {
712 result.Result = ToolResultOutputContentText{
713 Text: toolResult.Content,
714 }
715 }
716 if toolResultCallback != nil {
717 _ = toolResultCallback(result)
718 }
719 return result, false
720}
721
722// Stream implements Agent.
723func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
724 // Convert AgentStreamCall to AgentCall for preparation
725 call := AgentCall{
726 Prompt: opts.Prompt,
727 Files: opts.Files,
728 Messages: opts.Messages,
729 MaxOutputTokens: opts.MaxOutputTokens,
730 Temperature: opts.Temperature,
731 TopP: opts.TopP,
732 TopK: opts.TopK,
733 PresencePenalty: opts.PresencePenalty,
734 FrequencyPenalty: opts.FrequencyPenalty,
735 ActiveTools: opts.ActiveTools,
736 ProviderOptions: opts.ProviderOptions,
737 MaxRetries: opts.MaxRetries,
738 OnRetry: opts.OnRetry,
739 StopWhen: opts.StopWhen,
740 PrepareStep: opts.PrepareStep,
741 RepairToolCall: opts.RepairToolCall,
742 }
743
744 call = a.prepareCall(call)
745
746 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
747 if err != nil {
748 return nil, err
749 }
750
751 var responseMessages []Message
752 var steps []StepResult
753 var totalUsage Usage
754
755 // Start agent stream
756 if opts.OnAgentStart != nil {
757 opts.OnAgentStart()
758 }
759
760 for stepNumber := 0; ; stepNumber++ {
761 stepInputMessages := append(initialPrompt, responseMessages...)
762 stepModel := a.settings.model
763 stepSystemPrompt := a.settings.systemPrompt
764 stepActiveTools := call.ActiveTools
765 stepToolChoice := ToolChoiceAuto
766 disableAllTools := false
767 stepTools := a.settings.tools
768 // Apply step preparation if provided
769 if call.PrepareStep != nil {
770 updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
771 Model: stepModel,
772 Steps: steps,
773 StepNumber: stepNumber,
774 Messages: stepInputMessages,
775 })
776 if err != nil {
777 return nil, err
778 }
779
780 ctx = updatedCtx
781
782 if prepared.Messages != nil {
783 stepInputMessages = prepared.Messages
784 }
785 if prepared.Model != nil {
786 stepModel = prepared.Model
787 }
788 if prepared.System != nil {
789 stepSystemPrompt = *prepared.System
790 }
791 if prepared.ToolChoice != nil {
792 stepToolChoice = *prepared.ToolChoice
793 }
794 if len(prepared.ActiveTools) > 0 {
795 stepActiveTools = prepared.ActiveTools
796 }
797 disableAllTools = prepared.DisableAllTools
798 if prepared.Tools != nil {
799 stepTools = prepared.Tools
800 }
801 }
802
803 // Recreate prompt with potentially modified system prompt
804 if stepSystemPrompt != a.settings.systemPrompt {
805 stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
806 if err != nil {
807 return nil, err
808 }
809 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
810 stepInputMessages[0] = stepPrompt[0]
811 }
812 }
813
814 preparedTools := a.prepareTools(stepTools, stepActiveTools, disableAllTools)
815
816 // Start step stream
817 if opts.OnStepStart != nil {
818 _ = opts.OnStepStart(stepNumber)
819 }
820
821 // Create streaming call
822 streamCall := Call{
823 Prompt: stepInputMessages,
824 MaxOutputTokens: call.MaxOutputTokens,
825 Temperature: call.Temperature,
826 TopP: call.TopP,
827 TopK: call.TopK,
828 PresencePenalty: call.PresencePenalty,
829 FrequencyPenalty: call.FrequencyPenalty,
830 Tools: preparedTools,
831 ToolChoice: &stepToolChoice,
832 ProviderOptions: call.ProviderOptions,
833 }
834
835 // Execute step with retry logic wrapping both stream creation and processing
836 retryOptions := DefaultRetryOptions()
837 if call.MaxRetries != nil {
838 retryOptions.MaxRetries = *call.MaxRetries
839 }
840 retryOptions.OnRetry = call.OnRetry
841 retry := RetryWithExponentialBackoffRespectingRetryHeaders[stepExecutionResult](retryOptions)
842
843 result, err := retry(ctx, func() (stepExecutionResult, error) {
844 // Create the stream
845 stream, err := stepModel.Stream(ctx, streamCall)
846 if err != nil {
847 return stepExecutionResult{}, err
848 }
849
850 // Process the stream
851 result, err := a.processStepStream(ctx, stream, opts, steps, stepTools)
852 if err != nil {
853 return stepExecutionResult{}, err
854 }
855
856 return result, nil
857 })
858 if err != nil {
859 if opts.OnError != nil {
860 opts.OnError(err)
861 }
862 return nil, err
863 }
864
865 steps = append(steps, result.StepResult)
866 totalUsage = addUsage(totalUsage, result.StepResult.Usage)
867
868 // Call step finished callback
869 if opts.OnStepFinish != nil {
870 _ = opts.OnStepFinish(result.StepResult)
871 }
872
873 // Add step messages to response messages
874 stepMessages := toResponseMessages(result.StepResult.Content)
875 responseMessages = append(responseMessages, stepMessages...)
876
877 // Check stop conditions
878 shouldStop := isStopConditionMet(call.StopWhen, steps)
879 if shouldStop || !result.ShouldContinue {
880 break
881 }
882 }
883
884 // Finish agent stream
885 agentResult := &AgentResult{
886 Steps: steps,
887 Response: steps[len(steps)-1].Response,
888 TotalUsage: totalUsage,
889 }
890
891 if opts.OnFinish != nil {
892 opts.OnFinish(agentResult)
893 }
894
895 if opts.OnAgentFinish != nil {
896 _ = opts.OnAgentFinish(agentResult)
897 }
898
899 return agentResult, nil
900}
901
902func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
903 preparedTools := make([]Tool, 0, len(tools))
904
905 // If explicitly disabling all tools, return no tools
906 if disableAllTools {
907 return preparedTools
908 }
909
910 for _, tool := range tools {
911 // If activeTools has items, only include tools in the list
912 // If activeTools is empty, include all tools
913 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
914 continue
915 }
916 info := tool.Info()
917 inputSchema := map[string]any{
918 "type": "object",
919 "properties": info.Parameters,
920 "required": info.Required,
921 }
922 schema.Normalize(inputSchema)
923 preparedTools = append(preparedTools, FunctionTool{
924 Name: info.Name,
925 Description: info.Description,
926 InputSchema: inputSchema,
927 ProviderOptions: tool.ProviderOptions(),
928 })
929 }
930 return preparedTools
931}
932
933// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
934func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
935 if err := a.validateToolCall(toolCall, availableTools); err == nil {
936 return toolCall
937 } else { //nolint: revive
938 if repairFunc != nil {
939 repairOptions := ToolCallRepairOptions{
940 OriginalToolCall: toolCall,
941 ValidationError: err,
942 AvailableTools: availableTools,
943 SystemPrompt: systemPrompt,
944 Messages: messages,
945 }
946
947 if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
948 if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
949 return *repairedToolCall
950 }
951 }
952 }
953
954 invalidToolCall := toolCall
955 invalidToolCall.Invalid = true
956 invalidToolCall.ValidationError = err
957 return invalidToolCall
958 }
959}
960
961// validateToolCall validates a tool call against available tools and their schemas.
962func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
963 var tool AgentTool
964 for _, t := range availableTools {
965 if t.Info().Name == toolCall.ToolName {
966 tool = t
967 break
968 }
969 }
970
971 if tool == nil {
972 return fmt.Errorf("tool not found: %s", toolCall.ToolName)
973 }
974
975 // Validate JSON parsing
976 var input map[string]any
977 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
978 return fmt.Errorf("invalid JSON input: %w", err)
979 }
980
981 // Basic schema validation (check required fields)
982 // TODO: more robust schema validation using JSON Schema or similar
983 toolInfo := tool.Info()
984 for _, required := range toolInfo.Required {
985 if _, exists := input[required]; !exists {
986 return fmt.Errorf("missing required parameter: %s", required)
987 }
988 }
989 return nil
990}
991
992func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
993 if prompt == "" {
994 return nil, &Error{Title: "invalid argument", Message: "prompt can't be empty"}
995 }
996
997 var preparedPrompt Prompt
998
999 if system != "" {
1000 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
1001 }
1002 preparedPrompt = append(preparedPrompt, messages...)
1003 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
1004 return preparedPrompt, nil
1005}
1006
1007// WithSystemPrompt sets the system prompt for the agent.
1008func WithSystemPrompt(prompt string) AgentOption {
1009 return func(s *agentSettings) {
1010 s.systemPrompt = prompt
1011 }
1012}
1013
1014// WithMaxOutputTokens sets the maximum output tokens for the agent.
1015func WithMaxOutputTokens(tokens int64) AgentOption {
1016 return func(s *agentSettings) {
1017 s.maxOutputTokens = &tokens
1018 }
1019}
1020
1021// WithTemperature sets the temperature for the agent.
1022func WithTemperature(temp float64) AgentOption {
1023 return func(s *agentSettings) {
1024 s.temperature = &temp
1025 }
1026}
1027
1028// WithTopP sets the top-p value for the agent.
1029func WithTopP(topP float64) AgentOption {
1030 return func(s *agentSettings) {
1031 s.topP = &topP
1032 }
1033}
1034
1035// WithTopK sets the top-k value for the agent.
1036func WithTopK(topK int64) AgentOption {
1037 return func(s *agentSettings) {
1038 s.topK = &topK
1039 }
1040}
1041
1042// WithPresencePenalty sets the presence penalty for the agent.
1043func WithPresencePenalty(penalty float64) AgentOption {
1044 return func(s *agentSettings) {
1045 s.presencePenalty = &penalty
1046 }
1047}
1048
1049// WithFrequencyPenalty sets the frequency penalty for the agent.
1050func WithFrequencyPenalty(penalty float64) AgentOption {
1051 return func(s *agentSettings) {
1052 s.frequencyPenalty = &penalty
1053 }
1054}
1055
1056// WithTools sets the tools for the agent.
1057func WithTools(tools ...AgentTool) AgentOption {
1058 return func(s *agentSettings) {
1059 s.tools = append(s.tools, tools...)
1060 }
1061}
1062
1063// WithStopConditions sets the stop conditions for the agent.
1064func WithStopConditions(conditions ...StopCondition) AgentOption {
1065 return func(s *agentSettings) {
1066 s.stopWhen = append(s.stopWhen, conditions...)
1067 }
1068}
1069
1070// WithPrepareStep sets the prepare step function for the agent.
1071func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1072 return func(s *agentSettings) {
1073 s.prepareStep = fn
1074 }
1075}
1076
1077// WithRepairToolCall sets the repair tool call function for the agent.
1078func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1079 return func(s *agentSettings) {
1080 s.repairToolCall = fn
1081 }
1082}
1083
1084// WithMaxRetries sets the maximum number of retries for the agent.
1085func WithMaxRetries(maxRetries int) AgentOption {
1086 return func(s *agentSettings) {
1087 s.maxRetries = &maxRetries
1088 }
1089}
1090
1091// WithOnRetry sets the retry callback for the agent.
1092func WithOnRetry(callback OnRetryCallback) AgentOption {
1093 return func(s *agentSettings) {
1094 s.onRetry = callback
1095 }
1096}
1097
1098// processStepStream processes a single step's stream and returns the step result.
1099func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult, stepTools []AgentTool) (stepExecutionResult, error) {
1100 var stepContent []Content
1101 var stepToolCalls []ToolCallContent
1102 var stepUsage Usage
1103 stepFinishReason := FinishReasonUnknown
1104 var stepWarnings []CallWarning
1105 var stepProviderMetadata ProviderMetadata
1106
1107 activeToolCalls := make(map[string]*ToolCallContent)
1108 activeTextContent := make(map[string]string)
1109 type reasoningContent struct {
1110 content string
1111 options ProviderMetadata
1112 }
1113 activeReasoningContent := make(map[string]reasoningContent)
1114
1115 // Set up concurrent tool execution
1116 type toolExecutionRequest struct {
1117 toolCall ToolCallContent
1118 parallel bool
1119 }
1120 toolChan := make(chan toolExecutionRequest, 10)
1121 var toolExecutionWg sync.WaitGroup
1122 var toolStateMu sync.Mutex
1123 toolResults := make([]ToolResultContent, 0)
1124 var toolExecutionErr error
1125
1126 // Create a map for quick tool lookup
1127 toolMap := make(map[string]AgentTool)
1128 for _, tool := range stepTools {
1129 toolMap[tool.Info().Name] = tool
1130 }
1131
1132 // Semaphores for controlling parallelism
1133 parallelSem := make(chan struct{}, 5)
1134 var sequentialMu sync.Mutex
1135
1136 // Single coordinator goroutine that dispatches tools
1137 toolExecutionWg.Go(func() {
1138 for req := range toolChan {
1139 if req.parallel {
1140 parallelSem <- struct{}{}
1141 toolExecutionWg.Go(func() {
1142 defer func() { <-parallelSem }()
1143 result, isCriticalError := a.executeSingleTool(ctx, toolMap, req.toolCall, opts.OnToolResult)
1144 toolStateMu.Lock()
1145 toolResults = append(toolResults, result)
1146 if isCriticalError && toolExecutionErr == nil {
1147 if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
1148 toolExecutionErr = errorResult.Error
1149 }
1150 }
1151 toolStateMu.Unlock()
1152 })
1153 } else {
1154 sequentialMu.Lock()
1155 result, isCriticalError := a.executeSingleTool(ctx, toolMap, req.toolCall, opts.OnToolResult)
1156 toolStateMu.Lock()
1157 toolResults = append(toolResults, result)
1158 if isCriticalError && toolExecutionErr == nil {
1159 if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
1160 toolExecutionErr = errorResult.Error
1161 }
1162 }
1163 toolStateMu.Unlock()
1164 sequentialMu.Unlock()
1165 }
1166 }
1167 })
1168
1169 // Process stream parts
1170 for part := range stream {
1171 // Forward all parts to chunk callback
1172 if opts.OnChunk != nil {
1173 err := opts.OnChunk(part)
1174 if err != nil {
1175 return stepExecutionResult{}, err
1176 }
1177 }
1178
1179 switch part.Type {
1180 case StreamPartTypeWarnings:
1181 stepWarnings = part.Warnings
1182 if opts.OnWarnings != nil {
1183 err := opts.OnWarnings(part.Warnings)
1184 if err != nil {
1185 return stepExecutionResult{}, err
1186 }
1187 }
1188
1189 case StreamPartTypeTextStart:
1190 activeTextContent[part.ID] = ""
1191 if opts.OnTextStart != nil {
1192 err := opts.OnTextStart(part.ID)
1193 if err != nil {
1194 return stepExecutionResult{}, err
1195 }
1196 }
1197
1198 case StreamPartTypeTextDelta:
1199 if _, exists := activeTextContent[part.ID]; exists {
1200 activeTextContent[part.ID] += part.Delta
1201 }
1202 if opts.OnTextDelta != nil {
1203 err := opts.OnTextDelta(part.ID, part.Delta)
1204 if err != nil {
1205 return stepExecutionResult{}, err
1206 }
1207 }
1208
1209 case StreamPartTypeTextEnd:
1210 if text, exists := activeTextContent[part.ID]; exists {
1211 stepContent = append(stepContent, TextContent{
1212 Text: text,
1213 ProviderMetadata: part.ProviderMetadata,
1214 })
1215 delete(activeTextContent, part.ID)
1216 }
1217 if opts.OnTextEnd != nil {
1218 err := opts.OnTextEnd(part.ID)
1219 if err != nil {
1220 return stepExecutionResult{}, err
1221 }
1222 }
1223
1224 case StreamPartTypeReasoningStart:
1225 activeReasoningContent[part.ID] = reasoningContent{content: part.Delta, options: part.ProviderMetadata}
1226 if opts.OnReasoningStart != nil {
1227 content := ReasoningContent{
1228 Text: part.Delta,
1229 ProviderMetadata: part.ProviderMetadata,
1230 }
1231 err := opts.OnReasoningStart(part.ID, content)
1232 if err != nil {
1233 return stepExecutionResult{}, err
1234 }
1235 }
1236
1237 case StreamPartTypeReasoningDelta:
1238 if active, exists := activeReasoningContent[part.ID]; exists {
1239 active.content += part.Delta
1240 active.options = part.ProviderMetadata
1241 activeReasoningContent[part.ID] = active
1242 }
1243 if opts.OnReasoningDelta != nil {
1244 err := opts.OnReasoningDelta(part.ID, part.Delta)
1245 if err != nil {
1246 return stepExecutionResult{}, err
1247 }
1248 }
1249
1250 case StreamPartTypeReasoningEnd:
1251 if active, exists := activeReasoningContent[part.ID]; exists {
1252 if part.ProviderMetadata != nil {
1253 active.options = part.ProviderMetadata
1254 }
1255 content := ReasoningContent{
1256 Text: active.content,
1257 ProviderMetadata: active.options,
1258 }
1259 stepContent = append(stepContent, content)
1260 if opts.OnReasoningEnd != nil {
1261 err := opts.OnReasoningEnd(part.ID, content)
1262 if err != nil {
1263 return stepExecutionResult{}, err
1264 }
1265 }
1266 delete(activeReasoningContent, part.ID)
1267 }
1268
1269 case StreamPartTypeToolInputStart:
1270 activeToolCalls[part.ID] = &ToolCallContent{
1271 ToolCallID: part.ID,
1272 ToolName: part.ToolCallName,
1273 Input: "",
1274 ProviderExecuted: part.ProviderExecuted,
1275 }
1276 if opts.OnToolInputStart != nil {
1277 err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1278 if err != nil {
1279 return stepExecutionResult{}, err
1280 }
1281 }
1282
1283 case StreamPartTypeToolInputDelta:
1284 if toolCall, exists := activeToolCalls[part.ID]; exists {
1285 toolCall.Input += part.Delta
1286 }
1287 if opts.OnToolInputDelta != nil {
1288 err := opts.OnToolInputDelta(part.ID, part.Delta)
1289 if err != nil {
1290 return stepExecutionResult{}, err
1291 }
1292 }
1293
1294 case StreamPartTypeToolInputEnd:
1295 if opts.OnToolInputEnd != nil {
1296 err := opts.OnToolInputEnd(part.ID)
1297 if err != nil {
1298 return stepExecutionResult{}, err
1299 }
1300 }
1301
1302 case StreamPartTypeToolCall:
1303 toolCall := ToolCallContent{
1304 ToolCallID: part.ID,
1305 ToolName: part.ToolCallName,
1306 Input: part.ToolCallInput,
1307 ProviderExecuted: part.ProviderExecuted,
1308 ProviderMetadata: part.ProviderMetadata,
1309 }
1310
1311 // Validate and potentially repair the tool call
1312 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1313 stepToolCalls = append(stepToolCalls, validatedToolCall)
1314 stepContent = append(stepContent, validatedToolCall)
1315
1316 if opts.OnToolCall != nil {
1317 err := opts.OnToolCall(validatedToolCall)
1318 if err != nil {
1319 return stepExecutionResult{}, err
1320 }
1321 }
1322
1323 // Determine if tool can run in parallel
1324 isParallel := false
1325 if tool, exists := toolMap[validatedToolCall.ToolName]; exists {
1326 isParallel = tool.Info().Parallel
1327 }
1328
1329 // Send tool call to execution channel
1330 toolChan <- toolExecutionRequest{toolCall: validatedToolCall, parallel: isParallel}
1331
1332 // Clean up active tool call
1333 delete(activeToolCalls, part.ID)
1334
1335 case StreamPartTypeSource:
1336 sourceContent := SourceContent{
1337 SourceType: part.SourceType,
1338 ID: part.ID,
1339 URL: part.URL,
1340 Title: part.Title,
1341 ProviderMetadata: part.ProviderMetadata,
1342 }
1343 stepContent = append(stepContent, sourceContent)
1344 if opts.OnSource != nil {
1345 err := opts.OnSource(sourceContent)
1346 if err != nil {
1347 return stepExecutionResult{}, err
1348 }
1349 }
1350
1351 case StreamPartTypeFinish:
1352 stepUsage = part.Usage
1353 stepFinishReason = part.FinishReason
1354 stepProviderMetadata = part.ProviderMetadata
1355 if opts.OnStreamFinish != nil {
1356 err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1357 if err != nil {
1358 return stepExecutionResult{}, err
1359 }
1360 }
1361
1362 case StreamPartTypeError:
1363 return stepExecutionResult{}, part.Error
1364 }
1365 }
1366
1367 // Close the tool execution channel and wait for all executions to complete
1368 close(toolChan)
1369 toolExecutionWg.Wait()
1370
1371 // Check for tool execution errors
1372 if toolExecutionErr != nil {
1373 return stepExecutionResult{}, toolExecutionErr
1374 }
1375
1376 // Add tool results to content if any
1377 if len(toolResults) > 0 {
1378 for _, result := range toolResults {
1379 stepContent = append(stepContent, result)
1380 }
1381 }
1382
1383 stepResult := StepResult{
1384 Response: Response{
1385 Content: stepContent,
1386 FinishReason: stepFinishReason,
1387 Usage: stepUsage,
1388 Warnings: stepWarnings,
1389 ProviderMetadata: stepProviderMetadata,
1390 },
1391 Messages: toResponseMessages(stepContent),
1392 }
1393
1394 // Determine if we should continue (has tool calls and not stopped)
1395 shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1396
1397 return stepExecutionResult{
1398 StepResult: stepResult,
1399 ShouldContinue: shouldContinue,
1400 }, nil
1401}
1402
1403func addUsage(a, b Usage) Usage {
1404 return Usage{
1405 InputTokens: a.InputTokens + b.InputTokens,
1406 OutputTokens: a.OutputTokens + b.OutputTokens,
1407 TotalTokens: a.TotalTokens + b.TotalTokens,
1408 ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens,
1409 CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1410 CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens,
1411 }
1412}
1413
1414// WithHeaders sets the headers for the agent.
1415func WithHeaders(headers map[string]string) AgentOption {
1416 return func(s *agentSettings) {
1417 s.headers = headers
1418 }
1419}
1420
1421// WithProviderOptions sets the provider options for the agent.
1422func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1423 return func(s *agentSettings) {
1424 s.providerOptions = providerOptions
1425 }
1426}