1package ai
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "slices"
11 "sync"
12)
13
14type StepResult struct {
15 Response
16 Messages []Message
17}
18
19type StopCondition = func(steps []StepResult) bool
20
21// StepCountIs returns a stop condition that stops after the specified number of steps.
22func StepCountIs(stepCount int) StopCondition {
23 return func(steps []StepResult) bool {
24 return len(steps) >= stepCount
25 }
26}
27
28// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
29func HasToolCall(toolName string) StopCondition {
30 return func(steps []StepResult) bool {
31 if len(steps) == 0 {
32 return false
33 }
34 lastStep := steps[len(steps)-1]
35 toolCalls := lastStep.Content.ToolCalls()
36 for _, toolCall := range toolCalls {
37 if toolCall.ToolName == toolName {
38 return true
39 }
40 }
41 return false
42 }
43}
44
45// HasContent returns a stop condition that stops when the specified content type appears in the last step.
46func HasContent(contentType ContentType) StopCondition {
47 return func(steps []StepResult) bool {
48 if len(steps) == 0 {
49 return false
50 }
51 lastStep := steps[len(steps)-1]
52 for _, content := range lastStep.Content {
53 if content.GetType() == contentType {
54 return true
55 }
56 }
57 return false
58 }
59}
60
61// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
62func FinishReasonIs(reason FinishReason) StopCondition {
63 return func(steps []StepResult) bool {
64 if len(steps) == 0 {
65 return false
66 }
67 lastStep := steps[len(steps)-1]
68 return lastStep.FinishReason == reason
69 }
70}
71
72// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
73func MaxTokensUsed(maxTokens int64) StopCondition {
74 return func(steps []StepResult) bool {
75 var totalTokens int64
76 for _, step := range steps {
77 totalTokens += step.Usage.TotalTokens
78 }
79 return totalTokens >= maxTokens
80 }
81}
82
83type PrepareStepFunctionOptions struct {
84 Steps []StepResult
85 StepNumber int
86 Model LanguageModel
87 Messages []Message
88}
89
90type PrepareStepResult struct {
91 Model LanguageModel
92 Messages []Message
93 System *string
94 ToolChoice *ToolChoice
95 ActiveTools []string
96 DisableAllTools bool
97}
98
99type ToolCallRepairOptions struct {
100 OriginalToolCall ToolCallContent
101 ValidationError error
102 AvailableTools []AgentTool
103 SystemPrompt string
104 Messages []Message
105}
106
107type (
108 PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)
109 OnStepFinishedFunction = func(step StepResult)
110 RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
111)
112
113type agentSettings struct {
114 systemPrompt string
115 maxOutputTokens *int64
116 temperature *float64
117 topP *float64
118 topK *int64
119 presencePenalty *float64
120 frequencyPenalty *float64
121 headers map[string]string
122 providerOptions ProviderOptions
123
124 // TODO: add support for provider tools
125 tools []AgentTool
126 maxRetries *int
127
128 model LanguageModel
129
130 stopWhen []StopCondition
131 prepareStep PrepareStepFunction
132 repairToolCall RepairToolCallFunction
133 onRetry OnRetryCallback
134}
135
136type AgentCall struct {
137 Prompt string `json:"prompt"`
138 Files []FilePart `json:"files"`
139 Messages []Message `json:"messages"`
140 MaxOutputTokens *int64
141 Temperature *float64 `json:"temperature"`
142 TopP *float64 `json:"top_p"`
143 TopK *int64 `json:"top_k"`
144 PresencePenalty *float64 `json:"presence_penalty"`
145 FrequencyPenalty *float64 `json:"frequency_penalty"`
146 ActiveTools []string `json:"active_tools"`
147 ProviderOptions ProviderOptions
148 OnRetry OnRetryCallback
149 MaxRetries *int
150
151 StopWhen []StopCondition
152 PrepareStep PrepareStepFunction
153 RepairToolCall RepairToolCallFunction
154}
155
156// Agent-level callbacks.
157type (
158 // OnAgentStartFunc is called when agent starts.
159 OnAgentStartFunc func()
160
161 // OnAgentFinishFunc is called when agent finishes.
162 OnAgentFinishFunc func(result *AgentResult) error
163
164 // OnStepStartFunc is called when a step starts.
165 OnStepStartFunc func(stepNumber int) error
166
167 // OnStepFinishFunc is called when a step finishes.
168 OnStepFinishFunc func(stepResult StepResult) error
169
170 // OnFinishFunc is called when entire agent completes.
171 OnFinishFunc func(result *AgentResult)
172
173 // OnErrorFunc is called when an error occurs.
174 OnErrorFunc func(error)
175)
176
177// Stream part callbacks - called for each corresponding stream part type.
178type (
179 // OnChunkFunc is called for each stream part (catch-all).
180 OnChunkFunc func(StreamPart) error
181
182 // OnWarningsFunc is called for warnings.
183 OnWarningsFunc func(warnings []CallWarning) error
184
185 // OnTextStartFunc is called when text starts.
186 OnTextStartFunc func(id string) error
187
188 // OnTextDeltaFunc is called for text deltas.
189 OnTextDeltaFunc func(id, text string) error
190
191 // OnTextEndFunc is called when text ends.
192 OnTextEndFunc func(id string) error
193
194 // OnReasoningStartFunc is called when reasoning starts.
195 OnReasoningStartFunc func(id string, reasoning ReasoningContent) error
196
197 // OnReasoningDeltaFunc is called for reasoning deltas.
198 OnReasoningDeltaFunc func(id, text string) error
199
200 // OnReasoningEndFunc is called when reasoning ends.
201 OnReasoningEndFunc func(id string, reasoning ReasoningContent) error
202
203 // OnToolInputStartFunc is called when tool input starts.
204 OnToolInputStartFunc func(id, toolName string) error
205
206 // OnToolInputDeltaFunc is called for tool input deltas.
207 OnToolInputDeltaFunc func(id, delta string) error
208
209 // OnToolInputEndFunc is called when tool input ends.
210 OnToolInputEndFunc func(id string) error
211
212 // OnToolCallFunc is called when tool call is complete.
213 OnToolCallFunc func(toolCall ToolCallContent) error
214
215 // OnToolResultFunc is called when tool execution completes.
216 OnToolResultFunc func(result ToolResultContent) error
217
218 // OnSourceFunc is called for source references.
219 OnSourceFunc func(source SourceContent) error
220
221 // OnStreamFinishFunc is called when stream finishes.
222 OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
223)
224
225type AgentStreamCall struct {
226 Prompt string `json:"prompt"`
227 Files []FilePart `json:"files"`
228 Messages []Message `json:"messages"`
229 MaxOutputTokens *int64
230 Temperature *float64 `json:"temperature"`
231 TopP *float64 `json:"top_p"`
232 TopK *int64 `json:"top_k"`
233 PresencePenalty *float64 `json:"presence_penalty"`
234 FrequencyPenalty *float64 `json:"frequency_penalty"`
235 ActiveTools []string `json:"active_tools"`
236 Headers map[string]string
237 ProviderOptions ProviderOptions
238 OnRetry OnRetryCallback
239 MaxRetries *int
240
241 StopWhen []StopCondition
242 PrepareStep PrepareStepFunction
243 RepairToolCall RepairToolCallFunction
244
245 // Agent-level callbacks
246 OnAgentStart OnAgentStartFunc // Called when agent starts
247 OnAgentFinish OnAgentFinishFunc // Called when agent finishes
248 OnStepStart OnStepStartFunc // Called when a step starts
249 OnStepFinish OnStepFinishFunc // Called when a step finishes
250 OnFinish OnFinishFunc // Called when entire agent completes
251 OnError OnErrorFunc // Called when an error occurs
252
253 // Stream part callbacks - called for each corresponding stream part type
254 OnChunk OnChunkFunc // Called for each stream part (catch-all)
255 OnWarnings OnWarningsFunc // Called for warnings
256 OnTextStart OnTextStartFunc // Called when text starts
257 OnTextDelta OnTextDeltaFunc // Called for text deltas
258 OnTextEnd OnTextEndFunc // Called when text ends
259 OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
260 OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
261 OnReasoningEnd OnReasoningEndFunc // Called when reasoning ends
262 OnToolInputStart OnToolInputStartFunc // Called when tool input starts
263 OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
264 OnToolInputEnd OnToolInputEndFunc // Called when tool input ends
265 OnToolCall OnToolCallFunc // Called when tool call is complete
266 OnToolResult OnToolResultFunc // Called when tool execution completes
267 OnSource OnSourceFunc // Called for source references
268 OnStreamFinish OnStreamFinishFunc // Called when stream finishes
269}
270
271type AgentResult struct {
272 Steps []StepResult
273 // Final response
274 Response Response
275 TotalUsage Usage
276}
277
278type Agent interface {
279 Generate(context.Context, AgentCall) (*AgentResult, error)
280 Stream(context.Context, AgentStreamCall) (*AgentResult, error)
281}
282
283type AgentOption = func(*agentSettings)
284
285type agent struct {
286 settings agentSettings
287}
288
289func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
290 settings := agentSettings{
291 model: model,
292 }
293 for _, o := range opts {
294 o(&settings)
295 }
296 return &agent{
297 settings: settings,
298 }
299}
300
301func (a *agent) prepareCall(call AgentCall) AgentCall {
302 call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
303 call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
304 call.TopP = cmp.Or(call.TopP, a.settings.topP)
305 call.TopK = cmp.Or(call.TopK, a.settings.topK)
306 call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
307 call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
308 call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
309
310 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
311 call.StopWhen = a.settings.stopWhen
312 }
313 if call.PrepareStep == nil && a.settings.prepareStep != nil {
314 call.PrepareStep = a.settings.prepareStep
315 }
316 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
317 call.RepairToolCall = a.settings.repairToolCall
318 }
319 if call.OnRetry == nil && a.settings.onRetry != nil {
320 call.OnRetry = a.settings.onRetry
321 }
322
323 providerOptions := ProviderOptions{}
324 if a.settings.providerOptions != nil {
325 maps.Copy(providerOptions, a.settings.providerOptions)
326 }
327 if call.ProviderOptions != nil {
328 maps.Copy(providerOptions, call.ProviderOptions)
329 }
330 call.ProviderOptions = providerOptions
331
332 headers := map[string]string{}
333
334 if a.settings.headers != nil {
335 maps.Copy(headers, a.settings.headers)
336 }
337
338 return call
339}
340
341// Generate implements Agent.
342func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
343 opts = a.prepareCall(opts)
344 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
345 if err != nil {
346 return nil, err
347 }
348 var responseMessages []Message
349 var steps []StepResult
350
351 for {
352 stepInputMessages := append(initialPrompt, responseMessages...)
353 stepModel := a.settings.model
354 stepSystemPrompt := a.settings.systemPrompt
355 stepActiveTools := opts.ActiveTools
356 stepToolChoice := ToolChoiceAuto
357 disableAllTools := false
358
359 if opts.PrepareStep != nil {
360 updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
361 Model: stepModel,
362 Steps: steps,
363 StepNumber: len(steps),
364 Messages: stepInputMessages,
365 })
366 if err != nil {
367 return nil, err
368 }
369
370 ctx = updatedCtx
371
372 // Apply prepared step modifications
373 if prepared.Messages != nil {
374 stepInputMessages = prepared.Messages
375 }
376 if prepared.Model != nil {
377 stepModel = prepared.Model
378 }
379 if prepared.System != nil {
380 stepSystemPrompt = *prepared.System
381 }
382 if prepared.ToolChoice != nil {
383 stepToolChoice = *prepared.ToolChoice
384 }
385 if len(prepared.ActiveTools) > 0 {
386 stepActiveTools = prepared.ActiveTools
387 }
388 disableAllTools = prepared.DisableAllTools
389 }
390
391 // Recreate prompt with potentially modified system prompt
392 if stepSystemPrompt != a.settings.systemPrompt {
393 stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
394 if err != nil {
395 return nil, err
396 }
397 // Replace system message part, keep the rest
398 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
399 stepInputMessages[0] = stepPrompt[0] // Replace system message
400 }
401 }
402
403 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
404
405 retryOptions := DefaultRetryOptions()
406 retryOptions.OnRetry = opts.OnRetry
407 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
408
409 result, err := retry(ctx, func() (*Response, error) {
410 return stepModel.Generate(ctx, Call{
411 Prompt: stepInputMessages,
412 MaxOutputTokens: opts.MaxOutputTokens,
413 Temperature: opts.Temperature,
414 TopP: opts.TopP,
415 TopK: opts.TopK,
416 PresencePenalty: opts.PresencePenalty,
417 FrequencyPenalty: opts.FrequencyPenalty,
418 Tools: preparedTools,
419 ToolChoice: &stepToolChoice,
420 ProviderOptions: opts.ProviderOptions,
421 })
422 })
423 if err != nil {
424 return nil, err
425 }
426
427 var stepToolCalls []ToolCallContent
428 for _, content := range result.Content {
429 if content.GetType() == ContentTypeToolCall {
430 toolCall, ok := AsContentType[ToolCallContent](content)
431 if !ok {
432 continue
433 }
434
435 // Validate and potentially repair the tool call
436 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
437 stepToolCalls = append(stepToolCalls, validatedToolCall)
438 }
439 }
440
441 toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
442
443 // Build step content with validated tool calls and tool results
444 stepContent := []Content{}
445 toolCallIndex := 0
446 for _, content := range result.Content {
447 if content.GetType() == ContentTypeToolCall {
448 // Replace with validated tool call
449 if toolCallIndex < len(stepToolCalls) {
450 stepContent = append(stepContent, stepToolCalls[toolCallIndex])
451 toolCallIndex++
452 }
453 } else {
454 // Keep other content as-is
455 stepContent = append(stepContent, content)
456 }
457 }
458 // Add tool results
459 for _, result := range toolResults {
460 stepContent = append(stepContent, result)
461 }
462 currentStepMessages := toResponseMessages(stepContent)
463 responseMessages = append(responseMessages, currentStepMessages...)
464
465 stepResult := StepResult{
466 Response: Response{
467 Content: stepContent,
468 FinishReason: result.FinishReason,
469 Usage: result.Usage,
470 Warnings: result.Warnings,
471 ProviderMetadata: result.ProviderMetadata,
472 },
473 Messages: currentStepMessages,
474 }
475 steps = append(steps, stepResult)
476 shouldStop := isStopConditionMet(opts.StopWhen, steps)
477
478 if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
479 break
480 }
481 }
482
483 totalUsage := Usage{}
484
485 for _, step := range steps {
486 usage := step.Usage
487 totalUsage.InputTokens += usage.InputTokens
488 totalUsage.OutputTokens += usage.OutputTokens
489 totalUsage.ReasoningTokens += usage.ReasoningTokens
490 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
491 totalUsage.CacheReadTokens += usage.CacheReadTokens
492 totalUsage.TotalTokens += usage.TotalTokens
493 }
494
495 agentResult := &AgentResult{
496 Steps: steps,
497 Response: steps[len(steps)-1].Response,
498 TotalUsage: totalUsage,
499 }
500 return agentResult, nil
501}
502
503func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
504 if len(conditions) == 0 {
505 return false
506 }
507
508 for _, condition := range conditions {
509 if condition(steps) {
510 return true
511 }
512 }
513 return false
514}
515
516func toResponseMessages(content []Content) []Message {
517 var assistantParts []MessagePart
518 var toolParts []MessagePart
519
520 for _, c := range content {
521 switch c.GetType() {
522 case ContentTypeText:
523 text, ok := AsContentType[TextContent](c)
524 if !ok {
525 continue
526 }
527 assistantParts = append(assistantParts, TextPart{
528 Text: text.Text,
529 ProviderOptions: ProviderOptions(text.ProviderMetadata),
530 })
531 case ContentTypeReasoning:
532 reasoning, ok := AsContentType[ReasoningContent](c)
533 if !ok {
534 continue
535 }
536 assistantParts = append(assistantParts, ReasoningPart{
537 Text: reasoning.Text,
538 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
539 })
540 case ContentTypeToolCall:
541 toolCall, ok := AsContentType[ToolCallContent](c)
542 if !ok {
543 continue
544 }
545 assistantParts = append(assistantParts, ToolCallPart{
546 ToolCallID: toolCall.ToolCallID,
547 ToolName: toolCall.ToolName,
548 Input: toolCall.Input,
549 ProviderExecuted: toolCall.ProviderExecuted,
550 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
551 })
552 case ContentTypeFile:
553 file, ok := AsContentType[FileContent](c)
554 if !ok {
555 continue
556 }
557 assistantParts = append(assistantParts, FilePart{
558 Data: file.Data,
559 MediaType: file.MediaType,
560 ProviderOptions: ProviderOptions(file.ProviderMetadata),
561 })
562 case ContentTypeSource:
563 // Sources are metadata about references used to generate the response.
564 // They don't need to be included in the conversation messages.
565 continue
566 case ContentTypeToolResult:
567 result, ok := AsContentType[ToolResultContent](c)
568 if !ok {
569 continue
570 }
571 toolParts = append(toolParts, ToolResultPart{
572 ToolCallID: result.ToolCallID,
573 Output: result.Result,
574 ProviderOptions: ProviderOptions(result.ProviderMetadata),
575 })
576 }
577 }
578
579 var messages []Message
580 if len(assistantParts) > 0 {
581 messages = append(messages, Message{
582 Role: MessageRoleAssistant,
583 Content: assistantParts,
584 })
585 }
586 if len(toolParts) > 0 {
587 messages = append(messages, Message{
588 Role: MessageRoleTool,
589 Content: toolParts,
590 })
591 }
592 return messages
593}
594
595func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
596 if len(toolCalls) == 0 {
597 return nil, nil
598 }
599
600 // Create a map for quick tool lookup
601 toolMap := make(map[string]AgentTool)
602 for _, tool := range allTools {
603 toolMap[tool.Info().Name] = tool
604 }
605
606 // Execute all tool calls in parallel
607 results := make([]ToolResultContent, len(toolCalls))
608 executeErrors := make([]error, len(toolCalls))
609
610 var wg sync.WaitGroup
611
612 for i, toolCall := range toolCalls {
613 wg.Add(1)
614 go func(index int, call ToolCallContent) {
615 defer wg.Done()
616
617 // Skip invalid tool calls - create error result
618 if call.Invalid {
619 results[index] = ToolResultContent{
620 ToolCallID: call.ToolCallID,
621 ToolName: call.ToolName,
622 Result: ToolResultOutputContentError{
623 Error: call.ValidationError,
624 },
625 ProviderExecuted: false,
626 }
627 if toolResultCallback != nil {
628 err := toolResultCallback(results[index])
629 if err != nil {
630 executeErrors[index] = err
631 }
632 }
633
634 return
635 }
636
637 tool, exists := toolMap[call.ToolName]
638 if !exists {
639 results[index] = ToolResultContent{
640 ToolCallID: call.ToolCallID,
641 ToolName: call.ToolName,
642 Result: ToolResultOutputContentError{
643 Error: errors.New("Error: Tool not found: " + call.ToolName),
644 },
645 ProviderExecuted: false,
646 }
647
648 if toolResultCallback != nil {
649 err := toolResultCallback(results[index])
650 if err != nil {
651 executeErrors[index] = err
652 }
653 }
654 return
655 }
656
657 // Execute the tool
658 result, err := tool.Run(ctx, ToolCall{
659 ID: call.ToolCallID,
660 Name: call.ToolName,
661 Input: call.Input,
662 })
663 if err != nil {
664 results[index] = ToolResultContent{
665 ToolCallID: call.ToolCallID,
666 ToolName: call.ToolName,
667 Result: ToolResultOutputContentError{
668 Error: err,
669 },
670 ClientMetadata: result.Metadata,
671 ProviderExecuted: false,
672 }
673 if toolResultCallback != nil {
674 cbErr := toolResultCallback(results[index])
675 if cbErr != nil {
676 executeErrors[index] = cbErr
677 }
678 }
679 executeErrors[index] = err
680 return
681 }
682
683 if result.IsError {
684 results[index] = ToolResultContent{
685 ToolCallID: call.ToolCallID,
686 ToolName: call.ToolName,
687 Result: ToolResultOutputContentError{
688 Error: errors.New(result.Content),
689 },
690 ClientMetadata: result.Metadata,
691 ProviderExecuted: false,
692 }
693
694 if toolResultCallback != nil {
695 err := toolResultCallback(results[index])
696 if err != nil {
697 executeErrors[index] = err
698 }
699 }
700 } else {
701 results[index] = ToolResultContent{
702 ToolCallID: call.ToolCallID,
703 ToolName: toolCall.ToolName,
704 Result: ToolResultOutputContentText{
705 Text: result.Content,
706 },
707 ClientMetadata: result.Metadata,
708 ProviderExecuted: false,
709 }
710 if toolResultCallback != nil {
711 err := toolResultCallback(results[index])
712 if err != nil {
713 executeErrors[index] = err
714 }
715 }
716 }
717 }(i, toolCall)
718 }
719
720 // Wait for all tool executions to complete
721 wg.Wait()
722
723 for _, err := range executeErrors {
724 if err != nil {
725 return nil, err
726 }
727 }
728
729 return results, nil
730}
731
732// Stream implements Agent.
733func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
734 // Convert AgentStreamCall to AgentCall for preparation
735 call := AgentCall{
736 Prompt: opts.Prompt,
737 Files: opts.Files,
738 Messages: opts.Messages,
739 MaxOutputTokens: opts.MaxOutputTokens,
740 Temperature: opts.Temperature,
741 TopP: opts.TopP,
742 TopK: opts.TopK,
743 PresencePenalty: opts.PresencePenalty,
744 FrequencyPenalty: opts.FrequencyPenalty,
745 ActiveTools: opts.ActiveTools,
746 ProviderOptions: opts.ProviderOptions,
747 MaxRetries: opts.MaxRetries,
748 StopWhen: opts.StopWhen,
749 PrepareStep: opts.PrepareStep,
750 RepairToolCall: opts.RepairToolCall,
751 }
752
753 call = a.prepareCall(call)
754
755 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
756 if err != nil {
757 return nil, err
758 }
759
760 var responseMessages []Message
761 var steps []StepResult
762 var totalUsage Usage
763
764 // Start agent stream
765 if opts.OnAgentStart != nil {
766 opts.OnAgentStart()
767 }
768
769 for stepNumber := 0; ; stepNumber++ {
770 stepInputMessages := append(initialPrompt, responseMessages...)
771 stepModel := a.settings.model
772 stepSystemPrompt := a.settings.systemPrompt
773 stepActiveTools := call.ActiveTools
774 stepToolChoice := ToolChoiceAuto
775 disableAllTools := false
776
777 // Apply step preparation if provided
778 if call.PrepareStep != nil {
779 updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
780 Model: stepModel,
781 Steps: steps,
782 StepNumber: stepNumber,
783 Messages: stepInputMessages,
784 })
785 if err != nil {
786 return nil, err
787 }
788
789 ctx = updatedCtx
790
791 if prepared.Messages != nil {
792 stepInputMessages = prepared.Messages
793 }
794 if prepared.Model != nil {
795 stepModel = prepared.Model
796 }
797 if prepared.System != nil {
798 stepSystemPrompt = *prepared.System
799 }
800 if prepared.ToolChoice != nil {
801 stepToolChoice = *prepared.ToolChoice
802 }
803 if len(prepared.ActiveTools) > 0 {
804 stepActiveTools = prepared.ActiveTools
805 }
806 disableAllTools = prepared.DisableAllTools
807 }
808
809 // Recreate prompt with potentially modified system prompt
810 if stepSystemPrompt != a.settings.systemPrompt {
811 stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
812 if err != nil {
813 return nil, err
814 }
815 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
816 stepInputMessages[0] = stepPrompt[0]
817 }
818 }
819
820 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
821
822 // Start step stream
823 if opts.OnStepStart != nil {
824 _ = opts.OnStepStart(stepNumber)
825 }
826
827 // Create streaming call
828 streamCall := Call{
829 Prompt: stepInputMessages,
830 MaxOutputTokens: call.MaxOutputTokens,
831 Temperature: call.Temperature,
832 TopP: call.TopP,
833 TopK: call.TopK,
834 PresencePenalty: call.PresencePenalty,
835 FrequencyPenalty: call.FrequencyPenalty,
836 Tools: preparedTools,
837 ToolChoice: &stepToolChoice,
838 ProviderOptions: call.ProviderOptions,
839 }
840
841 // Get streaming response
842 stream, err := stepModel.Stream(ctx, streamCall)
843 if err != nil {
844 if opts.OnError != nil {
845 opts.OnError(err)
846 }
847 return nil, err
848 }
849
850 // Process stream with tool execution
851 stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
852 if err != nil {
853 if opts.OnError != nil {
854 opts.OnError(err)
855 }
856 return nil, err
857 }
858
859 steps = append(steps, stepResult)
860 totalUsage = addUsage(totalUsage, stepResult.Usage)
861
862 // Call step finished callback
863 if opts.OnStepFinish != nil {
864 _ = opts.OnStepFinish(stepResult)
865 }
866
867 // Add step messages to response messages
868 stepMessages := toResponseMessages(stepResult.Content)
869 responseMessages = append(responseMessages, stepMessages...)
870
871 // Check stop conditions
872 shouldStop := isStopConditionMet(call.StopWhen, steps)
873 if shouldStop || !shouldContinue {
874 break
875 }
876 }
877
878 // Finish agent stream
879 agentResult := &AgentResult{
880 Steps: steps,
881 Response: steps[len(steps)-1].Response,
882 TotalUsage: totalUsage,
883 }
884
885 if opts.OnFinish != nil {
886 opts.OnFinish(agentResult)
887 }
888
889 if opts.OnAgentFinish != nil {
890 _ = opts.OnAgentFinish(agentResult)
891 }
892
893 return agentResult, nil
894}
895
896func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
897 preparedTools := make([]Tool, 0, len(tools))
898
899 // If explicitly disabling all tools, return no tools
900 if disableAllTools {
901 return preparedTools
902 }
903
904 for _, tool := range tools {
905 // If activeTools has items, only include tools in the list
906 // If activeTools is empty, include all tools
907 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
908 continue
909 }
910 info := tool.Info()
911 preparedTools = append(preparedTools, FunctionTool{
912 Name: info.Name,
913 Description: info.Description,
914 InputSchema: map[string]any{
915 "type": "object",
916 "properties": info.Parameters,
917 "required": info.Required,
918 },
919 ProviderOptions: tool.ProviderOptions(),
920 })
921 }
922 return preparedTools
923}
924
925// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
926func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
927 if err := a.validateToolCall(toolCall, availableTools); err == nil {
928 return toolCall
929 } else {
930 if repairFunc != nil {
931 repairOptions := ToolCallRepairOptions{
932 OriginalToolCall: toolCall,
933 ValidationError: err,
934 AvailableTools: availableTools,
935 SystemPrompt: systemPrompt,
936 Messages: messages,
937 }
938
939 if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
940 if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
941 return *repairedToolCall
942 }
943 }
944 }
945
946 invalidToolCall := toolCall
947 invalidToolCall.Invalid = true
948 invalidToolCall.ValidationError = err
949 return invalidToolCall
950 }
951}
952
953// validateToolCall validates a tool call against available tools and their schemas.
954func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
955 var tool AgentTool
956 for _, t := range availableTools {
957 if t.Info().Name == toolCall.ToolName {
958 tool = t
959 break
960 }
961 }
962
963 if tool == nil {
964 return fmt.Errorf("tool not found: %s", toolCall.ToolName)
965 }
966
967 // Validate JSON parsing
968 var input map[string]any
969 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
970 return fmt.Errorf("invalid JSON input: %w", err)
971 }
972
973 // Basic schema validation (check required fields)
974 // TODO: more robust schema validation using JSON Schema or similar
975 toolInfo := tool.Info()
976 for _, required := range toolInfo.Required {
977 if _, exists := input[required]; !exists {
978 return fmt.Errorf("missing required parameter: %s", required)
979 }
980 }
981 return nil
982}
983
984func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
985 if prompt == "" {
986 return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
987 }
988
989 var preparedPrompt Prompt
990
991 if system != "" {
992 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
993 }
994 preparedPrompt = append(preparedPrompt, messages...)
995 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
996 return preparedPrompt, nil
997}
998
999func WithSystemPrompt(prompt string) AgentOption {
1000 return func(s *agentSettings) {
1001 s.systemPrompt = prompt
1002 }
1003}
1004
1005func WithMaxOutputTokens(tokens int64) AgentOption {
1006 return func(s *agentSettings) {
1007 s.maxOutputTokens = &tokens
1008 }
1009}
1010
1011func WithTemperature(temp float64) AgentOption {
1012 return func(s *agentSettings) {
1013 s.temperature = &temp
1014 }
1015}
1016
1017func WithTopP(topP float64) AgentOption {
1018 return func(s *agentSettings) {
1019 s.topP = &topP
1020 }
1021}
1022
1023func WithTopK(topK int64) AgentOption {
1024 return func(s *agentSettings) {
1025 s.topK = &topK
1026 }
1027}
1028
1029func WithPresencePenalty(penalty float64) AgentOption {
1030 return func(s *agentSettings) {
1031 s.presencePenalty = &penalty
1032 }
1033}
1034
1035func WithFrequencyPenalty(penalty float64) AgentOption {
1036 return func(s *agentSettings) {
1037 s.frequencyPenalty = &penalty
1038 }
1039}
1040
1041func WithTools(tools ...AgentTool) AgentOption {
1042 return func(s *agentSettings) {
1043 s.tools = append(s.tools, tools...)
1044 }
1045}
1046
1047func WithStopConditions(conditions ...StopCondition) AgentOption {
1048 return func(s *agentSettings) {
1049 s.stopWhen = append(s.stopWhen, conditions...)
1050 }
1051}
1052
1053func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1054 return func(s *agentSettings) {
1055 s.prepareStep = fn
1056 }
1057}
1058
1059func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1060 return func(s *agentSettings) {
1061 s.repairToolCall = fn
1062 }
1063}
1064
1065// processStepStream processes a single step's stream and returns the step result.
1066func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
1067 var stepContent []Content
1068 var stepToolCalls []ToolCallContent
1069 var stepUsage Usage
1070 stepFinishReason := FinishReasonUnknown
1071 var stepWarnings []CallWarning
1072 var stepProviderMetadata ProviderMetadata
1073
1074 activeToolCalls := make(map[string]*ToolCallContent)
1075 activeTextContent := make(map[string]string)
1076 type reasoningContent struct {
1077 content string
1078 options ProviderMetadata
1079 }
1080 activeReasoningContent := make(map[string]reasoningContent)
1081
1082 // Process stream parts
1083 for part := range stream {
1084 // Forward all parts to chunk callback
1085 if opts.OnChunk != nil {
1086 err := opts.OnChunk(part)
1087 if err != nil {
1088 return StepResult{}, false, err
1089 }
1090 }
1091
1092 switch part.Type {
1093 case StreamPartTypeWarnings:
1094 stepWarnings = part.Warnings
1095 if opts.OnWarnings != nil {
1096 err := opts.OnWarnings(part.Warnings)
1097 if err != nil {
1098 return StepResult{}, false, err
1099 }
1100 }
1101
1102 case StreamPartTypeTextStart:
1103 activeTextContent[part.ID] = ""
1104 if opts.OnTextStart != nil {
1105 err := opts.OnTextStart(part.ID)
1106 if err != nil {
1107 return StepResult{}, false, err
1108 }
1109 }
1110
1111 case StreamPartTypeTextDelta:
1112 if _, exists := activeTextContent[part.ID]; exists {
1113 activeTextContent[part.ID] += part.Delta
1114 }
1115 if opts.OnTextDelta != nil {
1116 err := opts.OnTextDelta(part.ID, part.Delta)
1117 if err != nil {
1118 return StepResult{}, false, err
1119 }
1120 }
1121
1122 case StreamPartTypeTextEnd:
1123 if text, exists := activeTextContent[part.ID]; exists {
1124 stepContent = append(stepContent, TextContent{
1125 Text: text,
1126 ProviderMetadata: part.ProviderMetadata,
1127 })
1128 delete(activeTextContent, part.ID)
1129 }
1130 if opts.OnTextEnd != nil {
1131 err := opts.OnTextEnd(part.ID)
1132 if err != nil {
1133 return StepResult{}, false, err
1134 }
1135 }
1136
1137 case StreamPartTypeReasoningStart:
1138 activeReasoningContent[part.ID] = reasoningContent{content: "", options: part.ProviderMetadata}
1139 if opts.OnReasoningStart != nil {
1140 content := ReasoningContent{
1141 Text: "",
1142 ProviderMetadata: part.ProviderMetadata,
1143 }
1144 err := opts.OnReasoningStart(part.ID, content)
1145 if err != nil {
1146 return StepResult{}, false, err
1147 }
1148 }
1149
1150 case StreamPartTypeReasoningDelta:
1151 if active, exists := activeReasoningContent[part.ID]; exists {
1152 active.content += part.Delta
1153 active.options = part.ProviderMetadata
1154 activeReasoningContent[part.ID] = active
1155 }
1156 if opts.OnReasoningDelta != nil {
1157 err := opts.OnReasoningDelta(part.ID, part.Delta)
1158 if err != nil {
1159 return StepResult{}, false, err
1160 }
1161 }
1162
1163 case StreamPartTypeReasoningEnd:
1164 if active, exists := activeReasoningContent[part.ID]; exists {
1165 if part.ProviderMetadata != nil {
1166 active.options = part.ProviderMetadata
1167 }
1168 content := ReasoningContent{
1169 Text: active.content,
1170 ProviderMetadata: active.options,
1171 }
1172 stepContent = append(stepContent, content)
1173 if opts.OnReasoningEnd != nil {
1174 err := opts.OnReasoningEnd(part.ID, content)
1175 if err != nil {
1176 return StepResult{}, false, err
1177 }
1178 }
1179 delete(activeReasoningContent, part.ID)
1180 }
1181
1182 case StreamPartTypeToolInputStart:
1183 activeToolCalls[part.ID] = &ToolCallContent{
1184 ToolCallID: part.ID,
1185 ToolName: part.ToolCallName,
1186 Input: "",
1187 ProviderExecuted: part.ProviderExecuted,
1188 }
1189 if opts.OnToolInputStart != nil {
1190 err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1191 if err != nil {
1192 return StepResult{}, false, err
1193 }
1194 }
1195
1196 case StreamPartTypeToolInputDelta:
1197 if toolCall, exists := activeToolCalls[part.ID]; exists {
1198 toolCall.Input += part.Delta
1199 }
1200 if opts.OnToolInputDelta != nil {
1201 err := opts.OnToolInputDelta(part.ID, part.Delta)
1202 if err != nil {
1203 return StepResult{}, false, err
1204 }
1205 }
1206
1207 case StreamPartTypeToolInputEnd:
1208 if opts.OnToolInputEnd != nil {
1209 err := opts.OnToolInputEnd(part.ID)
1210 if err != nil {
1211 return StepResult{}, false, err
1212 }
1213 }
1214
1215 case StreamPartTypeToolCall:
1216 toolCall := ToolCallContent{
1217 ToolCallID: part.ID,
1218 ToolName: part.ToolCallName,
1219 Input: part.ToolCallInput,
1220 ProviderExecuted: part.ProviderExecuted,
1221 ProviderMetadata: part.ProviderMetadata,
1222 }
1223
1224 // Validate and potentially repair the tool call
1225 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1226 stepToolCalls = append(stepToolCalls, validatedToolCall)
1227 stepContent = append(stepContent, validatedToolCall)
1228
1229 if opts.OnToolCall != nil {
1230 err := opts.OnToolCall(validatedToolCall)
1231 if err != nil {
1232 return StepResult{}, false, err
1233 }
1234 }
1235
1236 // Clean up active tool call
1237 delete(activeToolCalls, part.ID)
1238
1239 case StreamPartTypeSource:
1240 sourceContent := SourceContent{
1241 SourceType: part.SourceType,
1242 ID: part.ID,
1243 URL: part.URL,
1244 Title: part.Title,
1245 ProviderMetadata: part.ProviderMetadata,
1246 }
1247 stepContent = append(stepContent, sourceContent)
1248 if opts.OnSource != nil {
1249 err := opts.OnSource(sourceContent)
1250 if err != nil {
1251 return StepResult{}, false, err
1252 }
1253 }
1254
1255 case StreamPartTypeFinish:
1256 stepUsage = part.Usage
1257 stepFinishReason = part.FinishReason
1258 stepProviderMetadata = part.ProviderMetadata
1259 if opts.OnStreamFinish != nil {
1260 err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1261 if err != nil {
1262 return StepResult{}, false, err
1263 }
1264 }
1265
1266 case StreamPartTypeError:
1267 if opts.OnError != nil {
1268 opts.OnError(part.Error)
1269 }
1270 return StepResult{}, false, part.Error
1271 }
1272 }
1273
1274 // Execute tools if any
1275 var toolResults []ToolResultContent
1276 if len(stepToolCalls) > 0 {
1277 var err error
1278 toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
1279 if err != nil {
1280 return StepResult{}, false, err
1281 }
1282 // Add tool results to content
1283 for _, result := range toolResults {
1284 stepContent = append(stepContent, result)
1285 }
1286 }
1287
1288 stepResult := StepResult{
1289 Response: Response{
1290 Content: stepContent,
1291 FinishReason: stepFinishReason,
1292 Usage: stepUsage,
1293 Warnings: stepWarnings,
1294 ProviderMetadata: stepProviderMetadata,
1295 },
1296 Messages: toResponseMessages(stepContent),
1297 }
1298
1299 // Determine if we should continue (has tool calls and not stopped)
1300 shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1301
1302 return stepResult, shouldContinue, nil
1303}
1304
1305func addUsage(a, b Usage) Usage {
1306 return Usage{
1307 InputTokens: a.InputTokens + b.InputTokens,
1308 OutputTokens: a.OutputTokens + b.OutputTokens,
1309 TotalTokens: a.TotalTokens + b.TotalTokens,
1310 ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens,
1311 CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1312 CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens,
1313 }
1314}
1315
1316func WithHeaders(headers map[string]string) AgentOption {
1317 return func(s *agentSettings) {
1318 s.headers = headers
1319 }
1320}
1321
1322func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1323 return func(s *agentSettings) {
1324 s.providerOptions = providerOptions
1325 }
1326}