1package ai
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "maps"
9 "slices"
10 "sync"
11)
12
// StepResult is the outcome of one agent step: the response produced for the
// step together with the conversation messages built from its content.
type StepResult struct {
	// Response is embedded, so its fields (Content, FinishReason, Usage,
	// Warnings, ProviderMetadata) are reachable directly on the StepResult.
	Response
	// Messages holds the messages derived from this step's content.
	Messages []Message
}
17
// StopCondition reports whether the agent loop should stop, given the steps
// completed so far.
type StopCondition = func(steps []StepResult) bool
19
20// StepCountIs returns a stop condition that stops after the specified number of steps.
21func StepCountIs(stepCount int) StopCondition {
22 return func(steps []StepResult) bool {
23 return len(steps) >= stepCount
24 }
25}
26
27// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
28func HasToolCall(toolName string) StopCondition {
29 return func(steps []StepResult) bool {
30 if len(steps) == 0 {
31 return false
32 }
33 lastStep := steps[len(steps)-1]
34 toolCalls := lastStep.Content.ToolCalls()
35 for _, toolCall := range toolCalls {
36 if toolCall.ToolName == toolName {
37 return true
38 }
39 }
40 return false
41 }
42}
43
44// HasContent returns a stop condition that stops when the specified content type appears in the last step.
45func HasContent(contentType ContentType) StopCondition {
46 return func(steps []StepResult) bool {
47 if len(steps) == 0 {
48 return false
49 }
50 lastStep := steps[len(steps)-1]
51 for _, content := range lastStep.Content {
52 if content.GetType() == contentType {
53 return true
54 }
55 }
56 return false
57 }
58}
59
60// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
61func FinishReasonIs(reason FinishReason) StopCondition {
62 return func(steps []StepResult) bool {
63 if len(steps) == 0 {
64 return false
65 }
66 lastStep := steps[len(steps)-1]
67 return lastStep.FinishReason == reason
68 }
69}
70
71// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
72func MaxTokensUsed(maxTokens int64) StopCondition {
73 return func(steps []StepResult) bool {
74 var totalTokens int64
75 for _, step := range steps {
76 totalTokens += step.Usage.TotalTokens
77 }
78 return totalTokens >= maxTokens
79 }
80}
81
// PrepareStepFunctionOptions is the input handed to a PrepareStepFunction
// before each step of the agent loop.
type PrepareStepFunctionOptions struct {
	// Steps are the steps completed so far.
	Steps []StepResult
	// StepNumber is the zero-based index of the step about to run.
	StepNumber int
	// Model is the agent's configured model.
	Model LanguageModel
	// Messages are the input messages assembled for the upcoming step.
	Messages []Message
}
88
// PrepareStepResult carries the per-step overrides returned by a
// PrepareStepFunction. Zero-valued fields leave the corresponding default
// untouched.
type PrepareStepResult struct {
	// Model, when non-nil, replaces the model used for this step.
	Model LanguageModel
	// Messages, when non-nil, replaces the step's input messages.
	Messages []Message
	// System, when non-nil, replaces the system prompt for this step.
	System *string
	// ToolChoice, when non-nil, replaces the default tool choice for this step.
	ToolChoice *ToolChoice
	// ActiveTools, when non-empty, replaces the list of active tool names.
	ActiveTools []string
	// DisableAllTools, when true, causes no tools to be offered for this step.
	DisableAllTools bool
}
97
// ToolCallRepairOptions is the input handed to a RepairToolCallFunction when a
// tool call fails validation.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call that failed validation.
	OriginalToolCall ToolCallContent
	// ValidationError is the error returned by validation.
	ValidationError error
	// AvailableTools are the tools the call was validated against.
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt passed along by the caller of the
	// repair step.
	SystemPrompt string
	// Messages are the messages passed along by the caller of the repair
	// step; the streaming path passes nil.
	Messages []Message
}
105
// Callback signatures used to customize the agent loop.
type (
	// PrepareStepFunction is invoked before each step and returns overrides
	// for that step.
	PrepareStepFunction = func(options PrepareStepFunctionOptions) PrepareStepResult
	// OnStepFinishedFunction is invoked with each completed step.
	OnStepFinishedFunction = func(step StepResult)
	// RepairToolCallFunction is invoked for a tool call that failed
	// validation. A nil result or a non-nil error means no repair is applied.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
111
// AgentSettings holds the agent-level configuration assembled by NewAgent from
// its options. Most fields act as defaults that a call may override; see
// prepareCall for the merge rules.
type AgentSettings struct {
	// systemPrompt is the default system prompt; empty means none is added.
	systemPrompt string
	// Sampling defaults; nil means "not set".
	maxOutputTokens *int64
	temperature *float64
	topP *float64
	topK *int64
	presencePenalty *float64
	frequencyPenalty *float64
	// headers and providerOptions are merged into each call, with the call's
	// own entries taking precedence.
	headers map[string]string
	providerOptions ProviderOptions

	// TODO: add support for provider tools
	// tools are the tools the agent can offer to the model and execute.
	tools []AgentTool
	maxRetries *int

	// model is the default language model used for each step.
	model LanguageModel

	// Loop hooks used when the call does not supply its own.
	stopWhen []StopCondition
	prepareStep PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onStepFinished OnStepFinishedFunction
	onRetry OnRetryCallback
}
135
// AgentCall carries the per-call inputs and overrides for Agent.Generate.
// Pointer and callback fields left nil, and an empty StopWhen, fall back to
// the agent-level settings in prepareCall. Headers and ProviderOptions are
// merged with the agent-level maps, the call's entries winning on conflict.
//
// NOTE(review): only some fields carry json tags, and the callback fields are
// func-typed — confirm whether this struct is meant to be (un)marshalled.
type AgentCall struct {
	// Prompt is the user prompt; createPrompt rejects an empty value.
	Prompt string `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files []FilePart `json:"files"`
	// Messages are additional messages included in the prompt.
	Messages []Message `json:"messages"`
	MaxOutputTokens *int64
	Temperature *float64 `json:"temperature"`
	TopP *float64 `json:"top_p"`
	TopK *int64 `json:"top_k"`
	PresencePenalty *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, restricts the offered tools to these names.
	ActiveTools []string `json:"active_tools"`
	Headers map[string]string
	ProviderOptions ProviderOptions
	OnRetry OnRetryCallback
	MaxRetries *int

	// Loop hooks; each falls back to the agent-level setting when unset.
	StopWhen []StopCondition
	PrepareStep PrepareStepFunction
	RepairToolCall RepairToolCallFunction
	OnStepFinished OnStepFinishedFunction
}
157
// AgentStreamCall carries the per-call inputs, overrides and streaming
// callbacks for Agent.Stream. The input and override fields mirror AgentCall;
// the callback fields are optional and are skipped when nil.
type AgentStreamCall struct {
	// Prompt is the user prompt; createPrompt rejects an empty value.
	Prompt string `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files []FilePart `json:"files"`
	// Messages are additional messages included in the prompt.
	Messages []Message `json:"messages"`
	MaxOutputTokens *int64
	Temperature *float64 `json:"temperature"`
	TopP *float64 `json:"top_p"`
	TopK *int64 `json:"top_k"`
	PresencePenalty *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, restricts the offered tools to these names.
	ActiveTools []string `json:"active_tools"`
	Headers map[string]string
	ProviderOptions ProviderOptions
	OnRetry OnRetryCallback
	MaxRetries *int

	// Loop hooks.
	StopWhen []StopCondition
	PrepareStep PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart func() // Called when agent starts
	OnAgentFinish func(result *AgentResult) // Called when agent finishes
	OnStepStart func(stepNumber int) // Called when a step starts
	OnStepFinish func(stepResult StepResult) // Called when a step finishes
	OnFinish func(result *AgentResult) // Called when entire agent completes
	OnError func(error) // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk func(StreamPart) // Called for each stream part (catch-all)
	OnWarnings func(warnings []CallWarning) // Called for warnings
	OnTextStart func(id string) // Called when text starts
	OnTextDelta func(id, text string) // Called for text deltas
	OnTextEnd func(id string) // Called when text ends
	OnReasoningStart func(id string) // Called when reasoning starts
	OnReasoningDelta func(id, text string) // Called for reasoning deltas
	OnReasoningEnd func(id string) // Called when reasoning ends
	OnToolInputStart func(id, toolName string) // Called when tool input starts
	OnToolInputDelta func(id, delta string) // Called for tool input deltas
	OnToolInputEnd func(id string) // Called when tool input ends
	OnToolCall func(toolCall ToolCallContent) // Called when tool call is complete
	OnToolResult func(result ToolResultContent) // Called when tool execution completes
	OnSource func(source SourceContent) // Called for source references
	OnStreamFinish func(usage Usage, finishReason FinishReason, providerMetadata ProviderOptions) // Called when stream finishes
	OnStreamError func(error) // Called when stream error occurs
}
204
// AgentResult is the outcome of a complete agent run.
type AgentResult struct {
	// Steps are all steps executed during the run, in order.
	Steps []StepResult
	// Final response
	// Response is the response of the last step.
	Response Response
	// TotalUsage is the usage summed over all steps.
	TotalUsage Usage
}
211
// Agent runs a multi-step loop against a language model.
type Agent interface {
	// Generate runs the loop using non-streaming model calls.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the loop using streaming model calls, reporting progress
	// through the callbacks on AgentStreamCall.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
216
// agentOption mutates AgentSettings; NewAgent applies options in order.
type agentOption = func(*AgentSettings)
218
// agent is the Agent implementation returned by NewAgent.
type agent struct {
	// settings holds the agent-level configuration built from the options.
	settings AgentSettings
}
222
223func NewAgent(model LanguageModel, opts ...agentOption) Agent {
224 settings := AgentSettings{
225 model: model,
226 }
227 for _, o := range opts {
228 o(&settings)
229 }
230 return &agent{
231 settings: settings,
232 }
233}
234
235func (a *agent) prepareCall(call AgentCall) AgentCall {
236 if call.MaxOutputTokens == nil && a.settings.maxOutputTokens != nil {
237 call.MaxOutputTokens = a.settings.maxOutputTokens
238 }
239 if call.Temperature == nil && a.settings.temperature != nil {
240 call.Temperature = a.settings.temperature
241 }
242 if call.TopP == nil && a.settings.topP != nil {
243 call.TopP = a.settings.topP
244 }
245 if call.TopK == nil && a.settings.topK != nil {
246 call.TopK = a.settings.topK
247 }
248 if call.PresencePenalty == nil && a.settings.presencePenalty != nil {
249 call.PresencePenalty = a.settings.presencePenalty
250 }
251 if call.FrequencyPenalty == nil && a.settings.frequencyPenalty != nil {
252 call.FrequencyPenalty = a.settings.frequencyPenalty
253 }
254 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
255 call.StopWhen = a.settings.stopWhen
256 }
257 if call.PrepareStep == nil && a.settings.prepareStep != nil {
258 call.PrepareStep = a.settings.prepareStep
259 }
260 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
261 call.RepairToolCall = a.settings.repairToolCall
262 }
263 if call.OnStepFinished == nil && a.settings.onStepFinished != nil {
264 call.OnStepFinished = a.settings.onStepFinished
265 }
266 if call.OnRetry == nil && a.settings.onRetry != nil {
267 call.OnRetry = a.settings.onRetry
268 }
269 if call.MaxRetries == nil && a.settings.maxRetries != nil {
270 call.MaxRetries = a.settings.maxRetries
271 }
272
273 providerOptions := ProviderOptions{}
274 if a.settings.providerOptions != nil {
275 maps.Copy(providerOptions, a.settings.providerOptions)
276 }
277 if call.ProviderOptions != nil {
278 maps.Copy(providerOptions, call.ProviderOptions)
279 }
280 call.ProviderOptions = providerOptions
281
282 headers := map[string]string{}
283
284 if a.settings.headers != nil {
285 maps.Copy(headers, a.settings.headers)
286 }
287
288 if call.Headers != nil {
289 maps.Copy(headers, call.Headers)
290 }
291 call.Headers = headers
292 return call
293}
294
295// Generate implements Agent.
296func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
297 opts = a.prepareCall(opts)
298 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
299 if err != nil {
300 return nil, err
301 }
302 var responseMessages []Message
303 var steps []StepResult
304
305 for {
306 stepInputMessages := append(initialPrompt, responseMessages...)
307 stepModel := a.settings.model
308 stepSystemPrompt := a.settings.systemPrompt
309 stepActiveTools := opts.ActiveTools
310 stepToolChoice := ToolChoiceAuto
311 disableAllTools := false
312
313 if opts.PrepareStep != nil {
314 prepared := opts.PrepareStep(PrepareStepFunctionOptions{
315 Model: stepModel,
316 Steps: steps,
317 StepNumber: len(steps),
318 Messages: stepInputMessages,
319 })
320
321 // Apply prepared step modifications
322 if prepared.Messages != nil {
323 stepInputMessages = prepared.Messages
324 }
325 if prepared.Model != nil {
326 stepModel = prepared.Model
327 }
328 if prepared.System != nil {
329 stepSystemPrompt = *prepared.System
330 }
331 if prepared.ToolChoice != nil {
332 stepToolChoice = *prepared.ToolChoice
333 }
334 if len(prepared.ActiveTools) > 0 {
335 stepActiveTools = prepared.ActiveTools
336 }
337 disableAllTools = prepared.DisableAllTools
338 }
339
340 // Recreate prompt with potentially modified system prompt
341 if stepSystemPrompt != a.settings.systemPrompt {
342 stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
343 if err != nil {
344 return nil, err
345 }
346 // Replace system message part, keep the rest
347 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
348 stepInputMessages[0] = stepPrompt[0] // Replace system message
349 }
350 }
351
352 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
353
354 retryOptions := DefaultRetryOptions()
355 retryOptions.OnRetry = opts.OnRetry
356 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
357
358 result, err := retry(ctx, func() (*Response, error) {
359 return stepModel.Generate(ctx, Call{
360 Prompt: stepInputMessages,
361 MaxOutputTokens: opts.MaxOutputTokens,
362 Temperature: opts.Temperature,
363 TopP: opts.TopP,
364 TopK: opts.TopK,
365 PresencePenalty: opts.PresencePenalty,
366 FrequencyPenalty: opts.FrequencyPenalty,
367 Tools: preparedTools,
368 ToolChoice: &stepToolChoice,
369 Headers: opts.Headers,
370 ProviderOptions: opts.ProviderOptions,
371 })
372 })
373 if err != nil {
374 return nil, err
375 }
376
377 var stepToolCalls []ToolCallContent
378 for _, content := range result.Content {
379 if content.GetType() == ContentTypeToolCall {
380 toolCall, ok := AsContentType[ToolCallContent](content)
381 if !ok {
382 continue
383 }
384
385 // Validate and potentially repair the tool call
386 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
387 stepToolCalls = append(stepToolCalls, validatedToolCall)
388 }
389 }
390
391 toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
392
393 // Build step content with validated tool calls and tool results
394 stepContent := []Content{}
395 toolCallIndex := 0
396 for _, content := range result.Content {
397 if content.GetType() == ContentTypeToolCall {
398 // Replace with validated tool call
399 if toolCallIndex < len(stepToolCalls) {
400 stepContent = append(stepContent, stepToolCalls[toolCallIndex])
401 toolCallIndex++
402 }
403 } else {
404 // Keep other content as-is
405 stepContent = append(stepContent, content)
406 }
407 }
408 // Add tool results
409 for _, result := range toolResults {
410 stepContent = append(stepContent, result)
411 }
412 currentStepMessages := toResponseMessages(stepContent)
413 responseMessages = append(responseMessages, currentStepMessages...)
414
415 stepResult := StepResult{
416 Response: Response{
417 Content: stepContent,
418 FinishReason: result.FinishReason,
419 Usage: result.Usage,
420 Warnings: result.Warnings,
421 ProviderMetadata: result.ProviderMetadata,
422 },
423 Messages: currentStepMessages,
424 }
425 steps = append(steps, stepResult)
426 if opts.OnStepFinished != nil {
427 opts.OnStepFinished(stepResult)
428 }
429
430 shouldStop := isStopConditionMet(opts.StopWhen, steps)
431
432 if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
433 break
434 }
435 }
436
437 totalUsage := Usage{}
438
439 for _, step := range steps {
440 usage := step.Usage
441 totalUsage.InputTokens += usage.InputTokens
442 totalUsage.OutputTokens += usage.OutputTokens
443 totalUsage.ReasoningTokens += usage.ReasoningTokens
444 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
445 totalUsage.CacheReadTokens += usage.CacheReadTokens
446 totalUsage.TotalTokens += usage.TotalTokens
447 }
448
449 agentResult := &AgentResult{
450 Steps: steps,
451 Response: steps[len(steps)-1].Response,
452 TotalUsage: totalUsage,
453 }
454 return agentResult, nil
455}
456
457func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
458 if len(conditions) == 0 {
459 return false
460 }
461
462 for _, condition := range conditions {
463 if condition(steps) {
464 return true
465 }
466 }
467 return false
468}
469
470func toResponseMessages(content []Content) []Message {
471 var assistantParts []MessagePart
472 var toolParts []MessagePart
473
474 for _, c := range content {
475 switch c.GetType() {
476 case ContentTypeText:
477 text, ok := AsContentType[TextContent](c)
478 if !ok {
479 continue
480 }
481 assistantParts = append(assistantParts, TextPart{
482 Text: text.Text,
483 ProviderOptions: ProviderOptions(text.ProviderMetadata),
484 })
485 case ContentTypeReasoning:
486 reasoning, ok := AsContentType[ReasoningContent](c)
487 if !ok {
488 continue
489 }
490 assistantParts = append(assistantParts, ReasoningPart{
491 Text: reasoning.Text,
492 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
493 })
494 case ContentTypeToolCall:
495 toolCall, ok := AsContentType[ToolCallContent](c)
496 if !ok {
497 continue
498 }
499 assistantParts = append(assistantParts, ToolCallPart{
500 ToolCallID: toolCall.ToolCallID,
501 ToolName: toolCall.ToolName,
502 Input: toolCall.Input,
503 ProviderExecuted: toolCall.ProviderExecuted,
504 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
505 })
506 case ContentTypeFile:
507 file, ok := AsContentType[FileContent](c)
508 if !ok {
509 continue
510 }
511 assistantParts = append(assistantParts, FilePart{
512 Data: file.Data,
513 MediaType: file.MediaType,
514 ProviderOptions: ProviderOptions(file.ProviderMetadata),
515 })
516 case ContentTypeSource:
517 // Sources are metadata about references used to generate the response.
518 // They don't need to be included in the conversation messages.
519 continue
520 case ContentTypeToolResult:
521 result, ok := AsContentType[ToolResultContent](c)
522 if !ok {
523 continue
524 }
525 toolParts = append(toolParts, ToolResultPart{
526 ToolCallID: result.ToolCallID,
527 Output: result.Result,
528 ProviderOptions: ProviderOptions(result.ProviderMetadata),
529 })
530 }
531 }
532
533 var messages []Message
534 if len(assistantParts) > 0 {
535 messages = append(messages, Message{
536 Role: MessageRoleAssistant,
537 Content: assistantParts,
538 })
539 }
540 if len(toolParts) > 0 {
541 messages = append(messages, Message{
542 Role: MessageRoleTool,
543 Content: toolParts,
544 })
545 }
546 return messages
547}
548
549func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent)) ([]ToolResultContent, error) {
550 if len(toolCalls) == 0 {
551 return nil, nil
552 }
553
554 // Create a map for quick tool lookup
555 toolMap := make(map[string]AgentTool)
556 for _, tool := range allTools {
557 toolMap[tool.Info().Name] = tool
558 }
559
560 // Execute all tool calls in parallel
561 results := make([]ToolResultContent, len(toolCalls))
562 var toolExecutionError error
563 var wg sync.WaitGroup
564
565 for i, toolCall := range toolCalls {
566 wg.Add(1)
567 go func(index int, call ToolCallContent) {
568 defer wg.Done()
569
570 // Skip invalid tool calls - create error result
571 if call.Invalid {
572 results[index] = ToolResultContent{
573 ToolCallID: call.ToolCallID,
574 ToolName: call.ToolName,
575 Result: ToolResultOutputContentError{
576 Error: call.ValidationError,
577 },
578 ProviderExecuted: false,
579 }
580 if toolResultCallback != nil {
581 toolResultCallback(results[index])
582 }
583
584 return
585 }
586
587 tool, exists := toolMap[call.ToolName]
588 if !exists {
589 results[index] = ToolResultContent{
590 ToolCallID: call.ToolCallID,
591 ToolName: call.ToolName,
592 Result: ToolResultOutputContentError{
593 Error: errors.New("Error: Tool not found: " + call.ToolName),
594 },
595 ProviderExecuted: false,
596 }
597
598 if toolResultCallback != nil {
599 toolResultCallback(results[index])
600 }
601 return
602 }
603
604 // Execute the tool
605 result, err := tool.Run(ctx, ToolCall{
606 ID: call.ToolCallID,
607 Name: call.ToolName,
608 Input: call.Input,
609 })
610 if err != nil {
611 results[index] = ToolResultContent{
612 ToolCallID: call.ToolCallID,
613 ToolName: call.ToolName,
614 Result: ToolResultOutputContentError{
615 Error: err,
616 },
617 ClientMetadata: result.Metadata,
618 ProviderExecuted: false,
619 }
620 if toolResultCallback != nil {
621 toolResultCallback(results[index])
622 }
623 toolExecutionError = err
624 return
625 }
626
627 if result.IsError {
628 results[index] = ToolResultContent{
629 ToolCallID: call.ToolCallID,
630 ToolName: call.ToolName,
631 Result: ToolResultOutputContentError{
632 Error: errors.New(result.Content),
633 },
634 ClientMetadata: result.Metadata,
635 ProviderExecuted: false,
636 }
637
638 if toolResultCallback != nil {
639 toolResultCallback(results[index])
640 }
641 } else {
642 results[index] = ToolResultContent{
643 ToolCallID: call.ToolCallID,
644 ToolName: toolCall.ToolName,
645 Result: ToolResultOutputContentText{
646 Text: result.Content,
647 },
648 ClientMetadata: result.Metadata,
649 ProviderExecuted: false,
650 }
651 if toolResultCallback != nil {
652 toolResultCallback(results[index])
653 }
654 }
655 }(i, toolCall)
656 }
657
658 // Wait for all tool executions to complete
659 wg.Wait()
660
661 return results, toolExecutionError
662}
663
664// Stream implements Agent.
665func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
666 // Convert AgentStreamCall to AgentCall for preparation
667 call := AgentCall{
668 Prompt: opts.Prompt,
669 Files: opts.Files,
670 Messages: opts.Messages,
671 MaxOutputTokens: opts.MaxOutputTokens,
672 Temperature: opts.Temperature,
673 TopP: opts.TopP,
674 TopK: opts.TopK,
675 PresencePenalty: opts.PresencePenalty,
676 FrequencyPenalty: opts.FrequencyPenalty,
677 ActiveTools: opts.ActiveTools,
678 Headers: opts.Headers,
679 ProviderOptions: opts.ProviderOptions,
680 MaxRetries: opts.MaxRetries,
681 StopWhen: opts.StopWhen,
682 PrepareStep: opts.PrepareStep,
683 RepairToolCall: opts.RepairToolCall,
684 }
685
686 call = a.prepareCall(call)
687
688 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
689 if err != nil {
690 return nil, err
691 }
692
693 var responseMessages []Message
694 var steps []StepResult
695 var totalUsage Usage
696
697 // Start agent stream
698 if opts.OnAgentStart != nil {
699 opts.OnAgentStart()
700 }
701
702 for stepNumber := 0; ; stepNumber++ {
703 stepInputMessages := append(initialPrompt, responseMessages...)
704 stepModel := a.settings.model
705 stepSystemPrompt := a.settings.systemPrompt
706 stepActiveTools := call.ActiveTools
707 stepToolChoice := ToolChoiceAuto
708 disableAllTools := false
709
710 // Apply step preparation if provided
711 if call.PrepareStep != nil {
712 prepared := call.PrepareStep(PrepareStepFunctionOptions{
713 Model: stepModel,
714 Steps: steps,
715 StepNumber: stepNumber,
716 Messages: stepInputMessages,
717 })
718
719 if prepared.Messages != nil {
720 stepInputMessages = prepared.Messages
721 }
722 if prepared.Model != nil {
723 stepModel = prepared.Model
724 }
725 if prepared.System != nil {
726 stepSystemPrompt = *prepared.System
727 }
728 if prepared.ToolChoice != nil {
729 stepToolChoice = *prepared.ToolChoice
730 }
731 if len(prepared.ActiveTools) > 0 {
732 stepActiveTools = prepared.ActiveTools
733 }
734 disableAllTools = prepared.DisableAllTools
735 }
736
737 // Recreate prompt with potentially modified system prompt
738 if stepSystemPrompt != a.settings.systemPrompt {
739 stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
740 if err != nil {
741 return nil, err
742 }
743 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
744 stepInputMessages[0] = stepPrompt[0]
745 }
746 }
747
748 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
749
750 // Start step stream
751 if opts.OnStepStart != nil {
752 opts.OnStepStart(stepNumber)
753 }
754
755 // Create streaming call
756 streamCall := Call{
757 Prompt: stepInputMessages,
758 MaxOutputTokens: call.MaxOutputTokens,
759 Temperature: call.Temperature,
760 TopP: call.TopP,
761 TopK: call.TopK,
762 PresencePenalty: call.PresencePenalty,
763 FrequencyPenalty: call.FrequencyPenalty,
764 Tools: preparedTools,
765 ToolChoice: &stepToolChoice,
766 Headers: call.Headers,
767 ProviderOptions: call.ProviderOptions,
768 }
769
770 // Get streaming response
771 stream, err := stepModel.Stream(ctx, streamCall)
772 if err != nil {
773 if opts.OnError != nil {
774 opts.OnError(err)
775 }
776 return nil, err
777 }
778
779 // Process stream with tool execution
780 stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
781 if err != nil {
782 if opts.OnError != nil {
783 opts.OnError(err)
784 }
785 return nil, err
786 }
787
788 steps = append(steps, stepResult)
789 totalUsage = addUsage(totalUsage, stepResult.Usage)
790
791 // Call step finished callback
792 if opts.OnStepFinish != nil {
793 opts.OnStepFinish(stepResult)
794 }
795
796 // Add step messages to response messages
797 stepMessages := toResponseMessages(stepResult.Content)
798 responseMessages = append(responseMessages, stepMessages...)
799
800 // Check stop conditions
801 shouldStop := isStopConditionMet(call.StopWhen, steps)
802 if shouldStop || !shouldContinue {
803 break
804 }
805 }
806
807 // Finish agent stream
808 agentResult := &AgentResult{
809 Steps: steps,
810 Response: steps[len(steps)-1].Response,
811 TotalUsage: totalUsage,
812 }
813
814 if opts.OnFinish != nil {
815 opts.OnFinish(agentResult)
816 }
817
818 if opts.OnAgentFinish != nil {
819 opts.OnAgentFinish(agentResult)
820 }
821
822 return agentResult, nil
823}
824
825func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
826 var preparedTools []Tool
827
828 // If explicitly disabling all tools, return no tools
829 if disableAllTools {
830 return preparedTools
831 }
832
833 for _, tool := range tools {
834 // If activeTools has items, only include tools in the list
835 // If activeTools is empty, include all tools
836 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
837 continue
838 }
839 info := tool.Info()
840 preparedTools = append(preparedTools, FunctionTool{
841 Name: info.Name,
842 Description: info.Description,
843 InputSchema: map[string]any{
844 "type": "object",
845 "properties": info.Parameters,
846 "required": info.Required,
847 },
848 })
849 }
850 return preparedTools
851}
852
853// validateAndRepairToolCall validates a tool call and attempts repair if validation fails
854func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
855 if err := a.validateToolCall(toolCall, availableTools); err == nil {
856 return toolCall
857 } else {
858 if repairFunc != nil {
859 repairOptions := ToolCallRepairOptions{
860 OriginalToolCall: toolCall,
861 ValidationError: err,
862 AvailableTools: availableTools,
863 SystemPrompt: systemPrompt,
864 Messages: messages,
865 }
866
867 if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
868 if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
869 return *repairedToolCall
870 }
871 }
872 }
873
874 invalidToolCall := toolCall
875 invalidToolCall.Invalid = true
876 invalidToolCall.ValidationError = err
877 return invalidToolCall
878 }
879}
880
881// validateToolCall validates a tool call against available tools and their schemas
882func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
883 var tool AgentTool
884 for _, t := range availableTools {
885 if t.Info().Name == toolCall.ToolName {
886 tool = t
887 break
888 }
889 }
890
891 if tool == nil {
892 return fmt.Errorf("tool not found: %s", toolCall.ToolName)
893 }
894
895 // Validate JSON parsing
896 var input map[string]any
897 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
898 return fmt.Errorf("invalid JSON input: %w", err)
899 }
900
901 // Basic schema validation (check required fields)
902 // TODO: more robust schema validation using JSON Schema or similar
903 toolInfo := tool.Info()
904 for _, required := range toolInfo.Required {
905 if _, exists := input[required]; !exists {
906 return fmt.Errorf("missing required parameter: %s", required)
907 }
908 }
909 return nil
910}
911
912func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
913 if prompt == "" {
914 return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
915 }
916
917 var preparedPrompt Prompt
918
919 if system != "" {
920 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
921 }
922
923 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
924 preparedPrompt = append(preparedPrompt, messages...)
925 return preparedPrompt, nil
926}
927
928func WithSystemPrompt(prompt string) agentOption {
929 return func(s *AgentSettings) {
930 s.systemPrompt = prompt
931 }
932}
933
934func WithMaxOutputTokens(tokens int64) agentOption {
935 return func(s *AgentSettings) {
936 s.maxOutputTokens = &tokens
937 }
938}
939
940func WithTemperature(temp float64) agentOption {
941 return func(s *AgentSettings) {
942 s.temperature = &temp
943 }
944}
945
946func WithTopP(topP float64) agentOption {
947 return func(s *AgentSettings) {
948 s.topP = &topP
949 }
950}
951
952func WithTopK(topK int64) agentOption {
953 return func(s *AgentSettings) {
954 s.topK = &topK
955 }
956}
957
958func WithPresencePenalty(penalty float64) agentOption {
959 return func(s *AgentSettings) {
960 s.presencePenalty = &penalty
961 }
962}
963
964func WithFrequencyPenalty(penalty float64) agentOption {
965 return func(s *AgentSettings) {
966 s.frequencyPenalty = &penalty
967 }
968}
969
970func WithTools(tools ...AgentTool) agentOption {
971 return func(s *AgentSettings) {
972 s.tools = append(s.tools, tools...)
973 }
974}
975
976func WithStopConditions(conditions ...StopCondition) agentOption {
977 return func(s *AgentSettings) {
978 s.stopWhen = append(s.stopWhen, conditions...)
979 }
980}
981
982func WithPrepareStep(fn PrepareStepFunction) agentOption {
983 return func(s *AgentSettings) {
984 s.prepareStep = fn
985 }
986}
987
988func WithRepairToolCall(fn RepairToolCallFunction) agentOption {
989 return func(s *AgentSettings) {
990 s.repairToolCall = fn
991 }
992}
993
994func WithOnStepFinished(fn OnStepFinishedFunction) agentOption {
995 return func(s *AgentSettings) {
996 s.onStepFinished = fn
997 }
998}
999
1000// processStepStream processes a single step's stream and returns the step result
1001func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
1002 var stepContent []Content
1003 var stepToolCalls []ToolCallContent
1004 var stepUsage Usage
1005 var stepFinishReason FinishReason = FinishReasonUnknown
1006 var stepWarnings []CallWarning
1007 var stepProviderMetadata ProviderMetadata
1008
1009 activeToolCalls := make(map[string]*ToolCallContent)
1010 activeTextContent := make(map[string]string)
1011
1012 // Process stream parts
1013 for part := range stream {
1014 // Forward all parts to chunk callback
1015 if opts.OnChunk != nil {
1016 opts.OnChunk(part)
1017 }
1018
1019 switch part.Type {
1020 case StreamPartTypeWarnings:
1021 stepWarnings = part.Warnings
1022 if opts.OnWarnings != nil {
1023 opts.OnWarnings(part.Warnings)
1024 }
1025
1026 case StreamPartTypeTextStart:
1027 activeTextContent[part.ID] = ""
1028 if opts.OnTextStart != nil {
1029 opts.OnTextStart(part.ID)
1030 }
1031
1032 case StreamPartTypeTextDelta:
1033 if _, exists := activeTextContent[part.ID]; exists {
1034 activeTextContent[part.ID] += part.Delta
1035 }
1036 if opts.OnTextDelta != nil {
1037 opts.OnTextDelta(part.ID, part.Delta)
1038 }
1039
1040 case StreamPartTypeTextEnd:
1041 if text, exists := activeTextContent[part.ID]; exists {
1042 stepContent = append(stepContent, TextContent{
1043 Text: text,
1044 ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1045 })
1046 delete(activeTextContent, part.ID)
1047 }
1048 if opts.OnTextEnd != nil {
1049 opts.OnTextEnd(part.ID)
1050 }
1051
1052 case StreamPartTypeReasoningStart:
1053 activeTextContent[part.ID] = ""
1054 if opts.OnReasoningStart != nil {
1055 opts.OnReasoningStart(part.ID)
1056 }
1057
1058 case StreamPartTypeReasoningDelta:
1059 if _, exists := activeTextContent[part.ID]; exists {
1060 activeTextContent[part.ID] += part.Delta
1061 }
1062 if opts.OnReasoningDelta != nil {
1063 opts.OnReasoningDelta(part.ID, part.Delta)
1064 }
1065
1066 case StreamPartTypeReasoningEnd:
1067 if text, exists := activeTextContent[part.ID]; exists {
1068 stepContent = append(stepContent, ReasoningContent{
1069 Text: text,
1070 ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1071 })
1072 delete(activeTextContent, part.ID)
1073 }
1074 if opts.OnReasoningEnd != nil {
1075 opts.OnReasoningEnd(part.ID)
1076 }
1077
1078 case StreamPartTypeToolInputStart:
1079 activeToolCalls[part.ID] = &ToolCallContent{
1080 ToolCallID: part.ID,
1081 ToolName: part.ToolCallName,
1082 Input: "",
1083 ProviderExecuted: part.ProviderExecuted,
1084 }
1085 if opts.OnToolInputStart != nil {
1086 opts.OnToolInputStart(part.ID, part.ToolCallName)
1087 }
1088
1089 case StreamPartTypeToolInputDelta:
1090 if toolCall, exists := activeToolCalls[part.ID]; exists {
1091 toolCall.Input += part.Delta
1092 }
1093 if opts.OnToolInputDelta != nil {
1094 opts.OnToolInputDelta(part.ID, part.Delta)
1095 }
1096
1097 case StreamPartTypeToolInputEnd:
1098 if opts.OnToolInputEnd != nil {
1099 opts.OnToolInputEnd(part.ID)
1100 }
1101
1102 case StreamPartTypeToolCall:
1103 toolCall := ToolCallContent{
1104 ToolCallID: part.ID,
1105 ToolName: part.ToolCallName,
1106 Input: part.ToolCallInput,
1107 ProviderExecuted: part.ProviderExecuted,
1108 ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1109 }
1110
1111 // Validate and potentially repair the tool call
1112 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1113 stepToolCalls = append(stepToolCalls, validatedToolCall)
1114 stepContent = append(stepContent, validatedToolCall)
1115
1116 if opts.OnToolCall != nil {
1117 opts.OnToolCall(validatedToolCall)
1118 }
1119
1120 // Clean up active tool call
1121 delete(activeToolCalls, part.ID)
1122
1123 case StreamPartTypeSource:
1124 sourceContent := SourceContent{
1125 SourceType: part.SourceType,
1126 ID: part.ID,
1127 URL: part.URL,
1128 Title: part.Title,
1129 ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1130 }
1131 stepContent = append(stepContent, sourceContent)
1132 if opts.OnSource != nil {
1133 opts.OnSource(sourceContent)
1134 }
1135
1136 case StreamPartTypeFinish:
1137 stepUsage = part.Usage
1138 stepFinishReason = part.FinishReason
1139 stepProviderMetadata = ProviderMetadata(part.ProviderMetadata)
1140 if opts.OnStreamFinish != nil {
1141 opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1142 }
1143
1144 case StreamPartTypeError:
1145 if opts.OnStreamError != nil {
1146 opts.OnStreamError(part.Error)
1147 }
1148 if opts.OnError != nil {
1149 opts.OnError(part.Error)
1150 }
1151 return StepResult{}, false, part.Error
1152 }
1153 }
1154
1155 // Execute tools if any
1156 var toolResults []ToolResultContent
1157 if len(stepToolCalls) > 0 {
1158 var err error
1159 toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
1160 if err != nil {
1161 return StepResult{}, false, err
1162 }
1163 // Add tool results to content
1164 for _, result := range toolResults {
1165 stepContent = append(stepContent, result)
1166 }
1167 }
1168
1169 stepResult := StepResult{
1170 Response: Response{
1171 Content: stepContent,
1172 FinishReason: stepFinishReason,
1173 Usage: stepUsage,
1174 Warnings: stepWarnings,
1175 ProviderMetadata: stepProviderMetadata,
1176 },
1177 Messages: toResponseMessages(stepContent),
1178 }
1179
1180 // Determine if we should continue (has tool calls and not stopped)
1181 shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1182
1183 return stepResult, shouldContinue, nil
1184}
1185
1186func addUsage(a, b Usage) Usage {
1187 return Usage{
1188 InputTokens: a.InputTokens + b.InputTokens,
1189 OutputTokens: a.OutputTokens + b.OutputTokens,
1190 TotalTokens: a.TotalTokens + b.TotalTokens,
1191 ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens,
1192 CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1193 CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens,
1194 }
1195}
1196
1197func WithHeaders(headers map[string]string) agentOption {
1198 return func(s *AgentSettings) {
1199 s.headers = headers
1200 }
1201}
1202
1203func WithProviderOptions(providerOptions ProviderOptions) agentOption {
1204 return func(s *AgentSettings) {
1205 s.providerOptions = providerOptions
1206 }
1207}