1package fantasy
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "slices"
11 "sync"
12
13 "charm.land/fantasy/schema"
14)
15
// StepResult represents the result of a single step in an agent execution.
//
// The embedded Response carries the step's content (including the validated
// tool calls and any tool results), finish reason, usage, warnings and
// provider metadata.
type StepResult struct {
	Response
	// Messages are the conversation messages derived from this step's
	// content. In Generate they are built by toResponseMessages: an assistant
	// message and, when tools ran, a tool message. They are fed back into
	// the input of the next step.
	Messages []Message
}
21
// stepExecutionResult encapsulates the result of executing a step with stream processing.
//
// Stream produces one of these per step, inside its retry wrapper.
type stepExecutionResult struct {
	// StepResult is the completed step; Stream appends it to the run's steps.
	StepResult StepResult
	// ShouldContinue reports whether the loop may run another step. Stream
	// stops when it is false; a stop condition can also end the loop earlier.
	ShouldContinue bool
}
27
// StopCondition defines a function that determines when an agent should stop executing.
//
// It is evaluated after every step and receives all steps completed so far,
// with the most recent last. Returning true ends the run after that step.
// When several conditions are configured, the run stops as soon as any one
// of them returns true.
type StopCondition = func(steps []StepResult) bool
30
31// StepCountIs returns a stop condition that stops after the specified number of steps.
32func StepCountIs(stepCount int) StopCondition {
33 return func(steps []StepResult) bool {
34 return len(steps) >= stepCount
35 }
36}
37
38// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
39func HasToolCall(toolName string) StopCondition {
40 return func(steps []StepResult) bool {
41 if len(steps) == 0 {
42 return false
43 }
44 lastStep := steps[len(steps)-1]
45 toolCalls := lastStep.Content.ToolCalls()
46 for _, toolCall := range toolCalls {
47 if toolCall.ToolName == toolName {
48 return true
49 }
50 }
51 return false
52 }
53}
54
55// HasContent returns a stop condition that stops when the specified content type appears in the last step.
56func HasContent(contentType ContentType) StopCondition {
57 return func(steps []StepResult) bool {
58 if len(steps) == 0 {
59 return false
60 }
61 lastStep := steps[len(steps)-1]
62 for _, content := range lastStep.Content {
63 if content.GetType() == contentType {
64 return true
65 }
66 }
67 return false
68 }
69}
70
71// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
72func FinishReasonIs(reason FinishReason) StopCondition {
73 return func(steps []StepResult) bool {
74 if len(steps) == 0 {
75 return false
76 }
77 lastStep := steps[len(steps)-1]
78 return lastStep.FinishReason == reason
79 }
80}
81
82// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
83func MaxTokensUsed(maxTokens int64) StopCondition {
84 return func(steps []StepResult) bool {
85 var totalTokens int64
86 for _, step := range steps {
87 totalTokens += step.Usage.TotalTokens
88 }
89 return totalTokens >= maxTokens
90 }
91}
92
// PrepareStepFunctionOptions contains the options for preparing a step in an agent execution.
type PrepareStepFunctionOptions struct {
	// Steps are the steps completed so far (empty before the first step).
	Steps []StepResult
	// StepNumber is the zero-based index of the step about to run.
	StepNumber int
	// Model is the agent's configured model, i.e. the step's default model.
	Model LanguageModel
	// Messages is the input the step will send unless overridden: the
	// initial prompt followed by every message produced by earlier steps.
	Messages []Message
}
100
// PrepareStepResult contains the result of preparing a step in an agent execution.
//
// The zero value leaves the step unchanged; each field only takes effect
// under the condition described on it, and only for the current step.
type PrepareStepResult struct {
	// Model, when non-nil, replaces the agent's model.
	Model LanguageModel
	// Messages, when non-nil, replaces the step's input messages.
	Messages []Message
	// System, when non-nil, overrides the agent's system prompt.
	System *string
	// ToolChoice, when non-nil, replaces the default ToolChoiceAuto.
	ToolChoice *ToolChoice
	// ActiveTools, when non-empty, replaces the call's active-tool filter.
	ActiveTools []string
	// DisableAllTools, when true, sends no tools to the model.
	DisableAllTools bool
	// Tools, when non-nil, replaces the agent's tool set.
	Tools []AgentTool
}
111
// ToolCallRepairOptions contains the options for repairing a tool call.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call as produced by the model.
	OriginalToolCall ToolCallContent
	// ValidationError explains why OriginalToolCall failed validation
	// (unknown tool, invalid JSON input, or a missing required parameter).
	ValidationError error
	// AvailableTools are the tools the call was validated against.
	AvailableTools []AgentTool
	// SystemPrompt is the system prompt in effect for the step.
	SystemPrompt string
	// Messages are the input messages sent to the model for the step.
	Messages []Message
}
120
type (
	// PrepareStepFunction defines a function that prepares a step in an agent execution.
	//
	// It runs before every step. The returned context replaces the agent's
	// context from then on (including later steps), the returned
	// PrepareStepResult can override parts of the step, and a non-nil error
	// aborts the run.
	PrepareStepFunction = func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error)

	// OnStepFinishedFunction defines a function that is called when a step finishes.
	// NOTE(review): not referenced by the agent code in this section.
	OnStepFinishedFunction = func(step StepResult)

	// RepairToolCallFunction defines a function that repairs a tool call.
	//
	// It is called only for tool calls that fail validation. Returning a
	// non-nil call with a nil error proposes a replacement, which is used only
	// if it passes validation itself. Returned errors are not propagated; the
	// original call is then marked invalid.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
