From 451bf25d1c4efd4209483483d570972c6b4368f3 Mon Sep 17 00:00:00 2001
From: Marshall Bowers
Date: Fri, 9 Jan 2026 17:10:11 -0500
Subject: [PATCH] language_models: Add support for using OpenAI Responses API
 through Zed provider (#46482)

This PR adds support for using the OpenAI Responses API through the Zed
provider.

This is gated behind the `open-ai-responses-api` feature flag.

Part of CLO-34.

Release Notes:

- N/A
---
 Cargo.lock                                   |   1 +
 crates/feature_flags/src/flags.rs            |  12 ++
 crates/language_models/Cargo.toml            |   1 +
 crates/language_models/src/provider/cloud.rs | 140 ++++++++++++------
 .../language_models/src/provider/open_ai.rs  |   2 +-
 5 files changed, 111 insertions(+), 45 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 716e39b6410e9ab78c135fcc0cb65297c9ad400a..cb369e2e32750798eb40eb958fd629f7d695c750 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -8985,6 +8985,7 @@ dependencies = [
  "editor",
  "extension",
  "extension_host",
+ "feature_flags",
  "fs",
  "futures 0.3.31",
  "google_ai",
diff --git a/crates/feature_flags/src/flags.rs b/crates/feature_flags/src/flags.rs
index 2fa896b9b480810cd04d01bb1f6d02a2761a689b..83e422b5bf0824999a994e1facf57d5472634ed1 100644
--- a/crates/feature_flags/src/flags.rs
+++ b/crates/feature_flags/src/flags.rs
@@ -45,3 +45,15 @@ impl FeatureFlag for SubagentsFeatureFlag {
         false
     }
 }
+
+/// Whether to use the OpenAI Responses API format when sending requests to Cloud.
+pub struct OpenAiResponsesApiFeatureFlag;
+
+impl FeatureFlag for OpenAiResponsesApiFeatureFlag {
+    const NAME: &'static str = "open-ai-responses-api";
+
+    fn enabled_for_staff() -> bool {
+        // Add yourself to the flag manually to test it out.
+        false
+    }
+}
diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml
index b312342a2106e27d85531c905aab05e07ef8d68c..6f5ca58e221207b2732b4a0388351fa40826e296 100644
--- a/crates/language_models/Cargo.toml
+++ b/crates/language_models/Cargo.toml
@@ -30,6 +30,7 @@ credentials_provider.workspace = true
 deepseek = { workspace = true, features = ["schemars"] }
 extension.workspace = true
 extension_host.workspace = true
+feature_flags.workspace = true
 fs.workspace = true
 futures.workspace = true
 google_ai = { workspace = true, features = ["schemars"] }
diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs
index 354052e3af04b18dedb41f4a057bfdfb228c5996..ff8d01ec717c732bfd25dac024a33e48787e706d 100644
--- a/crates/language_models/src/provider/cloud.rs
+++ b/crates/language_models/src/provider/cloud.rs
@@ -11,6 +11,7 @@ use cloud_llm_client::{
     SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
     TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
 };
+use feature_flags::{FeatureFlagAppExt as _, OpenAiResponsesApiFeatureFlag};
 use futures::{
     AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
 };
@@ -46,7 +47,10 @@ use crate::provider::anthropic::{
     AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
 };
 use crate::provider::google::{GoogleEventMapper, into_google};
-use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
+use crate::provider::open_ai::{
+    OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
+    into_open_ai_response,
+};
 use crate::provider::x_ai::count_xai_tokens;

 const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
@@ -756,6 +760,7 @@ impl LanguageModel for CloudLanguageModel {
         let intent = request.intent;
         let mode = request.mode;
         let app_version = Some(cx.update(|cx| AppVersion::global(cx)));
+        let use_responses_api = cx.update(|cx| cx.has_flag::<OpenAiResponsesApiFeatureFlag>());
         let thinking_allowed = request.thinking_allowed;
         let provider_name = provider_name(&self.model.provider);
         match self.model.provider {
@@ -807,7 +812,7 @@ impl LanguageModel for CloudLanguageModel {
                         Box::pin(
                             response_lines(response, includes_status_messages)
                                 .chain(usage_updated_event(usage))
-                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)), // .map(|_| {}),
+                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                         ),
                         &provider_name,
                         move |event| mapper.map_event(event),
@@ -817,50 +822,97 @@
             }
             cloud_llm_client::LanguageModelProvider::OpenAi => {
                 let client = self.client.clone();
-                let request = into_open_ai(
-                    request,
-                    &self.model.id.0,
-                    self.model.supports_parallel_tool_calls,
-                    true,
-                    None,
-                    None,
-                );
                 let llm_api_token = self.llm_api_token.clone();
-                let future = self.request_limiter.stream(async move {
-                    let PerformLlmCompletionResponse {
-                        response,
-                        usage,
-                        includes_status_messages,
-                        tool_use_limit_reached,
-                    } = Self::perform_llm_completion(
-                        client.clone(),
-                        llm_api_token,
-                        app_version,
-                        CompletionBody {
-                            thread_id,
-                            prompt_id,
-                            intent,
-                            mode,
-                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
-                            model: request.model.clone(),
-                            provider_request: serde_json::to_value(&request)
-                                .map_err(|e| anyhow!(e))?,
-                        },
-                    )
-                    .await?;
-                    let mut mapper = OpenAiEventMapper::new();
-                    Ok(map_cloud_completion_events(
-                        Box::pin(
-                            response_lines(response, includes_status_messages)
-                                .chain(usage_updated_event(usage))
-                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
-                        ),
-                        &provider_name,
-                        move |event| mapper.map_event(event),
-                    ))
-                });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                if use_responses_api {
+                    let request = into_open_ai_response(
+                        request,
+                        &self.model.id.0,
+                        self.model.supports_parallel_tool_calls,
+                        true,
+                        None,
+                        None,
+                    );
+                    let future = self.request_limiter.stream(async move {
+                        let PerformLlmCompletionResponse {
+                            response,
+                            usage,
+                            includes_status_messages,
+                            tool_use_limit_reached,
+                        } = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            app_version,
+                            CompletionBody {
+                                thread_id,
+                                prompt_id,
+                                intent,
+                                mode,
+                                provider: cloud_llm_client::LanguageModelProvider::OpenAi,
+                                model: request.model.clone(),
+                                provider_request: serde_json::to_value(&request)
+                                    .map_err(|e| anyhow!(e))?,
+                            },
+                        )
+                        .await?;
+
+                        let mut mapper = OpenAiResponseEventMapper::new();
+                        Ok(map_cloud_completion_events(
+                            Box::pin(
+                                response_lines(response, includes_status_messages)
+                                    .chain(usage_updated_event(usage))
+                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                            ),
+                            &provider_name,
+                            move |event| mapper.map_event(event),
+                        ))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                } else {
+                    let request = into_open_ai(
+                        request,
+                        &self.model.id.0,
+                        self.model.supports_parallel_tool_calls,
+                        true,
+                        None,
+                        None,
+                    );
+                    let future = self.request_limiter.stream(async move {
+                        let PerformLlmCompletionResponse {
+                            response,
+                            usage,
+                            includes_status_messages,
+                            tool_use_limit_reached,
+                        } = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            app_version,
+                            CompletionBody {
+                                thread_id,
+                                prompt_id,
+                                intent,
+                                mode,
+                                provider: cloud_llm_client::LanguageModelProvider::OpenAi,
+                                model: request.model.clone(),
+                                provider_request: serde_json::to_value(&request)
+                                    .map_err(|e| anyhow!(e))?,
+                            },
+                        )
+                        .await?;
+
+                        let mut mapper = OpenAiEventMapper::new();
+                        Ok(map_cloud_completion_events(
+                            Box::pin(
+                                response_lines(response, includes_status_messages)
+                                    .chain(usage_updated_event(usage))
+                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                            ),
+                            &provider_name,
+                            move |event| mapper.map_event(event),
+                        ))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                }
             }
             cloud_llm_client::LanguageModelProvider::XAi => {
                 let client = self.client.clone();
diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs
index 04248ffc3c7a465d04bca0ffa11ca53e39a24880..f0e6802280f4c71bd1a376f37a5b9cade30e04a2 100644
--- a/crates/language_models/src/provider/open_ai.rs
+++ b/crates/language_models/src/provider/open_ai.rs
@@ -887,7 +887,7 @@ impl OpenAiResponseEventMapper {
         })
     }

-    fn map_event(
+    pub fn map_event(
         &mut self,
         event: ResponsesStreamEvent,
     ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {