feat(agent): add `ToolChoice` to `AgentCall`, `AgentStreamCall`, and `WithToolChoice` option (#213)

fwang2002 and wangfeng01 created

Previously ToolChoice could only be set via PrepareStep. Expose it as a
first-class field on AgentCall/AgentStreamCall and a WithToolChoice agent
option, mirroring how Temperature, TopP, and other call params are surfaced.

Precedence: PrepareStep > call.ToolChoice > WithToolChoice > ToolChoiceAuto.

Co-authored-by: wangfeng01 <wangfeng01@genora.com>

Change summary

agent.go | 43 +++++++++++++++++++++++++++++++------------
1 file changed, 31 insertions(+), 12 deletions(-)

Detailed changes

agent.go 🔗

@@ -146,6 +146,7 @@ type agentSettings struct {
 	providerDefinedTools    []ProviderDefinedTool
 	executableProviderTools []ExecutableProviderTool
 	tools                   []AgentTool
+	toolChoice              *ToolChoice
 	maxRetries              *int
 
 	model LanguageModel
@@ -162,12 +163,13 @@ type AgentCall struct {
 	Files            []FilePart `json:"files"`
 	Messages         []Message  `json:"messages"`
 	MaxOutputTokens  *int64
-	Temperature      *float64 `json:"temperature"`
-	TopP             *float64 `json:"top_p"`
-	TopK             *int64   `json:"top_k"`
-	PresencePenalty  *float64 `json:"presence_penalty"`
-	FrequencyPenalty *float64 `json:"frequency_penalty"`
-	ActiveTools      []string `json:"active_tools"`
+	Temperature      *float64    `json:"temperature"`
+	TopP             *float64    `json:"top_p"`
+	TopK             *int64      `json:"top_k"`
+	PresencePenalty  *float64    `json:"presence_penalty"`
+	FrequencyPenalty *float64    `json:"frequency_penalty"`
+	ActiveTools      []string    `json:"active_tools"`
+	ToolChoice       *ToolChoice `json:"tool_choice"`
 	ProviderOptions  ProviderOptions
 	OnRetry          OnRetryCallback
 	MaxRetries       *int
@@ -252,12 +254,13 @@ type AgentStreamCall struct {
 	Files            []FilePart `json:"files"`
 	Messages         []Message  `json:"messages"`
 	MaxOutputTokens  *int64
-	Temperature      *float64 `json:"temperature"`
-	TopP             *float64 `json:"top_p"`
-	TopK             *int64   `json:"top_k"`
-	PresencePenalty  *float64 `json:"presence_penalty"`
-	FrequencyPenalty *float64 `json:"frequency_penalty"`
-	ActiveTools      []string `json:"active_tools"`
+	Temperature      *float64    `json:"temperature"`
+	TopP             *float64    `json:"top_p"`
+	TopK             *int64      `json:"top_k"`
+	PresencePenalty  *float64    `json:"presence_penalty"`
+	FrequencyPenalty *float64    `json:"frequency_penalty"`
+	ActiveTools      []string    `json:"active_tools"`
+	ToolChoice       *ToolChoice `json:"tool_choice"`
 	Headers          map[string]string
 	ProviderOptions  ProviderOptions
 	OnRetry          OnRetryCallback
@@ -335,6 +338,7 @@ func (a *agent) prepareCall(call AgentCall) AgentCall {
 	call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
 	call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
 	call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
+	call.ToolChoice = cmp.Or(call.ToolChoice, a.settings.toolChoice)
 
 	if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
 		call.StopWhen = a.settings.stopWhen
@@ -383,6 +387,9 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
 		stepSystemPrompt := a.settings.systemPrompt
 		stepActiveTools := opts.ActiveTools
 		stepToolChoice := ToolChoiceAuto
+		if opts.ToolChoice != nil {
+			stepToolChoice = *opts.ToolChoice
+		}
 		disableAllTools := false
 		stepTools := a.settings.tools
 		if opts.PrepareStep != nil {
@@ -787,6 +794,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
 		PresencePenalty:  opts.PresencePenalty,
 		FrequencyPenalty: opts.FrequencyPenalty,
 		ActiveTools:      opts.ActiveTools,
+		ToolChoice:       opts.ToolChoice,
 		ProviderOptions:  opts.ProviderOptions,
 		MaxRetries:       opts.MaxRetries,
 		OnRetry:          opts.OnRetry,
@@ -817,6 +825,9 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
 		stepSystemPrompt := a.settings.systemPrompt
 		stepActiveTools := call.ActiveTools
 		stepToolChoice := ToolChoiceAuto
+		if call.ToolChoice != nil {
+			stepToolChoice = *call.ToolChoice
+		}
 		disableAllTools := false
 		stepTools := a.settings.tools
 		// Apply step preparation if provided
@@ -1206,6 +1217,14 @@ func WithProviderDefinedTools(tools ...ProviderTool) AgentOption {
 	}
 }
 
+// WithToolChoice sets the default tool choice for the agent. It is overridden
+// by the ToolChoice on a specific call, and by PrepareStep at the step level.
+func WithToolChoice(choice ToolChoice) AgentOption {
+	return func(s *agentSettings) {
+		s.toolChoice = &choice
+	}
+}
+
 // WithStopConditions sets the stop conditions for the agent.
 func WithStopConditions(conditions ...StopCondition) AgentOption {
 	return func(s *agentSettings) {