open_ai: Disable `parallel_tool_calls` (#28056)

Created by Marshall Bowers

This PR disables `parallel_tool_calls` for the models that support it,
as the Agent currently expects at most one tool use per turn.
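
For models that do support the parameter, the request now sends an explicit `false`; for everything else the field is left unset. A minimal stand-alone sketch of that gating, mirroring the `into_open_ai` change below (the helper name here is just a stand-in, not the crate's API):

```rust
// Only send `parallel_tool_calls` when the model supports it and the
// request actually carries tools; otherwise leave the field unset.
fn parallel_tool_calls_value(model_supports_it: bool, has_tools: bool) -> Option<bool> {
    if model_supports_it && has_tools {
        // Explicitly disable parallel tool calls: the Agent expects at most
        // one tool use per turn.
        Some(false)
    } else {
        // `None` keeps the field out of the request body entirely.
        None
    }
}

fn main() {
    assert_eq!(parallel_tool_calls_value(true, true), Some(false));
    assert_eq!(parallel_tool_calls_value(false, true), None);
    assert_eq!(parallel_tool_calls_value(true, false), None);
}
```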

Figuring this out took a bit of trial and error: OpenAI's API will, annoyingly,
return an error if you pass `parallel_tool_calls` to a model that doesn't
support it.
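
The fix is to omit the field entirely rather than send a null value: the field is an `Option<bool>` with serde's `skip_serializing_if`, so unsupported models never see it in the request body. A minimal sketch of that serialization behavior (stand-alone struct, assuming the `serde` and `serde_json` crates; the real `Request` struct is in the last hunk below):

```rust
use serde::Serialize;

#[derive(Serialize)]
struct Request {
    model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    parallel_tool_calls: Option<bool>,
}

fn main() {
    // `None` drops the field from the JSON body entirely, so a model that
    // rejects `parallel_tool_calls` never receives it.
    let omitted = Request { model: "o1-mini".into(), parallel_tool_calls: None };
    assert_eq!(serde_json::to_string(&omitted).unwrap(), r#"{"model":"o1-mini"}"#);

    // Supported models get an explicit `false`.
    let disabled = Request { model: "gpt-4o".into(), parallel_tool_calls: Some(false) };
    assert_eq!(
        serde_json::to_string(&disabled).unwrap(),
        r#"{"model":"gpt-4o","parallel_tool_calls":false}"#
    );
}
```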

Release Notes:

- N/A

Change summary

crates/language_models/src/provider/cloud.rs   |  2 +-
crates/language_models/src/provider/open_ai.rs | 16 +++++++++++-----
crates/open_ai/src/open_ai.rs                  | 20 ++++++++++++++++++++
3 files changed, 32 insertions(+), 6 deletions(-)

Detailed changes

crates/language_models/src/provider/cloud.rs

@@ -690,7 +690,7 @@ impl LanguageModel for CloudLanguageModel {
             }
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();
-                let request = into_open_ai(request, model.id().into(), model.max_output_tokens());
+                let request = into_open_ai(request, model, model.max_output_tokens());
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream(async move {
                     let response = Self::perform_llm_completion(

crates/language_models/src/provider/open_ai.rs

@@ -14,7 +14,7 @@ use language_model::{
     LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent,
     RateLimiter, Role, StopReason,
 };
-use open_ai::{ResponseStreamEvent, stream_completion};
+use open_ai::{Model, ResponseStreamEvent, stream_completion};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
@@ -324,7 +324,7 @@ impl LanguageModel for OpenAiLanguageModel {
         'static,
         Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
     > {
-        let request = into_open_ai(request, self.model.id().into(), self.max_output_tokens());
+        let request = into_open_ai(request, &self.model, self.max_output_tokens());
         let completions = self.stream_completion(request, cx);
         async move { Ok(map_to_language_model_completion_events(completions.await?).boxed()) }
             .boxed()
@@ -333,10 +333,10 @@ impl LanguageModel for OpenAiLanguageModel {
 
 pub fn into_open_ai(
     request: LanguageModelRequest,
-    model: String,
+    model: &Model,
     max_output_tokens: Option<u32>,
 ) -> open_ai::Request {
-    let stream = !model.starts_with("o1-");
+    let stream = !model.id().starts_with("o1-");
 
     let mut messages = Vec::new();
     for message in request.messages {
@@ -389,12 +389,18 @@ pub fn into_open_ai(
     }
 
     open_ai::Request {
-        model,
+        model: model.id().into(),
         messages,
         stream,
         stop: request.stop,
         temperature: request.temperature.unwrap_or(1.0),
         max_tokens: max_output_tokens,
+        parallel_tool_calls: if model.supports_parallel_tool_calls() && !request.tools.is_empty() {
+            // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
+            Some(false)
+        } else {
+            None
+        },
         tools: request
             .tools
             .into_iter()

crates/open_ai/src/open_ai.rs

@@ -162,6 +162,23 @@ impl Model {
             _ => None,
         }
     }
+
+    /// Returns whether the given model supports the `parallel_tool_calls` parameter.
+    ///
+    /// If the model does not support the parameter, do not pass it, or the API will return an error.
+    pub fn supports_parallel_tool_calls(&self) -> bool {
+        match self {
+            Self::ThreePointFiveTurbo
+            | Self::Four
+            | Self::FourTurbo
+            | Self::FourOmni
+            | Self::FourOmniMini
+            | Self::O1
+            | Self::O1Preview
+            | Self::O1Mini => true,
+            _ => false,
+        }
+    }
 }
 
 #[derive(Debug, Serialize, Deserialize)]
@@ -176,6 +193,9 @@ pub struct Request {
     pub temperature: f32,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub tool_choice: Option<ToolChoice>,
+    /// Whether to enable parallel function calling during tool use.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub parallel_tool_calls: Option<bool>,
     #[serde(default, skip_serializing_if = "Vec::is_empty")]
     pub tools: Vec<ToolDefinition>,
 }
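
Not part of the diff, but a hypothetical unit test inside `crates/open_ai/src/open_ai.rs` could pin down the capability table (variant names taken from the match arms above):

```rust
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parallel_tool_calls_support_follows_the_match_arms() {
        // Variants named in the match arm above report support...
        assert!(Model::ThreePointFiveTurbo.supports_parallel_tool_calls());
        assert!(Model::FourOmni.supports_parallel_tool_calls());
        assert!(Model::O1Mini.supports_parallel_tool_calls());
        // ...while anything that falls through to `_` reports `false`.
    }
}
```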