cloud_llm_client: Add StreamEnded and Unknown variants to CompletionRequestStatus (#49121)

Created by Tom Houlé and Marshall Bowers

Add a StreamEnded variant so the client can distinguish a stream
that the cloud ran to completion from one that was interrupted (see
CLO-258). **That logic will be added in a follow-up PR**.

Add an Unknown fallback with #[serde(other)] for forward-compatible
deserialization of future variants.

The client advertises support via a new
x-zed-client-supports-stream-ended-request-completion-status header. The
server will only send the new variant if that header is passed. Both
StreamEnded and Unknown are silently ignored at the event mapping layer
(from_completion_request_status returns Ok(None)).

Part of CLO-264 and CLO-266; cloud-side changes to follow.

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>

Change summary

crates/agent/src/thread.rs                      | 10 +
crates/cloud_llm_client/src/cloud_llm_client.rs |  8 +
crates/language_model/src/language_model.rs     | 10 +
crates/language_models/src/provider/cloud.rs    | 88 ++++++++++++++----
4 files changed, 90 insertions(+), 26 deletions(-)

Detailed changes

crates/agent/src/thread.rs 🔗

@@ -2753,10 +2753,12 @@ impl Thread {
             | ApiEndpointNotFound { .. }
             | PromptTooLarge { .. } => None,
             // These errors might be transient, so retry them
-            SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
-                delay: BASE_RETRY_DELAY,
-                max_attempts: 1,
-            }),
+            SerializeRequest { .. } | BuildRequestBody { .. } | StreamEndedUnexpectedly { .. } => {
+                Some(RetryStrategy::Fixed {
+                    delay: BASE_RETRY_DELAY,
+                    max_attempts: 1,
+                })
+            }
             // Retry all other 4xx and 5xx errors once.
             HttpResponseError { status_code, .. }
                 if status_code.is_client_error() || status_code.is_server_error() =>

crates/cloud_llm_client/src/cloud_llm_client.rs 🔗

@@ -43,6 +43,10 @@ pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-v
 pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
     "x-zed-client-supports-status-messages";
 
+/// The name of the header used by the client to indicate to the server that it supports receiving a "stream_ended" request completion status.
+pub const CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME: &str =
+    "x-zed-client-supports-stream-ended-request-completion-status";
+
 /// The name of the header used by the server to indicate to the client that it supports sending status messages.
 pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
     "x-zed-server-supports-status-messages";
@@ -223,6 +227,10 @@ pub enum CompletionRequestStatus {
         limit: UsageLimit,
     },
     ToolUseLimitReached,
+    /// The cloud sends a StreamEnded message when the stream from the LLM provider finishes.
+    StreamEnded,
+    #[serde(other)]
+    Unknown,
 }
 
 #[derive(Serialize, Deserialize)]

crates/language_model/src/language_model.rs 🔗

@@ -104,12 +104,13 @@ impl LanguageModelCompletionEvent {
     pub fn from_completion_request_status(
         status: CompletionRequestStatus,
         upstream_provider: LanguageModelProviderName,
-    ) -> Result<Self, LanguageModelCompletionError> {
+    ) -> Result<Option<Self>, LanguageModelCompletionError> {
         match status {
             CompletionRequestStatus::Queued { position } => {
-                Ok(LanguageModelCompletionEvent::Queued { position })
+                Ok(Some(LanguageModelCompletionEvent::Queued { position }))
             }
-            CompletionRequestStatus::Started => Ok(LanguageModelCompletionEvent::Started),
+            CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)),
+            CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None),
             CompletionRequestStatus::UsageUpdated { .. }
             | CompletionRequestStatus::ToolUseLimitReached => Err(
                 LanguageModelCompletionError::Other(anyhow!("Unexpected status: {status:?}")),
@@ -212,6 +213,9 @@ pub enum LanguageModelCompletionError {
         error: serde_json::Error,
     },
 
+    #[error("stream from {provider} ended unexpectedly")]
+    StreamEndedUnexpectedly { provider: LanguageModelProviderName },
+
     // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
     #[error(transparent)]
     Other(#[from] anyhow::Error),

crates/language_models/src/provider/cloud.rs 🔗

@@ -6,12 +6,14 @@ use client::{Client, UserStore, zed_urls};
 use cloud_api_types::Plan;
 use cloud_llm_client::{
     CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody,
-    CompletionEvent, CountTokensBody, CountTokensResponse, ListModelsResponse,
-    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
+    CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse,
+    ListModelsResponse, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
 };
 use feature_flags::{CloudThinkingEffortFeatureFlag, FeatureFlagAppExt as _};
 use futures::{
-    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
+    AsyncBufReadExt, FutureExt, Stream, StreamExt,
+    future::BoxFuture,
+    stream::{self, BoxStream},
 };
 use google_ai::GoogleModelMode;
 use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
@@ -33,9 +35,11 @@ use settings::SettingsStore;
 pub use settings::ZedDotDevAvailableModel as AvailableModel;
 pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
 use smol::io::{AsyncReadExt, BufReader};
+use std::collections::VecDeque;
 use std::pin::Pin;
 use std::str::FromStr;
 use std::sync::Arc;
+use std::task::Poll;
 use std::time::Duration;
 use thiserror::Error;
 use ui::{TintColor, prelude::*};
@@ -410,6 +414,8 @@ impl CloudLanguageModel {
                 .header("Content-Type", "application/json")
                 .header("Authorization", format!("Bearer {token}"))
                 .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
+                // TODO: Uncomment once the cloud-side StreamEnded support PR is merged.
+                // .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
                 .body(serde_json::to_string(&body)?.into())?;
 
             let mut response = http_client.send(request).await?;
@@ -953,24 +959,68 @@ where
         + 'static,
 {
     let provider = provider.clone();
-    stream
-        .flat_map(move |event| {
-            futures::stream::iter(match event {
-                Err(error) => {
-                    vec![Err(LanguageModelCompletionError::from(error))]
+    let mut stream = stream.fuse();
+
+    // TODO: Uncomment once the cloud-side StreamEnded support PR is merged.
+    // let mut saw_stream_ended = false;
+
+    let mut done = false;
+    let mut pending = VecDeque::new();
+
+    stream::poll_fn(move |cx| {
+        loop {
+            if let Some(item) = pending.pop_front() {
+                return Poll::Ready(Some(item));
+            }
+
+            if done {
+                return Poll::Ready(None);
+            }
+
+            match stream.poll_next_unpin(cx) {
+                Poll::Ready(Some(event)) => {
+                    let items = match event {
+                        Err(error) => {
+                            vec![Err(LanguageModelCompletionError::from(error))]
+                        }
+                        Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
+                            // TODO: Uncomment once the cloud-side StreamEnded support PR is merged.
+                            // saw_stream_ended = true;
+                            vec![]
+                        }
+                        Ok(CompletionEvent::Status(status)) => {
+                            LanguageModelCompletionEvent::from_completion_request_status(
+                                status,
+                                provider.clone(),
+                            )
+                            .transpose()
+                            .map(|event| vec![event])
+                            .unwrap_or_default()
+                        }
+                        Ok(CompletionEvent::Event(event)) => map_callback(event),
+                    };
+                    pending.extend(items);
                 }
-                Ok(CompletionEvent::Status(event)) => {
-                    vec![
-                        LanguageModelCompletionEvent::from_completion_request_status(
-                            event,
-                            provider.clone(),
-                        ),
-                    ]
+                Poll::Ready(None) => {
+                    done = true;
+
+                    // TODO: Uncomment once the cloud-side StreamEnded support PR is merged.
+                    //
+                    // if !saw_stream_ended {
+                    //     return Poll::Ready(Some(Err(
+                    //         LanguageModelCompletionError::StreamEndedUnexpectedly {
+                    //             provider: provider.clone(),
+                    //         },
+                    //     )));
+                    // }
                 }
-                Ok(CompletionEvent::Event(event)) => map_callback(event),
-            })
-        })
-        .boxed()
+                Poll::Pending => return Poll::Pending,
+            }
+        }
+    })
+    .boxed()
 }
 
 fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName {