language_models: Add support for using OpenAI Responses API through Zed provider (#46482)

Created by Marshall Bowers

This PR adds support for using the OpenAI Responses API through the Zed
provider.

This is gated behind the `open-ai-responses-api` feature flag.
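
Concretely, the gate is a runtime check of the flag when the Zed provider builds the OpenAI request. A minimal sketch of the decision (simplified from the `cloud.rs` change below; `cx`, `request`, `model_id`, and `supports_parallel_tool_calls` stand in for the real plumbing on `CloudLanguageModel`):

```rust
use feature_flags::{FeatureFlagAppExt as _, OpenAiResponsesApiFeatureFlag};

// Sketch only: when the flag is on, the request is serialized in the
// Responses API format; otherwise the existing Chat Completions path is
// used unchanged.
let use_responses_api = cx.has_flag::<OpenAiResponsesApiFeatureFlag>();

let provider_request = if use_responses_api {
    serde_json::to_value(&into_open_ai_response(
        request, model_id, supports_parallel_tool_calls, true, None, None,
    ))
    .map_err(|e| anyhow!(e))?
} else {
    serde_json::to_value(&into_open_ai(
        request, model_id, supports_parallel_tool_calls, true, None, None,
    ))
    .map_err(|e| anyhow!(e))?
};
```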

Part of CLO-34.

Release Notes:

- N/A

Change summary

Cargo.lock                                     |   1 
crates/feature_flags/src/flags.rs              |  12 +
crates/language_models/Cargo.toml              |   1 
crates/language_models/src/provider/cloud.rs   | 140 +++++++++++++------
crates/language_models/src/provider/open_ai.rs |   2 
5 files changed, 111 insertions(+), 45 deletions(-)

Detailed changes

Cargo.lock

@@ -8985,6 +8985,7 @@ dependencies = [
  "editor",
  "extension",
  "extension_host",
+ "feature_flags",
  "fs",
  "futures 0.3.31",
  "google_ai",

crates/feature_flags/src/flags.rs

@@ -45,3 +45,15 @@ impl FeatureFlag for SubagentsFeatureFlag {
         false
     }
 }
+
+/// Whether to use the OpenAI Responses API format when sending requests to Cloud.
+pub struct OpenAiResponsesApiFeatureFlag;
+
+impl FeatureFlag for OpenAiResponsesApiFeatureFlag {
+    const NAME: &'static str = "open-ai-responses-api";
+
+    fn enabled_for_staff() -> bool {
+        // Add yourself to the flag manually to test it out.
+        false
+    }
+}

crates/language_models/Cargo.toml

@@ -30,6 +30,7 @@ credentials_provider.workspace = true
 deepseek = { workspace = true, features = ["schemars"] }
 extension.workspace = true
 extension_host.workspace = true
+feature_flags.workspace = true
 fs.workspace = true
 futures.workspace = true
 google_ai = { workspace = true, features = ["schemars"] }

crates/language_models/src/provider/cloud.rs

@@ -11,6 +11,7 @@ use cloud_llm_client::{
     SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
     TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
 };
+use feature_flags::{FeatureFlagAppExt as _, OpenAiResponsesApiFeatureFlag};
 use futures::{
     AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
 };
@@ -46,7 +47,10 @@ use crate::provider::anthropic::{
     AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
 };
 use crate::provider::google::{GoogleEventMapper, into_google};
-use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
+use crate::provider::open_ai::{
+    OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
+    into_open_ai_response,
+};
 use crate::provider::x_ai::count_xai_tokens;
 
 const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
@@ -756,6 +760,7 @@ impl LanguageModel for CloudLanguageModel {
         let intent = request.intent;
         let mode = request.mode;
         let app_version = Some(cx.update(|cx| AppVersion::global(cx)));
+        let use_responses_api = cx.update(|cx| cx.has_flag::<OpenAiResponsesApiFeatureFlag>());
         let thinking_allowed = request.thinking_allowed;
         let provider_name = provider_name(&self.model.provider);
         match self.model.provider {
@@ -807,7 +812,7 @@ impl LanguageModel for CloudLanguageModel {
                         Box::pin(
                             response_lines(response, includes_status_messages)
                                 .chain(usage_updated_event(usage))
-                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)), // .map(|_| {}),
+                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                         ),
                         &provider_name,
                         move |event| mapper.map_event(event),
@@ -817,50 +822,97 @@ impl LanguageModel for CloudLanguageModel {
             }
             cloud_llm_client::LanguageModelProvider::OpenAi => {
                 let client = self.client.clone();
-                let request = into_open_ai(
-                    request,
-                    &self.model.id.0,
-                    self.model.supports_parallel_tool_calls,
-                    true,
-                    None,
-                    None,
-                );
                 let llm_api_token = self.llm_api_token.clone();
-                let future = self.request_limiter.stream(async move {
-                    let PerformLlmCompletionResponse {
-                        response,
-                        usage,
-                        includes_status_messages,
-                        tool_use_limit_reached,
-                    } = Self::perform_llm_completion(
-                        client.clone(),
-                        llm_api_token,
-                        app_version,
-                        CompletionBody {
-                            thread_id,
-                            prompt_id,
-                            intent,
-                            mode,
-                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
-                            model: request.model.clone(),
-                            provider_request: serde_json::to_value(&request)
-                                .map_err(|e| anyhow!(e))?,
-                        },
-                    )
-                    .await?;
 
-                    let mut mapper = OpenAiEventMapper::new();
-                    Ok(map_cloud_completion_events(
-                        Box::pin(
-                            response_lines(response, includes_status_messages)
-                                .chain(usage_updated_event(usage))
-                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
-                        ),
-                        &provider_name,
-                        move |event| mapper.map_event(event),
-                    ))
-                });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                if use_responses_api {
+                    let request = into_open_ai_response(
+                        request,
+                        &self.model.id.0,
+                        self.model.supports_parallel_tool_calls,
+                        true,
+                        None,
+                        None,
+                    );
+                    let future = self.request_limiter.stream(async move {
+                        let PerformLlmCompletionResponse {
+                            response,
+                            usage,
+                            includes_status_messages,
+                            tool_use_limit_reached,
+                        } = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            app_version,
+                            CompletionBody {
+                                thread_id,
+                                prompt_id,
+                                intent,
+                                mode,
+                                provider: cloud_llm_client::LanguageModelProvider::OpenAi,
+                                model: request.model.clone(),
+                                provider_request: serde_json::to_value(&request)
+                                    .map_err(|e| anyhow!(e))?,
+                            },
+                        )
+                        .await?;
+
+                        let mut mapper = OpenAiResponseEventMapper::new();
+                        Ok(map_cloud_completion_events(
+                            Box::pin(
+                                response_lines(response, includes_status_messages)
+                                    .chain(usage_updated_event(usage))
+                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                            ),
+                            &provider_name,
+                            move |event| mapper.map_event(event),
+                        ))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                } else {
+                    let request = into_open_ai(
+                        request,
+                        &self.model.id.0,
+                        self.model.supports_parallel_tool_calls,
+                        true,
+                        None,
+                        None,
+                    );
+                    let future = self.request_limiter.stream(async move {
+                        let PerformLlmCompletionResponse {
+                            response,
+                            usage,
+                            includes_status_messages,
+                            tool_use_limit_reached,
+                        } = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            app_version,
+                            CompletionBody {
+                                thread_id,
+                                prompt_id,
+                                intent,
+                                mode,
+                                provider: cloud_llm_client::LanguageModelProvider::OpenAi,
+                                model: request.model.clone(),
+                                provider_request: serde_json::to_value(&request)
+                                    .map_err(|e| anyhow!(e))?,
+                            },
+                        )
+                        .await?;
+
+                        let mut mapper = OpenAiEventMapper::new();
+                        Ok(map_cloud_completion_events(
+                            Box::pin(
+                                response_lines(response, includes_status_messages)
+                                    .chain(usage_updated_event(usage))
+                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                            ),
+                            &provider_name,
+                            move |event| mapper.map_event(event),
+                        ))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                }
             }
             cloud_llm_client::LanguageModelProvider::XAi => {
                 let client = self.client.clone();

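For context, both branches above wrap the provider-specific payload in the same `CompletionBody`; only the serialized `provider_request` changes with the flag. A condensed sketch of that shared shape (variables as in the hunk above, not new code):

```rust
// Shared shape of the body sent to Cloud in both branches; the only part
// that differs under the feature flag is the serialized `provider_request`.
let body = CompletionBody {
    thread_id,
    prompt_id,
    intent,
    mode,
    provider: cloud_llm_client::LanguageModelProvider::OpenAi,
    model: request.model.clone(),
    // `request` is either the Responses API payload or the Chat Completions
    // payload, depending on `use_responses_api`.
    provider_request: serde_json::to_value(&request).map_err(|e| anyhow!(e))?,
};
```
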
crates/language_models/src/provider/open_ai.rs

@@ -887,7 +887,7 @@ impl OpenAiResponseEventMapper {
         })
     }
 
-    fn map_event(
+    pub fn map_event(
         &mut self,
         event: ResponsesStreamEvent,
     ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
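
The visibility change is what lets the Zed provider in `cloud.rs` drive the Responses API mapper directly. A condensed excerpt of that call shape (with `event_stream` standing in for the chained response/usage/limit stream from the `cloud.rs` hunk above):

```rust
// Condensed from cloud.rs: the provider owns the mapper and invokes it once
// per decoded Responses API stream event, which is why `map_event` is now `pub`.
let mut mapper = OpenAiResponseEventMapper::new();
let events = map_cloud_completion_events(
    Box::pin(event_stream), // stream of decoded `ResponsesStreamEvent`s
    &provider_name,
    move |event| mapper.map_event(event),
);
```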