Implement broadcast of typed envelopes

Nathan Sobo and Max Brunsfeld created

This required a rework of the macro so that we can always construct a typed envelope from our list of available message types from incoming protobuf envelopes.

Co-Authored-By: Max Brunsfeld <maxbrunsfeld@gmail.com>

Change summary

zed/src/channel.rs  |  44 ++----------------
zed/src/worktree.rs |  18 +++---
zrpc/src/peer.rs    |  17 +++++--
zrpc/src/proto.rs   | 111 ++++++++++++++++++++++++++++++++++++----------
4 files changed, 113 insertions(+), 77 deletions(-)

Detailed changes

zed/src/channel.rs 🔗

@@ -1,14 +1,11 @@
 use crate::rpc::{self, Client};
-use anyhow::{anyhow, Result};
 use futures::StreamExt;
-use gpui::{
-    AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, WeakModelHandle,
-};
+use gpui::{Entity, ModelContext, Task, WeakModelHandle};
 use std::{
     collections::{HashMap, VecDeque},
     sync::Arc,
 };
-use zrpc::{proto::ChannelMessageSent, ForegroundRouter, Router, TypedEnvelope};
+use zrpc::{proto::ChannelMessageSent, TypedEnvelope};
 
 pub struct ChannelList {
     available_channels: Vec<ChannelDetails>,
@@ -31,49 +28,20 @@ pub struct Channel {
 pub struct ChannelMessage {
     id: u64,
 }
-enum Event {}
+pub enum Event {}
 
 impl Entity for ChannelList {
     type Event = Event;
 }
 
 impl ChannelList {
-    fn new(
-        rpc: Arc<rpc::Client>,
-        router: &mut ForegroundRouter,
-        cx: &mut ModelContext<Self>,
-    ) -> Self {
-        // Subscribe to messages.
-        let this = cx.handle().downgrade();
-
-        // rpc.on_message(
-        //     router,
-        //     |envelope, rpc, cx: &mut AsyncAppContext| async move {
-        //         cx.update(|cx| {
-        //             if let Some(this) = this.upgrade(cx) {
-        //                 this.update(cx, |this, cx| this.receive_message(envelope, cx))
-        //             } else {
-        //                 Err(anyhow!("can't upgrade ChannelList handle"))
-        //             }
-        //         })
-        //     },
-        //     cx,
-        // );
-
+    fn new(rpc: Arc<rpc::Client>) -> Self {
         Self {
             available_channels: Default::default(),
             channels: Default::default(),
             rpc,
         }
     }
-
-    fn receive_message(
-        &mut self,
-        envelope: TypedEnvelope<ChannelMessageSent>,
-        cx: &mut ModelContext<Self>,
-    ) -> Result<()> {
-        Ok(())
-    }
 }
 
 impl Entity for Channel {
@@ -82,8 +50,8 @@ impl Entity for Channel {
 
 impl Channel {
     pub fn new(details: ChannelDetails, rpc: Arc<Client>, cx: &mut ModelContext<Self>) -> Self {
-        let messages = rpc.subscribe();
-        let receive_messages = cx.spawn_weak(|this, cx| async move {
+        let mut messages = rpc.subscribe();
+        let receive_messages = cx.spawn_weak(|this, mut cx| async move {
             while let Some(message) = messages.next().await {
                 if let Some(this) = this.upgrade(&cx) {
                     this.update(&mut cx, |this, cx| this.message_received(&message, cx));

zed/src/worktree.rs 🔗

@@ -18,7 +18,7 @@ use futures::{Stream, StreamExt};
 pub use fuzzy::{match_paths, PathMatch};
 use gpui::{
     executor, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
-    Task, WeakModelHandle,
+    Task, UpgradeModelHandle, WeakModelHandle,
 };
 use lazy_static::lazy_static;
 use parking_lot::Mutex;
@@ -373,7 +373,7 @@ impl Worktree {
                 let buffer = worktree
                     .open_buffers
                     .get(&buffer_id)
-                    .and_then(|buf| buf.upgrade(&cx))
+                    .and_then(|buf| buf.upgrade(cx))
                     .ok_or_else(|| {
                         anyhow!("invalid buffer {} in update buffer message", buffer_id)
                     })?;
@@ -382,7 +382,7 @@ impl Worktree {
             Worktree::Remote(worktree) => match worktree.open_buffers.get_mut(&buffer_id) {
                 Some(RemoteBuffer::Operations(pending_ops)) => pending_ops.extend(ops),
                 Some(RemoteBuffer::Loaded(buffer)) => {
-                    if let Some(buffer) = buffer.upgrade(&cx) {
+                    if let Some(buffer) = buffer.upgrade(cx) {
                         buffer.update(cx, |buffer, cx| buffer.apply_ops(ops, cx))?;
                     } else {
                         worktree
@@ -410,7 +410,7 @@ impl Worktree {
             if let Some(buffer) = worktree
                 .open_buffers
                 .get(&(message.buffer_id as usize))
-                .and_then(|buf| buf.upgrade(&cx))
+                .and_then(|buf| buf.upgrade(cx))
             {
                 buffer.update(cx, |buffer, cx| {
                     let version = message.version.try_into()?;
@@ -480,7 +480,7 @@ impl Worktree {
 
         let mut buffers_to_delete = Vec::new();
         for (buffer_id, buffer) in open_buffers {
-            if let Some(buffer) = buffer.upgrade(&cx) {
+            if let Some(buffer) = buffer.upgrade(cx) {
                 buffer.update(cx, |buffer, cx| {
                     let buffer_is_clean = !buffer.is_dirty();
 
@@ -633,7 +633,7 @@ impl LocalWorktree {
 
             cx.spawn_weak(|this, mut cx| async move {
                 while let Ok(scan_state) = scan_states_rx.recv().await {
-                    if let Some(handle) = cx.read(|cx| this.upgrade(&cx)) {
+                    if let Some(handle) = cx.read(|cx| this.upgrade(cx)) {
                         let to_send = handle.update(&mut cx, |this, cx| {
                             last_scan_state_tx.blocking_send(scan_state).ok();
                             this.poll_snapshot(cx);
@@ -778,7 +778,7 @@ impl LocalWorktree {
             .ok_or_else(|| anyhow!("unknown peer {:?}", peer_id))?;
         self.shared_buffers.remove(&peer_id);
         for (_, buffer) in &self.open_buffers {
-            if let Some(buffer) = buffer.upgrade(&cx) {
+            if let Some(buffer) = buffer.upgrade(cx) {
                 buffer.update(cx, |buffer, cx| buffer.remove_peer(replica_id, cx));
             }
         }
@@ -1078,7 +1078,7 @@ impl RemoteWorktree {
             .remove(&peer_id)
             .ok_or_else(|| anyhow!("unknown peer {:?}", peer_id))?;
         for (_, buffer) in &self.open_buffers {
-            if let Some(buffer) = buffer.upgrade(&cx) {
+            if let Some(buffer) = buffer.upgrade(cx) {
                 buffer.update(cx, |buffer, cx| buffer.remove_peer(replica_id, cx));
             }
         }
@@ -1093,7 +1093,7 @@ enum RemoteBuffer {
 }
 
 impl RemoteBuffer {
-    fn upgrade(&self, cx: impl AsRef<AppContext>) -> Option<ModelHandle<Buffer>> {
+    fn upgrade(&self, cx: &impl UpgradeModelHandle) -> Option<ModelHandle<Buffer>> {
         match self {
             Self::Operations(_) => None,
             Self::Loaded(buffer) => buffer.upgrade(cx),

zrpc/src/peer.rs 🔗

@@ -45,7 +45,7 @@ pub struct Receipt<T> {
 
 pub struct TypedEnvelope<T> {
     pub sender_id: ConnectionId,
-    original_sender_id: Option<PeerId>,
+    pub original_sender_id: Option<PeerId>,
     pub message_id: u32,
     pub payload: T,
 }
@@ -158,18 +158,25 @@ impl Peer {
             }
         };
 
+        let mut broadcast_incoming_messages = self.incoming_messages.clone();
         let response_channels = connection.response_channels.clone();
         let handle_messages = async move {
-            while let Some(message) = incoming_rx.recv().await {
-                if let Some(responding_to) = message.responding_to {
+            while let Some(envelope) = incoming_rx.recv().await {
+                if let Some(responding_to) = envelope.responding_to {
                     let channel = response_channels.lock().await.remove(&responding_to);
                     if let Some(mut tx) = channel {
-                        tx.send(message).await.ok();
+                        tx.send(envelope).await.ok();
                     } else {
                         log::warn!("received RPC response to unknown request {}", responding_to);
                     }
                 } else {
-                    router.handle(connection_id, message).await;
+                    router.handle(connection_id, envelope.clone()).await;
+                    match proto::build_typed_envelope(connection_id, envelope) {
+                        Ok(envelope) => {
+                            broadcast_incoming_messages.send(envelope).await.ok();
+                        }
+                        Err(error) => log::error!("{}", error),
+                    }
                 }
             }
             response_channels.lock().await.clear();

zrpc/src/proto.rs 🔗

@@ -1,6 +1,10 @@
+use super::{ConnectionId, PeerId, TypedEnvelope};
+use anyhow::{anyhow, Result};
 use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
 use futures::{SinkExt as _, StreamExt as _};
 use prost::Message;
+use std::any::Any;
+use std::sync::Arc;
 use std::{
     io,
     time::{Duration, SystemTime, UNIX_EPOCH},
@@ -24,6 +28,59 @@ pub trait RequestMessage: EnvelopedMessage {
     type Response: EnvelopedMessage;
 }
 
+macro_rules! messages {
+    ($($name:ident),*) => {
+        fn unicast_message_into_typed_envelope(sender_id: ConnectionId, envelope: &mut Envelope) -> Option<Arc<dyn Any + Send + Sync>> {
+            match &mut envelope.payload {
+                $(payload @ Some(envelope::Payload::$name(_)) => Some(Arc::new(TypedEnvelope {
+                    sender_id,
+                    original_sender_id: envelope.original_sender_id.map(PeerId),
+                    message_id: envelope.id,
+                    payload: payload.take().unwrap(),
+                })), )*
+                _ => None
+            }
+        }
+
+        $(
+            message!($name);
+        )*
+    };
+}
+
+macro_rules! request_messages {
+    ($(($request_name:ident, $response_name:ident)),*) => {
+        fn request_message_into_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Arc<dyn Any + Send + Sync>> {
+            match envelope.payload {
+                $(
+                    Some(envelope::Payload::$request_name(payload)) => Some(Arc::new(TypedEnvelope {
+                        sender_id,
+                        original_sender_id: envelope.original_sender_id.map(PeerId),
+                        message_id: envelope.id,
+                        payload,
+                    })),
+                    Some(envelope::Payload::$response_name(payload)) => Some(Arc::new(TypedEnvelope {
+                        sender_id,
+                        original_sender_id: envelope.original_sender_id.map(PeerId),
+                        message_id: envelope.id,
+                        payload,
+                    })),
+                )*
+                _ => None
+            }
+        }
+
+        $(
+            message!($request_name);
+            message!($response_name);
+        )*
+
+        $(impl RequestMessage for $request_name {
+            type Response = $response_name;
+        })*
+    };
+}
+
 macro_rules! message {
     ($name:ident) => {
         impl EnvelopedMessage for $name {
@@ -58,33 +115,37 @@ macro_rules! message {
     };
 }
 
-macro_rules! request_message {
-    ($req:ident, $resp:ident) => {
-        message!($req);
-        message!($resp);
-        impl RequestMessage for $req {
-            type Response = $resp;
-        }
-    };
+messages!(
+    UpdateWorktree,
+    CloseWorktree,
+    CloseBuffer,
+    UpdateBuffer,
+    AddPeer,
+    RemovePeer,
+    SendChannelMessage,
+    ChannelMessageSent
+);
+
+request_messages!(
+    (Auth, AuthResponse),
+    (ShareWorktree, ShareWorktreeResponse),
+    (OpenWorktree, OpenWorktreeResponse),
+    (OpenBuffer, OpenBufferResponse),
+    (SaveBuffer, BufferSaved),
+    (GetChannels, GetChannelsResponse),
+    (JoinChannel, JoinChannelResponse),
+    (GetUsers, GetUsersResponse)
+);
+
+pub fn build_typed_envelope(
+    sender_id: ConnectionId,
+    mut envelope: Envelope,
+) -> Result<Arc<dyn Any + Send + Sync>> {
+    unicast_message_into_typed_envelope(sender_id, &mut envelope)
+        .or_else(|| request_message_into_typed_envelope(sender_id, envelope))
+        .ok_or_else(|| anyhow!("unrecognized payload type"))
 }
 
-request_message!(Auth, AuthResponse);
-request_message!(ShareWorktree, ShareWorktreeResponse);
-request_message!(OpenWorktree, OpenWorktreeResponse);
-message!(UpdateWorktree);
-message!(CloseWorktree);
-request_message!(OpenBuffer, OpenBufferResponse);
-message!(CloseBuffer);
-message!(UpdateBuffer);
-request_message!(SaveBuffer, BufferSaved);
-message!(AddPeer);
-message!(RemovePeer);
-request_message!(GetChannels, GetChannelsResponse);
-request_message!(JoinChannel, JoinChannelResponse);
-request_message!(GetUsers, GetUsersResponse);
-message!(SendChannelMessage);
-message!(ChannelMessageSent);
-
 /// A stream of protobuf messages.
 pub struct MessageStream<S> {
     stream: S,