@@ -2753,10 +2753,12 @@ impl Thread {
| ApiEndpointNotFound { .. }
| PromptTooLarge { .. } => None,
// These errors might be transient, so retry them
- SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
- delay: BASE_RETRY_DELAY,
- max_attempts: 1,
- }),
+ SerializeRequest { .. } | BuildRequestBody { .. } | StreamEndedUnexpectedly { .. } => {
+ Some(RetryStrategy::Fixed {
+ delay: BASE_RETRY_DELAY,
+ max_attempts: 1,
+ })
+ }
// Retry all other 4xx and 5xx errors once.
HttpResponseError { status_code, .. }
if status_code.is_client_error() || status_code.is_server_error() =>
@@ -43,6 +43,10 @@ pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-v
pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
"x-zed-client-supports-status-messages";
+/// The name of the header used by the client to indicate to the server that it supports receiving a "stream_ended" request completion status.
+pub const CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME: &str =
+ "x-zed-client-supports-stream-ended-request-completion-status";
+
/// The name of the header used by the server to indicate to the client that it supports sending status messages.
pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
"x-zed-server-supports-status-messages";
@@ -223,6 +227,10 @@ pub enum CompletionRequestStatus {
limit: UsageLimit,
},
ToolUseLimitReached,
+ /// The cloud sends a StreamEnded message when the stream from the LLM provider finishes.
+ StreamEnded,
+ #[serde(other)]
+ Unknown,
}
#[derive(Serialize, Deserialize)]
@@ -104,12 +104,13 @@ impl LanguageModelCompletionEvent {
pub fn from_completion_request_status(
status: CompletionRequestStatus,
upstream_provider: LanguageModelProviderName,
- ) -> Result<Self, LanguageModelCompletionError> {
+ ) -> Result<Option<Self>, LanguageModelCompletionError> {
match status {
CompletionRequestStatus::Queued { position } => {
- Ok(LanguageModelCompletionEvent::Queued { position })
+ Ok(Some(LanguageModelCompletionEvent::Queued { position }))
}
- CompletionRequestStatus::Started => Ok(LanguageModelCompletionEvent::Started),
+ CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)),
+ CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None),
CompletionRequestStatus::UsageUpdated { .. }
| CompletionRequestStatus::ToolUseLimitReached => Err(
LanguageModelCompletionError::Other(anyhow!("Unexpected status: {status:?}")),
@@ -212,6 +213,9 @@ pub enum LanguageModelCompletionError {
error: serde_json::Error,
},
+ #[error("stream from {provider} ended unexpectedly")]
+ StreamEndedUnexpectedly { provider: LanguageModelProviderName },
+
// TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
#[error(transparent)]
Other(#[from] anyhow::Error),
@@ -6,12 +6,14 @@ use client::{Client, UserStore, zed_urls};
use cloud_api_types::Plan;
use cloud_llm_client::{
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody,
- CompletionEvent, CountTokensBody, CountTokensResponse, ListModelsResponse,
- SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
+ CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse,
+ ListModelsResponse, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
use feature_flags::{CloudThinkingEffortFeatureFlag, FeatureFlagAppExt as _};
use futures::{
- AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
+ AsyncBufReadExt, FutureExt, Stream, StreamExt,
+ future::BoxFuture,
+ stream::{self, BoxStream},
};
use google_ai::GoogleModelMode;
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
@@ -33,9 +35,11 @@ use settings::SettingsStore;
pub use settings::ZedDotDevAvailableModel as AvailableModel;
pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
use smol::io::{AsyncReadExt, BufReader};
+use std::collections::VecDeque;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
+use std::task::Poll;
use std::time::Duration;
use thiserror::Error;
use ui::{TintColor, prelude::*};
@@ -410,6 +414,8 @@ impl CloudLanguageModel {
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {token}"))
.header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
+ // TODO: Uncomment once the cloud-side StreamEnded support PR is merged.
+ // .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
.body(serde_json::to_string(&body)?.into())?;
let mut response = http_client.send(request).await?;
@@ -953,24 +959,68 @@ where
+ 'static,
{
let provider = provider.clone();
- stream
- .flat_map(move |event| {
- futures::stream::iter(match event {
- Err(error) => {
- vec![Err(LanguageModelCompletionError::from(error))]
+ let mut stream = stream.fuse();
+
+ // TODO: Uncomment once the cloud-side StreamEnded support PR is merged.
+ // let mut saw_stream_ended = false;
+
+ let mut done = false;
+ let mut pending = VecDeque::new();
+
+ stream::poll_fn(move |cx| {
+ loop {
+ if let Some(item) = pending.pop_front() {
+ return Poll::Ready(Some(item));
+ }
+
+ if done {
+ return Poll::Ready(None);
+ }
+
+ match stream.poll_next_unpin(cx) {
+ Poll::Ready(Some(event)) => {
+ let items = match event {
+ Err(error) => {
+ vec![Err(LanguageModelCompletionError::from(error))]
+ }
+ Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
+                            // TODO: Uncomment once the cloud-side StreamEnded support PR is merged.
+                            // (This sets the `saw_stream_ended` flag declared outside this
+                            // closure — do NOT re-declare it here, or it will shadow the
+                            // outer flag.) saw_stream_ended = true;
+ vec![]
+ }
+ Ok(CompletionEvent::Status(status)) => {
+ LanguageModelCompletionEvent::from_completion_request_status(
+ status,
+ provider.clone(),
+ )
+ .transpose()
+ .map(|event| vec![event])
+ .unwrap_or_default()
+ }
+ Ok(CompletionEvent::Event(event)) => map_callback(event),
+ };
+ pending.extend(items);
}
- Ok(CompletionEvent::Status(event)) => {
- vec![
- LanguageModelCompletionEvent::from_completion_request_status(
- event,
- provider.clone(),
- ),
- ]
+ Poll::Ready(None) => {
+ done = true;
+
+ // TODO: Uncomment once the cloud-side StreamEnded support PR is merged.
+ //
+ // if !saw_stream_ended {
+ // return Poll::Ready(Some(Err(
+ // LanguageModelCompletionError::StreamEndedUnexpectedly {
+ // provider: provider.clone(),
+ // },
+ // )));
+ // }
}
- Ok(CompletionEvent::Event(event)) => map_callback(event),
- })
- })
- .boxed()
+ Poll::Pending => return Poll::Pending,
+ }
+ }
+ })
+ .boxed()
}
fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName {