From 93ead966c23858c88a8139f561bc2cbdf4e60fb6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tom=20Houl=C3=A9?= <13155277+tomhoule@users.noreply.github.com>
Date: Mon, 16 Feb 2026 15:39:47 +0100
Subject: [PATCH] cloud_llm_client: Add StreamEnded and Unknown variants to CompletionRequestStatus (#49121)

Add a StreamEnded variant so the client can distinguish between a stream
that the cloud ran to completion and one that was interrupted (see
CLO-258). **That logic is to be added in a follow-up PR**.

Add an Unknown fallback with #[serde(other)] for forward-compatible
deserialization of future variants (an illustrative sketch of this
behavior is appended after the diff).

The client advertises support via a new
x-zed-client-supports-stream-ended-request-completion-status header. The
server will only send the new variant if that header is passed.

Both StreamEnded and Unknown are silently ignored at the event mapping
layer (from_completion_request_status returns Ok(None)).

Part of CLO-264 and CLO-266; cloud-side changes to follow.

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers
---
 crates/agent/src/thread.rs                    | 10 ++-
 .../cloud_llm_client/src/cloud_llm_client.rs  |  8 ++
 crates/language_model/src/language_model.rs   | 10 ++-
 crates/language_models/src/provider/cloud.rs  | 88 +++++++++++++++----
 4 files changed, 90 insertions(+), 26 deletions(-)

diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs
index 5e16ec826d1ff56c7e185344a1431cc5d8cafe12..8e76291854e88524790acc321b72df707b718910 100644
--- a/crates/agent/src/thread.rs
+++ b/crates/agent/src/thread.rs
@@ -2753,10 +2753,12 @@ impl Thread {
             | ApiEndpointNotFound { .. }
             | PromptTooLarge { .. } => None,
             // These errors might be transient, so retry them
-            SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
-                delay: BASE_RETRY_DELAY,
-                max_attempts: 1,
-            }),
+            SerializeRequest { .. } | BuildRequestBody { .. } | StreamEndedUnexpectedly { .. } => {
+                Some(RetryStrategy::Fixed {
+                    delay: BASE_RETRY_DELAY,
+                    max_attempts: 1,
+                })
+            }
             // Retry all other 4xx and 5xx errors once.
             HttpResponseError { status_code, .. }
                 if status_code.is_client_error() || status_code.is_server_error() =>
diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs
index 01194e66e7db9eb399e431aa16dc9b9bbe0a6177..20e4d49bb3e42e0e9ce92e61bb0dfa377d9c2ad6 100644
--- a/crates/cloud_llm_client/src/cloud_llm_client.rs
+++ b/crates/cloud_llm_client/src/cloud_llm_client.rs
@@ -43,6 +43,10 @@ pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-v
 pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
     "x-zed-client-supports-status-messages";
 
+/// The name of the header used by the client to indicate to the server that it supports receiving a "stream_ended" request completion status.
+pub const CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME: &str =
+    "x-zed-client-supports-stream-ended-request-completion-status";
+
 /// The name of the header used by the server to indicate to the client that it supports sending status messages.
 pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
     "x-zed-server-supports-status-messages";
@@ -223,6 +227,10 @@ pub enum CompletionRequestStatus {
         limit: UsageLimit,
     },
     ToolUseLimitReached,
+    /// The cloud sends a StreamEnded message when the stream from the LLM provider finishes.
+    StreamEnded,
+    #[serde(other)]
+    Unknown,
 }
 
 #[derive(Serialize, Deserialize)]
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index 1ff58babbc77dc3ce8881fcd352afe393da7761d..e63a5537ff966ae95d75db31bf0ad3fae342dbde 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -104,12 +104,13 @@ impl LanguageModelCompletionEvent {
     pub fn from_completion_request_status(
         status: CompletionRequestStatus,
         upstream_provider: LanguageModelProviderName,
-    ) -> Result<Self, LanguageModelCompletionError> {
+    ) -> Result<Option<Self>, LanguageModelCompletionError> {
         match status {
             CompletionRequestStatus::Queued { position } => {
-                Ok(LanguageModelCompletionEvent::Queued { position })
+                Ok(Some(LanguageModelCompletionEvent::Queued { position }))
             }
-            CompletionRequestStatus::Started => Ok(LanguageModelCompletionEvent::Started),
+            CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)),
+            CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None),
             CompletionRequestStatus::UsageUpdated { .. }
             | CompletionRequestStatus::ToolUseLimitReached => Err(
                 LanguageModelCompletionError::Other(anyhow!("Unexpected status: {status:?}")),
@@ -212,6 +213,9 @@ pub enum LanguageModelCompletionError {
         error: serde_json::Error,
     },
 
+    #[error("stream from {provider} ended unexpectedly")]
+    StreamEndedUnexpectedly { provider: LanguageModelProviderName },
+
     // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
     #[error(transparent)]
     Other(#[from] anyhow::Error),
diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs
index 2473d37e363da06752214b3e9ac1de685840a4ea..6caf56ebcabd1f862caf5518e33099e9160b8653 100644
--- a/crates/language_models/src/provider/cloud.rs
+++ b/crates/language_models/src/provider/cloud.rs
@@ -6,12 +6,14 @@ use client::{Client, UserStore, zed_urls};
 use cloud_api_types::Plan;
 use cloud_llm_client::{
     CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody,
-    CompletionEvent, CountTokensBody, CountTokensResponse, ListModelsResponse,
-    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
+    CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse,
+    ListModelsResponse, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
 };
 use feature_flags::{CloudThinkingEffortFeatureFlag, FeatureFlagAppExt as _};
 use futures::{
-    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
+    AsyncBufReadExt, FutureExt, Stream, StreamExt,
+    future::BoxFuture,
+    stream::{self, BoxStream},
 };
 use google_ai::GoogleModelMode;
 use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
@@ -33,9 +35,11 @@ use settings::SettingsStore;
 pub use settings::ZedDotDevAvailableModel as AvailableModel;
 pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
 use smol::io::{AsyncReadExt, BufReader};
+use std::collections::VecDeque;
 use std::pin::Pin;
 use std::str::FromStr;
 use std::sync::Arc;
+use std::task::Poll;
 use std::time::Duration;
 use thiserror::Error;
 use ui::{TintColor, prelude::*};
@@ -410,6 +414,8 @@ impl CloudLanguageModel {
                 .header("Content-Type", "application/json")
                 .header("Authorization", format!("Bearer {token}"))
                 .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
+                // TODO: Uncomment once the cloud-side StreamEnded support PR is merged.
+                // .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
                 .body(serde_json::to_string(&body)?.into())?;
             let mut response = http_client.send(request).await?;
 
@@ -953,24 +959,68 @@ where
         + 'static,
 {
     let provider = provider.clone();
-    stream
-        .flat_map(move |event| {
-            futures::stream::iter(match event {
-                Err(error) => {
-                    vec![Err(LanguageModelCompletionError::from(error))]
+    let mut stream = stream.fuse();
+
+    // TODO: Uncomment once the cloud-side StreamEnded support PR is merged.
+    // let mut saw_stream_ended = false;
+
+    let mut done = false;
+    let mut pending = VecDeque::new();
+
+    stream::poll_fn(move |cx| {
+        loop {
+            if let Some(item) = pending.pop_front() {
+                return Poll::Ready(Some(item));
+            }
+
+            if done {
+                return Poll::Ready(None);
+            }
+
+            match stream.poll_next_unpin(cx) {
+                Poll::Ready(Some(event)) => {
+                    let items = match event {
+                        Err(error) => {
+                            vec![Err(LanguageModelCompletionError::from(error))]
+                        }
+                        Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
+                            // TODO: Uncomment once the cloud-side StreamEnded support PR is merged.
+                            // let mut saw_stream_ended = false;
+                            //
+                            // saw_stream_ended = true;
+                            vec![]
+                        }
+                        Ok(CompletionEvent::Status(status)) => {
+                            LanguageModelCompletionEvent::from_completion_request_status(
+                                status,
+                                provider.clone(),
+                            )
+                            .transpose()
+                            .map(|event| vec![event])
+                            .unwrap_or_default()
+                        }
+                        Ok(CompletionEvent::Event(event)) => map_callback(event),
+                    };
+                    pending.extend(items);
                 }
-                Ok(CompletionEvent::Status(event)) => {
-                    vec![
-                        LanguageModelCompletionEvent::from_completion_request_status(
-                            event,
-                            provider.clone(),
-                        ),
-                    ]
+                Poll::Ready(None) => {
+                    done = true;
+
+                    // TODO: Uncomment once the cloud-side StreamEnded support PR is merged.
+                    //
+                    // if !saw_stream_ended {
+                    //     return Poll::Ready(Some(Err(
+                    //         LanguageModelCompletionError::StreamEndedUnexpectedly {
+                    //             provider: provider.clone(),
+                    //         },
+                    //     )));
+                    // }
                 }
-                Ok(CompletionEvent::Event(event)) => map_callback(event),
-            })
-        })
-        .boxed()
+                Poll::Pending => return Poll::Pending,
+            }
+        }
+    })
+    .boxed()
 }
 
 fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName {
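
As referenced in the description above, here is a minimal sketch of the forward-compatible deserialization the Unknown fallback provides. It is not the crate's actual definitions: the enum is a trimmed-down, hypothetical stand-in for CompletionRequestStatus, and the internally tagged, snake_case serde representation is an assumption that may not match the real attributes.

```rust
// Hypothetical stand-in for CompletionRequestStatus, for illustration only.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")] // assumed wire representation
enum Status {
    Started,
    StreamEnded,
    // Any tag this client version does not recognize deserializes to Unknown
    // instead of failing the whole event stream.
    #[serde(other)]
    Unknown,
}

fn main() {
    // A variant introduced by a future server still deserializes on an old client.
    let status: Status =
        serde_json::from_str(r#"{ "status": "some_future_status" }"#).unwrap();
    assert!(matches!(status, Status::Unknown));

    // The newly added variant arrives under its snake_case wire name.
    let status: Status = serde_json::from_str(r#"{ "status": "stream_ended" }"#).unwrap();
    assert!(matches!(status, Status::StreamEnded));
}
```

Because the fallback carries no payload, mapping Unknown (and, for now, StreamEnded) to Ok(None) at the event mapping layer lets unrecognized statuses pass through without aborting the completion.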
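The cloud.rs hunk replaces a flat_map with a hand-rolled poll_fn loop. The sketch below is standalone and not Zed code (fan_out and the numeric fan-out are made up); it shows the same pattern in isolation: a VecDeque buffers the items produced from each upstream event, and because the loop observes the upstream's end directly, it has a natural place to enqueue a trailing synthetic item, which is what the commented-out saw_stream_ended check will need.

```rust
use std::collections::VecDeque;
use std::task::Poll;

use futures::{StreamExt, executor::block_on, stream};

// Fan each upstream item out into zero or more downstream items, like the
// cloud completion adapter does with CompletionEvent -> LanguageModelCompletionEvent.
fn fan_out<S>(upstream: S) -> impl futures::Stream<Item = String>
where
    S: futures::Stream<Item = u32> + Unpin,
{
    let mut upstream = upstream.fuse();
    let mut pending: VecDeque<String> = VecDeque::new();
    let mut done = false;

    stream::poll_fn(move |cx| {
        loop {
            // Drain anything buffered from a previous upstream item first.
            if let Some(item) = pending.pop_front() {
                return Poll::Ready(Some(item));
            }
            if done {
                return Poll::Ready(None);
            }
            match upstream.poll_next_unpin(cx) {
                Poll::Ready(Some(n)) => {
                    // One upstream item fans out into several downstream items.
                    pending.extend((0..n).map(|i| format!("{n}-{i}")));
                }
                Poll::Ready(None) => {
                    done = true;
                    // A trailing synthetic item (e.g. a StreamEndedUnexpectedly error)
                    // could be enqueued here before the stream reports completion.
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    })
}

fn main() {
    let items = block_on(fan_out(stream::iter([2u32, 1])).collect::<Vec<_>>());
    assert_eq!(items, ["2-0", "2-1", "1-0"]);
}
```

block_on and the string items only keep the sketch self-contained; the real adapter buffers Result<LanguageModelCompletionEvent, LanguageModelCompletionError> values instead.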