1package ai
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "maps"
9 "slices"
10 "sync"
11)
12
// StepResult captures the outcome of one agent step: the model Response
// (embedded) plus the conversation messages derived from that step's content.
type StepResult struct {
	Response
	// Messages holds the messages produced from this step's content.
	Messages []Message
}
17
// StopCondition reports whether the agent loop should stop, given every step
// recorded so far (including the one that just finished).
type StopCondition = func(steps []StepResult) bool
19
20// StepCountIs returns a stop condition that stops after the specified number of steps.
21func StepCountIs(stepCount int) StopCondition {
22 return func(steps []StepResult) bool {
23 return len(steps) >= stepCount
24 }
25}
26
27// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
28func HasToolCall(toolName string) StopCondition {
29 return func(steps []StepResult) bool {
30 if len(steps) == 0 {
31 return false
32 }
33 lastStep := steps[len(steps)-1]
34 toolCalls := lastStep.Content.ToolCalls()
35 for _, toolCall := range toolCalls {
36 if toolCall.ToolName == toolName {
37 return true
38 }
39 }
40 return false
41 }
42}
43
44// HasContent returns a stop condition that stops when the specified content type appears in the last step.
45func HasContent(contentType ContentType) StopCondition {
46 return func(steps []StepResult) bool {
47 if len(steps) == 0 {
48 return false
49 }
50 lastStep := steps[len(steps)-1]
51 for _, content := range lastStep.Content {
52 if content.GetType() == contentType {
53 return true
54 }
55 }
56 return false
57 }
58}
59
60// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
61func FinishReasonIs(reason FinishReason) StopCondition {
62 return func(steps []StepResult) bool {
63 if len(steps) == 0 {
64 return false
65 }
66 lastStep := steps[len(steps)-1]
67 return lastStep.FinishReason == reason
68 }
69}
70
71// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
72func MaxTokensUsed(maxTokens int64) StopCondition {
73 return func(steps []StepResult) bool {
74 var totalTokens int64
75 for _, step := range steps {
76 totalTokens += step.Usage.TotalTokens
77 }
78 return totalTokens >= maxTokens
79 }
80}
81
// PrepareStepFunctionOptions is the input handed to a PrepareStepFunction
// before each step of the agent loop.
type PrepareStepFunctionOptions struct {
	// Steps are the steps completed so far.
	Steps []StepResult
	// StepNumber is the zero-based index of the step about to run.
	StepNumber int
	// Model is the model the step will use unless the function overrides it.
	Model LanguageModel
	// Messages are the input messages assembled for the step.
	Messages []Message
}
88
// PrepareStepResult carries per-step overrides returned by a
// PrepareStepFunction. Zero-valued fields leave the corresponding default
// untouched.
type PrepareStepResult struct {
	// Model, when non-nil, replaces the model for this step.
	Model LanguageModel
	// Messages, when non-nil, replaces the step's input messages.
	Messages []Message
	// System, when non-nil, overrides the system prompt for this step.
	System *string
	// ToolChoice, when non-nil, overrides the default tool choice.
	ToolChoice *ToolChoice
	// ActiveTools, when non-empty, replaces the list of tool names offered.
	ActiveTools []string
	// DisableAllTools, when true, offers no tools to the model for this step.
	DisableAllTools bool
}
97
// ToolCallRepairOptions is the input handed to a RepairToolCallFunction when a
// tool call fails validation.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call that failed validation.
	OriginalToolCall ToolCallContent
	// ValidationError is the error reported by validation.
	ValidationError error
	// AvailableTools are the tools the call was validated against.
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt supplied by the caller of the repair.
	SystemPrompt string
	// Messages are the messages supplied by the caller of the repair; the
	// streaming path passes nil here.
	Messages []Message
}
105
type (
	// PrepareStepFunction is called before each step and may override the
	// model, messages, system prompt, tool choice and active tools. A non-nil
	// error aborts the run.
	PrepareStepFunction    = func(options PrepareStepFunctionOptions) (PrepareStepResult, error)
	// OnStepFinishedFunction receives a completed step.
	OnStepFinishedFunction = func(step StepResult)
	// RepairToolCallFunction attempts to fix a tool call that failed
	// validation. A nil result or a non-nil error means no repair was made.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
111
// agentSettings holds the agent-level defaults configured through
// AgentOption values. prepareCall copies these into a call wherever the call
// leaves the corresponding field unset.
type agentSettings struct {
	systemPrompt     string
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	// headers and providerOptions are merged with the per-call values; the
	// per-call entries win on key conflicts.
	headers         map[string]string
	providerOptions ProviderOptions

	// TODO: add support for provider tools
	tools      []AgentTool
	maxRetries *int

	model LanguageModel

	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
134
// AgentCall describes one non-streaming agent invocation. Nil or empty
// fields fall back to the agent-level defaults (see prepareCall).
//
// NOTE(review): several fields carry JSON tags while the function-typed
// fields do not carry `json:"-"`; encoding/json cannot marshal function
// values, so confirm whether this struct is ever serialized.
type AgentCall struct {
	// Prompt is the user prompt; createPrompt rejects an empty value.
	Prompt string `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files []FilePart `json:"files"`
	// Messages are additional conversation messages included in the prompt.
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, restricts the tools offered to the model
	// to the listed names.
	ActiveTools     []string `json:"active_tools"`
	Headers         map[string]string
	ProviderOptions ProviderOptions
	// OnRetry is installed on the retry options used for model calls.
	OnRetry OnRetryCallback
	// NOTE(review): MaxRetries is merged from the agent settings but is not
	// read by the Generate loop visible in this file — confirm intent.
	MaxRetries *int

	// StopWhen lists conditions checked after every step; any one stops the loop.
	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction
}
155
// AgentStreamCall describes one streaming agent invocation. It mirrors
// AgentCall and adds callbacks that are invoked while the stream is
// processed. Every callback is optional; nil callbacks are skipped.
type AgentStreamCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks

	// OnAgentStart is called once, before the first step.
	OnAgentStart  func()
	// OnAgentFinish is called with the final result, after OnFinish.
	OnAgentFinish func(result *AgentResult)
	// OnStepStart is called with the zero-based step number before the model is streamed.
	OnStepStart   func(stepNumber int)
	// OnStepFinish is called after a step (including its tool executions) completes.
	OnStepFinish  func(stepResult StepResult)
	// OnFinish is called with the final result when the whole run completes.
	OnFinish      func(result *AgentResult)
	// OnError is called when the model stream cannot be opened or a step fails.
	OnError       func(error)

	// Stream part callbacks - called for each corresponding stream part type

	// OnChunk is called for every stream part (catch-all), before the typed callback.
	OnChunk          func(StreamPart)
	// OnWarnings is called with the warnings carried by a warnings part.
	OnWarnings       func(warnings []CallWarning)
	// OnTextStart is called when a text block starts.
	OnTextStart      func(id string)
	// OnTextDelta is called for each text delta.
	OnTextDelta      func(id, text string)
	// OnTextEnd is called when a text block ends.
	OnTextEnd        func(id string)
	// OnReasoningStart is called when a reasoning block starts.
	OnReasoningStart func(id string)
	// OnReasoningDelta is called for each reasoning delta.
	OnReasoningDelta func(id, text string)
	// OnReasoningEnd is called with the accumulated reasoning when a started reasoning block ends.
	OnReasoningEnd   func(id string, reasoning ReasoningContent)
	// OnToolInputStart is called when tool input starts streaming.
	OnToolInputStart func(id, toolName string)
	// OnToolInputDelta is called for each tool input delta.
	OnToolInputDelta func(id, delta string)
	// OnToolInputEnd is called when tool input ends.
	OnToolInputEnd   func(id string)
	// OnToolCall is called with the validated (possibly repaired or invalid-marked) tool call.
	OnToolCall       func(toolCall ToolCallContent)
	// OnToolResult is called when a tool execution completes; tools run in
	// parallel goroutines, so it may be invoked concurrently.
	OnToolResult     func(result ToolResultContent)
	// OnSource is called for source references.
	OnSource         func(source SourceContent)
	// OnStreamFinish is called when a step's stream reports its finish part.
	OnStreamFinish   func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata)
	// OnStreamError is called when the stream delivers an error part.
	OnStreamError    func(error)
}
202
// AgentResult is the outcome of a complete agent run.
type AgentResult struct {
	// Steps lists every step that ran, in order.
	Steps []StepResult
	// Final response
	Response Response
	// TotalUsage is the usage summed over all steps.
	TotalUsage Usage
}
209
// Agent runs a model in a multi-step loop, executing tool calls between steps.
type Agent interface {
	// Generate runs the loop with non-streaming model calls.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the loop with streaming model calls, reporting progress
	// through the callbacks on AgentStreamCall.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
214
// AgentOption configures agent-level defaults; options are applied in order
// by NewAgent.
type AgentOption = func(*agentSettings)
216
// agent is the Agent implementation returned by NewAgent.
type agent struct {
	// settings holds the defaults assembled from the options given to NewAgent.
	settings agentSettings
}
220
221func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
222 settings := agentSettings{
223 model: model,
224 }
225 for _, o := range opts {
226 o(&settings)
227 }
228 return &agent{
229 settings: settings,
230 }
231}
232
233func (a *agent) prepareCall(call AgentCall) AgentCall {
234 if call.MaxOutputTokens == nil && a.settings.maxOutputTokens != nil {
235 call.MaxOutputTokens = a.settings.maxOutputTokens
236 }
237 if call.Temperature == nil && a.settings.temperature != nil {
238 call.Temperature = a.settings.temperature
239 }
240 if call.TopP == nil && a.settings.topP != nil {
241 call.TopP = a.settings.topP
242 }
243 if call.TopK == nil && a.settings.topK != nil {
244 call.TopK = a.settings.topK
245 }
246 if call.PresencePenalty == nil && a.settings.presencePenalty != nil {
247 call.PresencePenalty = a.settings.presencePenalty
248 }
249 if call.FrequencyPenalty == nil && a.settings.frequencyPenalty != nil {
250 call.FrequencyPenalty = a.settings.frequencyPenalty
251 }
252 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
253 call.StopWhen = a.settings.stopWhen
254 }
255 if call.PrepareStep == nil && a.settings.prepareStep != nil {
256 call.PrepareStep = a.settings.prepareStep
257 }
258 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
259 call.RepairToolCall = a.settings.repairToolCall
260 }
261 if call.OnRetry == nil && a.settings.onRetry != nil {
262 call.OnRetry = a.settings.onRetry
263 }
264 if call.MaxRetries == nil && a.settings.maxRetries != nil {
265 call.MaxRetries = a.settings.maxRetries
266 }
267
268 providerOptions := ProviderOptions{}
269 if a.settings.providerOptions != nil {
270 maps.Copy(providerOptions, a.settings.providerOptions)
271 }
272 if call.ProviderOptions != nil {
273 maps.Copy(providerOptions, call.ProviderOptions)
274 }
275 call.ProviderOptions = providerOptions
276
277 headers := map[string]string{}
278
279 if a.settings.headers != nil {
280 maps.Copy(headers, a.settings.headers)
281 }
282
283 if call.Headers != nil {
284 maps.Copy(headers, call.Headers)
285 }
286 call.Headers = headers
287 return call
288}
289
290// Generate implements Agent.
291func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
292 opts = a.prepareCall(opts)
293 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
294 if err != nil {
295 return nil, err
296 }
297 var responseMessages []Message
298 var steps []StepResult
299
300 for {
301 stepInputMessages := append(initialPrompt, responseMessages...)
302 stepModel := a.settings.model
303 stepSystemPrompt := a.settings.systemPrompt
304 stepActiveTools := opts.ActiveTools
305 stepToolChoice := ToolChoiceAuto
306 disableAllTools := false
307
308 if opts.PrepareStep != nil {
309 prepared, err := opts.PrepareStep(PrepareStepFunctionOptions{
310 Model: stepModel,
311 Steps: steps,
312 StepNumber: len(steps),
313 Messages: stepInputMessages,
314 })
315 if err != nil {
316 return nil, err
317 }
318
319 // Apply prepared step modifications
320 if prepared.Messages != nil {
321 stepInputMessages = prepared.Messages
322 }
323 if prepared.Model != nil {
324 stepModel = prepared.Model
325 }
326 if prepared.System != nil {
327 stepSystemPrompt = *prepared.System
328 }
329 if prepared.ToolChoice != nil {
330 stepToolChoice = *prepared.ToolChoice
331 }
332 if len(prepared.ActiveTools) > 0 {
333 stepActiveTools = prepared.ActiveTools
334 }
335 disableAllTools = prepared.DisableAllTools
336 }
337
338 // Recreate prompt with potentially modified system prompt
339 if stepSystemPrompt != a.settings.systemPrompt {
340 stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
341 if err != nil {
342 return nil, err
343 }
344 // Replace system message part, keep the rest
345 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
346 stepInputMessages[0] = stepPrompt[0] // Replace system message
347 }
348 }
349
350 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
351
352 retryOptions := DefaultRetryOptions()
353 retryOptions.OnRetry = opts.OnRetry
354 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
355
356 result, err := retry(ctx, func() (*Response, error) {
357 return stepModel.Generate(ctx, Call{
358 Prompt: stepInputMessages,
359 MaxOutputTokens: opts.MaxOutputTokens,
360 Temperature: opts.Temperature,
361 TopP: opts.TopP,
362 TopK: opts.TopK,
363 PresencePenalty: opts.PresencePenalty,
364 FrequencyPenalty: opts.FrequencyPenalty,
365 Tools: preparedTools,
366 ToolChoice: &stepToolChoice,
367 Headers: opts.Headers,
368 ProviderOptions: opts.ProviderOptions,
369 })
370 })
371 if err != nil {
372 return nil, err
373 }
374
375 var stepToolCalls []ToolCallContent
376 for _, content := range result.Content {
377 if content.GetType() == ContentTypeToolCall {
378 toolCall, ok := AsContentType[ToolCallContent](content)
379 if !ok {
380 continue
381 }
382
383 // Validate and potentially repair the tool call
384 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
385 stepToolCalls = append(stepToolCalls, validatedToolCall)
386 }
387 }
388
389 toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
390
391 // Build step content with validated tool calls and tool results
392 stepContent := []Content{}
393 toolCallIndex := 0
394 for _, content := range result.Content {
395 if content.GetType() == ContentTypeToolCall {
396 // Replace with validated tool call
397 if toolCallIndex < len(stepToolCalls) {
398 stepContent = append(stepContent, stepToolCalls[toolCallIndex])
399 toolCallIndex++
400 }
401 } else {
402 // Keep other content as-is
403 stepContent = append(stepContent, content)
404 }
405 }
406 // Add tool results
407 for _, result := range toolResults {
408 stepContent = append(stepContent, result)
409 }
410 currentStepMessages := toResponseMessages(stepContent)
411 responseMessages = append(responseMessages, currentStepMessages...)
412
413 stepResult := StepResult{
414 Response: Response{
415 Content: stepContent,
416 FinishReason: result.FinishReason,
417 Usage: result.Usage,
418 Warnings: result.Warnings,
419 ProviderMetadata: result.ProviderMetadata,
420 },
421 Messages: currentStepMessages,
422 }
423 steps = append(steps, stepResult)
424 shouldStop := isStopConditionMet(opts.StopWhen, steps)
425
426 if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
427 break
428 }
429 }
430
431 totalUsage := Usage{}
432
433 for _, step := range steps {
434 usage := step.Usage
435 totalUsage.InputTokens += usage.InputTokens
436 totalUsage.OutputTokens += usage.OutputTokens
437 totalUsage.ReasoningTokens += usage.ReasoningTokens
438 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
439 totalUsage.CacheReadTokens += usage.CacheReadTokens
440 totalUsage.TotalTokens += usage.TotalTokens
441 }
442
443 agentResult := &AgentResult{
444 Steps: steps,
445 Response: steps[len(steps)-1].Response,
446 TotalUsage: totalUsage,
447 }
448 return agentResult, nil
449}
450
451func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
452 if len(conditions) == 0 {
453 return false
454 }
455
456 for _, condition := range conditions {
457 if condition(steps) {
458 return true
459 }
460 }
461 return false
462}
463
464func toResponseMessages(content []Content) []Message {
465 var assistantParts []MessagePart
466 var toolParts []MessagePart
467
468 for _, c := range content {
469 switch c.GetType() {
470 case ContentTypeText:
471 text, ok := AsContentType[TextContent](c)
472 if !ok {
473 continue
474 }
475 assistantParts = append(assistantParts, TextPart{
476 Text: text.Text,
477 ProviderOptions: ProviderOptions(text.ProviderMetadata),
478 })
479 case ContentTypeReasoning:
480 reasoning, ok := AsContentType[ReasoningContent](c)
481 if !ok {
482 continue
483 }
484 assistantParts = append(assistantParts, ReasoningPart{
485 Text: reasoning.Text,
486 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
487 })
488 case ContentTypeToolCall:
489 toolCall, ok := AsContentType[ToolCallContent](c)
490 if !ok {
491 continue
492 }
493 assistantParts = append(assistantParts, ToolCallPart{
494 ToolCallID: toolCall.ToolCallID,
495 ToolName: toolCall.ToolName,
496 Input: toolCall.Input,
497 ProviderExecuted: toolCall.ProviderExecuted,
498 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
499 })
500 case ContentTypeFile:
501 file, ok := AsContentType[FileContent](c)
502 if !ok {
503 continue
504 }
505 assistantParts = append(assistantParts, FilePart{
506 Data: file.Data,
507 MediaType: file.MediaType,
508 ProviderOptions: ProviderOptions(file.ProviderMetadata),
509 })
510 case ContentTypeSource:
511 // Sources are metadata about references used to generate the response.
512 // They don't need to be included in the conversation messages.
513 continue
514 case ContentTypeToolResult:
515 result, ok := AsContentType[ToolResultContent](c)
516 if !ok {
517 continue
518 }
519 toolParts = append(toolParts, ToolResultPart{
520 ToolCallID: result.ToolCallID,
521 Output: result.Result,
522 ProviderOptions: ProviderOptions(result.ProviderMetadata),
523 })
524 }
525 }
526
527 var messages []Message
528 if len(assistantParts) > 0 {
529 messages = append(messages, Message{
530 Role: MessageRoleAssistant,
531 Content: assistantParts,
532 })
533 }
534 if len(toolParts) > 0 {
535 messages = append(messages, Message{
536 Role: MessageRoleTool,
537 Content: toolParts,
538 })
539 }
540 return messages
541}
542
543func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent)) ([]ToolResultContent, error) {
544 if len(toolCalls) == 0 {
545 return nil, nil
546 }
547
548 // Create a map for quick tool lookup
549 toolMap := make(map[string]AgentTool)
550 for _, tool := range allTools {
551 toolMap[tool.Info().Name] = tool
552 }
553
554 // Execute all tool calls in parallel
555 results := make([]ToolResultContent, len(toolCalls))
556 var toolExecutionError error
557 var wg sync.WaitGroup
558
559 for i, toolCall := range toolCalls {
560 wg.Add(1)
561 go func(index int, call ToolCallContent) {
562 defer wg.Done()
563
564 // Skip invalid tool calls - create error result
565 if call.Invalid {
566 results[index] = ToolResultContent{
567 ToolCallID: call.ToolCallID,
568 ToolName: call.ToolName,
569 Result: ToolResultOutputContentError{
570 Error: call.ValidationError,
571 },
572 ProviderExecuted: false,
573 }
574 if toolResultCallback != nil {
575 toolResultCallback(results[index])
576 }
577
578 return
579 }
580
581 tool, exists := toolMap[call.ToolName]
582 if !exists {
583 results[index] = ToolResultContent{
584 ToolCallID: call.ToolCallID,
585 ToolName: call.ToolName,
586 Result: ToolResultOutputContentError{
587 Error: errors.New("Error: Tool not found: " + call.ToolName),
588 },
589 ProviderExecuted: false,
590 }
591
592 if toolResultCallback != nil {
593 toolResultCallback(results[index])
594 }
595 return
596 }
597
598 // Execute the tool
599 result, err := tool.Run(ctx, ToolCall{
600 ID: call.ToolCallID,
601 Name: call.ToolName,
602 Input: call.Input,
603 })
604 if err != nil {
605 results[index] = ToolResultContent{
606 ToolCallID: call.ToolCallID,
607 ToolName: call.ToolName,
608 Result: ToolResultOutputContentError{
609 Error: err,
610 },
611 ClientMetadata: result.Metadata,
612 ProviderExecuted: false,
613 }
614 if toolResultCallback != nil {
615 toolResultCallback(results[index])
616 }
617 toolExecutionError = err
618 return
619 }
620
621 if result.IsError {
622 results[index] = ToolResultContent{
623 ToolCallID: call.ToolCallID,
624 ToolName: call.ToolName,
625 Result: ToolResultOutputContentError{
626 Error: errors.New(result.Content),
627 },
628 ClientMetadata: result.Metadata,
629 ProviderExecuted: false,
630 }
631
632 if toolResultCallback != nil {
633 toolResultCallback(results[index])
634 }
635 } else {
636 results[index] = ToolResultContent{
637 ToolCallID: call.ToolCallID,
638 ToolName: toolCall.ToolName,
639 Result: ToolResultOutputContentText{
640 Text: result.Content,
641 },
642 ClientMetadata: result.Metadata,
643 ProviderExecuted: false,
644 }
645 if toolResultCallback != nil {
646 toolResultCallback(results[index])
647 }
648 }
649 }(i, toolCall)
650 }
651
652 // Wait for all tool executions to complete
653 wg.Wait()
654
655 return results, toolExecutionError
656}
657
658// Stream implements Agent.
659func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
660 // Convert AgentStreamCall to AgentCall for preparation
661 call := AgentCall{
662 Prompt: opts.Prompt,
663 Files: opts.Files,
664 Messages: opts.Messages,
665 MaxOutputTokens: opts.MaxOutputTokens,
666 Temperature: opts.Temperature,
667 TopP: opts.TopP,
668 TopK: opts.TopK,
669 PresencePenalty: opts.PresencePenalty,
670 FrequencyPenalty: opts.FrequencyPenalty,
671 ActiveTools: opts.ActiveTools,
672 Headers: opts.Headers,
673 ProviderOptions: opts.ProviderOptions,
674 MaxRetries: opts.MaxRetries,
675 StopWhen: opts.StopWhen,
676 PrepareStep: opts.PrepareStep,
677 RepairToolCall: opts.RepairToolCall,
678 }
679
680 call = a.prepareCall(call)
681
682 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
683 if err != nil {
684 return nil, err
685 }
686
687 var responseMessages []Message
688 var steps []StepResult
689 var totalUsage Usage
690
691 // Start agent stream
692 if opts.OnAgentStart != nil {
693 opts.OnAgentStart()
694 }
695
696 for stepNumber := 0; ; stepNumber++ {
697 stepInputMessages := append(initialPrompt, responseMessages...)
698 stepModel := a.settings.model
699 stepSystemPrompt := a.settings.systemPrompt
700 stepActiveTools := call.ActiveTools
701 stepToolChoice := ToolChoiceAuto
702 disableAllTools := false
703
704 // Apply step preparation if provided
705 if call.PrepareStep != nil {
706 prepared, err := call.PrepareStep(PrepareStepFunctionOptions{
707 Model: stepModel,
708 Steps: steps,
709 StepNumber: stepNumber,
710 Messages: stepInputMessages,
711 })
712 if err != nil {
713 return nil, err
714 }
715
716 if prepared.Messages != nil {
717 stepInputMessages = prepared.Messages
718 }
719 if prepared.Model != nil {
720 stepModel = prepared.Model
721 }
722 if prepared.System != nil {
723 stepSystemPrompt = *prepared.System
724 }
725 if prepared.ToolChoice != nil {
726 stepToolChoice = *prepared.ToolChoice
727 }
728 if len(prepared.ActiveTools) > 0 {
729 stepActiveTools = prepared.ActiveTools
730 }
731 disableAllTools = prepared.DisableAllTools
732 }
733
734 // Recreate prompt with potentially modified system prompt
735 if stepSystemPrompt != a.settings.systemPrompt {
736 stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
737 if err != nil {
738 return nil, err
739 }
740 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
741 stepInputMessages[0] = stepPrompt[0]
742 }
743 }
744
745 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
746
747 // Start step stream
748 if opts.OnStepStart != nil {
749 opts.OnStepStart(stepNumber)
750 }
751
752 // Create streaming call
753 streamCall := Call{
754 Prompt: stepInputMessages,
755 MaxOutputTokens: call.MaxOutputTokens,
756 Temperature: call.Temperature,
757 TopP: call.TopP,
758 TopK: call.TopK,
759 PresencePenalty: call.PresencePenalty,
760 FrequencyPenalty: call.FrequencyPenalty,
761 Tools: preparedTools,
762 ToolChoice: &stepToolChoice,
763 Headers: call.Headers,
764 ProviderOptions: call.ProviderOptions,
765 }
766
767 // Get streaming response
768 stream, err := stepModel.Stream(ctx, streamCall)
769 if err != nil {
770 if opts.OnError != nil {
771 opts.OnError(err)
772 }
773 return nil, err
774 }
775
776 // Process stream with tool execution
777 stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
778 if err != nil {
779 if opts.OnError != nil {
780 opts.OnError(err)
781 }
782 return nil, err
783 }
784
785 steps = append(steps, stepResult)
786 totalUsage = addUsage(totalUsage, stepResult.Usage)
787
788 // Call step finished callback
789 if opts.OnStepFinish != nil {
790 opts.OnStepFinish(stepResult)
791 }
792
793 // Add step messages to response messages
794 stepMessages := toResponseMessages(stepResult.Content)
795 responseMessages = append(responseMessages, stepMessages...)
796
797 // Check stop conditions
798 shouldStop := isStopConditionMet(call.StopWhen, steps)
799 if shouldStop || !shouldContinue {
800 break
801 }
802 }
803
804 // Finish agent stream
805 agentResult := &AgentResult{
806 Steps: steps,
807 Response: steps[len(steps)-1].Response,
808 TotalUsage: totalUsage,
809 }
810
811 if opts.OnFinish != nil {
812 opts.OnFinish(agentResult)
813 }
814
815 if opts.OnAgentFinish != nil {
816 opts.OnAgentFinish(agentResult)
817 }
818
819 return agentResult, nil
820}
821
822func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
823 var preparedTools []Tool
824
825 // If explicitly disabling all tools, return no tools
826 if disableAllTools {
827 return preparedTools
828 }
829
830 for _, tool := range tools {
831 // If activeTools has items, only include tools in the list
832 // If activeTools is empty, include all tools
833 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
834 continue
835 }
836 info := tool.Info()
837 preparedTools = append(preparedTools, FunctionTool{
838 Name: info.Name,
839 Description: info.Description,
840 InputSchema: map[string]any{
841 "type": "object",
842 "properties": info.Parameters,
843 "required": info.Required,
844 },
845 })
846 }
847 return preparedTools
848}
849
850// validateAndRepairToolCall validates a tool call and attempts repair if validation fails
851func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
852 if err := a.validateToolCall(toolCall, availableTools); err == nil {
853 return toolCall
854 } else {
855 if repairFunc != nil {
856 repairOptions := ToolCallRepairOptions{
857 OriginalToolCall: toolCall,
858 ValidationError: err,
859 AvailableTools: availableTools,
860 SystemPrompt: systemPrompt,
861 Messages: messages,
862 }
863
864 if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
865 if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
866 return *repairedToolCall
867 }
868 }
869 }
870
871 invalidToolCall := toolCall
872 invalidToolCall.Invalid = true
873 invalidToolCall.ValidationError = err
874 return invalidToolCall
875 }
876}
877
878// validateToolCall validates a tool call against available tools and their schemas
879func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
880 var tool AgentTool
881 for _, t := range availableTools {
882 if t.Info().Name == toolCall.ToolName {
883 tool = t
884 break
885 }
886 }
887
888 if tool == nil {
889 return fmt.Errorf("tool not found: %s", toolCall.ToolName)
890 }
891
892 // Validate JSON parsing
893 var input map[string]any
894 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
895 return fmt.Errorf("invalid JSON input: %w", err)
896 }
897
898 // Basic schema validation (check required fields)
899 // TODO: more robust schema validation using JSON Schema or similar
900 toolInfo := tool.Info()
901 for _, required := range toolInfo.Required {
902 if _, exists := input[required]; !exists {
903 return fmt.Errorf("missing required parameter: %s", required)
904 }
905 }
906 return nil
907}
908
909func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
910 if prompt == "" {
911 return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
912 }
913
914 var preparedPrompt Prompt
915
916 if system != "" {
917 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
918 }
919
920 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
921 preparedPrompt = append(preparedPrompt, messages...)
922 return preparedPrompt, nil
923}
924
925func WithSystemPrompt(prompt string) AgentOption {
926 return func(s *agentSettings) {
927 s.systemPrompt = prompt
928 }
929}
930
931func WithMaxOutputTokens(tokens int64) AgentOption {
932 return func(s *agentSettings) {
933 s.maxOutputTokens = &tokens
934 }
935}
936
937func WithTemperature(temp float64) AgentOption {
938 return func(s *agentSettings) {
939 s.temperature = &temp
940 }
941}
942
943func WithTopP(topP float64) AgentOption {
944 return func(s *agentSettings) {
945 s.topP = &topP
946 }
947}
948
949func WithTopK(topK int64) AgentOption {
950 return func(s *agentSettings) {
951 s.topK = &topK
952 }
953}
954
955func WithPresencePenalty(penalty float64) AgentOption {
956 return func(s *agentSettings) {
957 s.presencePenalty = &penalty
958 }
959}
960
961func WithFrequencyPenalty(penalty float64) AgentOption {
962 return func(s *agentSettings) {
963 s.frequencyPenalty = &penalty
964 }
965}
966
967func WithTools(tools ...AgentTool) AgentOption {
968 return func(s *agentSettings) {
969 s.tools = append(s.tools, tools...)
970 }
971}
972
973func WithStopConditions(conditions ...StopCondition) AgentOption {
974 return func(s *agentSettings) {
975 s.stopWhen = append(s.stopWhen, conditions...)
976 }
977}
978
979func WithPrepareStep(fn PrepareStepFunction) AgentOption {
980 return func(s *agentSettings) {
981 s.prepareStep = fn
982 }
983}
984
985func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
986 return func(s *agentSettings) {
987 s.repairToolCall = fn
988 }
989}
990
991// processStepStream processes a single step's stream and returns the step result
992func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
993 var stepContent []Content
994 var stepToolCalls []ToolCallContent
995 var stepUsage Usage
996 stepFinishReason := FinishReasonUnknown
997 var stepWarnings []CallWarning
998 var stepProviderMetadata ProviderMetadata
999
1000 activeToolCalls := make(map[string]*ToolCallContent)
1001 activeTextContent := make(map[string]string)
1002
1003 // Process stream parts
1004 for part := range stream {
1005 // Forward all parts to chunk callback
1006 if opts.OnChunk != nil {
1007 opts.OnChunk(part)
1008 }
1009
1010 switch part.Type {
1011 case StreamPartTypeWarnings:
1012 stepWarnings = part.Warnings
1013 if opts.OnWarnings != nil {
1014 opts.OnWarnings(part.Warnings)
1015 }
1016
1017 case StreamPartTypeTextStart:
1018 activeTextContent[part.ID] = ""
1019 if opts.OnTextStart != nil {
1020 opts.OnTextStart(part.ID)
1021 }
1022
1023 case StreamPartTypeTextDelta:
1024 if _, exists := activeTextContent[part.ID]; exists {
1025 activeTextContent[part.ID] += part.Delta
1026 }
1027 if opts.OnTextDelta != nil {
1028 opts.OnTextDelta(part.ID, part.Delta)
1029 }
1030
1031 case StreamPartTypeTextEnd:
1032 if text, exists := activeTextContent[part.ID]; exists {
1033 stepContent = append(stepContent, TextContent{
1034 Text: text,
1035 ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1036 })
1037 delete(activeTextContent, part.ID)
1038 }
1039 if opts.OnTextEnd != nil {
1040 opts.OnTextEnd(part.ID)
1041 }
1042
1043 case StreamPartTypeReasoningStart:
1044 activeTextContent[part.ID] = ""
1045 if opts.OnReasoningStart != nil {
1046 opts.OnReasoningStart(part.ID)
1047 }
1048
1049 case StreamPartTypeReasoningDelta:
1050 if _, exists := activeTextContent[part.ID]; exists {
1051 activeTextContent[part.ID] += part.Delta
1052 }
1053 if opts.OnReasoningDelta != nil {
1054 opts.OnReasoningDelta(part.ID, part.Delta)
1055 }
1056
1057 case StreamPartTypeReasoningEnd:
1058 if text, exists := activeTextContent[part.ID]; exists {
1059 stepContent = append(stepContent, ReasoningContent{
1060 Text: text,
1061 ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1062 })
1063 if opts.OnReasoningEnd != nil {
1064 opts.OnReasoningEnd(part.ID, ReasoningContent{
1065 Text: text,
1066 ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1067 })
1068 }
1069 delete(activeTextContent, part.ID)
1070 }
1071
1072 case StreamPartTypeToolInputStart:
1073 activeToolCalls[part.ID] = &ToolCallContent{
1074 ToolCallID: part.ID,
1075 ToolName: part.ToolCallName,
1076 Input: "",
1077 ProviderExecuted: part.ProviderExecuted,
1078 }
1079 if opts.OnToolInputStart != nil {
1080 opts.OnToolInputStart(part.ID, part.ToolCallName)
1081 }
1082
1083 case StreamPartTypeToolInputDelta:
1084 if toolCall, exists := activeToolCalls[part.ID]; exists {
1085 toolCall.Input += part.Delta
1086 }
1087 if opts.OnToolInputDelta != nil {
1088 opts.OnToolInputDelta(part.ID, part.Delta)
1089 }
1090
1091 case StreamPartTypeToolInputEnd:
1092 if opts.OnToolInputEnd != nil {
1093 opts.OnToolInputEnd(part.ID)
1094 }
1095
1096 case StreamPartTypeToolCall:
1097 toolCall := ToolCallContent{
1098 ToolCallID: part.ID,
1099 ToolName: part.ToolCallName,
1100 Input: part.ToolCallInput,
1101 ProviderExecuted: part.ProviderExecuted,
1102 ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1103 }
1104
1105 // Validate and potentially repair the tool call
1106 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1107 stepToolCalls = append(stepToolCalls, validatedToolCall)
1108 stepContent = append(stepContent, validatedToolCall)
1109
1110 if opts.OnToolCall != nil {
1111 opts.OnToolCall(validatedToolCall)
1112 }
1113
1114 // Clean up active tool call
1115 delete(activeToolCalls, part.ID)
1116
1117 case StreamPartTypeSource:
1118 sourceContent := SourceContent{
1119 SourceType: part.SourceType,
1120 ID: part.ID,
1121 URL: part.URL,
1122 Title: part.Title,
1123 ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1124 }
1125 stepContent = append(stepContent, sourceContent)
1126 if opts.OnSource != nil {
1127 opts.OnSource(sourceContent)
1128 }
1129
1130 case StreamPartTypeFinish:
1131 stepUsage = part.Usage
1132 stepFinishReason = part.FinishReason
1133 stepProviderMetadata = ProviderMetadata(part.ProviderMetadata)
1134 if opts.OnStreamFinish != nil {
1135 opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1136 }
1137
1138 case StreamPartTypeError:
1139 if opts.OnStreamError != nil {
1140 opts.OnStreamError(part.Error)
1141 }
1142 if opts.OnError != nil {
1143 opts.OnError(part.Error)
1144 }
1145 return StepResult{}, false, part.Error
1146 }
1147 }
1148
1149 // Execute tools if any
1150 var toolResults []ToolResultContent
1151 if len(stepToolCalls) > 0 {
1152 var err error
1153 toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
1154 if err != nil {
1155 return StepResult{}, false, err
1156 }
1157 // Add tool results to content
1158 for _, result := range toolResults {
1159 stepContent = append(stepContent, result)
1160 }
1161 }
1162
1163 stepResult := StepResult{
1164 Response: Response{
1165 Content: stepContent,
1166 FinishReason: stepFinishReason,
1167 Usage: stepUsage,
1168 Warnings: stepWarnings,
1169 ProviderMetadata: stepProviderMetadata,
1170 },
1171 Messages: toResponseMessages(stepContent),
1172 }
1173
1174 // Determine if we should continue (has tool calls and not stopped)
1175 shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1176
1177 return stepResult, shouldContinue, nil
1178}
1179
1180func addUsage(a, b Usage) Usage {
1181 return Usage{
1182 InputTokens: a.InputTokens + b.InputTokens,
1183 OutputTokens: a.OutputTokens + b.OutputTokens,
1184 TotalTokens: a.TotalTokens + b.TotalTokens,
1185 ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens,
1186 CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1187 CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens,
1188 }
1189}
1190
1191func WithHeaders(headers map[string]string) AgentOption {
1192 return func(s *agentSettings) {
1193 s.headers = headers
1194 }
1195}
1196
1197func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1198 return func(s *agentSettings) {
1199 s.providerOptions = providerOptions
1200 }
1201}