1package ai
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "slices"
11 "sync"
12)
13
// StepResult captures the outcome of a single agent step: the model Response
// for that step (in Generate, its Content holds the validated tool calls plus
// any tool results) together with the conversation messages derived from it.
type StepResult struct {
	Response
	// Messages are the assistant and/or tool messages produced by this step.
	Messages []Message
}
18
// StopCondition reports whether the agent loop should stop, given every step
// completed so far in execution order. Conditions are checked after each step.
type StopCondition = func(steps []StepResult) bool
20
21// StepCountIs returns a stop condition that stops after the specified number of steps.
22func StepCountIs(stepCount int) StopCondition {
23 return func(steps []StepResult) bool {
24 return len(steps) >= stepCount
25 }
26}
27
28// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
29func HasToolCall(toolName string) StopCondition {
30 return func(steps []StepResult) bool {
31 if len(steps) == 0 {
32 return false
33 }
34 lastStep := steps[len(steps)-1]
35 toolCalls := lastStep.Content.ToolCalls()
36 for _, toolCall := range toolCalls {
37 if toolCall.ToolName == toolName {
38 return true
39 }
40 }
41 return false
42 }
43}
44
45// HasContent returns a stop condition that stops when the specified content type appears in the last step.
46func HasContent(contentType ContentType) StopCondition {
47 return func(steps []StepResult) bool {
48 if len(steps) == 0 {
49 return false
50 }
51 lastStep := steps[len(steps)-1]
52 for _, content := range lastStep.Content {
53 if content.GetType() == contentType {
54 return true
55 }
56 }
57 return false
58 }
59}
60
61// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
62func FinishReasonIs(reason FinishReason) StopCondition {
63 return func(steps []StepResult) bool {
64 if len(steps) == 0 {
65 return false
66 }
67 lastStep := steps[len(steps)-1]
68 return lastStep.FinishReason == reason
69 }
70}
71
72// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
73func MaxTokensUsed(maxTokens int64) StopCondition {
74 return func(steps []StepResult) bool {
75 var totalTokens int64
76 for _, step := range steps {
77 totalTokens += step.Usage.TotalTokens
78 }
79 return totalTokens >= maxTokens
80 }
81}
82
// PrepareStepFunctionOptions is passed to a PrepareStepFunction before each
// step so it can inspect the run so far and decide on per-step overrides.
type PrepareStepFunctionOptions struct {
	// Steps holds the results of all previously completed steps.
	Steps []StepResult
	// StepNumber is the zero-based index of the step about to run.
	StepNumber int
	// Model is the agent's configured model, used unless overridden.
	Model LanguageModel
	// Messages is the prompt the step will send unless overridden by the
	// returned PrepareStepResult.
	Messages []Message
}
89
// PrepareStepResult carries the per-step overrides returned by a
// PrepareStepFunction. Zero-valued fields leave the step's defaults in place.
type PrepareStepResult struct {
	// Model, when non-nil, replaces the agent's model for this step.
	Model LanguageModel
	// Messages, when non-nil, replaces the prompt messages for this step.
	Messages []Message
	// System, when non-nil, replaces the system prompt for this step.
	System *string
	// ToolChoice, when non-nil, replaces the default ToolChoiceAuto.
	ToolChoice *ToolChoice
	// ActiveTools, when non-empty, restricts the offered tools to these names.
	ActiveTools []string
	// DisableAllTools, when true, offers no tools at all for this step.
	DisableAllTools bool
}
98
// ToolCallRepairOptions describes a tool call that failed validation; it is
// handed to a RepairToolCallFunction for a repair attempt.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call as emitted by the model.
	OriginalToolCall ToolCallContent
	// ValidationError explains why the original call was rejected.
	ValidationError error
	// AvailableTools lists the tools the call was validated against.
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt in effect for the step.
	SystemPrompt string
	// Messages are the input messages of the step that produced the call.
	Messages []Message
}
106
type (
	// PrepareStepFunction runs before every step and may override the step's
	// model, messages, system prompt, tool choice or active tools. A non-nil
	// error aborts the run.
	PrepareStepFunction = func(options PrepareStepFunctionOptions) (PrepareStepResult, error)
	// OnStepFinishedFunction receives a completed step.
	// NOTE(review): not referenced elsewhere in this file — confirm it is still used.
	OnStepFinishedFunction = func(step StepResult)
	// RepairToolCallFunction attempts to fix a tool call that failed validation.
	// Returning an error or a nil call leaves the original call marked invalid;
	// a repaired call is re-validated before it is accepted.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
112
// agentSettings is the agent-wide configuration assembled by NewAgent from its
// AgentOptions. Nil pointer fields mean "unset"; prepareCall uses the set ones
// as fallbacks for per-call values.
type agentSettings struct {
	// systemPrompt becomes the leading system message when non-empty.
	systemPrompt     string
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	// NOTE(review): headers do not appear to be forwarded to any model Call
	// built in this file — confirm intent.
	headers map[string]string
	// providerOptions are shallow-merged under each call's ProviderOptions.
	providerOptions ProviderOptions

	// TODO: add support for provider tools
	tools      []AgentTool
	maxRetries *int

	model LanguageModel

	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
135
// AgentCall holds the per-call options for Agent.Generate. Unset (nil or
// empty) fields fall back to the agent's settings via prepareCall.
type AgentCall struct {
	// Prompt is the user message for this call; it must be non-empty.
	Prompt string `json:"prompt"`
	// Files are attached to the user message.
	Files []FilePart `json:"files"`
	// Messages are prior conversation messages, placed after the system
	// message and before the new user message.
	Messages []Message `json:"messages"`
	// NOTE(review): MaxOutputTokens, ProviderOptions, OnRetry and MaxRetries
	// carry no json tag, unlike their neighbours — confirm this is intended.
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, restricts which agent tools are offered.
	ActiveTools []string `json:"active_tools"`
	// ProviderOptions entries override the agent's entries with the same key.
	ProviderOptions ProviderOptions
	// OnRetry is invoked when the model call is retried.
	OnRetry    OnRetryCallback
	MaxRetries *int

	// StopWhen, when non-empty, replaces (does not extend) the agent's conditions.
	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction
}
155
// Agent-level callbacks, configured on AgentStreamCall and invoked by Stream.
type (
	// OnAgentStartFunc is called once, before the first step.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called once after the final step, after OnFinish.
	// Stream ignores the returned error.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called before each step with its zero-based number.
	// Stream ignores the returned error.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called after each step completes.
	// Stream ignores the returned error.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called once with the final result when the run completes.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called when opening or processing a step's stream fails,
	// just before Stream returns that error.
	OnErrorFunc func(error)
)
176
// Stream part callbacks - called for each corresponding stream part type.
// They are invoked by processStepStream while a step's stream is consumed.
// Returning a non-nil error stops processing of the current step; Stream then
// passes that error to OnError and returns it.
// NOTE(review): confirm every callback below is error-checked in processStepStream.
type (
	// OnChunkFunc is called for each stream part (catch-all), before the
	// type-specific callback for that part.
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when the text block identified by id starts.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for each text delta of the block identified by id.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when the text block identified by id ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when the reasoning block identified by id starts.
	OnReasoningStartFunc func(id string) error

	// OnReasoningDeltaFunc is called for each reasoning delta.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning ends; presumably reasoning
	// holds the text accumulated from the deltas — confirm in processStepStream.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when tool input starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for tool input deltas.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnToolResultFunc is called when tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when stream finishes.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
224
// AgentStreamCall holds the per-call options for Agent.Stream. Its request
// fields mirror AgentCall and are merged with the agent's defaults through
// prepareCall; its callbacks observe the run as it streams.
// NOTE(review): Headers is not copied into the model Call built by Stream, so
// it currently has no effect — confirm intent.
type AgentStreamCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called when agent starts
	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
	OnWarnings       OnWarningsFunc       // Called for warnings
	OnTextStart      OnTextStartFunc      // Called when text starts
	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
	OnTextEnd        OnTextEndFunc        // Called when text ends
	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
	OnToolCall       OnToolCallFunc       // Called when tool call is complete
	OnToolResult     OnToolResultFunc     // Called when tool execution completes
	OnSource         OnSourceFunc         // Called for source references
	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
}
270
// AgentResult is the outcome of a Generate or Stream run.
type AgentResult struct {
	// Steps lists every completed step in execution order.
	Steps []StepResult
	// Final response: the Response of the last step.
	Response Response
	// TotalUsage is the token usage summed over all steps.
	TotalUsage Usage
}
277
// Agent runs a multi-step, tool-using conversation with a language model.
type Agent interface {
	// Generate runs the agent loop using non-streaming model calls.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the agent loop using streaming model calls and reports
	// progress through the callbacks set on AgentStreamCall.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
282
// AgentOption mutates agent settings; options are passed to NewAgent and
// applied in order.
type AgentOption = func(*agentSettings)
284
// agent is the Agent implementation returned by NewAgent.
type agent struct {
	// settings is fixed after construction and only read afterwards.
	settings agentSettings
}
288
289func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
290 settings := agentSettings{
291 model: model,
292 }
293 for _, o := range opts {
294 o(&settings)
295 }
296 return &agent{
297 settings: settings,
298 }
299}
300
301func (a *agent) prepareCall(call AgentCall) AgentCall {
302 call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
303 call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
304 call.TopP = cmp.Or(call.TopP, a.settings.topP)
305 call.TopK = cmp.Or(call.TopK, a.settings.topK)
306 call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
307 call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
308 call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
309
310 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
311 call.StopWhen = a.settings.stopWhen
312 }
313 if call.PrepareStep == nil && a.settings.prepareStep != nil {
314 call.PrepareStep = a.settings.prepareStep
315 }
316 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
317 call.RepairToolCall = a.settings.repairToolCall
318 }
319 if call.OnRetry == nil && a.settings.onRetry != nil {
320 call.OnRetry = a.settings.onRetry
321 }
322
323 providerOptions := ProviderOptions{}
324 if a.settings.providerOptions != nil {
325 maps.Copy(providerOptions, a.settings.providerOptions)
326 }
327 if call.ProviderOptions != nil {
328 maps.Copy(providerOptions, call.ProviderOptions)
329 }
330 call.ProviderOptions = providerOptions
331
332 headers := map[string]string{}
333
334 if a.settings.headers != nil {
335 maps.Copy(headers, a.settings.headers)
336 }
337
338 return call
339}
340
341// Generate implements Agent.
342func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
343 opts = a.prepareCall(opts)
344 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
345 if err != nil {
346 return nil, err
347 }
348 var responseMessages []Message
349 var steps []StepResult
350
351 for {
352 stepInputMessages := append(initialPrompt, responseMessages...)
353 stepModel := a.settings.model
354 stepSystemPrompt := a.settings.systemPrompt
355 stepActiveTools := opts.ActiveTools
356 stepToolChoice := ToolChoiceAuto
357 disableAllTools := false
358
359 if opts.PrepareStep != nil {
360 prepared, err := opts.PrepareStep(PrepareStepFunctionOptions{
361 Model: stepModel,
362 Steps: steps,
363 StepNumber: len(steps),
364 Messages: stepInputMessages,
365 })
366 if err != nil {
367 return nil, err
368 }
369
370 // Apply prepared step modifications
371 if prepared.Messages != nil {
372 stepInputMessages = prepared.Messages
373 }
374 if prepared.Model != nil {
375 stepModel = prepared.Model
376 }
377 if prepared.System != nil {
378 stepSystemPrompt = *prepared.System
379 }
380 if prepared.ToolChoice != nil {
381 stepToolChoice = *prepared.ToolChoice
382 }
383 if len(prepared.ActiveTools) > 0 {
384 stepActiveTools = prepared.ActiveTools
385 }
386 disableAllTools = prepared.DisableAllTools
387 }
388
389 // Recreate prompt with potentially modified system prompt
390 if stepSystemPrompt != a.settings.systemPrompt {
391 stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
392 if err != nil {
393 return nil, err
394 }
395 // Replace system message part, keep the rest
396 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
397 stepInputMessages[0] = stepPrompt[0] // Replace system message
398 }
399 }
400
401 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
402
403 retryOptions := DefaultRetryOptions()
404 retryOptions.OnRetry = opts.OnRetry
405 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
406
407 result, err := retry(ctx, func() (*Response, error) {
408 return stepModel.Generate(ctx, Call{
409 Prompt: stepInputMessages,
410 MaxOutputTokens: opts.MaxOutputTokens,
411 Temperature: opts.Temperature,
412 TopP: opts.TopP,
413 TopK: opts.TopK,
414 PresencePenalty: opts.PresencePenalty,
415 FrequencyPenalty: opts.FrequencyPenalty,
416 Tools: preparedTools,
417 ToolChoice: &stepToolChoice,
418 ProviderOptions: opts.ProviderOptions,
419 })
420 })
421 if err != nil {
422 return nil, err
423 }
424
425 var stepToolCalls []ToolCallContent
426 for _, content := range result.Content {
427 if content.GetType() == ContentTypeToolCall {
428 toolCall, ok := AsContentType[ToolCallContent](content)
429 if !ok {
430 continue
431 }
432
433 // Validate and potentially repair the tool call
434 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
435 stepToolCalls = append(stepToolCalls, validatedToolCall)
436 }
437 }
438
439 toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
440
441 // Build step content with validated tool calls and tool results
442 stepContent := []Content{}
443 toolCallIndex := 0
444 for _, content := range result.Content {
445 if content.GetType() == ContentTypeToolCall {
446 // Replace with validated tool call
447 if toolCallIndex < len(stepToolCalls) {
448 stepContent = append(stepContent, stepToolCalls[toolCallIndex])
449 toolCallIndex++
450 }
451 } else {
452 // Keep other content as-is
453 stepContent = append(stepContent, content)
454 }
455 }
456 // Add tool results
457 for _, result := range toolResults {
458 stepContent = append(stepContent, result)
459 }
460 currentStepMessages := toResponseMessages(stepContent)
461 responseMessages = append(responseMessages, currentStepMessages...)
462
463 stepResult := StepResult{
464 Response: Response{
465 Content: stepContent,
466 FinishReason: result.FinishReason,
467 Usage: result.Usage,
468 Warnings: result.Warnings,
469 ProviderMetadata: result.ProviderMetadata,
470 },
471 Messages: currentStepMessages,
472 }
473 steps = append(steps, stepResult)
474 shouldStop := isStopConditionMet(opts.StopWhen, steps)
475
476 if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
477 break
478 }
479 }
480
481 totalUsage := Usage{}
482
483 for _, step := range steps {
484 usage := step.Usage
485 totalUsage.InputTokens += usage.InputTokens
486 totalUsage.OutputTokens += usage.OutputTokens
487 totalUsage.ReasoningTokens += usage.ReasoningTokens
488 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
489 totalUsage.CacheReadTokens += usage.CacheReadTokens
490 totalUsage.TotalTokens += usage.TotalTokens
491 }
492
493 agentResult := &AgentResult{
494 Steps: steps,
495 Response: steps[len(steps)-1].Response,
496 TotalUsage: totalUsage,
497 }
498 return agentResult, nil
499}
500
501func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
502 if len(conditions) == 0 {
503 return false
504 }
505
506 for _, condition := range conditions {
507 if condition(steps) {
508 return true
509 }
510 }
511 return false
512}
513
514func toResponseMessages(content []Content) []Message {
515 var assistantParts []MessagePart
516 var toolParts []MessagePart
517
518 for _, c := range content {
519 switch c.GetType() {
520 case ContentTypeText:
521 text, ok := AsContentType[TextContent](c)
522 if !ok {
523 continue
524 }
525 assistantParts = append(assistantParts, TextPart{
526 Text: text.Text,
527 ProviderOptions: ProviderOptions(text.ProviderMetadata),
528 })
529 case ContentTypeReasoning:
530 reasoning, ok := AsContentType[ReasoningContent](c)
531 if !ok {
532 continue
533 }
534 assistantParts = append(assistantParts, ReasoningPart{
535 Text: reasoning.Text,
536 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
537 })
538 case ContentTypeToolCall:
539 toolCall, ok := AsContentType[ToolCallContent](c)
540 if !ok {
541 continue
542 }
543 assistantParts = append(assistantParts, ToolCallPart{
544 ToolCallID: toolCall.ToolCallID,
545 ToolName: toolCall.ToolName,
546 Input: toolCall.Input,
547 ProviderExecuted: toolCall.ProviderExecuted,
548 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
549 })
550 case ContentTypeFile:
551 file, ok := AsContentType[FileContent](c)
552 if !ok {
553 continue
554 }
555 assistantParts = append(assistantParts, FilePart{
556 Data: file.Data,
557 MediaType: file.MediaType,
558 ProviderOptions: ProviderOptions(file.ProviderMetadata),
559 })
560 case ContentTypeSource:
561 // Sources are metadata about references used to generate the response.
562 // They don't need to be included in the conversation messages.
563 continue
564 case ContentTypeToolResult:
565 result, ok := AsContentType[ToolResultContent](c)
566 if !ok {
567 continue
568 }
569 toolParts = append(toolParts, ToolResultPart{
570 ToolCallID: result.ToolCallID,
571 Output: result.Result,
572 ProviderOptions: ProviderOptions(result.ProviderMetadata),
573 })
574 }
575 }
576
577 var messages []Message
578 if len(assistantParts) > 0 {
579 messages = append(messages, Message{
580 Role: MessageRoleAssistant,
581 Content: assistantParts,
582 })
583 }
584 if len(toolParts) > 0 {
585 messages = append(messages, Message{
586 Role: MessageRoleTool,
587 Content: toolParts,
588 })
589 }
590 return messages
591}
592
593func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
594 if len(toolCalls) == 0 {
595 return nil, nil
596 }
597
598 // Create a map for quick tool lookup
599 toolMap := make(map[string]AgentTool)
600 for _, tool := range allTools {
601 toolMap[tool.Info().Name] = tool
602 }
603
604 // Execute all tool calls in parallel
605 results := make([]ToolResultContent, len(toolCalls))
606 executeErrors := make([]error, len(toolCalls))
607
608 var wg sync.WaitGroup
609
610 for i, toolCall := range toolCalls {
611 wg.Add(1)
612 go func(index int, call ToolCallContent) {
613 defer wg.Done()
614
615 // Skip invalid tool calls - create error result
616 if call.Invalid {
617 results[index] = ToolResultContent{
618 ToolCallID: call.ToolCallID,
619 ToolName: call.ToolName,
620 Result: ToolResultOutputContentError{
621 Error: call.ValidationError,
622 },
623 ProviderExecuted: false,
624 }
625 if toolResultCallback != nil {
626 err := toolResultCallback(results[index])
627 if err != nil {
628 executeErrors[index] = err
629 }
630 }
631
632 return
633 }
634
635 tool, exists := toolMap[call.ToolName]
636 if !exists {
637 results[index] = ToolResultContent{
638 ToolCallID: call.ToolCallID,
639 ToolName: call.ToolName,
640 Result: ToolResultOutputContentError{
641 Error: errors.New("Error: Tool not found: " + call.ToolName),
642 },
643 ProviderExecuted: false,
644 }
645
646 if toolResultCallback != nil {
647 err := toolResultCallback(results[index])
648 if err != nil {
649 executeErrors[index] = err
650 }
651 }
652 return
653 }
654
655 // Execute the tool
656 result, err := tool.Run(ctx, ToolCall{
657 ID: call.ToolCallID,
658 Name: call.ToolName,
659 Input: call.Input,
660 })
661 if err != nil {
662 results[index] = ToolResultContent{
663 ToolCallID: call.ToolCallID,
664 ToolName: call.ToolName,
665 Result: ToolResultOutputContentError{
666 Error: err,
667 },
668 ClientMetadata: result.Metadata,
669 ProviderExecuted: false,
670 }
671 if toolResultCallback != nil {
672 cbErr := toolResultCallback(results[index])
673 if cbErr != nil {
674 executeErrors[index] = cbErr
675 }
676 }
677 executeErrors[index] = err
678 return
679 }
680
681 if result.IsError {
682 results[index] = ToolResultContent{
683 ToolCallID: call.ToolCallID,
684 ToolName: call.ToolName,
685 Result: ToolResultOutputContentError{
686 Error: errors.New(result.Content),
687 },
688 ClientMetadata: result.Metadata,
689 ProviderExecuted: false,
690 }
691
692 if toolResultCallback != nil {
693 err := toolResultCallback(results[index])
694 if err != nil {
695 executeErrors[index] = err
696 }
697 }
698 } else {
699 results[index] = ToolResultContent{
700 ToolCallID: call.ToolCallID,
701 ToolName: toolCall.ToolName,
702 Result: ToolResultOutputContentText{
703 Text: result.Content,
704 },
705 ClientMetadata: result.Metadata,
706 ProviderExecuted: false,
707 }
708 if toolResultCallback != nil {
709 err := toolResultCallback(results[index])
710 if err != nil {
711 executeErrors[index] = err
712 }
713 }
714 }
715 }(i, toolCall)
716 }
717
718 // Wait for all tool executions to complete
719 wg.Wait()
720
721 for _, err := range executeErrors {
722 if err != nil {
723 return nil, err
724 }
725 }
726
727 return results, nil
728}
729
730// Stream implements Agent.
731func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
732 // Convert AgentStreamCall to AgentCall for preparation
733 call := AgentCall{
734 Prompt: opts.Prompt,
735 Files: opts.Files,
736 Messages: opts.Messages,
737 MaxOutputTokens: opts.MaxOutputTokens,
738 Temperature: opts.Temperature,
739 TopP: opts.TopP,
740 TopK: opts.TopK,
741 PresencePenalty: opts.PresencePenalty,
742 FrequencyPenalty: opts.FrequencyPenalty,
743 ActiveTools: opts.ActiveTools,
744 ProviderOptions: opts.ProviderOptions,
745 MaxRetries: opts.MaxRetries,
746 StopWhen: opts.StopWhen,
747 PrepareStep: opts.PrepareStep,
748 RepairToolCall: opts.RepairToolCall,
749 }
750
751 call = a.prepareCall(call)
752
753 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
754 if err != nil {
755 return nil, err
756 }
757
758 var responseMessages []Message
759 var steps []StepResult
760 var totalUsage Usage
761
762 // Start agent stream
763 if opts.OnAgentStart != nil {
764 opts.OnAgentStart()
765 }
766
767 for stepNumber := 0; ; stepNumber++ {
768 stepInputMessages := append(initialPrompt, responseMessages...)
769 stepModel := a.settings.model
770 stepSystemPrompt := a.settings.systemPrompt
771 stepActiveTools := call.ActiveTools
772 stepToolChoice := ToolChoiceAuto
773 disableAllTools := false
774
775 // Apply step preparation if provided
776 if call.PrepareStep != nil {
777 prepared, err := call.PrepareStep(PrepareStepFunctionOptions{
778 Model: stepModel,
779 Steps: steps,
780 StepNumber: stepNumber,
781 Messages: stepInputMessages,
782 })
783 if err != nil {
784 return nil, err
785 }
786
787 if prepared.Messages != nil {
788 stepInputMessages = prepared.Messages
789 }
790 if prepared.Model != nil {
791 stepModel = prepared.Model
792 }
793 if prepared.System != nil {
794 stepSystemPrompt = *prepared.System
795 }
796 if prepared.ToolChoice != nil {
797 stepToolChoice = *prepared.ToolChoice
798 }
799 if len(prepared.ActiveTools) > 0 {
800 stepActiveTools = prepared.ActiveTools
801 }
802 disableAllTools = prepared.DisableAllTools
803 }
804
805 // Recreate prompt with potentially modified system prompt
806 if stepSystemPrompt != a.settings.systemPrompt {
807 stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
808 if err != nil {
809 return nil, err
810 }
811 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
812 stepInputMessages[0] = stepPrompt[0]
813 }
814 }
815
816 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
817
818 // Start step stream
819 if opts.OnStepStart != nil {
820 _ = opts.OnStepStart(stepNumber)
821 }
822
823 // Create streaming call
824 streamCall := Call{
825 Prompt: stepInputMessages,
826 MaxOutputTokens: call.MaxOutputTokens,
827 Temperature: call.Temperature,
828 TopP: call.TopP,
829 TopK: call.TopK,
830 PresencePenalty: call.PresencePenalty,
831 FrequencyPenalty: call.FrequencyPenalty,
832 Tools: preparedTools,
833 ToolChoice: &stepToolChoice,
834 ProviderOptions: call.ProviderOptions,
835 }
836
837 // Get streaming response
838 stream, err := stepModel.Stream(ctx, streamCall)
839 if err != nil {
840 if opts.OnError != nil {
841 opts.OnError(err)
842 }
843 return nil, err
844 }
845
846 // Process stream with tool execution
847 stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
848 if err != nil {
849 if opts.OnError != nil {
850 opts.OnError(err)
851 }
852 return nil, err
853 }
854
855 steps = append(steps, stepResult)
856 totalUsage = addUsage(totalUsage, stepResult.Usage)
857
858 // Call step finished callback
859 if opts.OnStepFinish != nil {
860 _ = opts.OnStepFinish(stepResult)
861 }
862
863 // Add step messages to response messages
864 stepMessages := toResponseMessages(stepResult.Content)
865 responseMessages = append(responseMessages, stepMessages...)
866
867 // Check stop conditions
868 shouldStop := isStopConditionMet(call.StopWhen, steps)
869 if shouldStop || !shouldContinue {
870 break
871 }
872 }
873
874 // Finish agent stream
875 agentResult := &AgentResult{
876 Steps: steps,
877 Response: steps[len(steps)-1].Response,
878 TotalUsage: totalUsage,
879 }
880
881 if opts.OnFinish != nil {
882 opts.OnFinish(agentResult)
883 }
884
885 if opts.OnAgentFinish != nil {
886 _ = opts.OnAgentFinish(agentResult)
887 }
888
889 return agentResult, nil
890}
891
892func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
893 preparedTools := make([]Tool, 0, len(tools))
894
895 // If explicitly disabling all tools, return no tools
896 if disableAllTools {
897 return preparedTools
898 }
899
900 for _, tool := range tools {
901 // If activeTools has items, only include tools in the list
902 // If activeTools is empty, include all tools
903 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
904 continue
905 }
906 info := tool.Info()
907 preparedTools = append(preparedTools, FunctionTool{
908 Name: info.Name,
909 Description: info.Description,
910 InputSchema: map[string]any{
911 "type": "object",
912 "properties": info.Parameters,
913 "required": info.Required,
914 },
915 })
916 }
917 return preparedTools
918}
919
920// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
921func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
922 if err := a.validateToolCall(toolCall, availableTools); err == nil {
923 return toolCall
924 } else {
925 if repairFunc != nil {
926 repairOptions := ToolCallRepairOptions{
927 OriginalToolCall: toolCall,
928 ValidationError: err,
929 AvailableTools: availableTools,
930 SystemPrompt: systemPrompt,
931 Messages: messages,
932 }
933
934 if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
935 if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
936 return *repairedToolCall
937 }
938 }
939 }
940
941 invalidToolCall := toolCall
942 invalidToolCall.Invalid = true
943 invalidToolCall.ValidationError = err
944 return invalidToolCall
945 }
946}
947
948// validateToolCall validates a tool call against available tools and their schemas.
949func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
950 var tool AgentTool
951 for _, t := range availableTools {
952 if t.Info().Name == toolCall.ToolName {
953 tool = t
954 break
955 }
956 }
957
958 if tool == nil {
959 return fmt.Errorf("tool not found: %s", toolCall.ToolName)
960 }
961
962 // Validate JSON parsing
963 var input map[string]any
964 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
965 return fmt.Errorf("invalid JSON input: %w", err)
966 }
967
968 // Basic schema validation (check required fields)
969 // TODO: more robust schema validation using JSON Schema or similar
970 toolInfo := tool.Info()
971 for _, required := range toolInfo.Required {
972 if _, exists := input[required]; !exists {
973 return fmt.Errorf("missing required parameter: %s", required)
974 }
975 }
976 return nil
977}
978
979func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
980 if prompt == "" {
981 return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
982 }
983
984 var preparedPrompt Prompt
985
986 if system != "" {
987 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
988 }
989 preparedPrompt = append(preparedPrompt, messages...)
990 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
991 return preparedPrompt, nil
992}
993
994func WithSystemPrompt(prompt string) AgentOption {
995 return func(s *agentSettings) {
996 s.systemPrompt = prompt
997 }
998}
999
1000func WithMaxOutputTokens(tokens int64) AgentOption {
1001 return func(s *agentSettings) {
1002 s.maxOutputTokens = &tokens
1003 }
1004}
1005
1006func WithTemperature(temp float64) AgentOption {
1007 return func(s *agentSettings) {
1008 s.temperature = &temp
1009 }
1010}
1011
1012func WithTopP(topP float64) AgentOption {
1013 return func(s *agentSettings) {
1014 s.topP = &topP
1015 }
1016}
1017
1018func WithTopK(topK int64) AgentOption {
1019 return func(s *agentSettings) {
1020 s.topK = &topK
1021 }
1022}
1023
1024func WithPresencePenalty(penalty float64) AgentOption {
1025 return func(s *agentSettings) {
1026 s.presencePenalty = &penalty
1027 }
1028}
1029
1030func WithFrequencyPenalty(penalty float64) AgentOption {
1031 return func(s *agentSettings) {
1032 s.frequencyPenalty = &penalty
1033 }
1034}
1035
1036func WithTools(tools ...AgentTool) AgentOption {
1037 return func(s *agentSettings) {
1038 s.tools = append(s.tools, tools...)
1039 }
1040}
1041
1042func WithStopConditions(conditions ...StopCondition) AgentOption {
1043 return func(s *agentSettings) {
1044 s.stopWhen = append(s.stopWhen, conditions...)
1045 }
1046}
1047
1048func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1049 return func(s *agentSettings) {
1050 s.prepareStep = fn
1051 }
1052}
1053
1054func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1055 return func(s *agentSettings) {
1056 s.repairToolCall = fn
1057 }
1058}
1059
// processStepStream consumes the stream for a single agent step and assembles
// the resulting StepResult.
//
// Every part is first passed to opts.OnChunk and then dispatched by type to the
// matching opts callback. Each callback is optional and is skipped when nil.
// Text, reasoning, tool-call and source parts are collected into the step's
// content in stream order. After the stream is exhausted, any tool calls are
// executed via a.executeTools and their results are appended to the content.
//
// It returns:
//   - the assembled StepResult (a zero StepResult whenever the error is non-nil);
//   - shouldContinue, true only when the step produced at least one tool call
//     and the reported finish reason is FinishReasonToolCalls;
//   - an error when a callback fails, the stream yields a StreamPartTypeError
//     part, or tool execution fails. Processing stops at the first error.
//
// The trailing []StepResult parameter is accepted but not used.
func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
	var stepContent []Content
	var stepToolCalls []ToolCallContent
	var stepUsage Usage
	stepFinishReason := FinishReasonUnknown
	var stepWarnings []CallWarning
	var stepProviderMetadata ProviderMetadata

	// In-flight items keyed by stream part ID. Text and reasoning are moved
	// into stepContent only when their matching *End part arrives. Anything
	// still open when the stream ends is discarded.
	activeToolCalls := make(map[string]*ToolCallContent)
	activeTextContent := make(map[string]string)
	type reasoningContent struct {
		content string
		options ProviderMetadata
	}
	activeReasoningContent := make(map[string]reasoningContent)

	// Process stream parts
	for part := range stream {
		// Forward all parts to chunk callback
		if opts.OnChunk != nil {
			err := opts.OnChunk(part)
			if err != nil {
				return StepResult{}, false, err
			}
		}

		switch part.Type {
		case StreamPartTypeWarnings:
			// A later warnings part replaces, rather than extends, earlier ones.
			stepWarnings = part.Warnings
			if opts.OnWarnings != nil {
				err := opts.OnWarnings(part.Warnings)
				if err != nil {
					return StepResult{}, false, err
				}
			}

		case StreamPartTypeTextStart:
			activeTextContent[part.ID] = ""
			if opts.OnTextStart != nil {
				err := opts.OnTextStart(part.ID)
				if err != nil {
					return StepResult{}, false, err
				}
			}

		case StreamPartTypeTextDelta:
			// Deltas for an ID with no preceding TextStart are not accumulated,
			// but they are still forwarded to OnTextDelta.
			if _, exists := activeTextContent[part.ID]; exists {
				activeTextContent[part.ID] += part.Delta
			}
			if opts.OnTextDelta != nil {
				err := opts.OnTextDelta(part.ID, part.Delta)
				if err != nil {
					return StepResult{}, false, err
				}
			}

		case StreamPartTypeTextEnd:
			// The end part's ProviderMetadata is attached to the finished text.
			if text, exists := activeTextContent[part.ID]; exists {
				stepContent = append(stepContent, TextContent{
					Text:             text,
					ProviderMetadata: part.ProviderMetadata,
				})
				delete(activeTextContent, part.ID)
			}
			// Unlike OnReasoningEnd below, OnTextEnd fires even when no
			// matching TextStart was seen.
			if opts.OnTextEnd != nil {
				err := opts.OnTextEnd(part.ID)
				if err != nil {
					return StepResult{}, false, err
				}
			}

		case StreamPartTypeReasoningStart:
			// NOTE(review): ProviderMetadata on the start part is not recorded.
			// Only metadata carried by deltas is kept (see below).
			activeReasoningContent[part.ID] = reasoningContent{content: ""}
			if opts.OnReasoningStart != nil {
				err := opts.OnReasoningStart(part.ID)
				if err != nil {
					return StepResult{}, false, err
				}
			}

		case StreamPartTypeReasoningDelta:
			if active, exists := activeReasoningContent[part.ID]; exists {
				active.content += part.Delta
				// Each delta overwrites the stored metadata, so the last delta
				// wins. NOTE(review): a trailing delta with nil metadata discards
				// metadata carried by earlier deltas. Confirm this is intended.
				active.options = part.ProviderMetadata
				activeReasoningContent[part.ID] = active
			}
			if opts.OnReasoningDelta != nil {
				err := opts.OnReasoningDelta(part.ID, part.Delta)
				if err != nil {
					return StepResult{}, false, err
				}
			}

		case StreamPartTypeReasoningEnd:
			// The end part's own ProviderMetadata is not used. OnReasoningEnd
			// only fires when a matching ReasoningStart was seen.
			if active, exists := activeReasoningContent[part.ID]; exists {
				content := ReasoningContent{
					Text:             active.content,
					ProviderMetadata: active.options,
				}
				stepContent = append(stepContent, content)
				if opts.OnReasoningEnd != nil {
					err := opts.OnReasoningEnd(part.ID, content)
					if err != nil {
						return StepResult{}, false, err
					}
				}
				delete(activeReasoningContent, part.ID)
			}

		case StreamPartTypeToolInputStart:
			// Tracks streamed tool input by ID. The accumulated Input is not
			// read in this function: the StreamPartTypeToolCall case builds
			// the call from part.ToolCallInput and then drops this entry.
			activeToolCalls[part.ID] = &ToolCallContent{
				ToolCallID:       part.ID,
				ToolName:         part.ToolCallName,
				Input:            "",
				ProviderExecuted: part.ProviderExecuted,
			}
			if opts.OnToolInputStart != nil {
				err := opts.OnToolInputStart(part.ID, part.ToolCallName)
				if err != nil {
					return StepResult{}, false, err
				}
			}

		case StreamPartTypeToolInputDelta:
			if toolCall, exists := activeToolCalls[part.ID]; exists {
				toolCall.Input += part.Delta
			}
			if opts.OnToolInputDelta != nil {
				err := opts.OnToolInputDelta(part.ID, part.Delta)
				if err != nil {
					return StepResult{}, false, err
				}
			}

		case StreamPartTypeToolInputEnd:
			if opts.OnToolInputEnd != nil {
				err := opts.OnToolInputEnd(part.ID)
				if err != nil {
					return StepResult{}, false, err
				}
			}

		case StreamPartTypeToolCall:
			toolCall := ToolCallContent{
				ToolCallID:       part.ID,
				ToolName:         part.ToolCallName,
				Input:            part.ToolCallInput,
				ProviderExecuted: part.ProviderExecuted,
				ProviderMetadata: part.ProviderMetadata,
			}

			// Validate and potentially repair the tool call
			// NOTE(review): messages are passed as nil, and the repair hook
			// comes from opts.RepairToolCall rather than
			// a.settings.repairToolCall. Presumably the caller has already
			// merged the two; confirm.
			validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
			stepToolCalls = append(stepToolCalls, validatedToolCall)
			stepContent = append(stepContent, validatedToolCall)

			if opts.OnToolCall != nil {
				err := opts.OnToolCall(validatedToolCall)
				if err != nil {
					return StepResult{}, false, err
				}
			}

			// Clean up active tool call
			delete(activeToolCalls, part.ID)

		case StreamPartTypeSource:
			sourceContent := SourceContent{
				SourceType:       part.SourceType,
				ID:               part.ID,
				URL:              part.URL,
				Title:            part.Title,
				ProviderMetadata: part.ProviderMetadata,
			}
			stepContent = append(stepContent, sourceContent)
			if opts.OnSource != nil {
				err := opts.OnSource(sourceContent)
				if err != nil {
					return StepResult{}, false, err
				}
			}

		case StreamPartTypeFinish:
			// If several finish parts arrive, the last one wins.
			stepUsage = part.Usage
			stepFinishReason = part.FinishReason
			stepProviderMetadata = part.ProviderMetadata
			if opts.OnStreamFinish != nil {
				err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
				if err != nil {
					return StepResult{}, false, err
				}
			}

		case StreamPartTypeError:
			if opts.OnError != nil {
				opts.OnError(part.Error)
			}
			// NOTE(review): if part.Error is nil, this returns a zero
			// StepResult together with a nil error.
			return StepResult{}, false, part.Error
		}
	}

	// Execute tools if any
	var toolResults []ToolResultContent
	if len(stepToolCalls) > 0 {
		var err error
		toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
		if err != nil {
			return StepResult{}, false, err
		}
		// Add tool results to content, after all streamed content. The loop
		// converts each ToolResultContent to Content individually, because
		// []ToolResultContent cannot be spread into a []Content append.
		for _, result := range toolResults {
			stepContent = append(stepContent, result)
		}
	}

	stepResult := StepResult{
		Response: Response{
			Content:          stepContent,
			FinishReason:     stepFinishReason,
			Usage:            stepUsage,
			Warnings:         stepWarnings,
			ProviderMetadata: stepProviderMetadata,
		},
		Messages: toResponseMessages(stepContent),
	}

	// Determine if we should continue (has tool calls and not stopped)
	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls

	return stepResult, shouldContinue, nil
}
1292
1293func addUsage(a, b Usage) Usage {
1294 return Usage{
1295 InputTokens: a.InputTokens + b.InputTokens,
1296 OutputTokens: a.OutputTokens + b.OutputTokens,
1297 TotalTokens: a.TotalTokens + b.TotalTokens,
1298 ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens,
1299 CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1300 CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens,
1301 }
1302}
1303
1304func WithHeaders(headers map[string]string) AgentOption {
1305 return func(s *agentSettings) {
1306 s.headers = headers
1307 }
1308}
1309
1310func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1311 return func(s *agentSettings) {
1312 s.providerOptions = providerOptions
1313 }
1314}