1package fantasy
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "slices"
11)
12
// StepResult represents the result of a single step in an agent execution.
type StepResult struct {
	// Response is embedded, so its fields (for example Content, FinishReason
	// and Usage) are accessible directly on the StepResult.
	Response
	// Messages holds the conversation messages derived from this step's content.
	Messages []Message
}
18
// StopCondition defines a function that determines when an agent should stop executing.
// It receives every step completed so far and returns true to stop the run.
type StopCondition = func(steps []StepResult) bool
21
22// StepCountIs returns a stop condition that stops after the specified number of steps.
23func StepCountIs(stepCount int) StopCondition {
24 return func(steps []StepResult) bool {
25 return len(steps) >= stepCount
26 }
27}
28
29// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
30func HasToolCall(toolName string) StopCondition {
31 return func(steps []StepResult) bool {
32 if len(steps) == 0 {
33 return false
34 }
35 lastStep := steps[len(steps)-1]
36 toolCalls := lastStep.Content.ToolCalls()
37 for _, toolCall := range toolCalls {
38 if toolCall.ToolName == toolName {
39 return true
40 }
41 }
42 return false
43 }
44}
45
46// HasContent returns a stop condition that stops when the specified content type appears in the last step.
47func HasContent(contentType ContentType) StopCondition {
48 return func(steps []StepResult) bool {
49 if len(steps) == 0 {
50 return false
51 }
52 lastStep := steps[len(steps)-1]
53 for _, content := range lastStep.Content {
54 if content.GetType() == contentType {
55 return true
56 }
57 }
58 return false
59 }
60}
61
62// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
63func FinishReasonIs(reason FinishReason) StopCondition {
64 return func(steps []StepResult) bool {
65 if len(steps) == 0 {
66 return false
67 }
68 lastStep := steps[len(steps)-1]
69 return lastStep.FinishReason == reason
70 }
71}
72
73// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
74func MaxTokensUsed(maxTokens int64) StopCondition {
75 return func(steps []StepResult) bool {
76 var totalTokens int64
77 for _, step := range steps {
78 totalTokens += step.Usage.TotalTokens
79 }
80 return totalTokens >= maxTokens
81 }
82}
83
// PrepareStepFunctionOptions contains the options for preparing a step in an agent execution.
type PrepareStepFunctionOptions struct {
	// Steps holds the results of the steps completed before this one.
	Steps []StepResult
	// StepNumber is the zero-based index of the step being prepared.
	StepNumber int
	// Model is the agent's configured language model.
	Model LanguageModel
	// Messages holds the input messages assembled for this step.
	Messages []Message
}
91
// PrepareStepResult contains the result of preparing a step in an agent execution.
// Fields left at their zero value keep the step's defaults.
type PrepareStepResult struct {
	// Model, when non-nil, replaces the model used for this step.
	Model LanguageModel
	// Messages, when non-nil, replaces the input messages for this step.
	Messages []Message
	// System, when non-nil, overrides the system prompt for this step.
	System *string
	// ToolChoice, when non-nil, overrides the tool choice (ToolChoiceAuto otherwise).
	ToolChoice *ToolChoice
	// ActiveTools, when non-empty, overrides the names of the tools offered to the model.
	ActiveTools []string
	// DisableAllTools, when true, offers no tools to the model for this step.
	DisableAllTools bool
}
101
// ToolCallRepairOptions contains the options for repairing a tool call.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call that failed validation.
	OriginalToolCall ToolCallContent
	// ValidationError is the error reported by validation.
	ValidationError error
	// AvailableTools lists the tools the call was validated against.
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt in effect for the step.
	SystemPrompt string
	// Messages holds the input messages of the step that produced the call.
	Messages []Message
}
110
type (
	// PrepareStepFunction defines a function that prepares a step in an agent execution.
	// The returned context is used for the remainder of the run; a non-nil
	// error aborts the run.
	PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)

	// OnStepFinishedFunction defines a function that is called when a step finishes.
	OnStepFinishedFunction = func(step StepResult)

	// RepairToolCallFunction defines a function that repairs a tool call.
	// The repaired call is only used when the function returns a nil error and
	// a non-nil call that itself passes validation.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
