@@ -146,6 +146,7 @@ type agentSettings struct {
providerDefinedTools []ProviderDefinedTool
executableProviderTools []ExecutableProviderTool
tools []AgentTool
+ toolChoice *ToolChoice
maxRetries *int
model LanguageModel
@@ -162,12 +163,13 @@ type AgentCall struct {
Files []FilePart `json:"files"`
Messages []Message `json:"messages"`
MaxOutputTokens *int64
- Temperature *float64 `json:"temperature"`
- TopP *float64 `json:"top_p"`
- TopK *int64 `json:"top_k"`
- PresencePenalty *float64 `json:"presence_penalty"`
- FrequencyPenalty *float64 `json:"frequency_penalty"`
- ActiveTools []string `json:"active_tools"`
+ Temperature *float64 `json:"temperature"`
+ TopP *float64 `json:"top_p"`
+ TopK *int64 `json:"top_k"`
+ PresencePenalty *float64 `json:"presence_penalty"`
+ FrequencyPenalty *float64 `json:"frequency_penalty"`
+ ActiveTools []string `json:"active_tools"`
+ ToolChoice *ToolChoice `json:"tool_choice"`
ProviderOptions ProviderOptions
OnRetry OnRetryCallback
MaxRetries *int
@@ -252,12 +254,13 @@ type AgentStreamCall struct {
Files []FilePart `json:"files"`
Messages []Message `json:"messages"`
MaxOutputTokens *int64
- Temperature *float64 `json:"temperature"`
- TopP *float64 `json:"top_p"`
- TopK *int64 `json:"top_k"`
- PresencePenalty *float64 `json:"presence_penalty"`
- FrequencyPenalty *float64 `json:"frequency_penalty"`
- ActiveTools []string `json:"active_tools"`
+ Temperature *float64 `json:"temperature"`
+ TopP *float64 `json:"top_p"`
+ TopK *int64 `json:"top_k"`
+ PresencePenalty *float64 `json:"presence_penalty"`
+ FrequencyPenalty *float64 `json:"frequency_penalty"`
+ ActiveTools []string `json:"active_tools"`
+ ToolChoice *ToolChoice `json:"tool_choice"`
Headers map[string]string
ProviderOptions ProviderOptions
OnRetry OnRetryCallback
@@ -335,6 +338,7 @@ func (a *agent) prepareCall(call AgentCall) AgentCall {
call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
+ call.ToolChoice = cmp.Or(call.ToolChoice, a.settings.toolChoice)
if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
call.StopWhen = a.settings.stopWhen
@@ -383,6 +387,9 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
stepSystemPrompt := a.settings.systemPrompt
stepActiveTools := opts.ActiveTools
stepToolChoice := ToolChoiceAuto
+ if opts.ToolChoice != nil {
+ stepToolChoice = *opts.ToolChoice
+ }
disableAllTools := false
stepTools := a.settings.tools
if opts.PrepareStep != nil {
@@ -787,6 +794,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
PresencePenalty: opts.PresencePenalty,
FrequencyPenalty: opts.FrequencyPenalty,
ActiveTools: opts.ActiveTools,
+ ToolChoice: opts.ToolChoice,
ProviderOptions: opts.ProviderOptions,
MaxRetries: opts.MaxRetries,
OnRetry: opts.OnRetry,
@@ -817,6 +825,9 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
stepSystemPrompt := a.settings.systemPrompt
stepActiveTools := call.ActiveTools
stepToolChoice := ToolChoiceAuto
+ if call.ToolChoice != nil {
+ stepToolChoice = *call.ToolChoice
+ }
disableAllTools := false
stepTools := a.settings.tools
// Apply step preparation if provided
@@ -1206,6 +1217,14 @@ func WithProviderDefinedTools(tools ...ProviderTool) AgentOption {
}
}
+// WithToolChoice sets the default tool choice for the agent. It is overridden
+// by the ToolChoice on a specific call, and by PrepareStep at the step level.
+func WithToolChoice(choice ToolChoice) AgentOption {
+ return func(s *agentSettings) {
+ s.toolChoice = &choice
+ }
+}
+
// WithStopConditions sets the stop conditions for the agent.
func WithStopConditions(conditions ...StopCondition) AgentOption {
return func(s *agentSettings) {