1package ai
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "slices"
11 "sync"
12)
13
// StepResult is the outcome of one agent step. It embeds the step's Response
// (content, finish reason, usage, warnings, provider metadata) and adds the
// conversation messages produced for that step.
type StepResult struct {
	Response
	// Messages holds the messages generated from this step's content.
	Messages []Message
}
18
// StopCondition inspects the steps completed so far and reports whether the
// agent loop should stop.
type StopCondition = func(steps []StepResult) bool
20
21// StepCountIs returns a stop condition that stops after the specified number of steps.
22func StepCountIs(stepCount int) StopCondition {
23 return func(steps []StepResult) bool {
24 return len(steps) >= stepCount
25 }
26}
27
28// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
29func HasToolCall(toolName string) StopCondition {
30 return func(steps []StepResult) bool {
31 if len(steps) == 0 {
32 return false
33 }
34 lastStep := steps[len(steps)-1]
35 toolCalls := lastStep.Content.ToolCalls()
36 for _, toolCall := range toolCalls {
37 if toolCall.ToolName == toolName {
38 return true
39 }
40 }
41 return false
42 }
43}
44
45// HasContent returns a stop condition that stops when the specified content type appears in the last step.
46func HasContent(contentType ContentType) StopCondition {
47 return func(steps []StepResult) bool {
48 if len(steps) == 0 {
49 return false
50 }
51 lastStep := steps[len(steps)-1]
52 for _, content := range lastStep.Content {
53 if content.GetType() == contentType {
54 return true
55 }
56 }
57 return false
58 }
59}
60
61// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
62func FinishReasonIs(reason FinishReason) StopCondition {
63 return func(steps []StepResult) bool {
64 if len(steps) == 0 {
65 return false
66 }
67 lastStep := steps[len(steps)-1]
68 return lastStep.FinishReason == reason
69 }
70}
71
72// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
73func MaxTokensUsed(maxTokens int64) StopCondition {
74 return func(steps []StepResult) bool {
75 var totalTokens int64
76 for _, step := range steps {
77 totalTokens += step.Usage.TotalTokens
78 }
79 return totalTokens >= maxTokens
80 }
81}
82
// PrepareStepFunctionOptions is the input handed to a PrepareStepFunction
// before each model call.
type PrepareStepFunctionOptions struct {
	// Steps holds the results of the steps completed so far.
	Steps []StepResult
	// StepNumber is the zero-based index of the step about to run.
	StepNumber int
	// Model is the agent's configured model.
	Model LanguageModel
	// Messages is the input the step will send unless the function replaces it.
	Messages []Message
}
89
// PrepareStepResult carries per-step overrides returned by a
// PrepareStepFunction. Zero-valued fields leave the defaults in place.
type PrepareStepResult struct {
	// Model, when non-nil, replaces the model used for this step.
	Model LanguageModel
	// Messages, when non-nil, replaces the step's input messages.
	Messages []Message
	// System, when non-nil, overrides the system prompt for this step.
	System *string
	// ToolChoice, when non-nil, overrides the default ToolChoiceAuto.
	ToolChoice *ToolChoice
	// ActiveTools, when non-empty, replaces the call's active tool list.
	ActiveTools []string
	// DisableAllTools, when true, sends no tools to the model for this step.
	DisableAllTools bool
}
98
// ToolCallRepairOptions describes a tool call that failed validation, together
// with the context a RepairToolCallFunction can use to fix it.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call that failed validation.
	OriginalToolCall ToolCallContent
	// ValidationError is the reason the original call was rejected.
	ValidationError error
	// AvailableTools lists the tools the call was validated against.
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt in effect for the step.
	SystemPrompt string
	// Messages holds the input messages of the step that produced the call.
	Messages []Message
}
106
type (
	// PrepareStepFunction runs before each step and may override that step's
	// model, messages, system prompt, tool choice, and tools. A returned
	// error aborts the agent run.
	PrepareStepFunction = func(options PrepareStepFunctionOptions) (PrepareStepResult, error)
	// OnStepFinishedFunction is a callback that receives a StepResult.
	OnStepFinishedFunction = func(step StepResult)
	// RepairToolCallFunction attempts to fix a tool call that failed
	// validation. A nil result or a non-nil error leaves the call marked
	// invalid; a returned call is validated again before it is used.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
112
// agentSettings holds the agent-wide configuration populated by AgentOption
// functions. prepareCall uses these values as defaults for every call;
// values set on the call itself take precedence.
type agentSettings struct {
	// systemPrompt becomes the leading system message when non-empty.
	systemPrompt string
	// Default sampling parameters; a nil pointer means "not set".
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	// headers and providerOptions are merged with the per-call maps; per-call
	// entries win on key conflicts.
	headers         map[string]string
	providerOptions ProviderOptions

	// TODO: add support for provider tools
	tools      []AgentTool
	maxRetries *int

	// model is the language model used unless a step overrides it.
	model LanguageModel

	// Loop hooks used when the call does not supply its own.
	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
135
// AgentCall holds the per-call parameters for Agent.Generate. Unset pointer
// fields, empty StopWhen, and nil hooks fall back to the agent defaults;
// Headers and ProviderOptions are merged over the agent-level maps.
//
// NOTE(review): only some fields carry json tags, and the func-typed fields
// (OnRetry, StopWhen, PrepareStep, RepairToolCall) have no `json:"-"` tag.
// encoding/json cannot marshal func values, so confirm whether this struct is
// ever meant to be marshaled as a whole.
type AgentCall struct {
	// Prompt is the user prompt for this call; it must not be empty.
	Prompt string `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files []FilePart `json:"files"`
	// Messages are additional conversation messages included in the prompt.
	Messages []Message `json:"messages"`
	// Sampling parameters; nil falls back to the agent default.
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, restricts the tools offered to the model
	// to those named here.
	ActiveTools     []string `json:"active_tools"`
	Headers         map[string]string
	ProviderOptions ProviderOptions
	// OnRetry is passed to the retry wrapper around the model call.
	OnRetry    OnRetryCallback
	MaxRetries *int

	// StopWhen lists conditions checked after every step; any match stops the loop.
	StopWhen []StopCondition
	// PrepareStep, when set, runs before every step.
	PrepareStep PrepareStepFunction
	// RepairToolCall, when set, is offered tool calls that fail validation.
	RepairToolCall RepairToolCallFunction
}
156
// Agent-level callbacks.
type (
	// OnAgentStartFunc is called when agent starts. Stream invokes it once,
	// after the initial prompt has been built and before the first step.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called when agent finishes. Stream invokes it
	// after OnFinish and discards the returned error.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called when a step starts. stepNumber is zero-based.
	// Stream discards the returned error.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called when a step finishes. Stream discards the
	// returned error.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called when entire agent completes.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called when an error occurs. Stream invokes it before
	// returning an error from starting or processing a model stream.
	OnErrorFunc func(error)
)
177
// Stream part callbacks - called for each corresponding stream part type.
//
// A non-nil error returned by a callback aborts processing of the current
// step and is returned to the caller.
// NOTE(review): this is established for the chunk, warnings, text, and
// reasoning start/delta callbacks; confirm it for the remaining ones.
type (
	// OnChunkFunc is called for each stream part (catch-all). It runs before
	// the type-specific callback for the same part.
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when text starts. id identifies the text
	// block that later deltas and the end event refer to.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for text deltas.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when text ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when reasoning starts.
	OnReasoningStartFunc func(id string) error

	// OnReasoningDeltaFunc is called for reasoning deltas.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning ends.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when tool input starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for tool input deltas.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnToolResultFunc is called when tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when stream finishes.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
225
// AgentStreamCall holds the per-call parameters for Agent.Stream. It mirrors
// AgentCall and adds callbacks for agent-level, step-level, and stream-part
// events. Every callback is optional; nil callbacks are skipped.
type AgentStreamCall struct {
	// Prompt is the user prompt for this call; it must not be empty.
	Prompt string `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files []FilePart `json:"files"`
	// Messages are additional conversation messages included in the prompt.
	Messages []Message `json:"messages"`
	// Sampling parameters; nil falls back to the agent default.
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, restricts the tools offered to the model.
	ActiveTools     []string `json:"active_tools"`
	Headers         map[string]string
	ProviderOptions ProviderOptions
	// NOTE(review): Stream does not copy OnRetry into the AgentCall it
	// prepares, and no retry wrapper is visible around the streaming model
	// call; confirm whether OnRetry and MaxRetries are meant to apply here.
	OnRetry    OnRetryCallback
	MaxRetries *int

	// StopWhen lists conditions checked after every step; any match stops the loop.
	StopWhen []StopCondition
	// PrepareStep, when set, runs before every step.
	PrepareStep PrepareStepFunction
	// RepairToolCall, when set, is offered tool calls that fail validation.
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called when agent starts
	OnAgentFinish OnAgentFinishFunc // Called when agent finishes
	OnStepStart   OnStepStartFunc   // Called when a step starts
	OnStepFinish  OnStepFinishFunc  // Called when a step finishes
	OnFinish      OnFinishFunc      // Called when entire agent completes
	OnError       OnErrorFunc       // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
	OnWarnings       OnWarningsFunc       // Called for warnings
	OnTextStart      OnTextStartFunc      // Called when text starts
	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
	OnTextEnd        OnTextEndFunc        // Called when text ends
	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
	OnToolCall       OnToolCallFunc       // Called when tool call is complete
	OnToolResult     OnToolResultFunc     // Called when tool execution completes
	OnSource         OnSourceFunc         // Called for source references
	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
}
271
// AgentResult is the final outcome of an agent run.
type AgentResult struct {
	// Steps holds every step executed during the run, in order.
	Steps []StepResult
	// Final response
	// Response is the Response of the last step.
	Response Response
	// TotalUsage is the token usage summed over all steps.
	TotalUsage Usage
}
278
// Agent runs a multi-step model/tool loop.
type Agent interface {
	// Generate runs the loop with non-streaming model calls and returns the
	// collected steps.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the loop with streaming model calls, reporting progress
	// through the callbacks on AgentStreamCall, and returns the collected
	// steps once the run ends.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
283
// AgentOption configures an agent at construction time by mutating its
// settings; options passed to NewAgent are applied in order.
type AgentOption = func(*agentSettings)
285
// agent is the Agent implementation returned by NewAgent.
type agent struct {
	// settings holds the model and the defaults applied to every call.
	settings agentSettings
}
289
290func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
291 settings := agentSettings{
292 model: model,
293 }
294 for _, o := range opts {
295 o(&settings)
296 }
297 return &agent{
298 settings: settings,
299 }
300}
301
302func (a *agent) prepareCall(call AgentCall) AgentCall {
303 call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
304 call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
305 call.TopP = cmp.Or(call.TopP, a.settings.topP)
306 call.TopK = cmp.Or(call.TopK, a.settings.topK)
307 call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
308 call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
309 call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
310
311 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
312 call.StopWhen = a.settings.stopWhen
313 }
314 if call.PrepareStep == nil && a.settings.prepareStep != nil {
315 call.PrepareStep = a.settings.prepareStep
316 }
317 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
318 call.RepairToolCall = a.settings.repairToolCall
319 }
320 if call.OnRetry == nil && a.settings.onRetry != nil {
321 call.OnRetry = a.settings.onRetry
322 }
323
324 providerOptions := ProviderOptions{}
325 if a.settings.providerOptions != nil {
326 maps.Copy(providerOptions, a.settings.providerOptions)
327 }
328 if call.ProviderOptions != nil {
329 maps.Copy(providerOptions, call.ProviderOptions)
330 }
331 call.ProviderOptions = providerOptions
332
333 headers := map[string]string{}
334
335 if a.settings.headers != nil {
336 maps.Copy(headers, a.settings.headers)
337 }
338
339 if call.Headers != nil {
340 maps.Copy(headers, call.Headers)
341 }
342 call.Headers = headers
343 return call
344}
345
346// Generate implements Agent.
347func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
348 opts = a.prepareCall(opts)
349 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
350 if err != nil {
351 return nil, err
352 }
353 var responseMessages []Message
354 var steps []StepResult
355
356 for {
357 stepInputMessages := append(initialPrompt, responseMessages...)
358 stepModel := a.settings.model
359 stepSystemPrompt := a.settings.systemPrompt
360 stepActiveTools := opts.ActiveTools
361 stepToolChoice := ToolChoiceAuto
362 disableAllTools := false
363
364 if opts.PrepareStep != nil {
365 prepared, err := opts.PrepareStep(PrepareStepFunctionOptions{
366 Model: stepModel,
367 Steps: steps,
368 StepNumber: len(steps),
369 Messages: stepInputMessages,
370 })
371 if err != nil {
372 return nil, err
373 }
374
375 // Apply prepared step modifications
376 if prepared.Messages != nil {
377 stepInputMessages = prepared.Messages
378 }
379 if prepared.Model != nil {
380 stepModel = prepared.Model
381 }
382 if prepared.System != nil {
383 stepSystemPrompt = *prepared.System
384 }
385 if prepared.ToolChoice != nil {
386 stepToolChoice = *prepared.ToolChoice
387 }
388 if len(prepared.ActiveTools) > 0 {
389 stepActiveTools = prepared.ActiveTools
390 }
391 disableAllTools = prepared.DisableAllTools
392 }
393
394 // Recreate prompt with potentially modified system prompt
395 if stepSystemPrompt != a.settings.systemPrompt {
396 stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
397 if err != nil {
398 return nil, err
399 }
400 // Replace system message part, keep the rest
401 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
402 stepInputMessages[0] = stepPrompt[0] // Replace system message
403 }
404 }
405
406 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
407
408 retryOptions := DefaultRetryOptions()
409 retryOptions.OnRetry = opts.OnRetry
410 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
411
412 result, err := retry(ctx, func() (*Response, error) {
413 return stepModel.Generate(ctx, Call{
414 Prompt: stepInputMessages,
415 MaxOutputTokens: opts.MaxOutputTokens,
416 Temperature: opts.Temperature,
417 TopP: opts.TopP,
418 TopK: opts.TopK,
419 PresencePenalty: opts.PresencePenalty,
420 FrequencyPenalty: opts.FrequencyPenalty,
421 Tools: preparedTools,
422 ToolChoice: &stepToolChoice,
423 Headers: opts.Headers,
424 ProviderOptions: opts.ProviderOptions,
425 })
426 })
427 if err != nil {
428 return nil, err
429 }
430
431 var stepToolCalls []ToolCallContent
432 for _, content := range result.Content {
433 if content.GetType() == ContentTypeToolCall {
434 toolCall, ok := AsContentType[ToolCallContent](content)
435 if !ok {
436 continue
437 }
438
439 // Validate and potentially repair the tool call
440 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
441 stepToolCalls = append(stepToolCalls, validatedToolCall)
442 }
443 }
444
445 toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
446
447 // Build step content with validated tool calls and tool results
448 stepContent := []Content{}
449 toolCallIndex := 0
450 for _, content := range result.Content {
451 if content.GetType() == ContentTypeToolCall {
452 // Replace with validated tool call
453 if toolCallIndex < len(stepToolCalls) {
454 stepContent = append(stepContent, stepToolCalls[toolCallIndex])
455 toolCallIndex++
456 }
457 } else {
458 // Keep other content as-is
459 stepContent = append(stepContent, content)
460 }
461 }
462 // Add tool results
463 for _, result := range toolResults {
464 stepContent = append(stepContent, result)
465 }
466 currentStepMessages := toResponseMessages(stepContent)
467 responseMessages = append(responseMessages, currentStepMessages...)
468
469 stepResult := StepResult{
470 Response: Response{
471 Content: stepContent,
472 FinishReason: result.FinishReason,
473 Usage: result.Usage,
474 Warnings: result.Warnings,
475 ProviderMetadata: result.ProviderMetadata,
476 },
477 Messages: currentStepMessages,
478 }
479 steps = append(steps, stepResult)
480 shouldStop := isStopConditionMet(opts.StopWhen, steps)
481
482 if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
483 break
484 }
485 }
486
487 totalUsage := Usage{}
488
489 for _, step := range steps {
490 usage := step.Usage
491 totalUsage.InputTokens += usage.InputTokens
492 totalUsage.OutputTokens += usage.OutputTokens
493 totalUsage.ReasoningTokens += usage.ReasoningTokens
494 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
495 totalUsage.CacheReadTokens += usage.CacheReadTokens
496 totalUsage.TotalTokens += usage.TotalTokens
497 }
498
499 agentResult := &AgentResult{
500 Steps: steps,
501 Response: steps[len(steps)-1].Response,
502 TotalUsage: totalUsage,
503 }
504 return agentResult, nil
505}
506
507func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
508 if len(conditions) == 0 {
509 return false
510 }
511
512 for _, condition := range conditions {
513 if condition(steps) {
514 return true
515 }
516 }
517 return false
518}
519
520func toResponseMessages(content []Content) []Message {
521 var assistantParts []MessagePart
522 var toolParts []MessagePart
523
524 for _, c := range content {
525 switch c.GetType() {
526 case ContentTypeText:
527 text, ok := AsContentType[TextContent](c)
528 if !ok {
529 continue
530 }
531 assistantParts = append(assistantParts, TextPart{
532 Text: text.Text,
533 ProviderOptions: ProviderOptions(text.ProviderMetadata),
534 })
535 case ContentTypeReasoning:
536 reasoning, ok := AsContentType[ReasoningContent](c)
537 if !ok {
538 continue
539 }
540 assistantParts = append(assistantParts, ReasoningPart{
541 Text: reasoning.Text,
542 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
543 })
544 case ContentTypeToolCall:
545 toolCall, ok := AsContentType[ToolCallContent](c)
546 if !ok {
547 continue
548 }
549 assistantParts = append(assistantParts, ToolCallPart{
550 ToolCallID: toolCall.ToolCallID,
551 ToolName: toolCall.ToolName,
552 Input: toolCall.Input,
553 ProviderExecuted: toolCall.ProviderExecuted,
554 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
555 })
556 case ContentTypeFile:
557 file, ok := AsContentType[FileContent](c)
558 if !ok {
559 continue
560 }
561 assistantParts = append(assistantParts, FilePart{
562 Data: file.Data,
563 MediaType: file.MediaType,
564 ProviderOptions: ProviderOptions(file.ProviderMetadata),
565 })
566 case ContentTypeSource:
567 // Sources are metadata about references used to generate the response.
568 // They don't need to be included in the conversation messages.
569 continue
570 case ContentTypeToolResult:
571 result, ok := AsContentType[ToolResultContent](c)
572 if !ok {
573 continue
574 }
575 toolParts = append(toolParts, ToolResultPart{
576 ToolCallID: result.ToolCallID,
577 Output: result.Result,
578 ProviderOptions: ProviderOptions(result.ProviderMetadata),
579 })
580 }
581 }
582
583 var messages []Message
584 if len(assistantParts) > 0 {
585 messages = append(messages, Message{
586 Role: MessageRoleAssistant,
587 Content: assistantParts,
588 })
589 }
590 if len(toolParts) > 0 {
591 messages = append(messages, Message{
592 Role: MessageRoleTool,
593 Content: toolParts,
594 })
595 }
596 return messages
597}
598
599func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
600 if len(toolCalls) == 0 {
601 return nil, nil
602 }
603
604 // Create a map for quick tool lookup
605 toolMap := make(map[string]AgentTool)
606 for _, tool := range allTools {
607 toolMap[tool.Info().Name] = tool
608 }
609
610 // Execute all tool calls in parallel
611 results := make([]ToolResultContent, len(toolCalls))
612 executeErrors := make([]error, len(toolCalls))
613
614 var wg sync.WaitGroup
615
616 for i, toolCall := range toolCalls {
617 wg.Add(1)
618 go func(index int, call ToolCallContent) {
619 defer wg.Done()
620
621 // Skip invalid tool calls - create error result
622 if call.Invalid {
623 results[index] = ToolResultContent{
624 ToolCallID: call.ToolCallID,
625 ToolName: call.ToolName,
626 Result: ToolResultOutputContentError{
627 Error: call.ValidationError,
628 },
629 ProviderExecuted: false,
630 }
631 if toolResultCallback != nil {
632 err := toolResultCallback(results[index])
633 if err != nil {
634 executeErrors[index] = err
635 }
636 }
637
638 return
639 }
640
641 tool, exists := toolMap[call.ToolName]
642 if !exists {
643 results[index] = ToolResultContent{
644 ToolCallID: call.ToolCallID,
645 ToolName: call.ToolName,
646 Result: ToolResultOutputContentError{
647 Error: errors.New("Error: Tool not found: " + call.ToolName),
648 },
649 ProviderExecuted: false,
650 }
651
652 if toolResultCallback != nil {
653 err := toolResultCallback(results[index])
654 if err != nil {
655 executeErrors[index] = err
656 }
657 }
658 return
659 }
660
661 // Execute the tool
662 result, err := tool.Run(ctx, ToolCall{
663 ID: call.ToolCallID,
664 Name: call.ToolName,
665 Input: call.Input,
666 })
667 if err != nil {
668 results[index] = ToolResultContent{
669 ToolCallID: call.ToolCallID,
670 ToolName: call.ToolName,
671 Result: ToolResultOutputContentError{
672 Error: err,
673 },
674 ClientMetadata: result.Metadata,
675 ProviderExecuted: false,
676 }
677 if toolResultCallback != nil {
678 cbErr := toolResultCallback(results[index])
679 if cbErr != nil {
680 executeErrors[index] = cbErr
681 }
682 }
683 executeErrors[index] = err
684 return
685 }
686
687 if result.IsError {
688 results[index] = ToolResultContent{
689 ToolCallID: call.ToolCallID,
690 ToolName: call.ToolName,
691 Result: ToolResultOutputContentError{
692 Error: errors.New(result.Content),
693 },
694 ClientMetadata: result.Metadata,
695 ProviderExecuted: false,
696 }
697
698 if toolResultCallback != nil {
699 err := toolResultCallback(results[index])
700 if err != nil {
701 executeErrors[index] = err
702 }
703 }
704 } else {
705 results[index] = ToolResultContent{
706 ToolCallID: call.ToolCallID,
707 ToolName: toolCall.ToolName,
708 Result: ToolResultOutputContentText{
709 Text: result.Content,
710 },
711 ClientMetadata: result.Metadata,
712 ProviderExecuted: false,
713 }
714 if toolResultCallback != nil {
715 err := toolResultCallback(results[index])
716 if err != nil {
717 executeErrors[index] = err
718 }
719 }
720 }
721 }(i, toolCall)
722 }
723
724 // Wait for all tool executions to complete
725 wg.Wait()
726
727 for _, err := range executeErrors {
728 if err != nil {
729 return nil, err
730 }
731 }
732
733 return results, nil
734}
735
736// Stream implements Agent.
737func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
738 // Convert AgentStreamCall to AgentCall for preparation
739 call := AgentCall{
740 Prompt: opts.Prompt,
741 Files: opts.Files,
742 Messages: opts.Messages,
743 MaxOutputTokens: opts.MaxOutputTokens,
744 Temperature: opts.Temperature,
745 TopP: opts.TopP,
746 TopK: opts.TopK,
747 PresencePenalty: opts.PresencePenalty,
748 FrequencyPenalty: opts.FrequencyPenalty,
749 ActiveTools: opts.ActiveTools,
750 Headers: opts.Headers,
751 ProviderOptions: opts.ProviderOptions,
752 MaxRetries: opts.MaxRetries,
753 StopWhen: opts.StopWhen,
754 PrepareStep: opts.PrepareStep,
755 RepairToolCall: opts.RepairToolCall,
756 }
757
758 call = a.prepareCall(call)
759
760 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
761 if err != nil {
762 return nil, err
763 }
764
765 var responseMessages []Message
766 var steps []StepResult
767 var totalUsage Usage
768
769 // Start agent stream
770 if opts.OnAgentStart != nil {
771 opts.OnAgentStart()
772 }
773
774 for stepNumber := 0; ; stepNumber++ {
775 stepInputMessages := append(initialPrompt, responseMessages...)
776 stepModel := a.settings.model
777 stepSystemPrompt := a.settings.systemPrompt
778 stepActiveTools := call.ActiveTools
779 stepToolChoice := ToolChoiceAuto
780 disableAllTools := false
781
782 // Apply step preparation if provided
783 if call.PrepareStep != nil {
784 prepared, err := call.PrepareStep(PrepareStepFunctionOptions{
785 Model: stepModel,
786 Steps: steps,
787 StepNumber: stepNumber,
788 Messages: stepInputMessages,
789 })
790 if err != nil {
791 return nil, err
792 }
793
794 if prepared.Messages != nil {
795 stepInputMessages = prepared.Messages
796 }
797 if prepared.Model != nil {
798 stepModel = prepared.Model
799 }
800 if prepared.System != nil {
801 stepSystemPrompt = *prepared.System
802 }
803 if prepared.ToolChoice != nil {
804 stepToolChoice = *prepared.ToolChoice
805 }
806 if len(prepared.ActiveTools) > 0 {
807 stepActiveTools = prepared.ActiveTools
808 }
809 disableAllTools = prepared.DisableAllTools
810 }
811
812 // Recreate prompt with potentially modified system prompt
813 if stepSystemPrompt != a.settings.systemPrompt {
814 stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
815 if err != nil {
816 return nil, err
817 }
818 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
819 stepInputMessages[0] = stepPrompt[0]
820 }
821 }
822
823 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
824
825 // Start step stream
826 if opts.OnStepStart != nil {
827 _ = opts.OnStepStart(stepNumber)
828 }
829
830 // Create streaming call
831 streamCall := Call{
832 Prompt: stepInputMessages,
833 MaxOutputTokens: call.MaxOutputTokens,
834 Temperature: call.Temperature,
835 TopP: call.TopP,
836 TopK: call.TopK,
837 PresencePenalty: call.PresencePenalty,
838 FrequencyPenalty: call.FrequencyPenalty,
839 Tools: preparedTools,
840 ToolChoice: &stepToolChoice,
841 Headers: call.Headers,
842 ProviderOptions: call.ProviderOptions,
843 }
844
845 // Get streaming response
846 stream, err := stepModel.Stream(ctx, streamCall)
847 if err != nil {
848 if opts.OnError != nil {
849 opts.OnError(err)
850 }
851 return nil, err
852 }
853
854 // Process stream with tool execution
855 stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
856 if err != nil {
857 if opts.OnError != nil {
858 opts.OnError(err)
859 }
860 return nil, err
861 }
862
863 steps = append(steps, stepResult)
864 totalUsage = addUsage(totalUsage, stepResult.Usage)
865
866 // Call step finished callback
867 if opts.OnStepFinish != nil {
868 _ = opts.OnStepFinish(stepResult)
869 }
870
871 // Add step messages to response messages
872 stepMessages := toResponseMessages(stepResult.Content)
873 responseMessages = append(responseMessages, stepMessages...)
874
875 // Check stop conditions
876 shouldStop := isStopConditionMet(call.StopWhen, steps)
877 if shouldStop || !shouldContinue {
878 break
879 }
880 }
881
882 // Finish agent stream
883 agentResult := &AgentResult{
884 Steps: steps,
885 Response: steps[len(steps)-1].Response,
886 TotalUsage: totalUsage,
887 }
888
889 if opts.OnFinish != nil {
890 opts.OnFinish(agentResult)
891 }
892
893 if opts.OnAgentFinish != nil {
894 _ = opts.OnAgentFinish(agentResult)
895 }
896
897 return agentResult, nil
898}
899
900func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
901 preparedTools := make([]Tool, 0, len(tools))
902
903 // If explicitly disabling all tools, return no tools
904 if disableAllTools {
905 return preparedTools
906 }
907
908 for _, tool := range tools {
909 // If activeTools has items, only include tools in the list
910 // If activeTools is empty, include all tools
911 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
912 continue
913 }
914 info := tool.Info()
915 preparedTools = append(preparedTools, FunctionTool{
916 Name: info.Name,
917 Description: info.Description,
918 InputSchema: map[string]any{
919 "type": "object",
920 "properties": info.Parameters,
921 "required": info.Required,
922 },
923 })
924 }
925 return preparedTools
926}
927
928// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
929func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
930 if err := a.validateToolCall(toolCall, availableTools); err == nil {
931 return toolCall
932 } else {
933 if repairFunc != nil {
934 repairOptions := ToolCallRepairOptions{
935 OriginalToolCall: toolCall,
936 ValidationError: err,
937 AvailableTools: availableTools,
938 SystemPrompt: systemPrompt,
939 Messages: messages,
940 }
941
942 if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
943 if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
944 return *repairedToolCall
945 }
946 }
947 }
948
949 invalidToolCall := toolCall
950 invalidToolCall.Invalid = true
951 invalidToolCall.ValidationError = err
952 return invalidToolCall
953 }
954}
955
956// validateToolCall validates a tool call against available tools and their schemas.
957func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
958 var tool AgentTool
959 for _, t := range availableTools {
960 if t.Info().Name == toolCall.ToolName {
961 tool = t
962 break
963 }
964 }
965
966 if tool == nil {
967 return fmt.Errorf("tool not found: %s", toolCall.ToolName)
968 }
969
970 // Validate JSON parsing
971 var input map[string]any
972 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
973 return fmt.Errorf("invalid JSON input: %w", err)
974 }
975
976 // Basic schema validation (check required fields)
977 // TODO: more robust schema validation using JSON Schema or similar
978 toolInfo := tool.Info()
979 for _, required := range toolInfo.Required {
980 if _, exists := input[required]; !exists {
981 return fmt.Errorf("missing required parameter: %s", required)
982 }
983 }
984 return nil
985}
986
987func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
988 if prompt == "" {
989 return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
990 }
991
992 var preparedPrompt Prompt
993
994 if system != "" {
995 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
996 }
997
998 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
999 preparedPrompt = append(preparedPrompt, messages...)
1000 return preparedPrompt, nil
1001}
1002
1003func WithSystemPrompt(prompt string) AgentOption {
1004 return func(s *agentSettings) {
1005 s.systemPrompt = prompt
1006 }
1007}
1008
1009func WithMaxOutputTokens(tokens int64) AgentOption {
1010 return func(s *agentSettings) {
1011 s.maxOutputTokens = &tokens
1012 }
1013}
1014
1015func WithTemperature(temp float64) AgentOption {
1016 return func(s *agentSettings) {
1017 s.temperature = &temp
1018 }
1019}
1020
1021func WithTopP(topP float64) AgentOption {
1022 return func(s *agentSettings) {
1023 s.topP = &topP
1024 }
1025}
1026
1027func WithTopK(topK int64) AgentOption {
1028 return func(s *agentSettings) {
1029 s.topK = &topK
1030 }
1031}
1032
1033func WithPresencePenalty(penalty float64) AgentOption {
1034 return func(s *agentSettings) {
1035 s.presencePenalty = &penalty
1036 }
1037}
1038
1039func WithFrequencyPenalty(penalty float64) AgentOption {
1040 return func(s *agentSettings) {
1041 s.frequencyPenalty = &penalty
1042 }
1043}
1044
1045func WithTools(tools ...AgentTool) AgentOption {
1046 return func(s *agentSettings) {
1047 s.tools = append(s.tools, tools...)
1048 }
1049}
1050
1051func WithStopConditions(conditions ...StopCondition) AgentOption {
1052 return func(s *agentSettings) {
1053 s.stopWhen = append(s.stopWhen, conditions...)
1054 }
1055}
1056
1057func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1058 return func(s *agentSettings) {
1059 s.prepareStep = fn
1060 }
1061}
1062
1063func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1064 return func(s *agentSettings) {
1065 s.repairToolCall = fn
1066 }
1067}
1068
1069// processStepStream processes a single step's stream and returns the step result.
1070func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
1071 var stepContent []Content
1072 var stepToolCalls []ToolCallContent
1073 var stepUsage Usage
1074 stepFinishReason := FinishReasonUnknown
1075 var stepWarnings []CallWarning
1076 var stepProviderMetadata ProviderMetadata
1077
1078 activeToolCalls := make(map[string]*ToolCallContent)
1079 activeTextContent := make(map[string]string)
1080
1081 // Process stream parts
1082 for part := range stream {
1083 // Forward all parts to chunk callback
1084 if opts.OnChunk != nil {
1085 err := opts.OnChunk(part)
1086 if err != nil {
1087 return StepResult{}, false, err
1088 }
1089 }
1090
1091 switch part.Type {
1092 case StreamPartTypeWarnings:
1093 stepWarnings = part.Warnings
1094 if opts.OnWarnings != nil {
1095 err := opts.OnWarnings(part.Warnings)
1096 if err != nil {
1097 return StepResult{}, false, err
1098 }
1099 }
1100
1101 case StreamPartTypeTextStart:
1102 activeTextContent[part.ID] = ""
1103 if opts.OnTextStart != nil {
1104 err := opts.OnTextStart(part.ID)
1105 if err != nil {
1106 return StepResult{}, false, err
1107 }
1108 }
1109
1110 case StreamPartTypeTextDelta:
1111 if _, exists := activeTextContent[part.ID]; exists {
1112 activeTextContent[part.ID] += part.Delta
1113 }
1114 if opts.OnTextDelta != nil {
1115 err := opts.OnTextDelta(part.ID, part.Delta)
1116 if err != nil {
1117 return StepResult{}, false, err
1118 }
1119 }
1120
1121 case StreamPartTypeTextEnd:
1122 if text, exists := activeTextContent[part.ID]; exists {
1123 stepContent = append(stepContent, TextContent{
1124 Text: text,
1125 ProviderMetadata: part.ProviderMetadata,
1126 })
1127 delete(activeTextContent, part.ID)
1128 }
1129 if opts.OnTextEnd != nil {
1130 err := opts.OnTextEnd(part.ID)
1131 if err != nil {
1132 return StepResult{}, false, err
1133 }
1134 }
1135
1136 case StreamPartTypeReasoningStart:
1137 activeTextContent[part.ID] = ""
1138 if opts.OnReasoningStart != nil {
1139 err := opts.OnReasoningStart(part.ID)
1140 if err != nil {
1141 return StepResult{}, false, err
1142 }
1143 }
1144
1145 case StreamPartTypeReasoningDelta:
1146 if _, exists := activeTextContent[part.ID]; exists {
1147 activeTextContent[part.ID] += part.Delta
1148 }
1149 if opts.OnReasoningDelta != nil {
1150 err := opts.OnReasoningDelta(part.ID, part.Delta)
1151 if err != nil {
1152 return StepResult{}, false, err
1153 }
1154 }
1155
1156 case StreamPartTypeReasoningEnd:
1157 if text, exists := activeTextContent[part.ID]; exists {
1158 stepContent = append(stepContent, ReasoningContent{
1159 Text: text,
1160 ProviderMetadata: part.ProviderMetadata,
1161 })
1162 if opts.OnReasoningEnd != nil {
1163 err := opts.OnReasoningEnd(part.ID, ReasoningContent{
1164 Text: text,
1165 ProviderMetadata: part.ProviderMetadata,
1166 })
1167 if err != nil {
1168 return StepResult{}, false, err
1169 }
1170 }
1171 delete(activeTextContent, part.ID)
1172 }
1173
1174 case StreamPartTypeToolInputStart:
1175 activeToolCalls[part.ID] = &ToolCallContent{
1176 ToolCallID: part.ID,
1177 ToolName: part.ToolCallName,
1178 Input: "",
1179 ProviderExecuted: part.ProviderExecuted,
1180 }
1181 if opts.OnToolInputStart != nil {
1182 err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1183 if err != nil {
1184 return StepResult{}, false, err
1185 }
1186 }
1187
1188 case StreamPartTypeToolInputDelta:
1189 if toolCall, exists := activeToolCalls[part.ID]; exists {
1190 toolCall.Input += part.Delta
1191 }
1192 if opts.OnToolInputDelta != nil {
1193 err := opts.OnToolInputDelta(part.ID, part.Delta)
1194 if err != nil {
1195 return StepResult{}, false, err
1196 }
1197 }
1198
1199 case StreamPartTypeToolInputEnd:
1200 if opts.OnToolInputEnd != nil {
1201 err := opts.OnToolInputEnd(part.ID)
1202 if err != nil {
1203 return StepResult{}, false, err
1204 }
1205 }
1206
1207 case StreamPartTypeToolCall:
1208 toolCall := ToolCallContent{
1209 ToolCallID: part.ID,
1210 ToolName: part.ToolCallName,
1211 Input: part.ToolCallInput,
1212 ProviderExecuted: part.ProviderExecuted,
1213 ProviderMetadata: part.ProviderMetadata,
1214 }
1215
1216 // Validate and potentially repair the tool call
1217 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1218 stepToolCalls = append(stepToolCalls, validatedToolCall)
1219 stepContent = append(stepContent, validatedToolCall)
1220
1221 if opts.OnToolCall != nil {
1222 err := opts.OnToolCall(validatedToolCall)
1223 if err != nil {
1224 return StepResult{}, false, err
1225 }
1226 }
1227
1228 // Clean up active tool call
1229 delete(activeToolCalls, part.ID)
1230
1231 case StreamPartTypeSource:
1232 sourceContent := SourceContent{
1233 SourceType: part.SourceType,
1234 ID: part.ID,
1235 URL: part.URL,
1236 Title: part.Title,
1237 ProviderMetadata: part.ProviderMetadata,
1238 }
1239 stepContent = append(stepContent, sourceContent)
1240 if opts.OnSource != nil {
1241 err := opts.OnSource(sourceContent)
1242 if err != nil {
1243 return StepResult{}, false, err
1244 }
1245 }
1246
1247 case StreamPartTypeFinish:
1248 stepUsage = part.Usage
1249 stepFinishReason = part.FinishReason
1250 stepProviderMetadata = part.ProviderMetadata
1251 if opts.OnStreamFinish != nil {
1252 err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1253 if err != nil {
1254 return StepResult{}, false, err
1255 }
1256 }
1257
1258 case StreamPartTypeError:
1259 if opts.OnError != nil {
1260 opts.OnError(part.Error)
1261 }
1262 return StepResult{}, false, part.Error
1263 }
1264 }
1265
1266 // Execute tools if any
1267 var toolResults []ToolResultContent
1268 if len(stepToolCalls) > 0 {
1269 var err error
1270 toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
1271 if err != nil {
1272 return StepResult{}, false, err
1273 }
1274 // Add tool results to content
1275 for _, result := range toolResults {
1276 stepContent = append(stepContent, result)
1277 }
1278 }
1279
1280 stepResult := StepResult{
1281 Response: Response{
1282 Content: stepContent,
1283 FinishReason: stepFinishReason,
1284 Usage: stepUsage,
1285 Warnings: stepWarnings,
1286 ProviderMetadata: stepProviderMetadata,
1287 },
1288 Messages: toResponseMessages(stepContent),
1289 }
1290
1291 // Determine if we should continue (has tool calls and not stopped)
1292 shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1293
1294 return stepResult, shouldContinue, nil
1295}
1296
1297func addUsage(a, b Usage) Usage {
1298 return Usage{
1299 InputTokens: a.InputTokens + b.InputTokens,
1300 OutputTokens: a.OutputTokens + b.OutputTokens,
1301 TotalTokens: a.TotalTokens + b.TotalTokens,
1302 ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens,
1303 CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1304 CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens,
1305 }
1306}
1307
1308func WithHeaders(headers map[string]string) AgentOption {
1309 return func(s *agentSettings) {
1310 s.headers = headers
1311 }
1312}
1313
1314func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1315 return func(s *agentSettings) {
1316 s.providerOptions = providerOptions
1317 }
1318}