cloud provider: Use `CompletionEvent` type from `zed_llm_client` (#35285)

Michael Sloan created

Release Notes:

- N/A

Change summary

crates/language_models/src/provider/cloud.rs | 31 ++++++++-------------
1 file changed, 12 insertions(+), 19 deletions(-)

Detailed changes

crates/language_models/src/provider/cloud.rs 🔗

@@ -35,8 +35,8 @@ use ui::{TintColor, prelude::*};
 use util::{ResultExt as _, maybe};
 use zed_llm_client::{
     CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
-    CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
-    ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
+    CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse,
+    EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
     SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
     TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
 };
@@ -1040,15 +1040,8 @@ impl LanguageModel for CloudLanguageModel {
     }
 }
 
-#[derive(Serialize, Deserialize)]
-#[serde(rename_all = "snake_case")]
-pub enum CloudCompletionEvent<T> {
-    Status(CompletionRequestStatus),
-    Event(T),
-}
-
 fn map_cloud_completion_events<T, F>(
-    stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
+    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
     mut map_callback: F,
 ) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
 where
@@ -1063,10 +1056,10 @@ where
                 Err(error) => {
                     vec![Err(LanguageModelCompletionError::from(error))]
                 }
-                Ok(CloudCompletionEvent::Status(event)) => {
+                Ok(CompletionEvent::Status(event)) => {
                     vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
                 }
-                Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
+                Ok(CompletionEvent::Event(event)) => map_callback(event),
             })
         })
         .boxed()
@@ -1074,9 +1067,9 @@ where
 
 fn usage_updated_event<T>(
     usage: Option<ModelRequestUsage>,
-) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
+) -> impl Stream<Item = Result<CompletionEvent<T>>> {
     futures::stream::iter(usage.map(|usage| {
-        Ok(CloudCompletionEvent::Status(
+        Ok(CompletionEvent::Status(
             CompletionRequestStatus::UsageUpdated {
                 amount: usage.amount as usize,
                 limit: usage.limit,
@@ -1087,9 +1080,9 @@ fn usage_updated_event<T>(
 
 fn tool_use_limit_reached_event<T>(
     tool_use_limit_reached: bool,
-) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
+) -> impl Stream<Item = Result<CompletionEvent<T>>> {
     futures::stream::iter(tool_use_limit_reached.then(|| {
-        Ok(CloudCompletionEvent::Status(
+        Ok(CompletionEvent::Status(
             CompletionRequestStatus::ToolUseLimitReached,
         ))
     }))
@@ -1098,7 +1091,7 @@ fn tool_use_limit_reached_event<T>(
 fn response_lines<T: DeserializeOwned>(
     response: Response<AsyncBody>,
     includes_status_messages: bool,
-) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
+) -> impl Stream<Item = Result<CompletionEvent<T>>> {
     futures::stream::try_unfold(
         (String::new(), BufReader::new(response.into_body())),
         move |(mut line, mut body)| async move {
@@ -1106,9 +1099,9 @@ fn response_lines<T: DeserializeOwned>(
                 Ok(0) => Ok(None),
                 Ok(_) => {
                     let event = if includes_status_messages {
-                        serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
+                        serde_json::from_str::<CompletionEvent<T>>(&line)?
                     } else {
-                        CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
+                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                     };
 
                     line.clear();