diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index bd88d5d3b384aadef0c34e997191033045dc2de5..f822b89916a60c32b5f076580f960d47c6a1463c 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -5,9 +5,10 @@ use chrono::{DateTime, Utc}; use client::{Client, UserStore, zed_urls}; use cloud_api_types::Plan; use cloud_llm_client::{ - CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, - CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse, - ListModelsResponse, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME, + CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, + CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus, + CountTokensBody, CountTokensResponse, ListModelsResponse, + SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME, }; use futures::{ AsyncBufReadExt, FutureExt, Stream, StreamExt, @@ -397,8 +398,7 @@ impl CloudLanguageModel { .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {token}")) .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true") - // TODO: Uncomment once the cloud-side StreamEnded support PR is merged. - // .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true") + .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true") .body(serde_json::to_string(&body)?.into())?; let mut response = http_client.send(request).await?; @@ -938,8 +938,7 @@ where let provider = provider.clone(); let mut stream = stream.fuse(); - // TODO: Uncomment once the cloud-side StreamEnded support PR is merged. - // let mut saw_stream_ended = false; + let mut saw_stream_ended = false; let mut done = false; let mut pending = VecDeque::new(); @@ -961,10 +960,7 @@ where vec![Err(LanguageModelCompletionError::from(error))] } Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => { - // TODO: Uncomment once the cloud-side StreamEnded support PR is merged. - // let mut saw_stream_ended = false; - // - // saw_stream_ended = true; + saw_stream_ended = true; vec![] } Ok(CompletionEvent::Status(status)) => { @@ -983,15 +979,13 @@ where Poll::Ready(None) => { done = true; - // TODO: Uncomment once the cloud-side StreamEnded support PR is merged. - // - // if !saw_stream_ended { - // return Poll::Ready(Some(Err( - // LanguageModelCompletionError::StreamEndedUnexpectedly { - // provider: provider.clone(), - // }, - // ))); - // } + if !saw_stream_ended { + return Poll::Ready(Some(Err( + LanguageModelCompletionError::StreamEndedUnexpectedly { + provider: provider.clone(), + }, + ))); + } } Poll::Pending => return Poll::Pending, }