Add `RpcClient::subscribe`

Antonio Scandurra and Max Brunsfeld created

Co-Authored-By: Max Brunsfeld <max@zed.dev>

Change summary

zed/src/rpc_client.rs | 48 +++++++++++++++++++++++++++++++++++---------
1 file changed, 38 insertions(+), 10 deletions(-)

Detailed changes

zed/src/rpc_client.rs 🔗

@@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
 use gpui::executor::Background;
 use parking_lot::Mutex;
 use postage::{
-    oneshot,
+    mpsc, oneshot,
     prelude::{Sink, Stream},
 };
 use smol::{
@@ -11,11 +11,13 @@ use smol::{
     prelude::{AsyncRead, AsyncWrite},
 };
 use std::{collections::HashMap, sync::Arc};
-use zed_rpc::proto::{self, MessageStream, RequestMessage, SendMessage, ServerMessage};
+use zed_rpc::proto::{
+    self, MessageStream, RequestMessage, SendMessage, ServerMessage, SubscribeMessage,
+};
 
 pub struct RpcClient<Conn> {
     stream: MessageStream<WriteHalf<Conn>>,
-    response_channels: Arc<Mutex<HashMap<i32, oneshot::Sender<proto::from_server::Variant>>>>,
+    response_channels: Arc<Mutex<HashMap<i32, (mpsc::Sender<proto::from_server::Variant>, bool)>>>,
     next_message_id: i32,
     _drop_tx: oneshot::Sender<()>,
 }
@@ -59,9 +61,15 @@ where
                         Message::Message(message) => {
                             if let Some(variant) = message.variant {
                                 if let Some(request_id) = message.request_id {
-                                    let tx = response_channels.lock().remove(&request_id);
-                                    if let Some(mut tx) = tx {
-                                        tx.send(variant).await?;
+                                    let channel = response_channels.lock().remove(&request_id);
+                                    if let Some((mut tx, oneshot)) = channel {
+                                        if tx.send(variant).await.is_ok() {
+                                            if !oneshot {
+                                                response_channels
+                                                    .lock()
+                                                    .insert(request_id, (tx, false));
+                                            }
+                                        }
                                     } else {
                                         log::warn!(
                                             "received RPC response to unknown request id {}",
@@ -85,10 +93,8 @@ where
     pub async fn request<T: RequestMessage>(&mut self, req: T) -> Result<T::Response> {
         let message_id = self.next_message_id;
         self.next_message_id += 1;
-
-        let (tx, mut rx) = oneshot::channel();
-        self.response_channels.lock().insert(message_id, tx);
-
+        let (tx, mut rx) = mpsc::channel(1);
+        self.response_channels.lock().insert(message_id, (tx, true));
         self.stream
             .write_message(&proto::FromClient {
                 id: message_id,
@@ -114,6 +120,28 @@ where
             .await?;
         Ok(())
     }
+
+    pub async fn subscribe<T: SubscribeMessage>(
+        &mut self,
+        subscription: T,
+    ) -> Result<impl Stream<Item = Result<T::Event>>> {
+        let message_id = self.next_message_id;
+        self.next_message_id += 1;
+        let (tx, rx) = mpsc::channel(256);
+        self.response_channels
+            .lock()
+            .insert(message_id, (tx, false));
+        self.stream
+            .write_message(&proto::FromClient {
+                id: message_id,
+                variant: Some(subscription.to_variant()),
+            })
+            .await?;
+
+        Ok(rx.map(|event| {
+            T::Event::from_variant(event).ok_or_else(|| anyhow!("invalid event {:?}"))
+        }))
+    }
 }
 
 #[cfg(test)]