131
// agentSettings holds the agent-level defaults configured with AgentOption
// values (see NewAgent). Per-call values take precedence wherever prepareCall
// merges the two.
type agentSettings struct {
	// systemPrompt, when non-empty, becomes the leading system message of
	// the prompt built by createPrompt.
	systemPrompt string
	// Default sampling parameters; nil means "not configured", in which case
	// only the per-call value (if any) is sent.
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	// headers presumably holds extra HTTP request headers.
	// NOTE(review): Generate/Stream in this section never forward them to
	// the model call — confirm whether that is intended.
	headers map[string]string
	// userAgent and modelSegment are copied into every model Call.
	userAgent    string
	modelSegment string
	// providerOptions are merged underneath the per-call ProviderOptions.
	providerOptions ProviderOptions

	// TODO: add support for provider tools
	// tools is the default tool set offered to the model on every step.
	tools []AgentTool
	// maxRetries, when non-nil, overrides the default retry count.
	maxRetries *int

	// model is the default language model for every step.
	model LanguageModel

	// Defaults used when the call does not provide its own.
	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onRetry        OnRetryCallback
}
156
// AgentCall represents a call to an agent.
//
// Nil pointer fields, an empty StopWhen and nil function fields fall back to
// the agent-level defaults configured with AgentOption values.
//
// NOTE(review): the function-typed fields have no `json:"-"` tag, and
// encoding/json cannot marshal func values, so json.Marshal of an AgentCall
// returns an error — confirm whether JSON encoding is expected.
type AgentCall struct {
	// Prompt is the user message for this call; it must be non-empty.
	Prompt string `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files []FilePart `json:"files"`
	// Messages is prior conversation history, placed after the system
	// prompt and before the Prompt message.
	Messages []Message `json:"messages"`
	// NOTE(review): unlike its neighbours, MaxOutputTokens has no json tag,
	// so it is (un)marshalled under its Go field name.
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, restricts the tools offered to the model
	// to those whose names are listed.
	ActiveTools []string `json:"active_tools"`
	// ProviderOptions are merged over the agent-level provider options; on
	// key collisions the per-call entry wins.
	ProviderOptions ProviderOptions
	// OnRetry is passed to the retry policy as its OnRetry hook.
	OnRetry OnRetryCallback
	// MaxRetries, when non-nil, overrides the default retry count.
	MaxRetries *int

	// StopWhen ends the run after a step for which any condition is true.
	StopWhen []StopCondition
	// PrepareStep, when set, can adjust every step before it runs.
	PrepareStep PrepareStepFunction
	// RepairToolCall is the per-call tool-call repair function; prepareCall
	// falls back to the agent-level one when it is nil.
	RepairToolCall RepairToolCallFunction
}
177
// Agent-level callbacks, set on AgentStreamCall and invoked by Stream.
type (
	// OnAgentStartFunc is called when agent starts.
	// Stream calls it once, before the first step.
	OnAgentStartFunc func()

	// OnAgentFinishFunc is called when agent finishes.
	// Stream calls it once with the final result, after OnFinish; the
	// returned error is ignored.
	OnAgentFinishFunc func(result *AgentResult) error

	// OnStepStartFunc is called when a step starts.
	// stepNumber is zero-based; the returned error is ignored by Stream.
	OnStepStartFunc func(stepNumber int) error

	// OnStepFinishFunc is called when a step finishes.
	// The returned error is ignored by Stream.
	OnStepFinishFunc func(stepResult StepResult) error

	// OnFinishFunc is called when entire agent completes.
	// Stream calls it once with the final result, before OnAgentFinish.
	OnFinishFunc func(result *AgentResult)

	// OnErrorFunc is called when an error occurs.
	// Stream calls it with an error that aborts a step, before returning
	// that error.
	OnErrorFunc func(error)
)
198
// Stream part callbacks - called for each corresponding stream part type.
//
// Stream passes its AgentStreamCall, which holds these callbacks, to
// processStepStream. That function is defined outside this section, so the
// exact dispatch and whether a returned error aborts the stream should be
// confirmed there. Because Stream retries a failed step by re-running the
// whole stream, these callbacks may observe parts from an attempt that is
// later retried.
type (
	// OnChunkFunc is called for each stream part (catch-all).
	OnChunkFunc func(StreamPart) error

	// OnWarningsFunc is called for warnings.
	OnWarningsFunc func(warnings []CallWarning) error

	// OnTextStartFunc is called when text starts.
	OnTextStartFunc func(id string) error

	// OnTextDeltaFunc is called for text deltas.
	OnTextDeltaFunc func(id, text string) error

	// OnTextEndFunc is called when text ends.
	OnTextEndFunc func(id string) error

	// OnReasoningStartFunc is called when reasoning starts.
	OnReasoningStartFunc func(id string, reasoning ReasoningContent) error

	// OnReasoningDeltaFunc is called for reasoning deltas.
	OnReasoningDeltaFunc func(id, text string) error

	// OnReasoningEndFunc is called when reasoning ends.
	OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

	// OnToolInputStartFunc is called when tool input starts.
	OnToolInputStartFunc func(id, toolName string) error

	// OnToolInputDeltaFunc is called for tool input deltas.
	OnToolInputDeltaFunc func(id, delta string) error

	// OnToolInputEndFunc is called when tool input ends.
	OnToolInputEndFunc func(id string) error

	// OnToolCallFunc is called when tool call is complete.
	OnToolCallFunc func(toolCall ToolCallContent) error

	// OnToolResultFunc is called when tool execution completes.
	OnToolResultFunc func(result ToolResultContent) error

	// OnSourceFunc is called for source references.
	OnSourceFunc func(source SourceContent) error

	// OnStreamFinishFunc is called when stream finishes.
	OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)
246
// AgentStreamCall represents a streaming call to an agent.
//
// The request fields mirror AgentCall and are defaulted the same way: Stream
// copies them into an AgentCall and runs prepareCall on it.
type AgentStreamCall struct {
	// Prompt is the user message for this call; it must be non-empty.
	Prompt string `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files []FilePart `json:"files"`
	// Messages is prior conversation history, placed after the system
	// prompt and before the Prompt message.
	Messages         []Message `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	// NOTE(review): Headers is not copied into the AgentCall built by
	// Stream and is not forwarded to the model call in this section.
	Headers         map[string]string
	ProviderOptions ProviderOptions
	OnRetry         OnRetryCallback
	MaxRetries      *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks
	OnAgentStart  OnAgentStartFunc  // Called once before the first step
	OnAgentFinish OnAgentFinishFunc // Called once after OnFinish; error ignored
	OnStepStart   OnStepStartFunc   // Called before each step; error ignored
	OnStepFinish  OnStepFinishFunc  // Called after each step; error ignored
	OnFinish      OnFinishFunc      // Called once with the final result
	OnError       OnErrorFunc       // Called with an error that aborts a step

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          OnChunkFunc          // Called for each stream part (catch-all)
	OnWarnings       OnWarningsFunc       // Called for warnings
	OnTextStart      OnTextStartFunc      // Called when text starts
	OnTextDelta      OnTextDeltaFunc      // Called for text deltas
	OnTextEnd        OnTextEndFunc        // Called when text ends
	OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
	OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
	OnReasoningEnd   OnReasoningEndFunc   // Called when reasoning ends
	OnToolInputStart OnToolInputStartFunc // Called when tool input starts
	OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
	OnToolInputEnd   OnToolInputEndFunc   // Called when tool input ends
	OnToolCall       OnToolCallFunc       // Called when tool call is complete
	OnToolResult     OnToolResultFunc     // Called when tool execution completes
	OnSource         OnSourceFunc         // Called for source references
	OnStreamFinish   OnStreamFinishFunc   // Called when stream finishes
}
293
// AgentResult represents the result of an agent execution.
type AgentResult struct {
	// Steps holds every completed step, in execution order.
	Steps []StepResult
	// Final response
	// Response is the Response of the last step.
	Response Response
	// TotalUsage is the token usage summed over all steps.
	TotalUsage Usage
}
301
// Agent represents an AI agent that can generate responses and stream responses.
type Agent interface {
	// Generate runs the agent's step loop with non-streaming model calls
	// and returns the completed result.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the agent's step loop with streaming model calls,
	// reporting progress through the callbacks in AgentStreamCall, and
	// returns the completed result once the loop ends.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
307
// AgentOption defines a function that configures agent settings.
// Options are applied in order by NewAgent, so a later option overrides an
// earlier one for single-valued settings.
type AgentOption = func(*agentSettings)
310
// agent is the default Agent implementation. Its settings are fixed when
// NewAgent returns; per-call options are merged over them by prepareCall.
type agent struct {
	settings agentSettings
}
314
315// NewAgent creates a new agent with the given language model and options.
316func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
317 settings := agentSettings{
318 model: model,
319 }
320 for _, o := range opts {
321 o(&settings)
322 }
323 return &agent{
324 settings: settings,
325 }
326}
327
328func (a *agent) prepareCall(call AgentCall) AgentCall {
329 call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
330 call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
331 call.TopP = cmp.Or(call.TopP, a.settings.topP)
332 call.TopK = cmp.Or(call.TopK, a.settings.topK)
333 call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
334 call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
335 call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
336
337 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
338 call.StopWhen = a.settings.stopWhen
339 }
340 if call.PrepareStep == nil && a.settings.prepareStep != nil {
341 call.PrepareStep = a.settings.prepareStep
342 }
343 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
344 call.RepairToolCall = a.settings.repairToolCall
345 }
346 if call.OnRetry == nil && a.settings.onRetry != nil {
347 call.OnRetry = a.settings.onRetry
348 }
349
350 providerOptions := ProviderOptions{}
351 if a.settings.providerOptions != nil {
352 maps.Copy(providerOptions, a.settings.providerOptions)
353 }
354 if call.ProviderOptions != nil {
355 maps.Copy(providerOptions, call.ProviderOptions)
356 }
357 call.ProviderOptions = providerOptions
358
359 headers := map[string]string{}
360
361 if a.settings.headers != nil {
362 maps.Copy(headers, a.settings.headers)
363 }
364
365 return call
366}
367
368// Generate implements Agent.
369func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
370 opts = a.prepareCall(opts)
371 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
372 if err != nil {
373 return nil, err
374 }
375 var responseMessages []Message
376 var steps []StepResult
377
378 for {
379 stepInputMessages := append(initialPrompt, responseMessages...)
380 stepModel := a.settings.model
381 stepSystemPrompt := a.settings.systemPrompt
382 stepActiveTools := opts.ActiveTools
383 stepToolChoice := ToolChoiceAuto
384 disableAllTools := false
385 stepTools := a.settings.tools
386 if opts.PrepareStep != nil {
387 updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
388 Model: stepModel,
389 Steps: steps,
390 StepNumber: len(steps),
391 Messages: stepInputMessages,
392 })
393 if err != nil {
394 return nil, err
395 }
396
397 ctx = updatedCtx
398
399 // Apply prepared step modifications
400 if prepared.Messages != nil {
401 stepInputMessages = prepared.Messages
402 }
403 if prepared.Model != nil {
404 stepModel = prepared.Model
405 }
406 if prepared.System != nil {
407 stepSystemPrompt = *prepared.System
408 }
409 if prepared.ToolChoice != nil {
410 stepToolChoice = *prepared.ToolChoice
411 }
412 if len(prepared.ActiveTools) > 0 {
413 stepActiveTools = prepared.ActiveTools
414 }
415 disableAllTools = prepared.DisableAllTools
416 if prepared.Tools != nil {
417 stepTools = prepared.Tools
418 }
419 }
420
421 // Recreate prompt with potentially modified system prompt
422 if stepSystemPrompt != a.settings.systemPrompt {
423 stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
424 if err != nil {
425 return nil, err
426 }
427 // Replace system message part, keep the rest
428 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
429 stepInputMessages[0] = stepPrompt[0] // Replace system message
430 }
431 }
432
433 preparedTools := a.prepareTools(stepTools, stepActiveTools, disableAllTools)
434
435 retryOptions := DefaultRetryOptions()
436 if opts.MaxRetries != nil {
437 retryOptions.MaxRetries = *opts.MaxRetries
438 }
439 retryOptions.OnRetry = opts.OnRetry
440 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
441
442 result, err := retry(ctx, func() (*Response, error) {
443 return stepModel.Generate(ctx, Call{
444 Prompt: stepInputMessages,
445 MaxOutputTokens: opts.MaxOutputTokens,
446 Temperature: opts.Temperature,
447 TopP: opts.TopP,
448 TopK: opts.TopK,
449 PresencePenalty: opts.PresencePenalty,
450 FrequencyPenalty: opts.FrequencyPenalty,
451 Tools: preparedTools,
452 ToolChoice: &stepToolChoice,
453 UserAgent: a.settings.userAgent,
454 ModelSegment: a.settings.modelSegment,
455 ProviderOptions: opts.ProviderOptions,
456 })
457 })
458 if err != nil {
459 return nil, err
460 }
461
462 var stepToolCalls []ToolCallContent
463 for _, content := range result.Content {
464 if content.GetType() == ContentTypeToolCall {
465 toolCall, ok := AsContentType[ToolCallContent](content)
466 if !ok {
467 continue
468 }
469
470 // Validate and potentially repair the tool call
471 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
472 stepToolCalls = append(stepToolCalls, validatedToolCall)
473 }
474 }
475
476 toolResults, err := a.executeTools(ctx, stepTools, stepToolCalls, nil)
477
478 // Build step content with validated tool calls and tool results
479 stepContent := []Content{}
480 toolCallIndex := 0
481 for _, content := range result.Content {
482 if content.GetType() == ContentTypeToolCall {
483 // Replace with validated tool call
484 if toolCallIndex < len(stepToolCalls) {
485 stepContent = append(stepContent, stepToolCalls[toolCallIndex])
486 toolCallIndex++
487 }
488 } else {
489 // Keep other content as-is
490 stepContent = append(stepContent, content)
491 }
492 }
493 // Add tool results
494 for _, result := range toolResults {
495 stepContent = append(stepContent, result)
496 }
497 currentStepMessages := toResponseMessages(stepContent)
498 responseMessages = append(responseMessages, currentStepMessages...)
499
500 stepResult := StepResult{
501 Response: Response{
502 Content: stepContent,
503 FinishReason: result.FinishReason,
504 Usage: result.Usage,
505 Warnings: result.Warnings,
506 ProviderMetadata: result.ProviderMetadata,
507 },
508 Messages: currentStepMessages,
509 }
510 steps = append(steps, stepResult)
511 shouldStop := isStopConditionMet(opts.StopWhen, steps)
512
513 if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
514 break
515 }
516 }
517
518 totalUsage := Usage{}
519
520 for _, step := range steps {
521 usage := step.Usage
522 totalUsage.InputTokens += usage.InputTokens
523 totalUsage.OutputTokens += usage.OutputTokens
524 totalUsage.ReasoningTokens += usage.ReasoningTokens
525 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
526 totalUsage.CacheReadTokens += usage.CacheReadTokens
527 totalUsage.TotalTokens += usage.TotalTokens
528 }
529
530 agentResult := &AgentResult{
531 Steps: steps,
532 Response: steps[len(steps)-1].Response,
533 TotalUsage: totalUsage,
534 }
535 return agentResult, nil
536}
537
538func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
539 if len(conditions) == 0 {
540 return false
541 }
542
543 for _, condition := range conditions {
544 if condition(steps) {
545 return true
546 }
547 }
548 return false
549}
550
551func toResponseMessages(content []Content) []Message {
552 var assistantParts []MessagePart
553 var toolParts []MessagePart
554
555 for _, c := range content {
556 switch c.GetType() {
557 case ContentTypeText:
558 text, ok := AsContentType[TextContent](c)
559 if !ok {
560 continue
561 }
562 assistantParts = append(assistantParts, TextPart{
563 Text: text.Text,
564 ProviderOptions: ProviderOptions(text.ProviderMetadata),
565 })
566 case ContentTypeReasoning:
567 reasoning, ok := AsContentType[ReasoningContent](c)
568 if !ok {
569 continue
570 }
571 assistantParts = append(assistantParts, ReasoningPart{
572 Text: reasoning.Text,
573 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
574 })
575 case ContentTypeToolCall:
576 toolCall, ok := AsContentType[ToolCallContent](c)
577 if !ok {
578 continue
579 }
580 assistantParts = append(assistantParts, ToolCallPart{
581 ToolCallID: toolCall.ToolCallID,
582 ToolName: toolCall.ToolName,
583 Input: toolCall.Input,
584 ProviderExecuted: toolCall.ProviderExecuted,
585 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
586 })
587 case ContentTypeFile:
588 file, ok := AsContentType[FileContent](c)
589 if !ok {
590 continue
591 }
592 assistantParts = append(assistantParts, FilePart{
593 Data: file.Data,
594 MediaType: file.MediaType,
595 ProviderOptions: ProviderOptions(file.ProviderMetadata),
596 })
597 case ContentTypeSource:
598 // Sources are metadata about references used to generate the response.
599 // They don't need to be included in the conversation messages.
600 continue
601 case ContentTypeToolResult:
602 result, ok := AsContentType[ToolResultContent](c)
603 if !ok {
604 continue
605 }
606 toolParts = append(toolParts, ToolResultPart{
607 ToolCallID: result.ToolCallID,
608 Output: result.Result,
609 ProviderOptions: ProviderOptions(result.ProviderMetadata),
610 })
611 }
612 }
613
614 var messages []Message
615 if len(assistantParts) > 0 {
616 messages = append(messages, Message{
617 Role: MessageRoleAssistant,
618 Content: assistantParts,
619 })
620 }
621 if len(toolParts) > 0 {
622 messages = append(messages, Message{
623 Role: MessageRoleTool,
624 Content: toolParts,
625 })
626 }
627 return messages
628}
629
630func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
631 if len(toolCalls) == 0 {
632 return nil, nil
633 }
634
635 // Create a map for quick tool lookup
636 toolMap := make(map[string]AgentTool)
637 for _, tool := range allTools {
638 toolMap[tool.Info().Name] = tool
639 }
640
641 // Execute all tool calls sequentially in order
642 results := make([]ToolResultContent, 0, len(toolCalls))
643
644 for _, toolCall := range toolCalls {
645 result, isCriticalError := a.executeSingleTool(ctx, toolMap, toolCall, toolResultCallback)
646 results = append(results, result)
647 if isCriticalError {
648 if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
649 return nil, errorResult.Error
650 }
651 }
652 }
653
654 return results, nil
655}
656
657// executeSingleTool executes a single tool and returns its result and a critical error flag.
658func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentTool, toolCall ToolCallContent, toolResultCallback func(result ToolResultContent) error) (ToolResultContent, bool) {
659 result := ToolResultContent{
660 ToolCallID: toolCall.ToolCallID,
661 ToolName: toolCall.ToolName,
662 ProviderExecuted: false,
663 }
664
665 // Skip invalid tool calls - create error result (not critical)
666 if toolCall.Invalid {
667 result.Result = ToolResultOutputContentError{
668 Error: toolCall.ValidationError,
669 }
670 if toolResultCallback != nil {
671 _ = toolResultCallback(result)
672 }
673 return result, false
674 }
675
676 tool, exists := toolMap[toolCall.ToolName]
677 if !exists {
678 result.Result = ToolResultOutputContentError{
679 Error: errors.New("Error: Tool not found: " + toolCall.ToolName),
680 }
681 if toolResultCallback != nil {
682 _ = toolResultCallback(result)
683 }
684 return result, false
685 }
686
687 // Execute the tool
688 toolResult, err := tool.Run(ctx, ToolCall{
689 ID: toolCall.ToolCallID,
690 Name: toolCall.ToolName,
691 Input: toolCall.Input,
692 })
693 if err != nil {
694 result.Result = ToolResultOutputContentError{
695 Error: err,
696 }
697 result.ClientMetadata = toolResult.Metadata
698 if toolResultCallback != nil {
699 _ = toolResultCallback(result)
700 }
701 return result, true
702 }
703
704 result.ClientMetadata = toolResult.Metadata
705 if toolResult.IsError {
706 result.Result = ToolResultOutputContentError{
707 Error: errors.New(toolResult.Content),
708 }
709 } else if toolResult.Type == "image" || toolResult.Type == "media" {
710 result.Result = ToolResultOutputContentMedia{
711 Data: string(toolResult.Data),
712 MediaType: toolResult.MediaType,
713 Text: toolResult.Content,
714 }
715 } else {
716 result.Result = ToolResultOutputContentText{
717 Text: toolResult.Content,
718 }
719 }
720 if toolResultCallback != nil {
721 _ = toolResultCallback(result)
722 }
723 return result, false
724}
725
726// Stream implements Agent.
727func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
728 // Convert AgentStreamCall to AgentCall for preparation
729 call := AgentCall{
730 Prompt: opts.Prompt,
731 Files: opts.Files,
732 Messages: opts.Messages,
733 MaxOutputTokens: opts.MaxOutputTokens,
734 Temperature: opts.Temperature,
735 TopP: opts.TopP,
736 TopK: opts.TopK,
737 PresencePenalty: opts.PresencePenalty,
738 FrequencyPenalty: opts.FrequencyPenalty,
739 ActiveTools: opts.ActiveTools,
740 ProviderOptions: opts.ProviderOptions,
741 MaxRetries: opts.MaxRetries,
742 OnRetry: opts.OnRetry,
743 StopWhen: opts.StopWhen,
744 PrepareStep: opts.PrepareStep,
745 RepairToolCall: opts.RepairToolCall,
746 }
747
748 call = a.prepareCall(call)
749
750 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
751 if err != nil {
752 return nil, err
753 }
754
755 var responseMessages []Message
756 var steps []StepResult
757 var totalUsage Usage
758
759 // Start agent stream
760 if opts.OnAgentStart != nil {
761 opts.OnAgentStart()
762 }
763
764 for stepNumber := 0; ; stepNumber++ {
765 stepInputMessages := append(initialPrompt, responseMessages...)
766 stepModel := a.settings.model
767 stepSystemPrompt := a.settings.systemPrompt
768 stepActiveTools := call.ActiveTools
769 stepToolChoice := ToolChoiceAuto
770 disableAllTools := false
771 stepTools := a.settings.tools
772 // Apply step preparation if provided
773 if call.PrepareStep != nil {
774 updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
775 Model: stepModel,
776 Steps: steps,
777 StepNumber: stepNumber,
778 Messages: stepInputMessages,
779 })
780 if err != nil {
781 return nil, err
782 }
783
784 ctx = updatedCtx
785
786 if prepared.Messages != nil {
787 stepInputMessages = prepared.Messages
788 }
789 if prepared.Model != nil {
790 stepModel = prepared.Model
791 }
792 if prepared.System != nil {
793 stepSystemPrompt = *prepared.System
794 }
795 if prepared.ToolChoice != nil {
796 stepToolChoice = *prepared.ToolChoice
797 }
798 if len(prepared.ActiveTools) > 0 {
799 stepActiveTools = prepared.ActiveTools
800 }
801 disableAllTools = prepared.DisableAllTools
802 if prepared.Tools != nil {
803 stepTools = prepared.Tools
804 }
805 }
806
807 // Recreate prompt with potentially modified system prompt
808 if stepSystemPrompt != a.settings.systemPrompt {
809 stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
810 if err != nil {
811 return nil, err
812 }
813 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
814 stepInputMessages[0] = stepPrompt[0]
815 }
816 }
817
818 preparedTools := a.prepareTools(stepTools, stepActiveTools, disableAllTools)
819
820 // Start step stream
821 if opts.OnStepStart != nil {
822 _ = opts.OnStepStart(stepNumber)
823 }
824
825 // Create streaming call
826 streamCall := Call{
827 Prompt: stepInputMessages,
828 MaxOutputTokens: call.MaxOutputTokens,
829 Temperature: call.Temperature,
830 TopP: call.TopP,
831 TopK: call.TopK,
832 PresencePenalty: call.PresencePenalty,
833 FrequencyPenalty: call.FrequencyPenalty,
834 Tools: preparedTools,
835 ToolChoice: &stepToolChoice,
836 UserAgent: a.settings.userAgent,
837 ModelSegment: a.settings.modelSegment,
838 ProviderOptions: call.ProviderOptions,
839 }
840
841 // Execute step with retry logic wrapping both stream creation and processing
842 retryOptions := DefaultRetryOptions()
843 if call.MaxRetries != nil {
844 retryOptions.MaxRetries = *call.MaxRetries
845 }
846 retryOptions.OnRetry = call.OnRetry
847 retry := RetryWithExponentialBackoffRespectingRetryHeaders[stepExecutionResult](retryOptions)
848
849 result, err := retry(ctx, func() (stepExecutionResult, error) {
850 // Create the stream
851 stream, err := stepModel.Stream(ctx, streamCall)
852 if err != nil {
853 return stepExecutionResult{}, err
854 }
855
856 // Process the stream
857 result, err := a.processStepStream(ctx, stream, opts, steps, stepTools)
858 if err != nil {
859 return stepExecutionResult{}, err
860 }
861
862 return result, nil
863 })
864 if err != nil {
865 if opts.OnError != nil {
866 opts.OnError(err)
867 }
868 return nil, err
869 }
870
871 steps = append(steps, result.StepResult)
872 totalUsage = addUsage(totalUsage, result.StepResult.Usage)
873
874 // Call step finished callback
875 if opts.OnStepFinish != nil {
876 _ = opts.OnStepFinish(result.StepResult)
877 }
878
879 // Add step messages to response messages
880 stepMessages := toResponseMessages(result.StepResult.Content)
881 responseMessages = append(responseMessages, stepMessages...)
882
883 // Check stop conditions
884 shouldStop := isStopConditionMet(call.StopWhen, steps)
885 if shouldStop || !result.ShouldContinue {
886 break
887 }
888 }
889
890 // Finish agent stream
891 agentResult := &AgentResult{
892 Steps: steps,
893 Response: steps[len(steps)-1].Response,
894 TotalUsage: totalUsage,
895 }
896
897 if opts.OnFinish != nil {
898 opts.OnFinish(agentResult)
899 }
900
901 if opts.OnAgentFinish != nil {
902 _ = opts.OnAgentFinish(agentResult)
903 }
904
905 return agentResult, nil
906}
907
908func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
909 preparedTools := make([]Tool, 0, len(tools))
910
911 // If explicitly disabling all tools, return no tools
912 if disableAllTools {
913 return preparedTools
914 }
915
916 for _, tool := range tools {
917 // If activeTools has items, only include tools in the list
918 // If activeTools is empty, include all tools
919 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
920 continue
921 }
922 info := tool.Info()
923 inputSchema := map[string]any{
924 "type": "object",
925 "properties": info.Parameters,
926 "required": info.Required,
927 }
928 schema.Normalize(inputSchema)
929 preparedTools = append(preparedTools, FunctionTool{
930 Name: info.Name,
931 Description: info.Description,
932 InputSchema: inputSchema,
933 ProviderOptions: tool.ProviderOptions(),
934 })
935 }
936 return preparedTools
937}
938
939// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
940func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []AgentTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
941 if err := a.validateToolCall(toolCall, availableTools); err == nil {
942 return toolCall
943 } else { //nolint: revive
944 if repairFunc != nil {
945 repairOptions := ToolCallRepairOptions{
946 OriginalToolCall: toolCall,
947 ValidationError: err,
948 AvailableTools: availableTools,
949 SystemPrompt: systemPrompt,
950 Messages: messages,
951 }
952
953 if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
954 if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
955 return *repairedToolCall
956 }
957 }
958 }
959
960 invalidToolCall := toolCall
961 invalidToolCall.Invalid = true
962 invalidToolCall.ValidationError = err
963 return invalidToolCall
964 }
965}
966
967// validateToolCall validates a tool call against available tools and their schemas.
968func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []AgentTool) error {
969 var tool AgentTool
970 for _, t := range availableTools {
971 if t.Info().Name == toolCall.ToolName {
972 tool = t
973 break
974 }
975 }
976
977 if tool == nil {
978 return fmt.Errorf("tool not found: %s", toolCall.ToolName)
979 }
980
981 // Validate JSON parsing
982 var input map[string]any
983 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
984 return fmt.Errorf("invalid JSON input: %w", err)
985 }
986
987 // Basic schema validation (check required fields)
988 // TODO: more robust schema validation using JSON Schema or similar
989 toolInfo := tool.Info()
990 for _, required := range toolInfo.Required {
991 if _, exists := input[required]; !exists {
992 return fmt.Errorf("missing required parameter: %s", required)
993 }
994 }
995 return nil
996}
997
998func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
999 if prompt == "" {
1000 return nil, &Error{Title: "invalid argument", Message: "prompt can't be empty"}
1001 }
1002
1003 var preparedPrompt Prompt
1004
1005 if system != "" {
1006 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
1007 }
1008 preparedPrompt = append(preparedPrompt, messages...)
1009 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
1010 return preparedPrompt, nil
1011}
1012
1013// WithSystemPrompt sets the system prompt for the agent.
1014func WithSystemPrompt(prompt string) AgentOption {
1015 return func(s *agentSettings) {
1016 s.systemPrompt = prompt
1017 }
1018}
1019
1020// WithMaxOutputTokens sets the maximum output tokens for the agent.
1021func WithMaxOutputTokens(tokens int64) AgentOption {
1022 return func(s *agentSettings) {
1023 s.maxOutputTokens = &tokens
1024 }
1025}
1026
1027// WithTemperature sets the temperature for the agent.
1028func WithTemperature(temp float64) AgentOption {
1029 return func(s *agentSettings) {
1030 s.temperature = &temp
1031 }
1032}
1033
1034// WithTopP sets the top-p value for the agent.
1035func WithTopP(topP float64) AgentOption {
1036 return func(s *agentSettings) {
1037 s.topP = &topP
1038 }
1039}
1040
1041// WithTopK sets the top-k value for the agent.
1042func WithTopK(topK int64) AgentOption {
1043 return func(s *agentSettings) {
1044 s.topK = &topK
1045 }
1046}
1047
1048// WithPresencePenalty sets the presence penalty for the agent.
1049func WithPresencePenalty(penalty float64) AgentOption {
1050 return func(s *agentSettings) {
1051 s.presencePenalty = &penalty
1052 }
1053}
1054
1055// WithFrequencyPenalty sets the frequency penalty for the agent.
1056func WithFrequencyPenalty(penalty float64) AgentOption {
1057 return func(s *agentSettings) {
1058 s.frequencyPenalty = &penalty
1059 }
1060}
1061
1062// WithTools sets the tools for the agent.
1063func WithTools(tools ...AgentTool) AgentOption {
1064 return func(s *agentSettings) {
1065 s.tools = append(s.tools, tools...)
1066 }
1067}
1068
1069// WithStopConditions sets the stop conditions for the agent.
1070func WithStopConditions(conditions ...StopCondition) AgentOption {
1071 return func(s *agentSettings) {
1072 s.stopWhen = append(s.stopWhen, conditions...)
1073 }
1074}
1075
1076// WithPrepareStep sets the prepare step function for the agent.
1077func WithPrepareStep(fn PrepareStepFunction) AgentOption {
1078 return func(s *agentSettings) {
1079 s.prepareStep = fn
1080 }
1081}
1082
1083// WithRepairToolCall sets the repair tool call function for the agent.
1084func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
1085 return func(s *agentSettings) {
1086 s.repairToolCall = fn
1087 }
1088}
1089
1090// WithMaxRetries sets the maximum number of retries for the agent.
1091func WithMaxRetries(maxRetries int) AgentOption {
1092 return func(s *agentSettings) {
1093 s.maxRetries = &maxRetries
1094 }
1095}
1096
1097// WithOnRetry sets the retry callback for the agent.
1098func WithOnRetry(callback OnRetryCallback) AgentOption {
1099 return func(s *agentSettings) {
1100 s.onRetry = callback
1101 }
1102}
1103
1104// processStepStream processes a single step's stream and returns the step result.
1105func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult, stepTools []AgentTool) (stepExecutionResult, error) {
1106 var stepContent []Content
1107 var stepToolCalls []ToolCallContent
1108 var stepUsage Usage
1109 stepFinishReason := FinishReasonUnknown
1110 var stepWarnings []CallWarning
1111 var stepProviderMetadata ProviderMetadata
1112
1113 activeToolCalls := make(map[string]*ToolCallContent)
1114 activeTextContent := make(map[string]string)
1115 type reasoningContent struct {
1116 content string
1117 options ProviderMetadata
1118 }
1119 activeReasoningContent := make(map[string]reasoningContent)
1120
1121 // Set up concurrent tool execution
1122 type toolExecutionRequest struct {
1123 toolCall ToolCallContent
1124 parallel bool
1125 }
1126 toolChan := make(chan toolExecutionRequest, 10)
1127 var toolExecutionWg sync.WaitGroup
1128 var toolStateMu sync.Mutex
1129 toolResults := make([]ToolResultContent, 0)
1130 var toolExecutionErr error
1131
1132 // Create a map for quick tool lookup
1133 toolMap := make(map[string]AgentTool)
1134 for _, tool := range stepTools {
1135 toolMap[tool.Info().Name] = tool
1136 }
1137
1138 // Semaphores for controlling parallelism
1139 parallelSem := make(chan struct{}, 5)
1140 var sequentialMu sync.Mutex
1141
1142 // Single coordinator goroutine that dispatches tools
1143 toolExecutionWg.Go(func() {
1144 for req := range toolChan {
1145 if req.parallel {
1146 parallelSem <- struct{}{}
1147 toolExecutionWg.Go(func() {
1148 defer func() { <-parallelSem }()
1149 result, isCriticalError := a.executeSingleTool(ctx, toolMap, req.toolCall, opts.OnToolResult)
1150 toolStateMu.Lock()
1151 toolResults = append(toolResults, result)
1152 if isCriticalError && toolExecutionErr == nil {
1153 if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
1154 toolExecutionErr = errorResult.Error
1155 }
1156 }
1157 toolStateMu.Unlock()
1158 })
1159 } else {
1160 sequentialMu.Lock()
1161 result, isCriticalError := a.executeSingleTool(ctx, toolMap, req.toolCall, opts.OnToolResult)
1162 toolStateMu.Lock()
1163 toolResults = append(toolResults, result)
1164 if isCriticalError && toolExecutionErr == nil {
1165 if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
1166 toolExecutionErr = errorResult.Error
1167 }
1168 }
1169 toolStateMu.Unlock()
1170 sequentialMu.Unlock()
1171 }
1172 }
1173 })
1174
1175 // Process stream parts
1176 for part := range stream {
1177 // Forward all parts to chunk callback
1178 if opts.OnChunk != nil {
1179 err := opts.OnChunk(part)
1180 if err != nil {
1181 return stepExecutionResult{}, err
1182 }
1183 }
1184
1185 switch part.Type {
1186 case StreamPartTypeWarnings:
1187 stepWarnings = part.Warnings
1188 if opts.OnWarnings != nil {
1189 err := opts.OnWarnings(part.Warnings)
1190 if err != nil {
1191 return stepExecutionResult{}, err
1192 }
1193 }
1194
1195 case StreamPartTypeTextStart:
1196 activeTextContent[part.ID] = ""
1197 if opts.OnTextStart != nil {
1198 err := opts.OnTextStart(part.ID)
1199 if err != nil {
1200 return stepExecutionResult{}, err
1201 }
1202 }
1203
1204 case StreamPartTypeTextDelta:
1205 if _, exists := activeTextContent[part.ID]; exists {
1206 activeTextContent[part.ID] += part.Delta
1207 }
1208 if opts.OnTextDelta != nil {
1209 err := opts.OnTextDelta(part.ID, part.Delta)
1210 if err != nil {
1211 return stepExecutionResult{}, err
1212 }
1213 }
1214
1215 case StreamPartTypeTextEnd:
1216 if text, exists := activeTextContent[part.ID]; exists {
1217 stepContent = append(stepContent, TextContent{
1218 Text: text,
1219 ProviderMetadata: part.ProviderMetadata,
1220 })
1221 delete(activeTextContent, part.ID)
1222 }
1223 if opts.OnTextEnd != nil {
1224 err := opts.OnTextEnd(part.ID)
1225 if err != nil {
1226 return stepExecutionResult{}, err
1227 }
1228 }
1229
1230 case StreamPartTypeReasoningStart:
1231 activeReasoningContent[part.ID] = reasoningContent{content: part.Delta, options: part.ProviderMetadata}
1232 if opts.OnReasoningStart != nil {
1233 content := ReasoningContent{
1234 Text: part.Delta,
1235 ProviderMetadata: part.ProviderMetadata,
1236 }
1237 err := opts.OnReasoningStart(part.ID, content)
1238 if err != nil {
1239 return stepExecutionResult{}, err
1240 }
1241 }
1242
1243 case StreamPartTypeReasoningDelta:
1244 if active, exists := activeReasoningContent[part.ID]; exists {
1245 active.content += part.Delta
1246 active.options = part.ProviderMetadata
1247 activeReasoningContent[part.ID] = active
1248 }
1249 if opts.OnReasoningDelta != nil {
1250 err := opts.OnReasoningDelta(part.ID, part.Delta)
1251 if err != nil {
1252 return stepExecutionResult{}, err
1253 }
1254 }
1255
1256 case StreamPartTypeReasoningEnd:
1257 if active, exists := activeReasoningContent[part.ID]; exists {
1258 if part.ProviderMetadata != nil {
1259 active.options = part.ProviderMetadata
1260 }
1261 content := ReasoningContent{
1262 Text: active.content,
1263 ProviderMetadata: active.options,
1264 }
1265 stepContent = append(stepContent, content)
1266 if opts.OnReasoningEnd != nil {
1267 err := opts.OnReasoningEnd(part.ID, content)
1268 if err != nil {
1269 return stepExecutionResult{}, err
1270 }
1271 }
1272 delete(activeReasoningContent, part.ID)
1273 }
1274
1275 case StreamPartTypeToolInputStart:
1276 activeToolCalls[part.ID] = &ToolCallContent{
1277 ToolCallID: part.ID,
1278 ToolName: part.ToolCallName,
1279 Input: "",
1280 ProviderExecuted: part.ProviderExecuted,
1281 }
1282 if opts.OnToolInputStart != nil {
1283 err := opts.OnToolInputStart(part.ID, part.ToolCallName)
1284 if err != nil {
1285 return stepExecutionResult{}, err
1286 }
1287 }
1288
1289 case StreamPartTypeToolInputDelta:
1290 if toolCall, exists := activeToolCalls[part.ID]; exists {
1291 toolCall.Input += part.Delta
1292 }
1293 if opts.OnToolInputDelta != nil {
1294 err := opts.OnToolInputDelta(part.ID, part.Delta)
1295 if err != nil {
1296 return stepExecutionResult{}, err
1297 }
1298 }
1299
1300 case StreamPartTypeToolInputEnd:
1301 if opts.OnToolInputEnd != nil {
1302 err := opts.OnToolInputEnd(part.ID)
1303 if err != nil {
1304 return stepExecutionResult{}, err
1305 }
1306 }
1307
1308 case StreamPartTypeToolCall:
1309 toolCall := ToolCallContent{
1310 ToolCallID: part.ID,
1311 ToolName: part.ToolCallName,
1312 Input: part.ToolCallInput,
1313 ProviderExecuted: part.ProviderExecuted,
1314 ProviderMetadata: part.ProviderMetadata,
1315 }
1316
1317 // Validate and potentially repair the tool call
1318 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1319 stepToolCalls = append(stepToolCalls, validatedToolCall)
1320 stepContent = append(stepContent, validatedToolCall)
1321
1322 if opts.OnToolCall != nil {
1323 err := opts.OnToolCall(validatedToolCall)
1324 if err != nil {
1325 return stepExecutionResult{}, err
1326 }
1327 }
1328
1329 // Determine if tool can run in parallel
1330 isParallel := false
1331 if tool, exists := toolMap[validatedToolCall.ToolName]; exists {
1332 isParallel = tool.Info().Parallel
1333 }
1334
1335 // Send tool call to execution channel
1336 toolChan <- toolExecutionRequest{toolCall: validatedToolCall, parallel: isParallel}
1337
1338 // Clean up active tool call
1339 delete(activeToolCalls, part.ID)
1340
1341 case StreamPartTypeSource:
1342 sourceContent := SourceContent{
1343 SourceType: part.SourceType,
1344 ID: part.ID,
1345 URL: part.URL,
1346 Title: part.Title,
1347 ProviderMetadata: part.ProviderMetadata,
1348 }
1349 stepContent = append(stepContent, sourceContent)
1350 if opts.OnSource != nil {
1351 err := opts.OnSource(sourceContent)
1352 if err != nil {
1353 return stepExecutionResult{}, err
1354 }
1355 }
1356
1357 case StreamPartTypeFinish:
1358 stepUsage = part.Usage
1359 stepFinishReason = part.FinishReason
1360 stepProviderMetadata = part.ProviderMetadata
1361 if opts.OnStreamFinish != nil {
1362 err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1363 if err != nil {
1364 return stepExecutionResult{}, err
1365 }
1366 }
1367
1368 case StreamPartTypeError:
1369 return stepExecutionResult{}, part.Error
1370 }
1371 }
1372
1373 // Close the tool execution channel and wait for all executions to complete
1374 close(toolChan)
1375 toolExecutionWg.Wait()
1376
1377 // Check for tool execution errors
1378 if toolExecutionErr != nil {
1379 return stepExecutionResult{}, toolExecutionErr
1380 }
1381
1382 // Add tool results to content if any
1383 if len(toolResults) > 0 {
1384 for _, result := range toolResults {
1385 stepContent = append(stepContent, result)
1386 }
1387 }
1388
1389 stepResult := StepResult{
1390 Response: Response{
1391 Content: stepContent,
1392 FinishReason: stepFinishReason,
1393 Usage: stepUsage,
1394 Warnings: stepWarnings,
1395 ProviderMetadata: stepProviderMetadata,
1396 },
1397 Messages: toResponseMessages(stepContent),
1398 }
1399
1400 // Determine if we should continue (has tool calls and not stopped)
1401 shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1402
1403 return stepExecutionResult{
1404 StepResult: stepResult,
1405 ShouldContinue: shouldContinue,
1406 }, nil
1407}
1408
1409func addUsage(a, b Usage) Usage {
1410 return Usage{
1411 InputTokens: a.InputTokens + b.InputTokens,
1412 OutputTokens: a.OutputTokens + b.OutputTokens,
1413 TotalTokens: a.TotalTokens + b.TotalTokens,
1414 ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens,
1415 CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1416 CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens,
1417 }
1418}
1419
1420// WithHeaders sets the headers for the agent.
1421func WithHeaders(headers map[string]string) AgentOption {
1422 return func(s *agentSettings) {
1423 s.headers = headers
1424 }
1425}
1426
1427// WithUserAgent sets the User-Agent header for the agent. This overrides any
1428// provider-level User-Agent setting.
1429func WithUserAgent(ua string) AgentOption {
1430 return func(s *agentSettings) {
1431 s.userAgent = ua
1432 }
1433}
1434
1435// WithModelSegment sets the model segment appended to the default User-Agent
1436// header. The default UA becomes "Fantasy/<version> (<segment>)". An empty
1437// string clears any previously set segment. This is overridden by WithUserAgent
1438// at either the agent or provider level.
1439func WithModelSegment(segment string) AgentOption {
1440 return func(s *agentSettings) {
1441 s.modelSegment = segment
1442 }
1443}
1444
1445// WithProviderOptions sets the provider options for the agent.
1446func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
1447 return func(s *agentSettings) {
1448 s.providerOptions = providerOptions
1449 }
1450}