@@ -690,7 +690,7 @@ impl LanguageModel for CloudLanguageModel {
}
CloudModel::OpenAi(model) => {
let client = self.client.clone();
- let request = into_open_ai(request, model.id().into(), model.max_output_tokens());
+ let request = into_open_ai(request, model, model.max_output_tokens());
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream(async move {
let response = Self::perform_llm_completion(
@@ -14,7 +14,7 @@ use language_model::{
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent,
RateLimiter, Role, StopReason,
};
-use open_ai::{ResponseStreamEvent, stream_completion};
+use open_ai::{Model, ResponseStreamEvent, stream_completion};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
@@ -324,7 +324,7 @@ impl LanguageModel for OpenAiLanguageModel {
'static,
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
> {
- let request = into_open_ai(request, self.model.id().into(), self.max_output_tokens());
+ let request = into_open_ai(request, &self.model, self.max_output_tokens());
let completions = self.stream_completion(request, cx);
async move { Ok(map_to_language_model_completion_events(completions.await?).boxed()) }
.boxed()
@@ -333,10 +333,10 @@ impl LanguageModel for OpenAiLanguageModel {
pub fn into_open_ai(
request: LanguageModelRequest,
- model: String,
+ model: &Model,
max_output_tokens: Option<u32>,
) -> open_ai::Request {
- let stream = !model.starts_with("o1-");
+ let stream = !model.id().starts_with("o1-");
let mut messages = Vec::new();
for message in request.messages {
@@ -389,12 +389,18 @@ pub fn into_open_ai(
}
open_ai::Request {
- model,
+ model: model.id().into(),
messages,
stream,
stop: request.stop,
temperature: request.temperature.unwrap_or(1.0),
max_tokens: max_output_tokens,
+ parallel_tool_calls: if model.supports_parallel_tool_calls() && !request.tools.is_empty() {
+ // Disable parallel tool calls, since the Agent currently expects at most one tool call per turn.
+ Some(false)
+ } else {
+ None
+ },
tools: request
.tools
.into_iter()
@@ -162,6 +162,23 @@ impl Model {
_ => None,
}
}
+
+ /// Returns whether this model supports the `parallel_tool_calls` parameter.
+ ///
+ /// If the model does not support the parameter, omit it from the request, or the API will return an error.
+ pub fn supports_parallel_tool_calls(&self) -> bool {
+ match self {
+ Self::ThreePointFiveTurbo
+ | Self::Four
+ | Self::FourTurbo
+ | Self::FourOmni
+ | Self::FourOmniMini
+ | Self::O1
+ | Self::O1Preview
+ | Self::O1Mini => true,
+ _ => false,
+ }
+ }
}
#[derive(Debug, Serialize, Deserialize)]
@@ -176,6 +193,9 @@ pub struct Request {
pub temperature: f32,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
+ /// Whether to enable parallel function calling during tool use.
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub parallel_tool_calls: Option<bool>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tools: Vec<ToolDefinition>,
}
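
Not part of the diff: a minimal sketch of the intended behavior, assuming the `open_ai` crate from the hunks above. The helper name is hypothetical; it simply mirrors the conditional added to `into_open_ai`.

    // Hypothetical helper: send `parallel_tool_calls` only when the model
    // accepts it and the request actually carries tools. A `None` value is
    // dropped from the JSON body by `skip_serializing_if`, so models that
    // reject the parameter never see it.
    fn parallel_tool_calls_for(model: &open_ai::Model, has_tools: bool) -> Option<bool> {
        if model.supports_parallel_tool_calls() && has_tools {
            // The Agent currently handles at most one tool call per turn.
            Some(false)
        } else {
            None
        }
    }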