1package ai
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "maps"
9 "slices"
10 "sync"
11)
12
// StepResult captures the outcome of a single model call within an agent run.
// It embeds the step's Response (content, finish reason, usage, warnings and
// provider metadata) and adds the conversation messages derived from that
// content.
type StepResult struct {
	Response
	// Messages are the messages built from this step's content (an assistant
	// message and, when tools ran, a tool message). They are appended to the
	// conversation for the next step.
	Messages []Message
}
17
// StopCondition reports whether the agent loop should stop. It receives every
// step completed so far, with the most recent step last. The loop stops as
// soon as any configured condition returns true.
type StopCondition = func(steps []StepResult) bool
19
20// StepCountIs returns a stop condition that stops after the specified number of steps.
21func StepCountIs(stepCount int) StopCondition {
22 return func(steps []StepResult) bool {
23 return len(steps) >= stepCount
24 }
25}
26
27// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
28func HasToolCall(toolName string) StopCondition {
29 return func(steps []StepResult) bool {
30 if len(steps) == 0 {
31 return false
32 }
33 lastStep := steps[len(steps)-1]
34 toolCalls := lastStep.Content.ToolCalls()
35 for _, toolCall := range toolCalls {
36 if toolCall.ToolName == toolName {
37 return true
38 }
39 }
40 return false
41 }
42}
43
44// HasContent returns a stop condition that stops when the specified content type appears in the last step.
45func HasContent(contentType ContentType) StopCondition {
46 return func(steps []StepResult) bool {
47 if len(steps) == 0 {
48 return false
49 }
50 lastStep := steps[len(steps)-1]
51 for _, content := range lastStep.Content {
52 if content.GetType() == contentType {
53 return true
54 }
55 }
56 return false
57 }
58}
59
60// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
61func FinishReasonIs(reason FinishReason) StopCondition {
62 return func(steps []StepResult) bool {
63 if len(steps) == 0 {
64 return false
65 }
66 lastStep := steps[len(steps)-1]
67 return lastStep.FinishReason == reason
68 }
69}
70
71// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
72func MaxTokensUsed(maxTokens int64) StopCondition {
73 return func(steps []StepResult) bool {
74 var totalTokens int64
75 for _, step := range steps {
76 totalTokens += step.Usage.TotalTokens
77 }
78 return totalTokens >= maxTokens
79 }
80}
81
// PrepareStepFunctionOptions is passed to a PrepareStepFunction before each
// step is sent to the model.
type PrepareStepFunctionOptions struct {
	// Steps holds the results of all steps completed so far.
	Steps []StepResult
	// StepNumber is the zero-based index of the step about to run.
	StepNumber int
	// Model is the agent's configured model, which the step uses unless
	// overridden.
	Model LanguageModel
	// Messages are the input messages the step would send unless overridden.
	Messages []Message
}
88
// PrepareStepResult lets a PrepareStepFunction override settings for a single
// step. Zero values leave the agent's defaults for that step untouched.
type PrepareStepResult struct {
	// Model, when non-nil, replaces the model used for this step.
	Model LanguageModel
	// Messages, when non-nil, replace the step's input messages.
	Messages []Message
	// System, when non-nil, replaces the system prompt for this step.
	System *string
	// ToolChoice, when non-nil, replaces the default automatic tool choice.
	ToolChoice *ToolChoice
	// ActiveTools, when non-empty, restricts the tools offered in this step
	// to the named ones.
	ActiveTools []string
	// DisableAllTools, when true, sends no tools in this step.
	DisableAllTools bool
}
97
// ToolCallRepairOptions describes a tool call that failed validation, together
// with the context a RepairToolCallFunction can use to produce a fix.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call as produced by the model.
	OriginalToolCall ToolCallContent
	// ValidationError explains why the call was rejected.
	ValidationError error
	// AvailableTools are the tools the call was validated against.
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt associated with the step.
	SystemPrompt string
	// Messages are the step's input messages. Some callers pass nil.
	Messages []Message
}
105
type (
	// PrepareStepFunction is called before every step. It may override the
	// model, messages, system prompt, tool choice and tool set for that step.
	// Returning an error aborts the agent run.
	PrepareStepFunction = func(options PrepareStepFunctionOptions) (PrepareStepResult, error)
	// OnStepFinishedFunction is a callback signature that receives a
	// completed step.
	OnStepFinishedFunction = func(step StepResult)
	// RepairToolCallFunction attempts to fix a tool call that failed
	// validation. The repaired call is re-validated before it is used. A nil
	// result, a non-nil error, or a repaired call that still fails validation
	// leaves the original call marked as invalid.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
111
// agentSettings holds the agent-level defaults configured through AgentOption
// values passed to NewAgent. Nil pointers and nil functions mean "not set".
// Per-call values on AgentCall take precedence over these (see prepareCall).
type agentSettings struct {
	// systemPrompt, when non-empty, is sent as the first (system) message.
	systemPrompt     string
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	// headers and providerOptions are merged with per-call values; per-call
	// entries win on key conflicts.
	headers         map[string]string
	providerOptions ProviderOptions

	// TODO: add support for provider tools
	// tools are offered to the model and executed locally by the agent.
	tools      []AgentTool
	maxRetries *int

	// model is the default model for every step.
	model LanguageModel

	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
134
// AgentCall holds the inputs for a single Agent.Generate run. Most of the
// tuning fields fall back to the agent-level defaults when left nil or empty.
// ProviderOptions and Headers are merged with the agent-level maps, and
// per-call entries win (see prepareCall).
type AgentCall struct {
	// Prompt is the user message for this run; it must not be empty.
	Prompt string `json:"prompt"`
	// Files are passed together with Prompt when building the user message.
	Files []FilePart `json:"files"`
	// Messages are additional conversation messages included in the prompt.
	Messages         []Message `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, limits the tools offered to the model to
	// the named ones.
	ActiveTools     []string `json:"active_tools"`
	Headers         map[string]string
	ProviderOptions ProviderOptions
	// OnRetry is invoked by the retry helper around model calls.
	OnRetry OnRetryCallback
	// MaxRetries falls back to the agent default.
	// NOTE(review): it does not appear to be read by Generate in the code shown
	// — confirm whether retries honor it.
	MaxRetries *int

	// StopWhen ends the run as soon as any condition is satisfied.
	StopWhen []StopCondition
	// PrepareStep may override per-step settings before each model call.
	PrepareStep PrepareStepFunction
	// RepairToolCall is asked to fix tool calls that fail validation.
	RepairToolCall RepairToolCallFunction
}
155
// AgentStreamCall holds the inputs for a single Agent.Stream run. The leading
// fields mirror AgentCall and get the same agent-level defaults. The remaining
// fields are callbacks. A stream-part callback that returns an error aborts
// the run, and Stream returns that error.
type AgentStreamCall struct {
	// Prompt is the user message for this run; it must not be empty.
	Prompt string `json:"prompt"`
	// Files are passed together with Prompt when building the user message.
	Files []FilePart `json:"files"`
	// Messages are additional conversation messages included in the prompt.
	Messages         []Message `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, limits the tools offered to the model to
	// the named ones.
	ActiveTools     []string `json:"active_tools"`
	Headers         map[string]string
	ProviderOptions ProviderOptions
	OnRetry         OnRetryCallback
	MaxRetries      *int

	// StopWhen ends the run as soon as any condition is satisfied.
	StopWhen []StopCondition
	// PrepareStep may override per-step settings before each model call.
	PrepareStep PrepareStepFunction
	// RepairToolCall is asked to fix tool calls that fail validation.
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  func()                     // Called when agent starts
	OnAgentFinish func(result *AgentResult)  // Called when agent finishes
	OnStepStart   func(stepNumber int)       // Called when a step starts
	OnStepFinish  func(stepResult StepResult) // Called when a step finishes
	OnFinish      func(result *AgentResult)  // Called when entire agent completes
	OnError       func(error)                // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          func(StreamPart) error                                                         // Called for each stream part (catch-all)
	OnWarnings       func(warnings []CallWarning) error                                             // Called for warnings
	OnTextStart      func(id string) error                                                          // Called when text starts
	OnTextDelta      func(id, text string) error                                                    // Called for text deltas
	OnTextEnd        func(id string) error                                                          // Called when text ends
	OnReasoningStart func(id string) error                                                          // Called when reasoning starts
	OnReasoningDelta func(id, text string) error                                                    // Called for reasoning deltas
	OnReasoningEnd   func(id string, reasoning ReasoningContent) error                              // Called when reasoning ends
	OnToolInputStart func(id, toolName string) error                                                // Called when tool input starts
	OnToolInputDelta func(id, delta string) error                                                   // Called for tool input deltas
	OnToolInputEnd   func(id string) error                                                          // Called when tool input ends
	OnToolCall       func(toolCall ToolCallContent) error                                           // Called when tool call is complete
	OnToolResult     func(result ToolResultContent) error                                           // Called when tool execution completes
	OnSource         func(source SourceContent) error                                               // Called for source references
	OnStreamFinish   func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error // Called when stream finishes
}
201
// AgentResult is the outcome of a complete agent run.
type AgentResult struct {
	// Steps holds every completed step, in execution order.
	Steps []StepResult
	// Final response: the Response of the last step.
	Response Response
	// TotalUsage accumulates the token usage of all steps.
	TotalUsage Usage
}
208
// Agent runs a multi-step loop against a language model and executes tool
// calls between steps.
type Agent interface {
	// Generate runs the agent using non-streaming model calls.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the agent using streaming model calls and reports progress
	// through the callbacks on AgentStreamCall.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
213
// AgentOption configures agent-level defaults; pass options to NewAgent.
type AgentOption = func(*agentSettings)
215
// agent is the Agent implementation returned by NewAgent.
type agent struct {
	// settings are the agent-level defaults applied to every call.
	settings agentSettings
}
219
220func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
221 settings := agentSettings{
222 model: model,
223 }
224 for _, o := range opts {
225 o(&settings)
226 }
227 return &agent{
228 settings: settings,
229 }
230}
231
232func (a *agent) prepareCall(call AgentCall) AgentCall {
233 if call.MaxOutputTokens == nil && a.settings.maxOutputTokens != nil {
234 call.MaxOutputTokens = a.settings.maxOutputTokens
235 }
236 if call.Temperature == nil && a.settings.temperature != nil {
237 call.Temperature = a.settings.temperature
238 }
239 if call.TopP == nil && a.settings.topP != nil {
240 call.TopP = a.settings.topP
241 }
242 if call.TopK == nil && a.settings.topK != nil {
243 call.TopK = a.settings.topK
244 }
245 if call.PresencePenalty == nil && a.settings.presencePenalty != nil {
246 call.PresencePenalty = a.settings.presencePenalty
247 }
248 if call.FrequencyPenalty == nil && a.settings.frequencyPenalty != nil {
249 call.FrequencyPenalty = a.settings.frequencyPenalty
250 }
251 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
252 call.StopWhen = a.settings.stopWhen
253 }
254 if call.PrepareStep == nil && a.settings.prepareStep != nil {
255 call.PrepareStep = a.settings.prepareStep
256 }
257 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
258 call.RepairToolCall = a.settings.repairToolCall
259 }
260 if call.OnRetry == nil && a.settings.onRetry != nil {
261 call.OnRetry = a.settings.onRetry
262 }
263 if call.MaxRetries == nil && a.settings.maxRetries != nil {
264 call.MaxRetries = a.settings.maxRetries
265 }
266
267 providerOptions := ProviderOptions{}
268 if a.settings.providerOptions != nil {
269 maps.Copy(providerOptions, a.settings.providerOptions)
270 }
271 if call.ProviderOptions != nil {
272 maps.Copy(providerOptions, call.ProviderOptions)
273 }
274 call.ProviderOptions = providerOptions
275
276 headers := map[string]string{}
277
278 if a.settings.headers != nil {
279 maps.Copy(headers, a.settings.headers)
280 }
281
282 if call.Headers != nil {
283 maps.Copy(headers, call.Headers)
284 }
285 call.Headers = headers
286 return call
287}
288
289// Generate implements Agent.
290func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
291 opts = a.prepareCall(opts)
292 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
293 if err != nil {
294 return nil, err
295 }
296 var responseMessages []Message
297 var steps []StepResult
298
299 for {
300 stepInputMessages := append(initialPrompt, responseMessages...)
301 stepModel := a.settings.model
302 stepSystemPrompt := a.settings.systemPrompt
303 stepActiveTools := opts.ActiveTools
304 stepToolChoice := ToolChoiceAuto
305 disableAllTools := false
306
307 if opts.PrepareStep != nil {
308 prepared, err := opts.PrepareStep(PrepareStepFunctionOptions{
309 Model: stepModel,
310 Steps: steps,
311 StepNumber: len(steps),
312 Messages: stepInputMessages,
313 })
314 if err != nil {
315 return nil, err
316 }
317
318 // Apply prepared step modifications
319 if prepared.Messages != nil {
320 stepInputMessages = prepared.Messages
321 }
322 if prepared.Model != nil {
323 stepModel = prepared.Model
324 }
325 if prepared.System != nil {
326 stepSystemPrompt = *prepared.System
327 }
328 if prepared.ToolChoice != nil {
329 stepToolChoice = *prepared.ToolChoice
330 }
331 if len(prepared.ActiveTools) > 0 {
332 stepActiveTools = prepared.ActiveTools
333 }
334 disableAllTools = prepared.DisableAllTools
335 }
336
337 // Recreate prompt with potentially modified system prompt
338 if stepSystemPrompt != a.settings.systemPrompt {
339 stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
340 if err != nil {
341 return nil, err
342 }
343 // Replace system message part, keep the rest
344 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
345 stepInputMessages[0] = stepPrompt[0] // Replace system message
346 }
347 }
348
349 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
350
351 retryOptions := DefaultRetryOptions()
352 retryOptions.OnRetry = opts.OnRetry
353 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
354
355 result, err := retry(ctx, func() (*Response, error) {
356 return stepModel.Generate(ctx, Call{
357 Prompt: stepInputMessages,
358 MaxOutputTokens: opts.MaxOutputTokens,
359 Temperature: opts.Temperature,
360 TopP: opts.TopP,
361 TopK: opts.TopK,
362 PresencePenalty: opts.PresencePenalty,
363 FrequencyPenalty: opts.FrequencyPenalty,
364 Tools: preparedTools,
365 ToolChoice: &stepToolChoice,
366 Headers: opts.Headers,
367 ProviderOptions: opts.ProviderOptions,
368 })
369 })
370 if err != nil {
371 return nil, err
372 }
373
374 var stepToolCalls []ToolCallContent
375 for _, content := range result.Content {
376 if content.GetType() == ContentTypeToolCall {
377 toolCall, ok := AsContentType[ToolCallContent](content)
378 if !ok {
379 continue
380 }
381
382 // Validate and potentially repair the tool call
383 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
384 stepToolCalls = append(stepToolCalls, validatedToolCall)
385 }
386 }
387
388 toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
389
390 // Build step content with validated tool calls and tool results
391 stepContent := []Content{}
392 toolCallIndex := 0
393 for _, content := range result.Content {
394 if content.GetType() == ContentTypeToolCall {
395 // Replace with validated tool call
396 if toolCallIndex < len(stepToolCalls) {
397 stepContent = append(stepContent, stepToolCalls[toolCallIndex])
398 toolCallIndex++
399 }
400 } else {
401 // Keep other content as-is
402 stepContent = append(stepContent, content)
403 }
404 }
405 // Add tool results
406 for _, result := range toolResults {
407 stepContent = append(stepContent, result)
408 }
409 currentStepMessages := toResponseMessages(stepContent)
410 responseMessages = append(responseMessages, currentStepMessages...)
411
412 stepResult := StepResult{
413 Response: Response{
414 Content: stepContent,
415 FinishReason: result.FinishReason,
416 Usage: result.Usage,
417 Warnings: result.Warnings,
418 ProviderMetadata: result.ProviderMetadata,
419 },
420 Messages: currentStepMessages,
421 }
422 steps = append(steps, stepResult)
423 shouldStop := isStopConditionMet(opts.StopWhen, steps)
424
425 if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
426 break
427 }
428 }
429
430 totalUsage := Usage{}
431
432 for _, step := range steps {
433 usage := step.Usage
434 totalUsage.InputTokens += usage.InputTokens
435 totalUsage.OutputTokens += usage.OutputTokens
436 totalUsage.ReasoningTokens += usage.ReasoningTokens
437 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
438 totalUsage.CacheReadTokens += usage.CacheReadTokens
439 totalUsage.TotalTokens += usage.TotalTokens
440 }
441
442 agentResult := &AgentResult{
443 Steps: steps,
444 Response: steps[len(steps)-1].Response,
445 TotalUsage: totalUsage,
446 }
447 return agentResult, nil
448}
449
450func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
451 if len(conditions) == 0 {
452 return false
453 }
454
455 for _, condition := range conditions {
456 if condition(steps) {
457 return true
458 }
459 }
460 return false
461}
462
463func toResponseMessages(content []Content) []Message {
464 var assistantParts []MessagePart
465 var toolParts []MessagePart
466
467 for _, c := range content {
468 switch c.GetType() {
469 case ContentTypeText:
470 text, ok := AsContentType[TextContent](c)
471 if !ok {
472 continue
473 }
474 assistantParts = append(assistantParts, TextPart{
475 Text: text.Text,
476 ProviderOptions: ProviderOptions(text.ProviderMetadata),
477 })
478 case ContentTypeReasoning:
479 reasoning, ok := AsContentType[ReasoningContent](c)
480 if !ok {
481 continue
482 }
483 assistantParts = append(assistantParts, ReasoningPart{
484 Text: reasoning.Text,
485 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
486 })
487 case ContentTypeToolCall:
488 toolCall, ok := AsContentType[ToolCallContent](c)
489 if !ok {
490 continue
491 }
492 assistantParts = append(assistantParts, ToolCallPart{
493 ToolCallID: toolCall.ToolCallID,
494 ToolName: toolCall.ToolName,
495 Input: toolCall.Input,
496 ProviderExecuted: toolCall.ProviderExecuted,
497 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
498 })
499 case ContentTypeFile:
500 file, ok := AsContentType[FileContent](c)
501 if !ok {
502 continue
503 }
504 assistantParts = append(assistantParts, FilePart{
505 Data: file.Data,
506 MediaType: file.MediaType,
507 ProviderOptions: ProviderOptions(file.ProviderMetadata),
508 })
509 case ContentTypeSource:
510 // Sources are metadata about references used to generate the response.
511 // They don't need to be included in the conversation messages.
512 continue
513 case ContentTypeToolResult:
514 result, ok := AsContentType[ToolResultContent](c)
515 if !ok {
516 continue
517 }
518 toolParts = append(toolParts, ToolResultPart{
519 ToolCallID: result.ToolCallID,
520 Output: result.Result,
521 ProviderOptions: ProviderOptions(result.ProviderMetadata),
522 })
523 }
524 }
525
526 var messages []Message
527 if len(assistantParts) > 0 {
528 messages = append(messages, Message{
529 Role: MessageRoleAssistant,
530 Content: assistantParts,
531 })
532 }
533 if len(toolParts) > 0 {
534 messages = append(messages, Message{
535 Role: MessageRoleTool,
536 Content: toolParts,
537 })
538 }
539 return messages
540}
541
542func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
543 if len(toolCalls) == 0 {
544 return nil, nil
545 }
546
547 // Create a map for quick tool lookup
548 toolMap := make(map[string]AgentTool)
549 for _, tool := range allTools {
550 toolMap[tool.Info().Name] = tool
551 }
552
553 // Execute all tool calls in parallel
554 results := make([]ToolResultContent, len(toolCalls))
555 executeErrors := make([]error, len(toolCalls))
556
557 var wg sync.WaitGroup
558
559 for i, toolCall := range toolCalls {
560 wg.Add(1)
561 go func(index int, call ToolCallContent) {
562 defer wg.Done()
563
564 // Skip invalid tool calls - create error result
565 if call.Invalid {
566 results[index] = ToolResultContent{
567 ToolCallID: call.ToolCallID,
568 ToolName: call.ToolName,
569 Result: ToolResultOutputContentError{
570 Error: call.ValidationError,
571 },
572 ProviderExecuted: false,
573 }
574 if toolResultCallback != nil {
575 err := toolResultCallback(results[index])
576 if err != nil {
577 executeErrors[index] = err
578 }
579 }
580
581 return
582 }
583
584 tool, exists := toolMap[call.ToolName]
585 if !exists {
586 results[index] = ToolResultContent{
587 ToolCallID: call.ToolCallID,
588 ToolName: call.ToolName,
589 Result: ToolResultOutputContentError{
590 Error: errors.New("Error: Tool not found: " + call.ToolName),
591 },
592 ProviderExecuted: false,
593 }
594
595 if toolResultCallback != nil {
596 err := toolResultCallback(results[index])
597 if err != nil {
598 executeErrors[index] = err
599 }
600 }
601 return
602 }
603
604 // Execute the tool
605 result, err := tool.Run(ctx, ToolCall{
606 ID: call.ToolCallID,
607 Name: call.ToolName,
608 Input: call.Input,
609 })
610 if err != nil {
611 results[index] = ToolResultContent{
612 ToolCallID: call.ToolCallID,
613 ToolName: call.ToolName,
614 Result: ToolResultOutputContentError{
615 Error: err,
616 },
617 ClientMetadata: result.Metadata,
618 ProviderExecuted: false,
619 }
620 if toolResultCallback != nil {
621 cbErr := toolResultCallback(results[index])
622 if cbErr != nil {
623 executeErrors[index] = cbErr
624 }
625 }
626 executeErrors[index] = err
627 return
628 }
629
630 if result.IsError {
631 results[index] = ToolResultContent{
632 ToolCallID: call.ToolCallID,
633 ToolName: call.ToolName,
634 Result: ToolResultOutputContentError{
635 Error: errors.New(result.Content),
636 },
637 ClientMetadata: result.Metadata,
638 ProviderExecuted: false,
639 }
640
641 if toolResultCallback != nil {
642 err := toolResultCallback(results[index])
643 if err != nil {
644 executeErrors[index] = err
645 }
646 }
647 } else {
648 results[index] = ToolResultContent{
649 ToolCallID: call.ToolCallID,
650 ToolName: toolCall.ToolName,
651 Result: ToolResultOutputContentText{
652 Text: result.Content,
653 },
654 ClientMetadata: result.Metadata,
655 ProviderExecuted: false,
656 }
657 if toolResultCallback != nil {
658 err := toolResultCallback(results[index])
659 if err != nil {
660 executeErrors[index] = err
661 }
662 }
663 }
664 }(i, toolCall)
665 }
666
667 // Wait for all tool executions to complete
668 wg.Wait()
669
670 for _, err := range executeErrors {
671 if err != nil {
672 return nil, err
673 }
674 }
675
676 return results, nil
677}
678
679// Stream implements Agent.
680func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
681 // Convert AgentStreamCall to AgentCall for preparation
682 call := AgentCall{
683 Prompt: opts.Prompt,
684 Files: opts.Files,
685 Messages: opts.Messages,
686 MaxOutputTokens: opts.MaxOutputTokens,
687 Temperature: opts.Temperature,
688 TopP: opts.TopP,
689 TopK: opts.TopK,
690 PresencePenalty: opts.PresencePenalty,
691 FrequencyPenalty: opts.FrequencyPenalty,
692 ActiveTools: opts.ActiveTools,
693 Headers: opts.Headers,
694 ProviderOptions: opts.ProviderOptions,
695 MaxRetries: opts.MaxRetries,
696 StopWhen: opts.StopWhen,
697 PrepareStep: opts.PrepareStep,
698 RepairToolCall: opts.RepairToolCall,
699 }
700
701 call = a.prepareCall(call)
702
703 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
704 if err != nil {
705 return nil, err
706 }
707
708 var responseMessages []Message
709 var steps []StepResult
710 var totalUsage Usage
711
712 // Start agent stream
713 if opts.OnAgentStart != nil {
714 opts.OnAgentStart()
715 }
716
717 for stepNumber := 0; ; stepNumber++ {
718 stepInputMessages := append(initialPrompt, responseMessages...)
719 stepModel := a.settings.model
720 stepSystemPrompt := a.settings.systemPrompt
721 stepActiveTools := call.ActiveTools
722 stepToolChoice := ToolChoiceAuto
723 disableAllTools := false
724
725 // Apply step preparation if provided
726 if call.PrepareStep != nil {
727 prepared, err := call.PrepareStep(PrepareStepFunctionOptions{
728 Model: stepModel,
729 Steps: steps,
730 StepNumber: stepNumber,
731 Messages: stepInputMessages,
732 })
733 if err != nil {
734 return nil, err
735 }
736
737 if prepared.Messages != nil {
738 stepInputMessages = prepared.Messages
739 }
740 if prepared.Model != nil {
741 stepModel = prepared.Model
742 }
743 if prepared.System != nil {
744 stepSystemPrompt = *prepared.System
745 }
746 if prepared.ToolChoice != nil {
747 stepToolChoice = *prepared.ToolChoice
748 }
749 if len(prepared.ActiveTools) > 0 {
750 stepActiveTools = prepared.ActiveTools
751 }
752 disableAllTools = prepared.DisableAllTools
753 }
754
755 // Recreate prompt with potentially modified system prompt
756 if stepSystemPrompt != a.settings.systemPrompt {
757 stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
758 if err != nil {
759 return nil, err
760 }
761 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
762 stepInputMessages[0] = stepPrompt[0]
763 }
764 }
765
766 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
767
768 // Start step stream
769 if opts.OnStepStart != nil {
770 opts.OnStepStart(stepNumber)
771 }
772
773 // Create streaming call
774 streamCall := Call{
775 Prompt: stepInputMessages,
776 MaxOutputTokens: call.MaxOutputTokens,
777 Temperature: call.Temperature,
778 TopP: call.TopP,
779 TopK: call.TopK,
780 PresencePenalty: call.PresencePenalty,
781 FrequencyPenalty: call.FrequencyPenalty,
782 Tools: preparedTools,
783 ToolChoice: &stepToolChoice,
784 Headers: call.Headers,
785 ProviderOptions: call.ProviderOptions,
786 }
787
788 // Get streaming response
789 stream, err := stepModel.Stream(ctx, streamCall)
790 if err != nil {
791 if opts.OnError != nil {
792 opts.OnError(err)
793 }
794 return nil, err
795 }
796
797 // Process stream with tool execution
798 stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
799 if err != nil {
800 if opts.OnError != nil {
801 opts.OnError(err)
802 }
803 return nil, err
804 }
805
806 steps = append(steps, stepResult)
807 totalUsage = addUsage(totalUsage, stepResult.Usage)
808
809 // Call step finished callback
810 if opts.OnStepFinish != nil {
811 opts.OnStepFinish(stepResult)
812 }
813
814 // Add step messages to response messages
815 stepMessages := toResponseMessages(stepResult.Content)
816 responseMessages = append(responseMessages, stepMessages...)
817
818 // Check stop conditions
819 shouldStop := isStopConditionMet(call.StopWhen, steps)
820 if shouldStop || !shouldContinue {
821 break
822 }
823 }
824
825 // Finish agent stream
826 agentResult := &AgentResult{
827 Steps: steps,
828 Response: steps[len(steps)-1].Response,
829 TotalUsage: totalUsage,
830 }
831
832 if opts.OnFinish != nil {
833 opts.OnFinish(agentResult)
834 }
835
836 if opts.OnAgentFinish != nil {
837 opts.OnAgentFinish(agentResult)
838 }
839
840 return agentResult, nil
841}
842
843func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
844 var preparedTools []Tool
845
846 // If explicitly disabling all tools, return no tools
847 if disableAllTools {
848 return preparedTools
849 }
850
851 for _, tool := range tools {
852 // If activeTools has items, only include tools in the list
853 // If activeTools is empty, include all tools
854 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
855 continue
856 }
857 info := tool.Info()
858 preparedTools = append(preparedTools, FunctionTool{
859 Name: info.Name,
860 Description: info.Description,
861 InputSchema: map[string]any{
862 "type": "object",
863 "properties": info.Parameters,
864 "required": info.Required,
865 },
866 })
867 }
868 return preparedTools
869}
870
871// validateAndRepairToolCall validates a tool call and attempts repair if validation fails
872func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
873 if err := a.validateToolCall(toolCall, availableTools); err == nil {
874 return toolCall
875 } else {
876 if repairFunc != nil {
877 repairOptions := ToolCallRepairOptions{
878 OriginalToolCall: toolCall,
879 ValidationError: err,
880 AvailableTools: availableTools,
881 SystemPrompt: systemPrompt,
882 Messages: messages,
883 }
884
885 if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
886 if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
887 return *repairedToolCall
888 }
889 }
890 }
891
892 invalidToolCall := toolCall
893 invalidToolCall.Invalid = true
894 invalidToolCall.ValidationError = err
895 return invalidToolCall
896 }
897}
898
899// validateToolCall validates a tool call against available tools and their schemas
900func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
901 var tool AgentTool
902 for _, t := range availableTools {
903 if t.Info().Name == toolCall.ToolName {
904 tool = t
905 break
906 }
907 }
908
909 if tool == nil {
910 return fmt.Errorf("tool not found: %s", toolCall.ToolName)
911 }
912
913 // Validate JSON parsing
914 var input map[string]any
915 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
916 return fmt.Errorf("invalid JSON input: %w", err)
917 }
918
919 // Basic schema validation (check required fields)
920 // TODO: more robust schema validation using JSON Schema or similar
921 toolInfo := tool.Info()
922 for _, required := range toolInfo.Required {
923 if _, exists := input[required]; !exists {
924 return fmt.Errorf("missing required parameter: %s", required)
925 }
926 }
927 return nil
928}
929
930func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
931 if prompt == "" {
932 return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
933 }
934
935 var preparedPrompt Prompt
936
937 if system != "" {
938 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
939 }
940
941 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
942 preparedPrompt = append(preparedPrompt, messages...)
943 return preparedPrompt, nil
944}
945
946func WithSystemPrompt(prompt string) AgentOption {
947 return func(s *agentSettings) {
948 s.systemPrompt = prompt
949 }
950}
951
952func WithMaxOutputTokens(tokens int64) AgentOption {
953 return func(s *agentSettings) {
954 s.maxOutputTokens = &tokens
955 }
956}
957
958func WithTemperature(temp float64) AgentOption {
959 return func(s *agentSettings) {
960 s.temperature = &temp
961 }
962}
963
964func WithTopP(topP float64) AgentOption {
965 return func(s *agentSettings) {
966 s.topP = &topP
967 }
968}
969
970func WithTopK(topK int64) AgentOption {
971 return func(s *agentSettings) {
972 s.topK = &topK
973 }
974}
975
976func WithPresencePenalty(penalty float64) AgentOption {
977 return func(s *agentSettings) {
978 s.presencePenalty = &penalty
979 }
980}
981
982func WithFrequencyPenalty(penalty float64) AgentOption {
983 return func(s *agentSettings) {
984 s.frequencyPenalty = &penalty
985 }
986}
987
988func WithTools(tools ...AgentTool) AgentOption {
989 return func(s *agentSettings) {
990 s.tools = append(s.tools, tools...)
991 }
992}
993
994func WithStopConditions(conditions ...StopCondition) AgentOption {
995 return func(s *agentSettings) {
996 s.stopWhen = append(s.stopWhen, conditions...)
997 }
998}
999
1000func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1001 return func(s *agentSettings) {
1002 s.prepareStep = fn
1003 }
1004}
1005
1006func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1007 return func(s *agentSettings) {
1008 s.repairToolCall = fn
1009 }
1010}
1011
1012// processStepStream processes a single step's stream and returns the step result
1013func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
1014 var stepContent []Content
1015 var stepToolCalls []ToolCallContent
1016 var stepUsage Usage
1017 stepFinishReason := FinishReasonUnknown
1018 var stepWarnings []CallWarning
1019 var stepProviderMetadata ProviderMetadata
1020
1021 activeToolCalls := make(map[string]*ToolCallContent)
1022 activeTextContent := make(map[string]string)
1023
1024 // Process stream parts
1025 for part := range stream {
1026 // Forward all parts to chunk callback
1027 if opts.OnChunk != nil {
1028 err := opts.OnChunk(part)
1029 if err != nil {
1030 return StepResult{}, false, err
1031 }
1032 }
1033
1034 switch part.Type {
1035 case StreamPartTypeWarnings:
1036 stepWarnings = part.Warnings
1037 if opts.OnWarnings != nil {
1038 err := opts.OnWarnings(part.Warnings)
1039 if err != nil {
1040 return StepResult{}, false, err
1041 }
1042 }
1043
1044 case StreamPartTypeTextStart:
1045 activeTextContent[part.ID] = ""
1046 if opts.OnTextStart != nil {
1047 err := opts.OnTextStart(part.ID)
1048 if err != nil {
1049 return StepResult{}, false, err
1050 }
1051 }
1052
1053 case StreamPartTypeTextDelta:
1054 if _, exists := activeTextContent[part.ID]; exists {
1055 activeTextContent[part.ID] += part.Delta
1056 }
1057 if opts.OnTextDelta != nil {
1058 err := opts.OnTextDelta(part.ID, part.Delta)
1059 if err != nil {
1060 return StepResult{}, false, err
1061 }
1062 }
1063
1064 case StreamPartTypeTextEnd:
1065 if text, exists := activeTextContent[part.ID]; exists {
1066 stepContent = append(stepContent, TextContent{
1067 Text: text,
1068 ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1069 })
1070 delete(activeTextContent, part.ID)
1071 }
1072 if opts.OnTextEnd != nil {
1073 err := opts.OnTextEnd(part.ID)
1074 if err != nil {
1075 return StepResult{}, false, err
1076 }
1077 }
1078
1079 case StreamPartTypeReasoningStart:
1080 activeTextContent[part.ID] = ""
1081 if opts.OnReasoningStart != nil {
1082 err := opts.OnReasoningStart(part.ID)
1083 if err != nil {
1084 return StepResult{}, false, err
1085 }
1086 }
1087
1088 case StreamPartTypeReasoningDelta:
1089 if _, exists := activeTextContent[part.ID]; exists {
1090 activeTextContent[part.ID] += part.Delta
1091 }
1092 if opts.OnReasoningDelta != nil {
1093 err := opts.OnReasoningDelta(part.ID, part.Delta)
1094 if err != nil {
1095 return StepResult{}, false, err
1096 }
1097 }
1098
1099 case StreamPartTypeReasoningEnd:
1100 if text, exists := activeTextContent[part.ID]; exists {
1101 stepContent = append(stepContent, ReasoningContent{
1102 Text: text,
1103 ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1104 })
1105 if opts.OnReasoningEnd != nil {
1106 err := opts.OnReasoningEnd(part.ID, ReasoningContent{
1107 Text: text,
1108 ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1109 })
1110 if err != nil {
1111 return StepResult{}, false, err
1112 }
1113 }
1114 delete(activeTextContent, part.ID)
1115 }
1116
1117 case StreamPartTypeToolInputStart:
1118 activeToolCalls[part.ID] = &ToolCallContent{
1119 ToolCallID: part.ID,
1120 ToolName: part.ToolCallName,
1121 Input: "",
1122 ProviderExecuted: part.ProviderExecuted,
1123 }
1124 if opts.OnToolInputStart != nil {
1125 err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1126 if err != nil {
1127 return StepResult{}, false, err
1128 }
1129 }
1130
1131 case StreamPartTypeToolInputDelta:
1132 if toolCall, exists := activeToolCalls[part.ID]; exists {
1133 toolCall.Input += part.Delta
1134 }
1135 if opts.OnToolInputDelta != nil {
1136 err := opts.OnToolInputDelta(part.ID, part.Delta)
1137 if err != nil {
1138 return StepResult{}, false, err
1139 }
1140 }
1141
1142 case StreamPartTypeToolInputEnd:
1143 if opts.OnToolInputEnd != nil {
1144 err := opts.OnToolInputEnd(part.ID)
1145 if err != nil {
1146 return StepResult{}, false, err
1147 }
1148 }
1149
1150 case StreamPartTypeToolCall:
1151 toolCall := ToolCallContent{
1152 ToolCallID: part.ID,
1153 ToolName: part.ToolCallName,
1154 Input: part.ToolCallInput,
1155 ProviderExecuted: part.ProviderExecuted,
1156 ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1157 }
1158
1159 // Validate and potentially repair the tool call
1160 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1161 stepToolCalls = append(stepToolCalls, validatedToolCall)
1162 stepContent = append(stepContent, validatedToolCall)
1163
1164 if opts.OnToolCall != nil {
1165 err := opts.OnToolCall(validatedToolCall)
1166 if err != nil {
1167 return StepResult{}, false, err
1168 }
1169 }
1170
1171 // Clean up active tool call
1172 delete(activeToolCalls, part.ID)
1173
1174 case StreamPartTypeSource:
1175 sourceContent := SourceContent{
1176 SourceType: part.SourceType,
1177 ID: part.ID,
1178 URL: part.URL,
1179 Title: part.Title,
1180 ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1181 }
1182 stepContent = append(stepContent, sourceContent)
1183 if opts.OnSource != nil {
1184 err := opts.OnSource(sourceContent)
1185 if err != nil {
1186 return StepResult{}, false, err
1187 }
1188 }
1189
1190 case StreamPartTypeFinish:
1191 stepUsage = part.Usage
1192 stepFinishReason = part.FinishReason
1193 stepProviderMetadata = ProviderMetadata(part.ProviderMetadata)
1194 if opts.OnStreamFinish != nil {
1195 err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1196 if err != nil {
1197 return StepResult{}, false, err
1198 }
1199 }
1200
1201 case StreamPartTypeError:
1202 if opts.OnError != nil {
1203 opts.OnError(part.Error)
1204 }
1205 return StepResult{}, false, part.Error
1206 }
1207 }
1208
1209 // Execute tools if any
1210 var toolResults []ToolResultContent
1211 if len(stepToolCalls) > 0 {
1212 var err error
1213 toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
1214 if err != nil {
1215 return StepResult{}, false, err
1216 }
1217 // Add tool results to content
1218 for _, result := range toolResults {
1219 stepContent = append(stepContent, result)
1220 }
1221 }
1222
1223 stepResult := StepResult{
1224 Response: Response{
1225 Content: stepContent,
1226 FinishReason: stepFinishReason,
1227 Usage: stepUsage,
1228 Warnings: stepWarnings,
1229 ProviderMetadata: stepProviderMetadata,
1230 },
1231 Messages: toResponseMessages(stepContent),
1232 }
1233
1234 // Determine if we should continue (has tool calls and not stopped)
1235 shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1236
1237 return stepResult, shouldContinue, nil
1238}
1239
1240func addUsage(a, b Usage) Usage {
1241 return Usage{
1242 InputTokens: a.InputTokens + b.InputTokens,
1243 OutputTokens: a.OutputTokens + b.OutputTokens,
1244 TotalTokens: a.TotalTokens + b.TotalTokens,
1245 ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens,
1246 CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1247 CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens,
1248 }
1249}
1250
1251func WithHeaders(headers map[string]string) AgentOption {
1252 return func(s *agentSettings) {
1253 s.headers = headers
1254 }
1255}
1256
1257func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1258 return func(s *agentSettings) {
1259 s.providerOptions = providerOptions
1260 }
1261}