Replace Request and Message with a single TypedEnvelope

Antonio Scandurra and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

zed/src/rpc_client.rs | 124 +++++++++++++-------------------------------
zed/src/util.rs       |  64 ++---------------------
zed/src/workspace.rs  |  10 +-
3 files changed, 47 insertions(+), 151 deletions(-)

Detailed changes

zed/src/rpc_client.rs 🔗

@@ -1,5 +1,8 @@
 use anyhow::{anyhow, Result};
-use futures::future::Either;
+use futures::{
+    future::{BoxFuture, Either},
+    FutureExt,
+};
 use postage::{
     barrier, mpsc, oneshot,
     prelude::{Sink, Stream},
@@ -31,67 +34,27 @@ struct RpcConnection {
 }
 
 type MessageHandler =
-    Box<dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<ErasedMessage>>;
-
-struct ErasedMessage {
-    id: u32,
-    connection_id: ConnectionId,
-    body: proto::Envelope,
-}
-
-pub struct Message<T: EnvelopedMessage> {
-    connection_id: ConnectionId,
-    body: Option<T>,
-}
-
-impl<T: EnvelopedMessage> From<ErasedMessage> for Message<T> {
-    fn from(message: ErasedMessage) -> Self {
-        Self {
-            connection_id: message.connection_id,
-            body: T::from_envelope(message.body),
-        }
-    }
-}
-
-impl<T: EnvelopedMessage> Message<T> {
-    pub fn connection_id(&self) -> ConnectionId {
-        self.connection_id
-    }
-
-    pub fn body(&mut self) -> T {
-        self.body.take().expect("body already taken")
-    }
-}
+    Box<dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<()>>>;
 
-pub struct Request<T: RequestMessage> {
+pub struct TypedEnvelope<T> {
     id: u32,
     connection_id: ConnectionId,
-    body: Option<T>,
+    payload: T,
 }
 
-impl<T: RequestMessage> From<ErasedMessage> for Request<T> {
-    fn from(message: ErasedMessage) -> Self {
-        Self {
-            id: message.id,
-            connection_id: message.connection_id,
-            body: T::from_envelope(message.body),
-        }
-    }
-}
-
-impl<T: RequestMessage> Request<T> {
+impl<T> TypedEnvelope<T> {
     pub fn connection_id(&self) -> ConnectionId {
         self.connection_id
     }
 
-    pub fn body(&mut self) -> T {
-        self.body.take().expect("body already taken")
+    pub fn payload(&self) -> &T {
+        &self.payload
     }
 }
 
 pub struct RpcClient {
     connections: RwLock<HashMap<ConnectionId, Arc<RpcConnection>>>,
-    message_handlers: RwLock<Vec<(mpsc::Sender<ErasedMessage>, MessageHandler)>>,
+    message_handlers: RwLock<Vec<MessageHandler>>,
     handler_types: Mutex<HashSet<TypeId>>,
     next_connection_id: AtomicU32,
 }
@@ -106,52 +69,37 @@ impl RpcClient {
         })
     }
 
-    pub async fn add_request_handler<T: RequestMessage>(&self) -> impl Stream<Item = Request<T>> {
+    pub async fn add_message_handler<T: EnvelopedMessage>(
+        &self,
+    ) -> mpsc::Receiver<TypedEnvelope<T>> {
         if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
             panic!("duplicate handler type");
         }
 
         let (tx, rx) = mpsc::channel(256);
-        self.message_handlers.write().await.push((
-            tx,
-            Box::new(move |envelope, connection_id| {
-                if envelope.as_ref().map_or(false, T::matches_envelope) {
-                    let envelope = Option::take(envelope).unwrap();
-                    Some(ErasedMessage {
-                        id: envelope.id,
-                        connection_id,
-                        body: envelope,
-                    })
-                } else {
-                    None
-                }
-            }),
-        ));
-        rx.map(Request::from)
-    }
-
-    pub async fn add_message_handler<T: EnvelopedMessage>(&self) -> impl Stream<Item = Message<T>> {
-        if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
-            panic!("duplicate handler type");
-        }
-
-        let (tx, rx) = mpsc::channel(256);
-        self.message_handlers.write().await.push((
-            tx,
-            Box::new(move |envelope, connection_id| {
+        self.message_handlers
+            .write()
+            .await
+            .push(Box::new(move |envelope, connection_id| {
                 if envelope.as_ref().map_or(false, T::matches_envelope) {
                     let envelope = Option::take(envelope).unwrap();
-                    Some(ErasedMessage {
-                        id: envelope.id,
-                        connection_id,
-                        body: envelope,
-                    })
+                    let mut tx = tx.clone();
+                    Some(
+                        async move {
+                            tx.send(TypedEnvelope {
+                                id: envelope.id,
+                                connection_id,
+                                payload: T::from_envelope(envelope).unwrap(),
+                            })
+                            .await;
+                        }
+                        .boxed(),
+                    )
                 } else {
                     None
                 }
-            }),
-        ));
-        rx.map(Message::from)
+            }));
+        rx
     }
 
     pub async fn add_connection<Conn>(
@@ -208,9 +156,9 @@ impl RpcClient {
                         } else {
                             let mut handled = false;
                             let mut envelope = Some(incoming);
-                            for (tx, handler) in this.message_handlers.read().await.iter() {
-                                if let Some(message) = handler(&mut envelope, connection_id) {
-                                    let _ = tx.clone().send(message).await;
+                            for handler in this.message_handlers.read().await.iter() {
+                                if let Some(future) = handler(&mut envelope, connection_id) {
+                                    future.await;
                                     handled = true;
                                     break;
                                 }
@@ -303,7 +251,7 @@ impl RpcClient {
 
     pub fn respond<T: RequestMessage>(
         self: &Arc<Self>,
-        request: Request<T>,
+        request: TypedEnvelope<T>,
         response: T::Response,
     ) -> impl Future<Output = Result<()>> {
         let this = self.clone();

zed/src/util.rs 🔗

@@ -1,4 +1,4 @@
-use crate::rpc_client::{Message, Request, RpcClient};
+use crate::rpc_client::{RpcClient, TypedEnvelope};
 use postage::prelude::Stream;
 use rand::prelude::*;
 use std::{cmp::Ordering, future::Future, sync::Arc};
@@ -56,41 +56,12 @@ where
     }
 }
 
-pub trait RequestHandler<'a, R: proto::RequestMessage> {
-    type Output: 'a + Future<Output = anyhow::Result<()>>;
-
-    fn handle(
-        &self,
-        request: Request<R>,
-        client: Arc<RpcClient>,
-        cx: &'a mut gpui::AsyncAppContext,
-    ) -> Self::Output;
-}
-
-impl<'a, R, F, Fut> RequestHandler<'a, R> for F
-where
-    R: proto::RequestMessage,
-    F: Fn(Request<R>, Arc<RpcClient>, &'a mut gpui::AsyncAppContext) -> Fut,
-    Fut: 'a + Future<Output = anyhow::Result<()>>,
-{
-    type Output = Fut;
-
-    fn handle(
-        &self,
-        request: Request<R>,
-        client: Arc<RpcClient>,
-        cx: &'a mut gpui::AsyncAppContext,
-    ) -> Self::Output {
-        (self)(request, client, cx)
-    }
-}
-
 pub trait MessageHandler<'a, M: proto::EnvelopedMessage> {
     type Output: 'a + Future<Output = anyhow::Result<()>>;
 
     fn handle(
         &self,
-        message: Message<M>,
+        message: TypedEnvelope<M>,
         client: Arc<RpcClient>,
         cx: &'a mut gpui::AsyncAppContext,
     ) -> Self::Output;
@@ -99,14 +70,14 @@ pub trait MessageHandler<'a, M: proto::EnvelopedMessage> {
 impl<'a, M, F, Fut> MessageHandler<'a, M> for F
 where
     M: proto::EnvelopedMessage,
-    F: Fn(Message<M>, Arc<RpcClient>, &'a mut gpui::AsyncAppContext) -> Fut,
+    F: Fn(TypedEnvelope<M>, Arc<RpcClient>, &'a mut gpui::AsyncAppContext) -> Fut,
     Fut: 'a + Future<Output = anyhow::Result<()>>,
 {
     type Output = Fut;
 
     fn handle(
         &self,
-        message: Message<M>,
+        message: TypedEnvelope<M>,
         client: Arc<RpcClient>,
         cx: &'a mut gpui::AsyncAppContext,
     ) -> Self::Output {
@@ -114,31 +85,8 @@ where
     }
 }
 
-pub fn spawn_request_handler<H, R>(
-    handler: H,
-    client: &Arc<RpcClient>,
-    cx: &mut gpui::MutableAppContext,
-) where
-    H: 'static + for<'a> RequestHandler<'a, R>,
-    R: proto::RequestMessage,
-{
-    let client = client.clone();
-    let mut requests = smol::block_on(client.add_request_handler::<R>());
-    cx.spawn(|mut cx| async move {
-        while let Some(request) = requests.recv().await {
-            if let Err(err) = handler.handle(request, client.clone(), &mut cx).await {
-                log::error!("error handling request: {:?}", err);
-            }
-        }
-    })
-    .detach();
-}
-
-pub fn spawn_message_handler<H, M>(
-    handler: H,
-    client: &Arc<RpcClient>,
-    cx: &mut gpui::MutableAppContext,
-) where
+pub fn handle_messages<H, M>(handler: H, client: &Arc<RpcClient>, cx: &mut gpui::MutableAppContext)
+where
     H: 'static + for<'a> MessageHandler<'a, M>,
     M: proto::EnvelopedMessage,
 {

zed/src/workspace.rs 🔗

@@ -4,7 +4,7 @@ pub mod pane_group;
 use crate::{
     editor::{Buffer, Editor},
     language::LanguageRegistry,
-    rpc_client::{Request, RpcClient},
+    rpc_client::{RpcClient, TypedEnvelope},
     settings::Settings,
     time::ReplicaId,
     util::{self, SurfResultExt as _},
@@ -46,7 +46,7 @@ pub fn init(cx: &mut MutableAppContext, rpc_client: Arc<RpcClient>) {
     ]);
     pane::init(cx);
 
-    util::spawn_request_handler(handle_open_buffer, &rpc_client, cx);
+    util::handle_messages(handle_open_buffer, &rpc_client, cx);
 }
 
 pub struct OpenParams {
@@ -109,12 +109,12 @@ fn open_paths(params: &OpenParams, cx: &mut MutableAppContext) {
 }
 
 async fn handle_open_buffer(
-    mut request: Request<proto::OpenBuffer>,
+    request: TypedEnvelope<proto::OpenBuffer>,
     rpc_client: Arc<RpcClient>,
     cx: &mut AsyncAppContext,
 ) -> anyhow::Result<()> {
-    let body = request.body();
-    dbg!(body.path);
+    let payload = request.payload();
+    dbg!(&payload.path);
     rpc_client
         .respond(request, proto::OpenBufferResponse { buffer: None })
         .await?;