1package ai
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "slices"
11 "sync"
12)
13
// StepResult records the outcome of a single agent step.
//
// The embedded Response carries the step's content, finish reason, usage,
// warnings and provider metadata. Messages holds the conversation messages
// derived from that content (Generate fills it from toResponseMessages); the
// agent loop appends these messages to the prompt of the following step.
type StepResult struct {
	Response
	Messages []Message
}
18
// StopCondition decides whether the agent loop should stop. It receives every
// step completed so far, oldest first, and returns true to stop. When several
// conditions are configured the loop stops as soon as any one returns true.
type StopCondition = func(steps []StepResult) bool
20
21// StepCountIs returns a stop condition that stops after the specified number of steps.
22func StepCountIs(stepCount int) StopCondition {
23 return func(steps []StepResult) bool {
24 return len(steps) >= stepCount
25 }
26}
27
28// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
29func HasToolCall(toolName string) StopCondition {
30 return func(steps []StepResult) bool {
31 if len(steps) == 0 {
32 return false
33 }
34 lastStep := steps[len(steps)-1]
35 toolCalls := lastStep.Content.ToolCalls()
36 for _, toolCall := range toolCalls {
37 if toolCall.ToolName == toolName {
38 return true
39 }
40 }
41 return false
42 }
43}
44
45// HasContent returns a stop condition that stops when the specified content type appears in the last step.
46func HasContent(contentType ContentType) StopCondition {
47 return func(steps []StepResult) bool {
48 if len(steps) == 0 {
49 return false
50 }
51 lastStep := steps[len(steps)-1]
52 for _, content := range lastStep.Content {
53 if content.GetType() == contentType {
54 return true
55 }
56 }
57 return false
58 }
59}
60
61// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
62func FinishReasonIs(reason FinishReason) StopCondition {
63 return func(steps []StepResult) bool {
64 if len(steps) == 0 {
65 return false
66 }
67 lastStep := steps[len(steps)-1]
68 return lastStep.FinishReason == reason
69 }
70}
71
72// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
73func MaxTokensUsed(maxTokens int64) StopCondition {
74 return func(steps []StepResult) bool {
75 var totalTokens int64
76 for _, step := range steps {
77 totalTokens += step.Usage.TotalTokens
78 }
79 return totalTokens >= maxTokens
80 }
81}
82
// PrepareStepFunctionOptions is the input passed to a PrepareStepFunction
// before each step of the agent loop.
type PrepareStepFunctionOptions struct {
	// Steps are the steps completed so far, oldest first.
	Steps []StepResult
	// StepNumber is the zero-based index of the step about to run.
	StepNumber int
	// Model is the agent's configured model, used unless overridden.
	Model LanguageModel
	// Messages is the prompt the step will send unless overridden.
	Messages []Message
}
89
// PrepareStepResult lets a PrepareStepFunction override settings for a single
// step. Zero values leave the corresponding default in place.
type PrepareStepResult struct {
	// Model, when non-nil, replaces the model for this step.
	Model LanguageModel
	// Messages, when non-nil, replaces the prompt messages for this step.
	Messages []Message
	// System, when non-nil and different from the agent's system prompt,
	// replaces the system prompt for this step.
	System *string
	// ToolChoice, when non-nil, replaces the default ToolChoiceAuto.
	ToolChoice *ToolChoice
	// ActiveTools, when non-empty, replaces the call's ActiveTools, limiting
	// the tools offered to the model to those with these names.
	ActiveTools []string
	// DisableAllTools, when true, offers no tools to the model for this step.
	DisableAllTools bool
}
98
// ToolCallRepairOptions is the input passed to a RepairToolCallFunction when a
// tool call fails validation.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call as produced by the model.
	OriginalToolCall ToolCallContent
	// ValidationError describes why OriginalToolCall failed validation.
	ValidationError error
	// AvailableTools are the tools the call was validated against.
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt supplied by the agent loop.
	SystemPrompt string
	// Messages are the prompt messages supplied by the agent loop; the
	// streaming path passes nil.
	Messages []Message
}
106
type (
	// PrepareStepFunction is called before each step and may override the
	// step's model, messages, system prompt, tool choice and tools. A non-nil
	// error aborts the run.
	PrepareStepFunction = func(options PrepareStepFunctionOptions) (PrepareStepResult, error)
	// OnStepFinishedFunction is a callback taking a completed StepResult.
	// NOTE(review): not referenced in this section; AgentStreamCall.OnStepFinish
	// uses a different signature that returns an error.
	OnStepFinishedFunction = func(step StepResult)
	// RepairToolCallFunction may return a corrected version of a tool call that
	// failed validation. The correction is used only when the returned error is
	// nil, the returned call is non-nil, and it passes validation itself.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
112
// agentSettings holds the agent-wide defaults assembled by NewAgent from its
// AgentOption arguments. prepareCall uses them as fallbacks for values a call
// leaves unset; nil pointers and nil funcs mean "not configured".
type agentSettings struct {
	// systemPrompt, when non-empty, is sent as the leading system message.
	systemPrompt string
	// Default sampling limits and penalties; nil means unset.
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	// headers and providerOptions are merged under the per-call values, so
	// call entries win on key conflicts.
	headers         map[string]string
	providerOptions ProviderOptions

	// TODO: add support for provider tools
	// tools are offered to the model and executed locally via Run when called.
	tools      []AgentTool
	maxRetries *int

	// model is the default model for every step; PrepareStep may override it.
	model LanguageModel

	// prepareCall copies these into calls that leave them unset.
	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
135
// AgentCall describes a single Generate invocation. Nil pointer and func
// fields, and an empty StopWhen, fall back to the agent-level defaults
// configured through AgentOption values (see prepareCall).
//
// NOTE(review): the func-typed fields carry no `json:"-"` tag; confirm
// encoding/json behavior before ever marshaling this struct.
type AgentCall struct {
	// Prompt is the user message for this run; it must be non-empty.
	Prompt string `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files []FilePart `json:"files"`
	// Messages are additional conversation messages included in the initial
	// prompt.
	Messages []Message `json:"messages"`
	// Sampling options forwarded to the model on every step.
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, limits which agent tools are offered.
	ActiveTools []string `json:"active_tools"`
	// Headers and ProviderOptions are merged over the agent defaults; call
	// values win on conflicting keys.
	Headers         map[string]string
	ProviderOptions ProviderOptions
	// OnRetry is handed to the retry helper around Generate's model calls.
	OnRetry OnRetryCallback
	// MaxRetries is defaulted by prepareCall. NOTE(review): not read by the
	// visible Generate/Stream code; confirm it reaches the retry options.
	MaxRetries *int

	// StopWhen ends the loop as soon as any condition returns true.
	StopWhen []StopCondition
	// PrepareStep may override per-step settings before each model call.
	PrepareStep PrepareStepFunction
	// RepairToolCall may fix tool calls that fail validation.
	RepairToolCall RepairToolCallFunction
}
156
// AgentStreamCall describes a single Stream invocation.
//
// The request fields mirror AgentCall; Stream converts them into an AgentCall
// and applies the same agent-level defaults via prepareCall. The callback
// fields let the caller observe the run while it progresses; they are
// optional (the visible code nil-checks each one before calling it).
type AgentStreamCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	// Loop hooks; same meaning as on AgentCall.
	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  func()                         // Called when agent starts
	OnAgentFinish func(result *AgentResult) error // Called when agent finishes
	OnStepStart   func(stepNumber int) error      // Called when a step starts (zero-based step number)
	OnStepFinish  func(stepResult StepResult) error // Called when a step finishes
	OnFinish      func(result *AgentResult)        // Called when entire agent completes
	OnError       func(error)                      // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          func(StreamPart) error                                                                    // Called for each stream part (catch-all)
	OnWarnings       func(warnings []CallWarning) error                                                        // Called for warnings
	OnTextStart      func(id string) error                                                                     // Called when text starts
	OnTextDelta      func(id, text string) error                                                               // Called for text deltas
	OnTextEnd        func(id string) error                                                                     // Called when text ends
	OnReasoningStart func(id string) error                                                                     // Called when reasoning starts
	OnReasoningDelta func(id, text string) error                                                               // Called for reasoning deltas
	OnReasoningEnd   func(id string, reasoning ReasoningContent) error                                         // Called when reasoning ends
	OnToolInputStart func(id, toolName string) error                                                           // Called when tool input starts
	OnToolInputDelta func(id, delta string) error                                                              // Called for tool input deltas
	OnToolInputEnd   func(id string) error                                                                     // Called when tool input ends
	OnToolCall       func(toolCall ToolCallContent) error                                                      // Called when tool call is complete
	OnToolResult     func(result ToolResultContent) error                                                      // Called when tool execution completes
	OnSource         func(source SourceContent) error                                                          // Called for source references
	OnStreamFinish   func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error // Called when stream finishes
}
202
// AgentResult is the outcome of a complete agent run.
type AgentResult struct {
	// Steps lists every executed step, oldest first.
	Steps []StepResult
	// Response is the Response of the final step.
	Response Response
	// TotalUsage sums the token usage of all steps.
	TotalUsage Usage
}
209
// Agent runs a language model in a multi-step loop, executing the tool calls
// it requests between steps until the model stops requesting tools or a stop
// condition is met.
type Agent interface {
	// Generate runs the loop using non-streaming model calls.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the loop using streaming model calls, reporting progress
	// through the callbacks in AgentStreamCall.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
214
// AgentOption configures the settings of an agent built by NewAgent. Options
// are applied in the order they are passed.
type AgentOption = func(*agentSettings)
216
// agent is the Agent implementation returned by NewAgent.
type agent struct {
	// settings are the agent-wide defaults built from the AgentOptions.
	settings agentSettings
}
220
221func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
222 settings := agentSettings{
223 model: model,
224 }
225 for _, o := range opts {
226 o(&settings)
227 }
228 return &agent{
229 settings: settings,
230 }
231}
232
233func (a *agent) prepareCall(call AgentCall) AgentCall {
234 call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
235 call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
236 call.TopP = cmp.Or(call.TopP, a.settings.topP)
237 call.TopK = cmp.Or(call.TopK, a.settings.topK)
238 call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
239 call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
240 call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
241
242 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
243 call.StopWhen = a.settings.stopWhen
244 }
245 if call.PrepareStep == nil && a.settings.prepareStep != nil {
246 call.PrepareStep = a.settings.prepareStep
247 }
248 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
249 call.RepairToolCall = a.settings.repairToolCall
250 }
251 if call.OnRetry == nil && a.settings.onRetry != nil {
252 call.OnRetry = a.settings.onRetry
253 }
254
255 providerOptions := ProviderOptions{}
256 if a.settings.providerOptions != nil {
257 maps.Copy(providerOptions, a.settings.providerOptions)
258 }
259 if call.ProviderOptions != nil {
260 maps.Copy(providerOptions, call.ProviderOptions)
261 }
262 call.ProviderOptions = providerOptions
263
264 headers := map[string]string{}
265
266 if a.settings.headers != nil {
267 maps.Copy(headers, a.settings.headers)
268 }
269
270 if call.Headers != nil {
271 maps.Copy(headers, call.Headers)
272 }
273 call.Headers = headers
274 return call
275}
276
277// Generate implements Agent.
278func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
279 opts = a.prepareCall(opts)
280 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
281 if err != nil {
282 return nil, err
283 }
284 var responseMessages []Message
285 var steps []StepResult
286
287 for {
288 stepInputMessages := append(initialPrompt, responseMessages...)
289 stepModel := a.settings.model
290 stepSystemPrompt := a.settings.systemPrompt
291 stepActiveTools := opts.ActiveTools
292 stepToolChoice := ToolChoiceAuto
293 disableAllTools := false
294
295 if opts.PrepareStep != nil {
296 prepared, err := opts.PrepareStep(PrepareStepFunctionOptions{
297 Model: stepModel,
298 Steps: steps,
299 StepNumber: len(steps),
300 Messages: stepInputMessages,
301 })
302 if err != nil {
303 return nil, err
304 }
305
306 // Apply prepared step modifications
307 if prepared.Messages != nil {
308 stepInputMessages = prepared.Messages
309 }
310 if prepared.Model != nil {
311 stepModel = prepared.Model
312 }
313 if prepared.System != nil {
314 stepSystemPrompt = *prepared.System
315 }
316 if prepared.ToolChoice != nil {
317 stepToolChoice = *prepared.ToolChoice
318 }
319 if len(prepared.ActiveTools) > 0 {
320 stepActiveTools = prepared.ActiveTools
321 }
322 disableAllTools = prepared.DisableAllTools
323 }
324
325 // Recreate prompt with potentially modified system prompt
326 if stepSystemPrompt != a.settings.systemPrompt {
327 stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
328 if err != nil {
329 return nil, err
330 }
331 // Replace system message part, keep the rest
332 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
333 stepInputMessages[0] = stepPrompt[0] // Replace system message
334 }
335 }
336
337 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
338
339 retryOptions := DefaultRetryOptions()
340 retryOptions.OnRetry = opts.OnRetry
341 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
342
343 result, err := retry(ctx, func() (*Response, error) {
344 return stepModel.Generate(ctx, Call{
345 Prompt: stepInputMessages,
346 MaxOutputTokens: opts.MaxOutputTokens,
347 Temperature: opts.Temperature,
348 TopP: opts.TopP,
349 TopK: opts.TopK,
350 PresencePenalty: opts.PresencePenalty,
351 FrequencyPenalty: opts.FrequencyPenalty,
352 Tools: preparedTools,
353 ToolChoice: &stepToolChoice,
354 Headers: opts.Headers,
355 ProviderOptions: opts.ProviderOptions,
356 })
357 })
358 if err != nil {
359 return nil, err
360 }
361
362 var stepToolCalls []ToolCallContent
363 for _, content := range result.Content {
364 if content.GetType() == ContentTypeToolCall {
365 toolCall, ok := AsContentType[ToolCallContent](content)
366 if !ok {
367 continue
368 }
369
370 // Validate and potentially repair the tool call
371 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
372 stepToolCalls = append(stepToolCalls, validatedToolCall)
373 }
374 }
375
376 toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
377
378 // Build step content with validated tool calls and tool results
379 stepContent := []Content{}
380 toolCallIndex := 0
381 for _, content := range result.Content {
382 if content.GetType() == ContentTypeToolCall {
383 // Replace with validated tool call
384 if toolCallIndex < len(stepToolCalls) {
385 stepContent = append(stepContent, stepToolCalls[toolCallIndex])
386 toolCallIndex++
387 }
388 } else {
389 // Keep other content as-is
390 stepContent = append(stepContent, content)
391 }
392 }
393 // Add tool results
394 for _, result := range toolResults {
395 stepContent = append(stepContent, result)
396 }
397 currentStepMessages := toResponseMessages(stepContent)
398 responseMessages = append(responseMessages, currentStepMessages...)
399
400 stepResult := StepResult{
401 Response: Response{
402 Content: stepContent,
403 FinishReason: result.FinishReason,
404 Usage: result.Usage,
405 Warnings: result.Warnings,
406 ProviderMetadata: result.ProviderMetadata,
407 },
408 Messages: currentStepMessages,
409 }
410 steps = append(steps, stepResult)
411 shouldStop := isStopConditionMet(opts.StopWhen, steps)
412
413 if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
414 break
415 }
416 }
417
418 totalUsage := Usage{}
419
420 for _, step := range steps {
421 usage := step.Usage
422 totalUsage.InputTokens += usage.InputTokens
423 totalUsage.OutputTokens += usage.OutputTokens
424 totalUsage.ReasoningTokens += usage.ReasoningTokens
425 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
426 totalUsage.CacheReadTokens += usage.CacheReadTokens
427 totalUsage.TotalTokens += usage.TotalTokens
428 }
429
430 agentResult := &AgentResult{
431 Steps: steps,
432 Response: steps[len(steps)-1].Response,
433 TotalUsage: totalUsage,
434 }
435 return agentResult, nil
436}
437
438func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
439 if len(conditions) == 0 {
440 return false
441 }
442
443 for _, condition := range conditions {
444 if condition(steps) {
445 return true
446 }
447 }
448 return false
449}
450
451func toResponseMessages(content []Content) []Message {
452 var assistantParts []MessagePart
453 var toolParts []MessagePart
454
455 for _, c := range content {
456 switch c.GetType() {
457 case ContentTypeText:
458 text, ok := AsContentType[TextContent](c)
459 if !ok {
460 continue
461 }
462 assistantParts = append(assistantParts, TextPart{
463 Text: text.Text,
464 ProviderOptions: ProviderOptions(text.ProviderMetadata),
465 })
466 case ContentTypeReasoning:
467 reasoning, ok := AsContentType[ReasoningContent](c)
468 if !ok {
469 continue
470 }
471 assistantParts = append(assistantParts, ReasoningPart{
472 Text: reasoning.Text,
473 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
474 })
475 case ContentTypeToolCall:
476 toolCall, ok := AsContentType[ToolCallContent](c)
477 if !ok {
478 continue
479 }
480 assistantParts = append(assistantParts, ToolCallPart{
481 ToolCallID: toolCall.ToolCallID,
482 ToolName: toolCall.ToolName,
483 Input: toolCall.Input,
484 ProviderExecuted: toolCall.ProviderExecuted,
485 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
486 })
487 case ContentTypeFile:
488 file, ok := AsContentType[FileContent](c)
489 if !ok {
490 continue
491 }
492 assistantParts = append(assistantParts, FilePart{
493 Data: file.Data,
494 MediaType: file.MediaType,
495 ProviderOptions: ProviderOptions(file.ProviderMetadata),
496 })
497 case ContentTypeSource:
498 // Sources are metadata about references used to generate the response.
499 // They don't need to be included in the conversation messages.
500 continue
501 case ContentTypeToolResult:
502 result, ok := AsContentType[ToolResultContent](c)
503 if !ok {
504 continue
505 }
506 toolParts = append(toolParts, ToolResultPart{
507 ToolCallID: result.ToolCallID,
508 Output: result.Result,
509 ProviderOptions: ProviderOptions(result.ProviderMetadata),
510 })
511 }
512 }
513
514 var messages []Message
515 if len(assistantParts) > 0 {
516 messages = append(messages, Message{
517 Role: MessageRoleAssistant,
518 Content: assistantParts,
519 })
520 }
521 if len(toolParts) > 0 {
522 messages = append(messages, Message{
523 Role: MessageRoleTool,
524 Content: toolParts,
525 })
526 }
527 return messages
528}
529
530func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
531 if len(toolCalls) == 0 {
532 return nil, nil
533 }
534
535 // Create a map for quick tool lookup
536 toolMap := make(map[string]AgentTool)
537 for _, tool := range allTools {
538 toolMap[tool.Info().Name] = tool
539 }
540
541 // Execute all tool calls in parallel
542 results := make([]ToolResultContent, len(toolCalls))
543 executeErrors := make([]error, len(toolCalls))
544
545 var wg sync.WaitGroup
546
547 for i, toolCall := range toolCalls {
548 wg.Add(1)
549 go func(index int, call ToolCallContent) {
550 defer wg.Done()
551
552 // Skip invalid tool calls - create error result
553 if call.Invalid {
554 results[index] = ToolResultContent{
555 ToolCallID: call.ToolCallID,
556 ToolName: call.ToolName,
557 Result: ToolResultOutputContentError{
558 Error: call.ValidationError,
559 },
560 ProviderExecuted: false,
561 }
562 if toolResultCallback != nil {
563 err := toolResultCallback(results[index])
564 if err != nil {
565 executeErrors[index] = err
566 }
567 }
568
569 return
570 }
571
572 tool, exists := toolMap[call.ToolName]
573 if !exists {
574 results[index] = ToolResultContent{
575 ToolCallID: call.ToolCallID,
576 ToolName: call.ToolName,
577 Result: ToolResultOutputContentError{
578 Error: errors.New("Error: Tool not found: " + call.ToolName),
579 },
580 ProviderExecuted: false,
581 }
582
583 if toolResultCallback != nil {
584 err := toolResultCallback(results[index])
585 if err != nil {
586 executeErrors[index] = err
587 }
588 }
589 return
590 }
591
592 // Execute the tool
593 result, err := tool.Run(ctx, ToolCall{
594 ID: call.ToolCallID,
595 Name: call.ToolName,
596 Input: call.Input,
597 })
598 if err != nil {
599 results[index] = ToolResultContent{
600 ToolCallID: call.ToolCallID,
601 ToolName: call.ToolName,
602 Result: ToolResultOutputContentError{
603 Error: err,
604 },
605 ClientMetadata: result.Metadata,
606 ProviderExecuted: false,
607 }
608 if toolResultCallback != nil {
609 cbErr := toolResultCallback(results[index])
610 if cbErr != nil {
611 executeErrors[index] = cbErr
612 }
613 }
614 executeErrors[index] = err
615 return
616 }
617
618 if result.IsError {
619 results[index] = ToolResultContent{
620 ToolCallID: call.ToolCallID,
621 ToolName: call.ToolName,
622 Result: ToolResultOutputContentError{
623 Error: errors.New(result.Content),
624 },
625 ClientMetadata: result.Metadata,
626 ProviderExecuted: false,
627 }
628
629 if toolResultCallback != nil {
630 err := toolResultCallback(results[index])
631 if err != nil {
632 executeErrors[index] = err
633 }
634 }
635 } else {
636 results[index] = ToolResultContent{
637 ToolCallID: call.ToolCallID,
638 ToolName: toolCall.ToolName,
639 Result: ToolResultOutputContentText{
640 Text: result.Content,
641 },
642 ClientMetadata: result.Metadata,
643 ProviderExecuted: false,
644 }
645 if toolResultCallback != nil {
646 err := toolResultCallback(results[index])
647 if err != nil {
648 executeErrors[index] = err
649 }
650 }
651 }
652 }(i, toolCall)
653 }
654
655 // Wait for all tool executions to complete
656 wg.Wait()
657
658 for _, err := range executeErrors {
659 if err != nil {
660 return nil, err
661 }
662 }
663
664 return results, nil
665}
666
667// Stream implements Agent.
668func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
669 // Convert AgentStreamCall to AgentCall for preparation
670 call := AgentCall{
671 Prompt: opts.Prompt,
672 Files: opts.Files,
673 Messages: opts.Messages,
674 MaxOutputTokens: opts.MaxOutputTokens,
675 Temperature: opts.Temperature,
676 TopP: opts.TopP,
677 TopK: opts.TopK,
678 PresencePenalty: opts.PresencePenalty,
679 FrequencyPenalty: opts.FrequencyPenalty,
680 ActiveTools: opts.ActiveTools,
681 Headers: opts.Headers,
682 ProviderOptions: opts.ProviderOptions,
683 MaxRetries: opts.MaxRetries,
684 StopWhen: opts.StopWhen,
685 PrepareStep: opts.PrepareStep,
686 RepairToolCall: opts.RepairToolCall,
687 }
688
689 call = a.prepareCall(call)
690
691 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
692 if err != nil {
693 return nil, err
694 }
695
696 var responseMessages []Message
697 var steps []StepResult
698 var totalUsage Usage
699
700 // Start agent stream
701 if opts.OnAgentStart != nil {
702 opts.OnAgentStart()
703 }
704
705 for stepNumber := 0; ; stepNumber++ {
706 stepInputMessages := append(initialPrompt, responseMessages...)
707 stepModel := a.settings.model
708 stepSystemPrompt := a.settings.systemPrompt
709 stepActiveTools := call.ActiveTools
710 stepToolChoice := ToolChoiceAuto
711 disableAllTools := false
712
713 // Apply step preparation if provided
714 if call.PrepareStep != nil {
715 prepared, err := call.PrepareStep(PrepareStepFunctionOptions{
716 Model: stepModel,
717 Steps: steps,
718 StepNumber: stepNumber,
719 Messages: stepInputMessages,
720 })
721 if err != nil {
722 return nil, err
723 }
724
725 if prepared.Messages != nil {
726 stepInputMessages = prepared.Messages
727 }
728 if prepared.Model != nil {
729 stepModel = prepared.Model
730 }
731 if prepared.System != nil {
732 stepSystemPrompt = *prepared.System
733 }
734 if prepared.ToolChoice != nil {
735 stepToolChoice = *prepared.ToolChoice
736 }
737 if len(prepared.ActiveTools) > 0 {
738 stepActiveTools = prepared.ActiveTools
739 }
740 disableAllTools = prepared.DisableAllTools
741 }
742
743 // Recreate prompt with potentially modified system prompt
744 if stepSystemPrompt != a.settings.systemPrompt {
745 stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
746 if err != nil {
747 return nil, err
748 }
749 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
750 stepInputMessages[0] = stepPrompt[0]
751 }
752 }
753
754 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
755
756 // Start step stream
757 if opts.OnStepStart != nil {
758 _ = opts.OnStepStart(stepNumber)
759 }
760
761 // Create streaming call
762 streamCall := Call{
763 Prompt: stepInputMessages,
764 MaxOutputTokens: call.MaxOutputTokens,
765 Temperature: call.Temperature,
766 TopP: call.TopP,
767 TopK: call.TopK,
768 PresencePenalty: call.PresencePenalty,
769 FrequencyPenalty: call.FrequencyPenalty,
770 Tools: preparedTools,
771 ToolChoice: &stepToolChoice,
772 Headers: call.Headers,
773 ProviderOptions: call.ProviderOptions,
774 }
775
776 // Get streaming response
777 stream, err := stepModel.Stream(ctx, streamCall)
778 if err != nil {
779 if opts.OnError != nil {
780 opts.OnError(err)
781 }
782 return nil, err
783 }
784
785 // Process stream with tool execution
786 stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
787 if err != nil {
788 if opts.OnError != nil {
789 opts.OnError(err)
790 }
791 return nil, err
792 }
793
794 steps = append(steps, stepResult)
795 totalUsage = addUsage(totalUsage, stepResult.Usage)
796
797 // Call step finished callback
798 if opts.OnStepFinish != nil {
799 _ = opts.OnStepFinish(stepResult)
800 }
801
802 // Add step messages to response messages
803 stepMessages := toResponseMessages(stepResult.Content)
804 responseMessages = append(responseMessages, stepMessages...)
805
806 // Check stop conditions
807 shouldStop := isStopConditionMet(call.StopWhen, steps)
808 if shouldStop || !shouldContinue {
809 break
810 }
811 }
812
813 // Finish agent stream
814 agentResult := &AgentResult{
815 Steps: steps,
816 Response: steps[len(steps)-1].Response,
817 TotalUsage: totalUsage,
818 }
819
820 if opts.OnFinish != nil {
821 opts.OnFinish(agentResult)
822 }
823
824 if opts.OnAgentFinish != nil {
825 _ = opts.OnAgentFinish(agentResult)
826 }
827
828 return agentResult, nil
829}
830
831func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
832 preparedTools := make([]Tool, 0, len(tools))
833
834 // If explicitly disabling all tools, return no tools
835 if disableAllTools {
836 return preparedTools
837 }
838
839 for _, tool := range tools {
840 // If activeTools has items, only include tools in the list
841 // If activeTools is empty, include all tools
842 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
843 continue
844 }
845 info := tool.Info()
846 preparedTools = append(preparedTools, FunctionTool{
847 Name: info.Name,
848 Description: info.Description,
849 InputSchema: map[string]any{
850 "type": "object",
851 "properties": info.Parameters,
852 "required": info.Required,
853 },
854 })
855 }
856 return preparedTools
857}
858
859// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
860func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
861 if err := a.validateToolCall(toolCall, availableTools); err == nil {
862 return toolCall
863 } else {
864 if repairFunc != nil {
865 repairOptions := ToolCallRepairOptions{
866 OriginalToolCall: toolCall,
867 ValidationError: err,
868 AvailableTools: availableTools,
869 SystemPrompt: systemPrompt,
870 Messages: messages,
871 }
872
873 if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
874 if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
875 return *repairedToolCall
876 }
877 }
878 }
879
880 invalidToolCall := toolCall
881 invalidToolCall.Invalid = true
882 invalidToolCall.ValidationError = err
883 return invalidToolCall
884 }
885}
886
887// validateToolCall validates a tool call against available tools and their schemas.
888func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
889 var tool AgentTool
890 for _, t := range availableTools {
891 if t.Info().Name == toolCall.ToolName {
892 tool = t
893 break
894 }
895 }
896
897 if tool == nil {
898 return fmt.Errorf("tool not found: %s", toolCall.ToolName)
899 }
900
901 // Validate JSON parsing
902 var input map[string]any
903 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
904 return fmt.Errorf("invalid JSON input: %w", err)
905 }
906
907 // Basic schema validation (check required fields)
908 // TODO: more robust schema validation using JSON Schema or similar
909 toolInfo := tool.Info()
910 for _, required := range toolInfo.Required {
911 if _, exists := input[required]; !exists {
912 return fmt.Errorf("missing required parameter: %s", required)
913 }
914 }
915 return nil
916}
917
918func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
919 if prompt == "" {
920 return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
921 }
922
923 var preparedPrompt Prompt
924
925 if system != "" {
926 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
927 }
928
929 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
930 preparedPrompt = append(preparedPrompt, messages...)
931 return preparedPrompt, nil
932}
933
934func WithSystemPrompt(prompt string) AgentOption {
935 return func(s *agentSettings) {
936 s.systemPrompt = prompt
937 }
938}
939
940func WithMaxOutputTokens(tokens int64) AgentOption {
941 return func(s *agentSettings) {
942 s.maxOutputTokens = &tokens
943 }
944}
945
946func WithTemperature(temp float64) AgentOption {
947 return func(s *agentSettings) {
948 s.temperature = &temp
949 }
950}
951
952func WithTopP(topP float64) AgentOption {
953 return func(s *agentSettings) {
954 s.topP = &topP
955 }
956}
957
958func WithTopK(topK int64) AgentOption {
959 return func(s *agentSettings) {
960 s.topK = &topK
961 }
962}
963
964func WithPresencePenalty(penalty float64) AgentOption {
965 return func(s *agentSettings) {
966 s.presencePenalty = &penalty
967 }
968}
969
970func WithFrequencyPenalty(penalty float64) AgentOption {
971 return func(s *agentSettings) {
972 s.frequencyPenalty = &penalty
973 }
974}
975
976func WithTools(tools ...AgentTool) AgentOption {
977 return func(s *agentSettings) {
978 s.tools = append(s.tools, tools...)
979 }
980}
981
982func WithStopConditions(conditions ...StopCondition) AgentOption {
983 return func(s *agentSettings) {
984 s.stopWhen = append(s.stopWhen, conditions...)
985 }
986}
987
988func WithPrepareStep(fn PrepareStepFunction) AgentOption {
989 return func(s *agentSettings) {
990 s.prepareStep = fn
991 }
992}
993
994func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
995 return func(s *agentSettings) {
996 s.repairToolCall = fn
997 }
998}
999
1000// processStepStream processes a single step's stream and returns the step result.
1001func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
1002 var stepContent []Content
1003 var stepToolCalls []ToolCallContent
1004 var stepUsage Usage
1005 stepFinishReason := FinishReasonUnknown
1006 var stepWarnings []CallWarning
1007 var stepProviderMetadata ProviderMetadata
1008
1009 activeToolCalls := make(map[string]*ToolCallContent)
1010 activeTextContent := make(map[string]string)
1011
1012 // Process stream parts
1013 for part := range stream {
1014 // Forward all parts to chunk callback
1015 if opts.OnChunk != nil {
1016 err := opts.OnChunk(part)
1017 if err != nil {
1018 return StepResult{}, false, err
1019 }
1020 }
1021
1022 switch part.Type {
1023 case StreamPartTypeWarnings:
1024 stepWarnings = part.Warnings
1025 if opts.OnWarnings != nil {
1026 err := opts.OnWarnings(part.Warnings)
1027 if err != nil {
1028 return StepResult{}, false, err
1029 }
1030 }
1031
1032 case StreamPartTypeTextStart:
1033 activeTextContent[part.ID] = ""
1034 if opts.OnTextStart != nil {
1035 err := opts.OnTextStart(part.ID)
1036 if err != nil {
1037 return StepResult{}, false, err
1038 }
1039 }
1040
1041 case StreamPartTypeTextDelta:
1042 if _, exists := activeTextContent[part.ID]; exists {
1043 activeTextContent[part.ID] += part.Delta
1044 }
1045 if opts.OnTextDelta != nil {
1046 err := opts.OnTextDelta(part.ID, part.Delta)
1047 if err != nil {
1048 return StepResult{}, false, err
1049 }
1050 }
1051
1052 case StreamPartTypeTextEnd:
1053 if text, exists := activeTextContent[part.ID]; exists {
1054 stepContent = append(stepContent, TextContent{
1055 Text: text,
1056 ProviderMetadata: part.ProviderMetadata,
1057 })
1058 delete(activeTextContent, part.ID)
1059 }
1060 if opts.OnTextEnd != nil {
1061 err := opts.OnTextEnd(part.ID)
1062 if err != nil {
1063 return StepResult{}, false, err
1064 }
1065 }
1066
1067 case StreamPartTypeReasoningStart:
1068 activeTextContent[part.ID] = ""
1069 if opts.OnReasoningStart != nil {
1070 err := opts.OnReasoningStart(part.ID)
1071 if err != nil {
1072 return StepResult{}, false, err
1073 }
1074 }
1075
1076 case StreamPartTypeReasoningDelta:
1077 if _, exists := activeTextContent[part.ID]; exists {
1078 activeTextContent[part.ID] += part.Delta
1079 }
1080 if opts.OnReasoningDelta != nil {
1081 err := opts.OnReasoningDelta(part.ID, part.Delta)
1082 if err != nil {
1083 return StepResult{}, false, err
1084 }
1085 }
1086
1087 case StreamPartTypeReasoningEnd:
1088 if text, exists := activeTextContent[part.ID]; exists {
1089 stepContent = append(stepContent, ReasoningContent{
1090 Text: text,
1091 ProviderMetadata: part.ProviderMetadata,
1092 })
1093 if opts.OnReasoningEnd != nil {
1094 err := opts.OnReasoningEnd(part.ID, ReasoningContent{
1095 Text: text,
1096 ProviderMetadata: part.ProviderMetadata,
1097 })
1098 if err != nil {
1099 return StepResult{}, false, err
1100 }
1101 }
1102 delete(activeTextContent, part.ID)
1103 }
1104
1105 case StreamPartTypeToolInputStart:
1106 activeToolCalls[part.ID] = &ToolCallContent{
1107 ToolCallID: part.ID,
1108 ToolName: part.ToolCallName,
1109 Input: "",
1110 ProviderExecuted: part.ProviderExecuted,
1111 }
1112 if opts.OnToolInputStart != nil {
1113 err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1114 if err != nil {
1115 return StepResult{}, false, err
1116 }
1117 }
1118
1119 case StreamPartTypeToolInputDelta:
1120 if toolCall, exists := activeToolCalls[part.ID]; exists {
1121 toolCall.Input += part.Delta
1122 }
1123 if opts.OnToolInputDelta != nil {
1124 err := opts.OnToolInputDelta(part.ID, part.Delta)
1125 if err != nil {
1126 return StepResult{}, false, err
1127 }
1128 }
1129
1130 case StreamPartTypeToolInputEnd:
1131 if opts.OnToolInputEnd != nil {
1132 err := opts.OnToolInputEnd(part.ID)
1133 if err != nil {
1134 return StepResult{}, false, err
1135 }
1136 }
1137
1138 case StreamPartTypeToolCall:
1139 toolCall := ToolCallContent{
1140 ToolCallID: part.ID,
1141 ToolName: part.ToolCallName,
1142 Input: part.ToolCallInput,
1143 ProviderExecuted: part.ProviderExecuted,
1144 ProviderMetadata: part.ProviderMetadata,
1145 }
1146
1147 // Validate and potentially repair the tool call
1148 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1149 stepToolCalls = append(stepToolCalls, validatedToolCall)
1150 stepContent = append(stepContent, validatedToolCall)
1151
1152 if opts.OnToolCall != nil {
1153 err := opts.OnToolCall(validatedToolCall)
1154 if err != nil {
1155 return StepResult{}, false, err
1156 }
1157 }
1158
1159 // Clean up active tool call
1160 delete(activeToolCalls, part.ID)
1161
1162 case StreamPartTypeSource:
1163 sourceContent := SourceContent{
1164 SourceType: part.SourceType,
1165 ID: part.ID,
1166 URL: part.URL,
1167 Title: part.Title,
1168 ProviderMetadata: part.ProviderMetadata,
1169 }
1170 stepContent = append(stepContent, sourceContent)
1171 if opts.OnSource != nil {
1172 err := opts.OnSource(sourceContent)
1173 if err != nil {
1174 return StepResult{}, false, err
1175 }
1176 }
1177
1178 case StreamPartTypeFinish:
1179 stepUsage = part.Usage
1180 stepFinishReason = part.FinishReason
1181 stepProviderMetadata = part.ProviderMetadata
1182 if opts.OnStreamFinish != nil {
1183 err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1184 if err != nil {
1185 return StepResult{}, false, err
1186 }
1187 }
1188
1189 case StreamPartTypeError:
1190 if opts.OnError != nil {
1191 opts.OnError(part.Error)
1192 }
1193 return StepResult{}, false, part.Error
1194 }
1195 }
1196
1197 // Execute tools if any
1198 var toolResults []ToolResultContent
1199 if len(stepToolCalls) > 0 {
1200 var err error
1201 toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
1202 if err != nil {
1203 return StepResult{}, false, err
1204 }
1205 // Add tool results to content
1206 for _, result := range toolResults {
1207 stepContent = append(stepContent, result)
1208 }
1209 }
1210
1211 stepResult := StepResult{
1212 Response: Response{
1213 Content: stepContent,
1214 FinishReason: stepFinishReason,
1215 Usage: stepUsage,
1216 Warnings: stepWarnings,
1217 ProviderMetadata: stepProviderMetadata,
1218 },
1219 Messages: toResponseMessages(stepContent),
1220 }
1221
1222 // Determine if we should continue (has tool calls and not stopped)
1223 shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1224
1225 return stepResult, shouldContinue, nil
1226}
1227
1228func addUsage(a, b Usage) Usage {
1229 return Usage{
1230 InputTokens: a.InputTokens + b.InputTokens,
1231 OutputTokens: a.OutputTokens + b.OutputTokens,
1232 TotalTokens: a.TotalTokens + b.TotalTokens,
1233 ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens,
1234 CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1235 CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens,
1236 }
1237}
1238
1239func WithHeaders(headers map[string]string) AgentOption {
1240 return func(s *agentSettings) {
1241 s.headers = headers
1242 }
1243}
1244
1245func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1246 return func(s *agentSettings) {
1247 s.providerOptions = providerOptions
1248 }
1249}