121
// agentSettings holds the agent-wide configuration populated by AgentOption
// functions. Per-call options that are left unset fall back to these values
// (see prepareCall).
type agentSettings struct {
	systemPrompt     string
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	headers          map[string]string
	providerOptions  ProviderOptions

	// TODO: add support for provider tools
	tools      []AgentTool
	maxRetries *int

	// model is the language model used unless a step overrides it.
	model LanguageModel

	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
144
// AgentCall represents a call to an agent.
//
// Prompt must be non-empty. Pointer-valued options and callbacks that are left
// nil fall back to the agent-level settings, and ProviderOptions are merged
// with the agent's options, with the call's entries taking precedence.
type AgentCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	// StopWhen lists conditions checked after every step; the run stops as
	// soon as any of them returns true.
	StopWhen []StopCondition
	// PrepareStep, when set, is invoked before every step and may adjust it.
	PrepareStep PrepareStepFunction
	// RepairToolCall, when set, may repair tool calls that fail validation.
	RepairToolCall RepairToolCallFunction
}
165
// Agent-level callbacks.
type (
	// OnAgentStartFunc is called when agent starts.
	// Stream invokes it once, before the first step.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called when agent finishes.
	// Stream passes the final result and discards the returned error.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called when a step starts.
	// stepNumber is zero-based; Stream discards the returned error.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called when a step finishes.
	// Stream discards the returned error.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called when entire agent completes.
	// Stream invokes it before the OnAgentFinish callback.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called when an error occurs.
	// Stream invokes it when opening or processing a step's model stream fails.
	OnErrorFunc func(error)
)
186
// Stream part callbacks - called for each corresponding stream part type.
//
// NOTE(review): the code that dispatches stream parts to these callbacks is
// not part of this section; how a returned error is handled should be
// confirmed there.
type (
	// OnChunkFunc is called for each stream part (catch-all).
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when text starts.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for text deltas.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when text ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when reasoning starts.
	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error

	// OnReasoningDeltaFunc is called for reasoning deltas.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning ends.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when tool input starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for tool input deltas.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnBeforeToolExecutionFunc is called right before tool execution.
	// It can modify tool input and/or skip tool execution by returning (modifiedInput, skipExecution, error).
	// If err is non-nil, an error result carrying err is recorded for the call;
	// tool execution then continues with the next call when skipExecution is
	// true and is aborted with err otherwise.
	// If skipExecution is true and err is nil, the tool is not run and an error
	// result stating that execution was skipped is recorded.
	// If modifiedInput is not empty, it replaces the original tool input.
	OnBeforeToolExecutionFunc func(toolCall ToolCallContent) (modifiedInput string, skipExecution bool, err error)

	// OnToolResultFunc is called when tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when stream finishes.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
240
// AgentStreamCall represents a streaming call to an agent.
//
// It carries the same generation options as AgentCall plus callbacks that
// report progress while the run is in flight. Every callback is optional;
// the agent-level ones are skipped when nil.
type AgentStreamCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	// StopWhen lists conditions checked after every step; the run stops as
	// soon as any of them returns true.
	StopWhen []StopCondition
	// PrepareStep, when set, is invoked before every step and may adjust it.
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called when agent starts
	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk               OnChunkFunc               // Called for each stream part (catch-all)
	OnWarnings            OnWarningsFunc            // Called for warnings
	OnTextStart           OnTextStartFunc           // Called when text starts
	OnTextDelta           OnTextDeltaFunc           // Called for text deltas
	OnTextEnd             OnTextEndFunc             // Called when text ends
	OnReasoningStart      OnReasoningStartFunc      // Called when reasoning starts
	OnReasoningDelta      OnReasoningDeltaFunc      // Called for reasoning deltas
	OnReasoningEnd        OnReasoningEndFunc        // Called when reasoning ends
	OnToolInputStart      OnToolInputStartFunc      // Called when tool input starts
	OnToolInputDelta      OnToolInputDeltaFunc      // Called for tool input deltas
	OnToolInputEnd        OnToolInputEndFunc        // Called when tool input ends
	OnToolCall            OnToolCallFunc            // Called when tool call is complete
	OnBeforeToolExecution OnBeforeToolExecutionFunc // Called right before tool execution
	OnToolResult          OnToolResultFunc          // Called when tool execution completes
	OnSource              OnSourceFunc              // Called for source references
	OnStreamFinish        OnStreamFinishFunc        // Called when stream finishes
}
288
// AgentResult represents the result of an agent execution.
type AgentResult struct {
	// Steps holds every step of the run, in execution order.
	Steps []StepResult
	// Final response
	// (the Response of the last entry in Steps).
	Response Response
	// TotalUsage is the token usage accumulated over all steps.
	TotalUsage Usage
}
296
// Agent represents an AI agent that can generate responses and stream responses.
type Agent interface {
	// Generate runs the agent step by step until it stops and returns the
	// accumulated result.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the agent using the model's streaming API, reporting
	// progress through the callbacks on AgentStreamCall.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
302
// AgentOption defines a function that configures agent settings.
// Options are applied by NewAgent in the order they are passed.
type AgentOption = func(*agentSettings)
305
// agent is the Agent implementation returned by NewAgent.
type agent struct {
	// settings holds the model and the defaults configured at construction time.
	settings agentSettings
}
309
310// NewAgent creates a new agent with the given language model and options.
311func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
312 settings := agentSettings{
313 model: model,
314 }
315 for _, o := range opts {
316 o(&settings)
317 }
318 return &agent{
319 settings: settings,
320 }
321}
322
323func (a *agent) prepareCall(call AgentCall) AgentCall {
324 call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
325 call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
326 call.TopP = cmp.Or(call.TopP, a.settings.topP)
327 call.TopK = cmp.Or(call.TopK, a.settings.topK)
328 call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
329 call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
330 call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
331
332 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
333 call.StopWhen = a.settings.stopWhen
334 }
335 if call.PrepareStep == nil && a.settings.prepareStep != nil {
336 call.PrepareStep = a.settings.prepareStep
337 }
338 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
339 call.RepairToolCall = a.settings.repairToolCall
340 }
341 if call.OnRetry == nil && a.settings.onRetry != nil {
342 call.OnRetry = a.settings.onRetry
343 }
344
345 providerOptions := ProviderOptions{}
346 if a.settings.providerOptions != nil {
347 maps.Copy(providerOptions, a.settings.providerOptions)
348 }
349 if call.ProviderOptions != nil {
350 maps.Copy(providerOptions, call.ProviderOptions)
351 }
352 call.ProviderOptions = providerOptions
353
354 headers := map[string]string{}
355
356 if a.settings.headers != nil {
357 maps.Copy(headers, a.settings.headers)
358 }
359
360 return call
361}
362
363// Generate implements Agent.
364func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
365 opts = a.prepareCall(opts)
366 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
367 if err != nil {
368 return nil, err
369 }
370 var responseMessages []Message
371 var steps []StepResult
372
373 for {
374 stepInputMessages := append(initialPrompt, responseMessages...)
375 stepModel := a.settings.model
376 stepSystemPrompt := a.settings.systemPrompt
377 stepActiveTools := opts.ActiveTools
378 stepToolChoice := ToolChoiceAuto
379 disableAllTools := false
380
381 if opts.PrepareStep != nil {
382 updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
383 Model: stepModel,
384 Steps: steps,
385 StepNumber: len(steps),
386 Messages: stepInputMessages,
387 })
388 if err != nil {
389 return nil, err
390 }
391
392 ctx = updatedCtx
393
394 // Apply prepared step modifications
395 if prepared.Messages != nil {
396 stepInputMessages = prepared.Messages
397 }
398 if prepared.Model != nil {
399 stepModel = prepared.Model
400 }
401 if prepared.System != nil {
402 stepSystemPrompt = *prepared.System
403 }
404 if prepared.ToolChoice != nil {
405 stepToolChoice = *prepared.ToolChoice
406 }
407 if len(prepared.ActiveTools) > 0 {
408 stepActiveTools = prepared.ActiveTools
409 }
410 disableAllTools = prepared.DisableAllTools
411 }
412
413 // Recreate prompt with potentially modified system prompt
414 if stepSystemPrompt != a.settings.systemPrompt {
415 stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
416 if err != nil {
417 return nil, err
418 }
419 // Replace system message part, keep the rest
420 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
421 stepInputMessages[0] = stepPrompt[0] // Replace system message
422 }
423 }
424
425 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
426
427 retryOptions := DefaultRetryOptions()
428 retryOptions.OnRetry = opts.OnRetry
429 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
430
431 result, err := retry(ctx, func() (*Response, error) {
432 return stepModel.Generate(ctx, Call{
433 Prompt: stepInputMessages,
434 MaxOutputTokens: opts.MaxOutputTokens,
435 Temperature: opts.Temperature,
436 TopP: opts.TopP,
437 TopK: opts.TopK,
438 PresencePenalty: opts.PresencePenalty,
439 FrequencyPenalty: opts.FrequencyPenalty,
440 Tools: preparedTools,
441 ToolChoice: &stepToolChoice,
442 ProviderOptions: opts.ProviderOptions,
443 })
444 })
445 if err != nil {
446 return nil, err
447 }
448
449 var stepToolCalls []ToolCallContent
450 for _, content := range result.Content {
451 if content.GetType() == ContentTypeToolCall {
452 toolCall, ok := AsContentType[ToolCallContent](content)
453 if !ok {
454 continue
455 }
456
457 // Validate and potentially repair the tool call
458 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
459 stepToolCalls = append(stepToolCalls, validatedToolCall)
460 }
461 }
462
463 toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil, nil)
464
465 // Build step content with validated tool calls and tool results
466 stepContent := []Content{}
467 toolCallIndex := 0
468 for _, content := range result.Content {
469 if content.GetType() == ContentTypeToolCall {
470 // Replace with validated tool call
471 if toolCallIndex < len(stepToolCalls) {
472 stepContent = append(stepContent, stepToolCalls[toolCallIndex])
473 toolCallIndex++
474 }
475 } else {
476 // Keep other content as-is
477 stepContent = append(stepContent, content)
478 }
479 }
480 // Add tool results
481 for _, result := range toolResults {
482 stepContent = append(stepContent, result)
483 }
484 currentStepMessages := toResponseMessages(stepContent)
485 responseMessages = append(responseMessages, currentStepMessages...)
486
487 stepResult := StepResult{
488 Response: Response{
489 Content: stepContent,
490 FinishReason: result.FinishReason,
491 Usage: result.Usage,
492 Warnings: result.Warnings,
493 ProviderMetadata: result.ProviderMetadata,
494 },
495 Messages: currentStepMessages,
496 }
497 steps = append(steps, stepResult)
498 shouldStop := isStopConditionMet(opts.StopWhen, steps)
499
500 if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
501 break
502 }
503 }
504
505 totalUsage := Usage{}
506
507 for _, step := range steps {
508 usage := step.Usage
509 totalUsage.InputTokens += usage.InputTokens
510 totalUsage.OutputTokens += usage.OutputTokens
511 totalUsage.ReasoningTokens += usage.ReasoningTokens
512 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
513 totalUsage.CacheReadTokens += usage.CacheReadTokens
514 totalUsage.TotalTokens += usage.TotalTokens
515 }
516
517 agentResult := &AgentResult{
518 Steps: steps,
519 Response: steps[len(steps)-1].Response,
520 TotalUsage: totalUsage,
521 }
522 return agentResult, nil
523}
524
525func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
526 if len(conditions) == 0 {
527 return false
528 }
529
530 for _, condition := range conditions {
531 if condition(steps) {
532 return true
533 }
534 }
535 return false
536}
537
538func toResponseMessages(content []Content) []Message {
539 var assistantParts []MessagePart
540 var toolParts []MessagePart
541
542 for _, c := range content {
543 switch c.GetType() {
544 case ContentTypeText:
545 text, ok := AsContentType[TextContent](c)
546 if !ok {
547 continue
548 }
549 assistantParts = append(assistantParts, TextPart{
550 Text: text.Text,
551 ProviderOptions: ProviderOptions(text.ProviderMetadata),
552 })
553 case ContentTypeReasoning:
554 reasoning, ok := AsContentType[ReasoningContent](c)
555 if !ok {
556 continue
557 }
558 assistantParts = append(assistantParts, ReasoningPart{
559 Text: reasoning.Text,
560 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
561 })
562 case ContentTypeToolCall:
563 toolCall, ok := AsContentType[ToolCallContent](c)
564 if !ok {
565 continue
566 }
567 assistantParts = append(assistantParts, ToolCallPart{
568 ToolCallID: toolCall.ToolCallID,
569 ToolName: toolCall.ToolName,
570 Input: toolCall.Input,
571 ProviderExecuted: toolCall.ProviderExecuted,
572 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
573 })
574 case ContentTypeFile:
575 file, ok := AsContentType[FileContent](c)
576 if !ok {
577 continue
578 }
579 assistantParts = append(assistantParts, FilePart{
580 Data: file.Data,
581 MediaType: file.MediaType,
582 ProviderOptions: ProviderOptions(file.ProviderMetadata),
583 })
584 case ContentTypeSource:
585 // Sources are metadata about references used to generate the response.
586 // They don't need to be included in the conversation messages.
587 continue
588 case ContentTypeToolResult:
589 result, ok := AsContentType[ToolResultContent](c)
590 if !ok {
591 continue
592 }
593 toolParts = append(toolParts, ToolResultPart{
594 ToolCallID: result.ToolCallID,
595 Output: result.Result,
596 ProviderOptions: ProviderOptions(result.ProviderMetadata),
597 })
598 }
599 }
600
601 var messages []Message
602 if len(assistantParts) > 0 {
603 messages = append(messages, Message{
604 Role: MessageRoleAssistant,
605 Content: assistantParts,
606 })
607 }
608 if len(toolParts) > 0 {
609 messages = append(messages, Message{
610 Role: MessageRoleTool,
611 Content: toolParts,
612 })
613 }
614 return messages
615}
616
617func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, beforeExecCallback OnBeforeToolExecutionFunc, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
618 if len(toolCalls) == 0 {
619 return nil, nil
620 }
621
622 // Create a map for quick tool lookup
623 toolMap := make(map[string]AgentTool)
624 for _, tool := range allTools {
625 toolMap[tool.Info().Name] = tool
626 }
627
628 // Execute all tool calls sequentially in order
629 results := make([]ToolResultContent, 0, len(toolCalls))
630
631 for _, toolCall := range toolCalls {
632 // Skip invalid tool calls - create error result
633 if toolCall.Invalid {
634 result := ToolResultContent{
635 ToolCallID: toolCall.ToolCallID,
636 ToolName: toolCall.ToolName,
637 Result: ToolResultOutputContentError{
638 Error: toolCall.ValidationError,
639 },
640 ProviderExecuted: false,
641 }
642 results = append(results, result)
643 if toolResultCallback != nil {
644 if err := toolResultCallback(result); err != nil {
645 return nil, err
646 }
647 }
648 continue
649 }
650
651 tool, exists := toolMap[toolCall.ToolName]
652 if !exists {
653 result := ToolResultContent{
654 ToolCallID: toolCall.ToolCallID,
655 ToolName: toolCall.ToolName,
656 Result: ToolResultOutputContentError{
657 Error: errors.New("Error: Tool not found: " + toolCall.ToolName),
658 },
659 ProviderExecuted: false,
660 }
661 results = append(results, result)
662 if toolResultCallback != nil {
663 if err := toolResultCallback(result); err != nil {
664 return nil, err
665 }
666 }
667 continue
668 }
669
670 // Call OnBeforeToolExecution hook to allow modification or skipping
671 inputToUse := toolCall.Input
672 if beforeExecCallback != nil {
673 modifiedInput, skipExecution, err := beforeExecCallback(toolCall)
674 if err != nil {
675 // Callback returned error - create synthetic error result
676 result := ToolResultContent{
677 ToolCallID: toolCall.ToolCallID,
678 ToolName: toolCall.ToolName,
679 Result: ToolResultOutputContentError{
680 Error: err,
681 },
682 ProviderExecuted: false,
683 }
684 results = append(results, result)
685 if toolResultCallback != nil {
686 if cbErr := toolResultCallback(result); cbErr != nil {
687 return nil, cbErr
688 }
689 }
690 if skipExecution {
691 // Hook wants to skip execution
692 continue
693 } else {
694 // Hook returned error but doesn't want to skip - stop execution
695 return nil, err
696 }
697 }
698 if skipExecution {
699 // Hook wants to skip execution without error
700 result := ToolResultContent{
701 ToolCallID: toolCall.ToolCallID,
702 ToolName: toolCall.ToolName,
703 Result: ToolResultOutputContentError{
704 Error: errors.New("Tool execution skipped by hook"),
705 },
706 ProviderExecuted: false,
707 }
708 results = append(results, result)
709 if toolResultCallback != nil {
710 if cbErr := toolResultCallback(result); cbErr != nil {
711 return nil, cbErr
712 }
713 }
714 continue
715 }
716 if modifiedInput != "" {
717 inputToUse = modifiedInput
718 }
719 }
720
721 // Execute the tool
722 toolResult, err := tool.Run(ctx, ToolCall{
723 ID: toolCall.ToolCallID,
724 Name: toolCall.ToolName,
725 Input: inputToUse,
726 })
727 if err != nil {
728 result := ToolResultContent{
729 ToolCallID: toolCall.ToolCallID,
730 ToolName: toolCall.ToolName,
731 Result: ToolResultOutputContentError{
732 Error: err,
733 },
734 ClientMetadata: toolResult.Metadata,
735 ProviderExecuted: false,
736 }
737 if toolResultCallback != nil {
738 if cbErr := toolResultCallback(result); cbErr != nil {
739 return nil, cbErr
740 }
741 }
742 return nil, err
743 }
744
745 var result ToolResultContent
746 if toolResult.IsError {
747 result = ToolResultContent{
748 ToolCallID: toolCall.ToolCallID,
749 ToolName: toolCall.ToolName,
750 Result: ToolResultOutputContentError{
751 Error: errors.New(toolResult.Content),
752 },
753 ClientMetadata: toolResult.Metadata,
754 ProviderExecuted: false,
755 }
756 } else {
757 result = ToolResultContent{
758 ToolCallID: toolCall.ToolCallID,
759 ToolName: toolCall.ToolName,
760 Result: ToolResultOutputContentText{
761 Text: toolResult.Content,
762 },
763 ClientMetadata: toolResult.Metadata,
764 ProviderExecuted: false,
765 }
766 }
767 results = append(results, result)
768 if toolResultCallback != nil {
769 if err := toolResultCallback(result); err != nil {
770 return nil, err
771 }
772 }
773 }
774
775 return results, nil
776}
777
778// Stream implements Agent.
779func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
780 // Convert AgentStreamCall to AgentCall for preparation
781 call := AgentCall{
782 Prompt: opts.Prompt,
783 Files: opts.Files,
784 Messages: opts.Messages,
785 MaxOutputTokens: opts.MaxOutputTokens,
786 Temperature: opts.Temperature,
787 TopP: opts.TopP,
788 TopK: opts.TopK,
789 PresencePenalty: opts.PresencePenalty,
790 FrequencyPenalty: opts.FrequencyPenalty,
791 ActiveTools: opts.ActiveTools,
792 ProviderOptions: opts.ProviderOptions,
793 MaxRetries: opts.MaxRetries,
794 StopWhen: opts.StopWhen,
795 PrepareStep: opts.PrepareStep,
796 RepairToolCall: opts.RepairToolCall,
797 }
798
799 call = a.prepareCall(call)
800
801 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
802 if err != nil {
803 return nil, err
804 }
805
806 var responseMessages []Message
807 var steps []StepResult
808 var totalUsage Usage
809
810 // Start agent stream
811 if opts.OnAgentStart != nil {
812 opts.OnAgentStart()
813 }
814
815 for stepNumber := 0; ; stepNumber++ {
816 stepInputMessages := append(initialPrompt, responseMessages...)
817 stepModel := a.settings.model
818 stepSystemPrompt := a.settings.systemPrompt
819 stepActiveTools := call.ActiveTools
820 stepToolChoice := ToolChoiceAuto
821 disableAllTools := false
822
823 // Apply step preparation if provided
824 if call.PrepareStep != nil {
825 updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
826 Model: stepModel,
827 Steps: steps,
828 StepNumber: stepNumber,
829 Messages: stepInputMessages,
830 })
831 if err != nil {
832 return nil, err
833 }
834
835 ctx = updatedCtx
836
837 if prepared.Messages != nil {
838 stepInputMessages = prepared.Messages
839 }
840 if prepared.Model != nil {
841 stepModel = prepared.Model
842 }
843 if prepared.System != nil {
844 stepSystemPrompt = *prepared.System
845 }
846 if prepared.ToolChoice != nil {
847 stepToolChoice = *prepared.ToolChoice
848 }
849 if len(prepared.ActiveTools) > 0 {
850 stepActiveTools = prepared.ActiveTools
851 }
852 disableAllTools = prepared.DisableAllTools
853 }
854
855 // Recreate prompt with potentially modified system prompt
856 if stepSystemPrompt != a.settings.systemPrompt {
857 stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
858 if err != nil {
859 return nil, err
860 }
861 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
862 stepInputMessages[0] = stepPrompt[0]
863 }
864 }
865
866 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
867
868 // Start step stream
869 if opts.OnStepStart != nil {
870 _ = opts.OnStepStart(stepNumber)
871 }
872
873 // Create streaming call
874 streamCall := Call{
875 Prompt: stepInputMessages,
876 MaxOutputTokens: call.MaxOutputTokens,
877 Temperature: call.Temperature,
878 TopP: call.TopP,
879 TopK: call.TopK,
880 PresencePenalty: call.PresencePenalty,
881 FrequencyPenalty: call.FrequencyPenalty,
882 Tools: preparedTools,
883 ToolChoice: &stepToolChoice,
884 ProviderOptions: call.ProviderOptions,
885 }
886
887 // Get streaming response
888 stream, err := stepModel.Stream(ctx, streamCall)
889 if err != nil {
890 if opts.OnError != nil {
891 opts.OnError(err)
892 }
893 return nil, err
894 }
895
896 // Process stream with tool execution
897 stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
898 if err != nil {
899 if opts.OnError != nil {
900 opts.OnError(err)
901 }
902 return nil, err
903 }
904
905 steps = append(steps, stepResult)
906 totalUsage = addUsage(totalUsage, stepResult.Usage)
907
908 // Call step finished callback
909 if opts.OnStepFinish != nil {
910 _ = opts.OnStepFinish(stepResult)
911 }
912
913 // Add step messages to response messages
914 stepMessages := toResponseMessages(stepResult.Content)
915 responseMessages = append(responseMessages, stepMessages...)
916
917 // Check stop conditions
918 shouldStop := isStopConditionMet(call.StopWhen, steps)
919 if shouldStop || !shouldContinue {
920 break
921 }
922 }
923
924 // Finish agent stream
925 agentResult := &AgentResult{
926 Steps: steps,
927 Response: steps[len(steps)-1].Response,
928 TotalUsage: totalUsage,
929 }
930
931 if opts.OnFinish != nil {
932 opts.OnFinish(agentResult)
933 }
934
935 if opts.OnAgentFinish != nil {
936 _ = opts.OnAgentFinish(agentResult)
937 }
938
939 return agentResult, nil
940}
941
942func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
943 preparedTools := make([]Tool, 0, len(tools))
944
945 // If explicitly disabling all tools, return no tools
946 if disableAllTools {
947 return preparedTools
948 }
949
950 for _, tool := range tools {
951 // If activeTools has items, only include tools in the list
952 // If activeTools is empty, include all tools
953 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
954 continue
955 }
956 info := tool.Info()
957 preparedTools = append(preparedTools, FunctionTool{
958 Name: info.Name,
959 Description: info.Description,
960 InputSchema: map[string]any{
961 "type": "object",
962 "properties": info.Parameters,
963 "required": info.Required,
964 },
965 ProviderOptions: tool.ProviderOptions(),
966 })
967 }
968 return preparedTools
969}
970
971// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
972func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
973 if err := a.validateToolCall(toolCall, availableTools); err == nil {
974 return toolCall
975 } else { //nolint: revive
976 if repairFunc != nil {
977 repairOptions := ToolCallRepairOptions{
978 OriginalToolCall: toolCall,
979 ValidationError: err,
980 AvailableTools: availableTools,
981 SystemPrompt: systemPrompt,
982 Messages: messages,
983 }
984
985 if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
986 if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
987 return *repairedToolCall
988 }
989 }
990 }
991
992 invalidToolCall := toolCall
993 invalidToolCall.Invalid = true
994 invalidToolCall.ValidationError = err
995 return invalidToolCall
996 }
997}
998
999// validateToolCall validates a tool call against available tools and their schemas.
1000func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
1001 var tool AgentTool
1002 for _, t := range availableTools {
1003 if t.Info().Name == toolCall.ToolName {
1004 tool = t
1005 break
1006 }
1007 }
1008
1009 if tool == nil {
1010 return fmt.Errorf("tool not found: %s", toolCall.ToolName)
1011 }
1012
1013 // Validate JSON parsing
1014 var input map[string]any
1015 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
1016 return fmt.Errorf("invalid JSON input: %w", err)
1017 }
1018
1019 // Basic schema validation (check required fields)
1020 // TODO: more robust schema validation using JSON Schema or similar
1021 toolInfo := tool.Info()
1022 for _, required := range toolInfo.Required {
1023 if _, exists := input[required]; !exists {
1024 return fmt.Errorf("missing required parameter: %s", required)
1025 }
1026 }
1027 return nil
1028}
1029
1030func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
1031 if prompt == "" {
1032 return nil, &Error{Title: "invalid argument", Message: "prompt can't be empty"}
1033 }
1034
1035 var preparedPrompt Prompt
1036
1037 if system != "" {
1038 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
1039 }
1040 preparedPrompt = append(preparedPrompt, messages...)
1041 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
1042 return preparedPrompt, nil
1043}
1044
1045// WithSystemPrompt sets the system prompt for the agent.
1046func WithSystemPrompt(prompt string) AgentOption {
1047 return func(s *agentSettings) {
1048 s.systemPrompt = prompt
1049 }
1050}
1051
1052// WithMaxOutputTokens sets the maximum output tokens for the agent.
1053func WithMaxOutputTokens(tokens int64) AgentOption {
1054 return func(s *agentSettings) {
1055 s.maxOutputTokens = &tokens
1056 }
1057}
1058
1059// WithTemperature sets the temperature for the agent.
1060func WithTemperature(temp float64) AgentOption {
1061 return func(s *agentSettings) {
1062 s.temperature = &temp
1063 }
1064}
1065
1066// WithTopP sets the top-p value for the agent.
1067func WithTopP(topP float64) AgentOption {
1068 return func(s *agentSettings) {
1069 s.topP = &topP
1070 }
1071}
1072
1073// WithTopK sets the top-k value for the agent.
1074func WithTopK(topK int64) AgentOption {
1075 return func(s *agentSettings) {
1076 s.topK = &topK
1077 }
1078}
1079
1080// WithPresencePenalty sets the presence penalty for the agent.
1081func WithPresencePenalty(penalty float64) AgentOption {
1082 return func(s *agentSettings) {
1083 s.presencePenalty = &penalty
1084 }
1085}
1086
1087// WithFrequencyPenalty sets the frequency penalty for the agent.
1088func WithFrequencyPenalty(penalty float64) AgentOption {
1089 return func(s *agentSettings) {
1090 s.frequencyPenalty = &penalty
1091 }
1092}
1093
1094// WithTools sets the tools for the agent.
1095func WithTools(tools ...AgentTool) AgentOption {
1096 return func(s *agentSettings) {
1097 s.tools = append(s.tools, tools...)
1098 }
1099}
1100
1101// WithStopConditions sets the stop conditions for the agent.
1102func WithStopConditions(conditions ...StopCondition) AgentOption {
1103 return func(s *agentSettings) {
1104 s.stopWhen = append(s.stopWhen, conditions...)
1105 }
1106}
1107
1108// WithPrepareStep sets the prepare step function for the agent.
1109func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1110 return func(s *agentSettings) {
1111 s.prepareStep = fn
1112 }
1113}
1114
1115// WithRepairToolCall sets the repair tool call function for the agent.
1116func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1117 return func(s *agentSettings) {
1118 s.repairToolCall = fn
1119 }
1120}
1121
1122// processStepStream processes a single step's stream and returns the step result.
1123func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
1124 var stepContent []Content
1125 var stepToolCalls []ToolCallContent
1126 var stepUsage Usage
1127 stepFinishReason := FinishReasonUnknown
1128 var stepWarnings []CallWarning
1129 var stepProviderMetadata ProviderMetadata
1130
1131 activeToolCalls := make(map[string]*ToolCallContent)
1132 activeTextContent := make(map[string]string)
1133 type reasoningContent struct {
1134 content string
1135 options ProviderMetadata
1136 }
1137 activeReasoningContent := make(map[string]reasoningContent)
1138
1139 // Process stream parts
1140 for part := range stream {
1141 // Forward all parts to chunk callback
1142 if opts.OnChunk != nil {
1143 err := opts.OnChunk(part)
1144 if err != nil {
1145 return StepResult{}, false, err
1146 }
1147 }
1148
1149 switch part.Type {
1150 case StreamPartTypeWarnings:
1151 stepWarnings = part.Warnings
1152 if opts.OnWarnings != nil {
1153 err := opts.OnWarnings(part.Warnings)
1154 if err != nil {
1155 return StepResult{}, false, err
1156 }
1157 }
1158
1159 case StreamPartTypeTextStart:
1160 activeTextContent[part.ID] = ""
1161 if opts.OnTextStart != nil {
1162 err := opts.OnTextStart(part.ID)
1163 if err != nil {
1164 return StepResult{}, false, err
1165 }
1166 }
1167
1168 case StreamPartTypeTextDelta:
1169 if _, exists := activeTextContent[part.ID]; exists {
1170 activeTextContent[part.ID] += part.Delta
1171 }
1172 if opts.OnTextDelta != nil {
1173 err := opts.OnTextDelta(part.ID, part.Delta)
1174 if err != nil {
1175 return StepResult{}, false, err
1176 }
1177 }
1178
1179 case StreamPartTypeTextEnd:
1180 if text, exists := activeTextContent[part.ID]; exists {
1181 stepContent = append(stepContent, TextContent{
1182 Text: text,
1183 ProviderMetadata: part.ProviderMetadata,
1184 })
1185 delete(activeTextContent, part.ID)
1186 }
1187 if opts.OnTextEnd != nil {
1188 err := opts.OnTextEnd(part.ID)
1189 if err != nil {
1190 return StepResult{}, false, err
1191 }
1192 }
1193
1194 case StreamPartTypeReasoningStart:
1195 activeReasoningContent[part.ID] = reasoningContent{content: part.Delta, options: part.ProviderMetadata}
1196 if opts.OnReasoningStart != nil {
1197 content := ReasoningContent{
1198 Text: part.Delta,
1199 ProviderMetadata: part.ProviderMetadata,
1200 }
1201 err := opts.OnReasoningStart(part.ID, content)
1202 if err != nil {
1203 return StepResult{}, false, err
1204 }
1205 }
1206
1207 case StreamPartTypeReasoningDelta:
1208 if active, exists := activeReasoningContent[part.ID]; exists {
1209 active.content += part.Delta
1210 active.options = part.ProviderMetadata
1211 activeReasoningContent[part.ID] = active
1212 }
1213 if opts.OnReasoningDelta != nil {
1214 err := opts.OnReasoningDelta(part.ID, part.Delta)
1215 if err != nil {
1216 return StepResult{}, false, err
1217 }
1218 }
1219
1220 case StreamPartTypeReasoningEnd:
1221 if active, exists := activeReasoningContent[part.ID]; exists {
1222 if part.ProviderMetadata != nil {
1223 active.options = part.ProviderMetadata
1224 }
1225 content := ReasoningContent{
1226 Text: active.content,
1227 ProviderMetadata: active.options,
1228 }
1229 stepContent = append(stepContent, content)
1230 if opts.OnReasoningEnd != nil {
1231 err := opts.OnReasoningEnd(part.ID, content)
1232 if err != nil {
1233 return StepResult{}, false, err
1234 }
1235 }
1236 delete(activeReasoningContent, part.ID)
1237 }
1238
1239 case StreamPartTypeToolInputStart:
1240 activeToolCalls[part.ID] = &ToolCallContent{
1241 ToolCallID: part.ID,
1242 ToolName: part.ToolCallName,
1243 Input: "",
1244 ProviderExecuted: part.ProviderExecuted,
1245 }
1246 if opts.OnToolInputStart != nil {
1247 err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1248 if err != nil {
1249 return StepResult{}, false, err
1250 }
1251 }
1252
1253 case StreamPartTypeToolInputDelta:
1254 if toolCall, exists := activeToolCalls[part.ID]; exists {
1255 toolCall.Input += part.Delta
1256 }
1257 if opts.OnToolInputDelta != nil {
1258 err := opts.OnToolInputDelta(part.ID, part.Delta)
1259 if err != nil {
1260 return StepResult{}, false, err
1261 }
1262 }
1263
1264 case StreamPartTypeToolInputEnd:
1265 if opts.OnToolInputEnd != nil {
1266 err := opts.OnToolInputEnd(part.ID)
1267 if err != nil {
1268 return StepResult{}, false, err
1269 }
1270 }
1271
1272 case StreamPartTypeToolCall:
1273 toolCall := ToolCallContent{
1274 ToolCallID: part.ID,
1275 ToolName: part.ToolCallName,
1276 Input: part.ToolCallInput,
1277 ProviderExecuted: part.ProviderExecuted,
1278 ProviderMetadata: part.ProviderMetadata,
1279 }
1280
1281 // Validate and potentially repair the tool call
1282 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1283 stepToolCalls = append(stepToolCalls, validatedToolCall)
1284 stepContent = append(stepContent, validatedToolCall)
1285
1286 if opts.OnToolCall != nil {
1287 err := opts.OnToolCall(validatedToolCall)
1288 if err != nil {
1289 return StepResult{}, false, err
1290 }
1291 }
1292
1293 // Clean up active tool call
1294 delete(activeToolCalls, part.ID)
1295
1296 case StreamPartTypeSource:
1297 sourceContent := SourceContent{
1298 SourceType: part.SourceType,
1299 ID: part.ID,
1300 URL: part.URL,
1301 Title: part.Title,
1302 ProviderMetadata: part.ProviderMetadata,
1303 }
1304 stepContent = append(stepContent, sourceContent)
1305 if opts.OnSource != nil {
1306 err := opts.OnSource(sourceContent)
1307 if err != nil {
1308 return StepResult{}, false, err
1309 }
1310 }
1311
1312 case StreamPartTypeFinish:
1313 stepUsage = part.Usage
1314 stepFinishReason = part.FinishReason
1315 stepProviderMetadata = part.ProviderMetadata
1316 if opts.OnStreamFinish != nil {
1317 err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1318 if err != nil {
1319 return StepResult{}, false, err
1320 }
1321 }
1322
1323 case StreamPartTypeError:
1324 if opts.OnError != nil {
1325 opts.OnError(part.Error)
1326 }
1327 return StepResult{}, false, part.Error
1328 }
1329 }
1330
1331 // Execute tools if any
1332 var toolResults []ToolResultContent
1333 if len(stepToolCalls) > 0 {
1334 var err error
1335 toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnBeforeToolExecution, opts.OnToolResult)
1336 if err != nil {
1337 return StepResult{}, false, err
1338 }
1339 // Add tool results to content
1340 for _, result := range toolResults {
1341 stepContent = append(stepContent, result)
1342 }
1343 }
1344
1345 stepResult := StepResult{
1346 Response: Response{
1347 Content: stepContent,
1348 FinishReason: stepFinishReason,
1349 Usage: stepUsage,
1350 Warnings: stepWarnings,
1351 ProviderMetadata: stepProviderMetadata,
1352 },
1353 Messages: toResponseMessages(stepContent),
1354 }
1355
1356 // Determine if we should continue (has tool calls and not stopped)
1357 shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1358
1359 return stepResult, shouldContinue, nil
1360}
1361
1362func addUsage(a, b Usage) Usage {
1363 return Usage{
1364 InputTokens: a.InputTokens + b.InputTokens,
1365 OutputTokens: a.OutputTokens + b.OutputTokens,
1366 TotalTokens: a.TotalTokens + b.TotalTokens,
1367 ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens,
1368 CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1369 CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens,
1370 }
1371}
1372
1373// WithHeaders sets the headers for the agent.
1374func WithHeaders(headers map[string]string) AgentOption {
1375 return func(s *agentSettings) {
1376 s.headers = headers
1377 }
1378}
1379
1380// WithProviderOptions sets the provider options for the agent.
1381func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1382 return func(s *agentSettings) {
1383 s.providerOptions = providerOptions
1384 }
1385}