Marshall Bowers created
This PR adds support for using the OpenAI Responses API through the Zed
provider.
This is gated behind the `open-ai-responses-api` feature flag.
Part of CLO-34.
Release Notes:
- N/A
Cargo.lock                                     |   1 +
crates/feature_flags/src/flags.rs              |  12 +
crates/language_models/Cargo.toml              |   1 +
crates/language_models/src/provider/cloud.rs   | 140 +++++++++++++------
crates/language_models/src/provider/open_ai.rs |   2 +-
5 files changed, 111 insertions(+), 45 deletions(-)
Cargo.lock
@@ -8985,6 +8985,7 @@ dependencies = [
"editor",
"extension",
"extension_host",
+ "feature_flags",
"fs",
"futures 0.3.31",
"google_ai",
crates/feature_flags/src/flags.rs
@@ -45,3 +45,15 @@ impl FeatureFlag for SubagentsFeatureFlag {
false
}
}
+
+/// Whether to use the OpenAI Responses API format when sending requests to Cloud.
+pub struct OpenAiResponsesApiFeatureFlag;
+
+impl FeatureFlag for OpenAiResponsesApiFeatureFlag {
+ const NAME: &'static str = "open-ai-responses-api";
+
+ fn enabled_for_staff() -> bool {
+ // Add yourself to the flag manually to test it out.
+ false
+ }
+}
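For context, callers check the flag through the `FeatureFlagAppExt` extension trait, as the `cloud.rs` hunk further down does. A minimal sketch of that call-site pattern (the standalone helper and the `gpui::App` parameter type are illustrative assumptions, not code from this PR):

```rust
use feature_flags::{FeatureFlagAppExt as _, OpenAiResponsesApiFeatureFlag};
use gpui::App;

// Hypothetical helper mirroring the check added in cloud.rs: use the
// Responses API request format only when the flag is enabled.
fn should_use_responses_api(cx: &App) -> bool {
    cx.has_flag::<OpenAiResponsesApiFeatureFlag>()
}
```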
crates/language_models/Cargo.toml
@@ -30,6 +30,7 @@ credentials_provider.workspace = true
deepseek = { workspace = true, features = ["schemars"] }
extension.workspace = true
extension_host.workspace = true
+feature_flags.workspace = true
fs.workspace = true
futures.workspace = true
google_ai = { workspace = true, features = ["schemars"] }
crates/language_models/src/provider/cloud.rs
@@ -11,6 +11,7 @@ use cloud_llm_client::{
SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
+use feature_flags::{FeatureFlagAppExt as _, OpenAiResponsesApiFeatureFlag};
use futures::{
AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
@@ -46,7 +47,10 @@ use crate::provider::anthropic::{
AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
};
use crate::provider::google::{GoogleEventMapper, into_google};
-use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
+use crate::provider::open_ai::{
+ OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
+ into_open_ai_response,
+};
use crate::provider::x_ai::count_xai_tokens;

const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
@@ -756,6 +760,7 @@ impl LanguageModel for CloudLanguageModel {
let intent = request.intent;
let mode = request.mode;
let app_version = Some(cx.update(|cx| AppVersion::global(cx)));
+ let use_responses_api = cx.update(|cx| cx.has_flag::<OpenAiResponsesApiFeatureFlag>());
let thinking_allowed = request.thinking_allowed;
let provider_name = provider_name(&self.model.provider);
match self.model.provider {
@@ -807,7 +812,7 @@ impl LanguageModel for CloudLanguageModel {
Box::pin(
response_lines(response, includes_status_messages)
.chain(usage_updated_event(usage))
- .chain(tool_use_limit_reached_event(tool_use_limit_reached)), // .map(|_| {}),
+ .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
),
&provider_name,
move |event| mapper.map_event(event),
@@ -817,50 +822,97 @@ impl LanguageModel for CloudLanguageModel {
}
cloud_llm_client::LanguageModelProvider::OpenAi => {
let client = self.client.clone();
- let request = into_open_ai(
- request,
- &self.model.id.0,
- self.model.supports_parallel_tool_calls,
- true,
- None,
- None,
- );
let llm_api_token = self.llm_api_token.clone();
- let future = self.request_limiter.stream(async move {
- let PerformLlmCompletionResponse {
- response,
- usage,
- includes_status_messages,
- tool_use_limit_reached,
- } = Self::perform_llm_completion(
- client.clone(),
- llm_api_token,
- app_version,
- CompletionBody {
- thread_id,
- prompt_id,
- intent,
- mode,
- provider: cloud_llm_client::LanguageModelProvider::OpenAi,
- model: request.model.clone(),
- provider_request: serde_json::to_value(&request)
- .map_err(|e| anyhow!(e))?,
- },
- )
- .await?;
- let mut mapper = OpenAiEventMapper::new();
- Ok(map_cloud_completion_events(
- Box::pin(
- response_lines(response, includes_status_messages)
- .chain(usage_updated_event(usage))
- .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
- ),
- &provider_name,
- move |event| mapper.map_event(event),
- ))
- });
- async move { Ok(future.await?.boxed()) }.boxed()
+ if use_responses_api {
+ let request = into_open_ai_response(
+ request,
+ &self.model.id.0,
+ self.model.supports_parallel_tool_calls,
+ true,
+ None,
+ None,
+ );
+ let future = self.request_limiter.stream(async move {
+ let PerformLlmCompletionResponse {
+ response,
+ usage,
+ includes_status_messages,
+ tool_use_limit_reached,
+ } = Self::perform_llm_completion(
+ client.clone(),
+ llm_api_token,
+ app_version,
+ CompletionBody {
+ thread_id,
+ prompt_id,
+ intent,
+ mode,
+ provider: cloud_llm_client::LanguageModelProvider::OpenAi,
+ model: request.model.clone(),
+ provider_request: serde_json::to_value(&request)
+ .map_err(|e| anyhow!(e))?,
+ },
+ )
+ .await?;
+
+ let mut mapper = OpenAiResponseEventMapper::new();
+ Ok(map_cloud_completion_events(
+ Box::pin(
+ response_lines(response, includes_status_messages)
+ .chain(usage_updated_event(usage))
+ .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+ ),
+ &provider_name,
+ move |event| mapper.map_event(event),
+ ))
+ });
+ async move { Ok(future.await?.boxed()) }.boxed()
+ } else {
+ let request = into_open_ai(
+ request,
+ &self.model.id.0,
+ self.model.supports_parallel_tool_calls,
+ true,
+ None,
+ None,
+ );
+ let future = self.request_limiter.stream(async move {
+ let PerformLlmCompletionResponse {
+ response,
+ usage,
+ includes_status_messages,
+ tool_use_limit_reached,
+ } = Self::perform_llm_completion(
+ client.clone(),
+ llm_api_token,
+ app_version,
+ CompletionBody {
+ thread_id,
+ prompt_id,
+ intent,
+ mode,
+ provider: cloud_llm_client::LanguageModelProvider::OpenAi,
+ model: request.model.clone(),
+ provider_request: serde_json::to_value(&request)
+ .map_err(|e| anyhow!(e))?,
+ },
+ )
+ .await?;
+
+ let mut mapper = OpenAiEventMapper::new();
+ Ok(map_cloud_completion_events(
+ Box::pin(
+ response_lines(response, includes_status_messages)
+ .chain(usage_updated_event(usage))
+ .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+ ),
+ &provider_name,
+ move |event| mapper.map_event(event),
+ ))
+ });
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
}
cloud_llm_client::LanguageModelProvider::XAi => {
let client = self.client.clone();
crates/language_models/src/provider/open_ai.rs
@@ -887,7 +887,7 @@ impl OpenAiResponseEventMapper {
})
}

- fn map_event(
+ pub fn map_event(
&mut self,
event: ResponsesStreamEvent,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
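Making `map_event` `pub` is what lets `cloud.rs` drive the mapper over the streamed Responses API events, as in the hunk above. A rough sketch of that consuming pattern (module paths and the plain-iterator shape are simplifications for illustration; the real code maps over a boxed event stream):

```rust
use crate::provider::open_ai::OpenAiResponseEventMapper;
use language_model::{LanguageModelCompletionError, LanguageModelCompletionEvent};
use open_ai::ResponsesStreamEvent;

// Each streamed event can expand into zero or more completion events,
// so the mapper's output is flattened into a single sequence.
fn map_response_events(
    events: impl IntoIterator<Item = ResponsesStreamEvent>,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
    let mut mapper = OpenAiResponseEventMapper::new();
    events
        .into_iter()
        .flat_map(|event| mapper.map_event(event))
        .collect()
}
```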