Start on a peer2 module with an alternative implementation

Nathan Sobo created

Change summary

zrpc/proto/zed.proto |  58 +++--
zrpc/src/lib.rs      |   1 
zrpc/src/peer.rs     |   6 
zrpc/src/peer2.rs    | 470 ++++++++++++++++++++++++++++++++++++++++++++++
zrpc/src/proto.rs    |  20 +
5 files changed, 520 insertions(+), 35 deletions(-)

Detailed changes

zrpc/proto/zed.proto 🔗

@@ -6,35 +6,45 @@ message Envelope {
     optional uint32 responding_to = 2;
     optional uint32 original_sender_id = 3;
     oneof payload {
-        Auth auth = 4;
-        AuthResponse auth_response = 5;
-        ShareWorktree share_worktree = 6;
-        ShareWorktreeResponse share_worktree_response = 7;
-        OpenWorktree open_worktree = 8;
-        OpenWorktreeResponse open_worktree_response = 9;
-        UpdateWorktree update_worktree = 10;
-        CloseWorktree close_worktree = 11;
-        OpenBuffer open_buffer = 12;
-        OpenBufferResponse open_buffer_response = 13;
-        CloseBuffer close_buffer = 14;
-        UpdateBuffer update_buffer = 15;
-        SaveBuffer save_buffer = 16;
-        BufferSaved buffer_saved = 17;
-        AddPeer add_peer = 18;
-        RemovePeer remove_peer = 19;
-        GetChannels get_channels = 20;
-        GetChannelsResponse get_channels_response = 21;
-        GetUsers get_users = 22;
-        GetUsersResponse get_users_response = 23;
-        JoinChannel join_channel = 24;
-        JoinChannelResponse join_channel_response = 25;
-        SendChannelMessage send_channel_message = 26;
-        ChannelMessageSent channel_message_sent = 27;
+        Ping ping = 4;
+        Pong pong = 5;
+        Auth auth = 6;
+        AuthResponse auth_response = 7;
+        ShareWorktree share_worktree = 8;
+        ShareWorktreeResponse share_worktree_response = 9;
+        OpenWorktree open_worktree = 10;
+        OpenWorktreeResponse open_worktree_response = 11;
+        UpdateWorktree update_worktree = 12;
+        CloseWorktree close_worktree = 13;
+        OpenBuffer open_buffer = 14;
+        OpenBufferResponse open_buffer_response = 15;
+        CloseBuffer close_buffer = 16;
+        UpdateBuffer update_buffer = 17;
+        SaveBuffer save_buffer = 18;
+        BufferSaved buffer_saved = 19;
+        AddPeer add_peer = 20;
+        RemovePeer remove_peer = 21;
+        GetChannels get_channels = 22;
+        GetChannelsResponse get_channels_response = 23;
+        GetUsers get_users = 24;
+        GetUsersResponse get_users_response = 25;
+        JoinChannel join_channel = 26;
+        JoinChannelResponse join_channel_response = 27;
+        SendChannelMessage send_channel_message = 28;
+        ChannelMessageSent channel_message_sent = 29;
     }
 }
 
 // Messages
 
+message Ping {
+    int32 id = 1;
+}
+
+message Pong {
+    int32 id = 2;
+}
+
 message Auth {
     int32 user_id = 1;
     string access_token = 2;

zrpc/src/lib.rs 🔗

@@ -1,5 +1,6 @@
 pub mod auth;
 mod peer;
+mod peer2;
 pub mod proto;
 #[cfg(any(test, feature = "test-support"))]
 pub mod test;

zrpc/src/peer.rs 🔗

@@ -38,8 +38,8 @@ type ForegroundMessageHandler =
     Box<dyn Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<LocalBoxFuture<'static, ()>>>;
 
 pub struct Receipt<T> {
-    sender_id: ConnectionId,
-    message_id: u32,
+    pub sender_id: ConnectionId,
+    pub message_id: u32,
     payload_type: PhantomData<T>,
 }
 
@@ -172,7 +172,7 @@ impl Peer {
                 } else {
                     router.handle(connection_id, envelope.clone()).await;
                     if let Some(envelope) = proto::build_typed_envelope(connection_id, envelope) {
-                        broadcast_incoming_messages.send(envelope).await.ok();
+                        broadcast_incoming_messages.send(Arc::from(envelope)).await.ok();
                     } else {
                         log::error!("unable to construct a typed envelope");
                     }

zrpc/src/peer2.rs 🔗

@@ -0,0 +1,470 @@
+use crate::{
+    proto::{self, EnvelopedMessage, MessageStream, RequestMessage},
+    ConnectionId, PeerId, Receipt,
+};
+use anyhow::{anyhow, Context, Result};
+use async_lock::{Mutex, RwLock};
+use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
+use futures::{FutureExt, StreamExt};
+use postage::{
+    mpsc,
+    prelude::{Sink as _, Stream as _},
+};
+use std::{
+    any::Any,
+    collections::HashMap,
+    future::Future,
+    sync::{
+        atomic::{self, AtomicU32},
+        Arc,
+    },
+};
+
+pub struct Peer {
+    connections: RwLock<HashMap<ConnectionId, Connection>>,
+    next_connection_id: AtomicU32,
+}
+
+#[derive(Clone)]
+struct Connection {
+    outgoing_tx: mpsc::Sender<proto::Envelope>,
+    next_message_id: Arc<AtomicU32>,
+    response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
+}
+
+impl Peer {
+    pub fn new() -> Arc<Self> {
+        Arc::new(Self {
+            connections: Default::default(),
+            next_connection_id: Default::default(),
+        })
+    }
+
+    pub async fn add_connection<Conn>(
+        self: &Arc<Self>,
+        conn: Conn,
+    ) -> (
+        ConnectionId,
+        impl Future<Output = anyhow::Result<()>> + Send,
+        mpsc::Receiver<Box<dyn Any + Sync + Send>>,
+    )
+    where
+        Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
+            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
+            + Send
+            + Unpin,
+    {
+        let (tx, rx) = conn.split();
+        let connection_id = ConnectionId(
+            self.next_connection_id
+                .fetch_add(1, atomic::Ordering::SeqCst),
+        );
+        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
+        let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
+        let connection = Connection {
+            outgoing_tx,
+            next_message_id: Default::default(),
+            response_channels: Default::default(),
+        };
+        let mut writer = MessageStream::new(tx);
+        let mut reader = MessageStream::new(rx);
+
+        let response_channels = connection.response_channels.clone();
+        let handle_io = async move {
+            loop {
+                let read_message = reader.read_message().fuse();
+                futures::pin_mut!(read_message);
+                loop {
+                    futures::select_biased! {
+                        incoming = read_message => match incoming {
+                            Ok(incoming) => {
+                                if let Some(responding_to) = incoming.responding_to {
+                                    let channel = response_channels.lock().await.remove(&responding_to);
+                                    if let Some(mut tx) = channel {
+                                        tx.send(incoming).await.ok();
+                                    } else {
+                                        log::warn!("received RPC response to unknown request {}", responding_to);
+                                    }
+                                } else {
+                                    if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
+                                        if incoming_tx.send(envelope).await.is_err() {
+                                            response_channels.lock().await.clear();
+                                            return Ok(())
+                                        }
+                                    } else {
+                                        log::error!("unable to construct a typed envelope");
+                                    }
+                                }
+
+                                break;
+                            }
+                            Err(error) => {
+                                response_channels.lock().await.clear();
+                                Err(error).context("received invalid RPC message")?;
+                            }
+                        },
+                        outgoing = outgoing_rx.recv().fuse() => match outgoing {
+                            Some(outgoing) => {
+                                if let Err(result) = writer.write_message(&outgoing).await {
+                                    response_channels.lock().await.clear();
+                                    Err(result).context("failed to write RPC message")?;
+                                }
+                            }
+                            None => {
+                                response_channels.lock().await.clear();
+                                return Ok(())
+                            }
+                        }
+                    }
+                }
+            }
+        };
+
+        self.connections
+            .write()
+            .await
+            .insert(connection_id, connection);
+
+        (connection_id, handle_io, incoming_rx)
+    }
+
+    pub async fn disconnect(&self, connection_id: ConnectionId) {
+        self.connections.write().await.remove(&connection_id);
+    }
+
+    pub async fn reset(&self) {
+        self.connections.write().await.clear();
+    }
+
+    pub fn request<T: RequestMessage>(
+        self: &Arc<Self>,
+        receiver_id: ConnectionId,
+        request: T,
+    ) -> impl Future<Output = Result<T::Response>> {
+        self.request_internal(None, receiver_id, request)
+    }
+
+    pub fn forward_request<T: RequestMessage>(
+        self: &Arc<Self>,
+        sender_id: ConnectionId,
+        receiver_id: ConnectionId,
+        request: T,
+    ) -> impl Future<Output = Result<T::Response>> {
+        self.request_internal(Some(sender_id), receiver_id, request)
+    }
+
+    pub fn request_internal<T: RequestMessage>(
+        self: &Arc<Self>,
+        original_sender_id: Option<ConnectionId>,
+        receiver_id: ConnectionId,
+        request: T,
+    ) -> impl Future<Output = Result<T::Response>> {
+        let this = self.clone();
+        let (tx, mut rx) = mpsc::channel(1);
+        async move {
+            let mut connection = this.connection(receiver_id).await?;
+            let message_id = connection
+                .next_message_id
+                .fetch_add(1, atomic::Ordering::SeqCst);
+            connection
+                .response_channels
+                .lock()
+                .await
+                .insert(message_id, tx);
+            connection
+                .outgoing_tx
+                .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
+                .await
+                .map_err(|_| anyhow!("connection was closed"))?;
+            let response = rx
+                .recv()
+                .await
+                .ok_or_else(|| anyhow!("connection was closed"))?;
+            T::Response::from_envelope(response)
+                .ok_or_else(|| anyhow!("received response of the wrong type"))
+        }
+    }
+
+    pub fn send<T: EnvelopedMessage>(
+        self: &Arc<Self>,
+        receiver_id: ConnectionId,
+        message: T,
+    ) -> impl Future<Output = Result<()>> {
+        let this = self.clone();
+        async move {
+            let mut connection = this.connection(receiver_id).await?;
+            let message_id = connection
+                .next_message_id
+                .fetch_add(1, atomic::Ordering::SeqCst);
+            connection
+                .outgoing_tx
+                .send(message.into_envelope(message_id, None, None))
+                .await?;
+            Ok(())
+        }
+    }
+
+    pub fn forward_send<T: EnvelopedMessage>(
+        self: &Arc<Self>,
+        sender_id: ConnectionId,
+        receiver_id: ConnectionId,
+        message: T,
+    ) -> impl Future<Output = Result<()>> {
+        let this = self.clone();
+        async move {
+            let mut connection = this.connection(receiver_id).await?;
+            let message_id = connection
+                .next_message_id
+                .fetch_add(1, atomic::Ordering::SeqCst);
+            connection
+                .outgoing_tx
+                .send(message.into_envelope(message_id, None, Some(sender_id.0)))
+                .await?;
+            Ok(())
+        }
+    }
+
+    pub fn respond<T: RequestMessage>(
+        self: &Arc<Self>,
+        receipt: Receipt<T>,
+        response: T::Response,
+    ) -> impl Future<Output = Result<()>> {
+        let this = self.clone();
+        async move {
+            let mut connection = this.connection(receipt.sender_id).await?;
+            let message_id = connection
+                .next_message_id
+                .fetch_add(1, atomic::Ordering::SeqCst);
+            connection
+                .outgoing_tx
+                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
+                .await?;
+            Ok(())
+        }
+    }
+
+    fn connection(
+        self: &Arc<Self>,
+        connection_id: ConnectionId,
+    ) -> impl Future<Output = Result<Connection>> {
+        let this = self.clone();
+        async move {
+            let connections = this.connections.read().await;
+            let connection = connections
+                .get(&connection_id)
+                .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
+            Ok(connection.clone())
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{test, TypedEnvelope};
+
+    #[test]
+    fn test_request_response() {
+        smol::block_on(async move {
+            // create 2 clients connected to 1 server
+            let server = Peer::new();
+            let client1 = Peer::new();
+            let client2 = Peer::new();
+
+            let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
+            let (client1_conn_id, io_task1, _) =
+                client1.add_connection(client1_to_server_conn).await;
+            let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;
+
+            let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
+            let (client2_conn_id, io_task3, _) =
+                client2.add_connection(client2_to_server_conn).await;
+            let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;
+
+            smol::spawn(io_task1).detach();
+            smol::spawn(io_task2).detach();
+            smol::spawn(io_task3).detach();
+            smol::spawn(io_task4).detach();
+            smol::spawn(handle_messages(incoming1, server.clone())).detach();
+            smol::spawn(handle_messages(incoming2, server.clone())).detach();
+
+            assert_eq!(
+                client1
+                    .request(client1_conn_id, proto::Ping { id: 1 },)
+                    .await
+                    .unwrap(),
+                proto::Pong { id: 1 }
+            );
+
+            assert_eq!(
+                client2
+                    .request(client2_conn_id, proto::Ping { id: 2 },)
+                    .await
+                    .unwrap(),
+                proto::Pong { id: 2 }
+            );
+
+            assert_eq!(
+                client1
+                    .request(
+                        client1_conn_id,
+                        proto::OpenBuffer {
+                            worktree_id: 1,
+                            path: "path/one".to_string(),
+                        },
+                    )
+                    .await
+                    .unwrap(),
+                proto::OpenBufferResponse {
+                    buffer: Some(proto::Buffer {
+                        id: 101,
+                        content: "path/one content".to_string(),
+                        history: vec![],
+                        selections: vec![],
+                    }),
+                }
+            );
+
+            assert_eq!(
+                client2
+                    .request(
+                        client2_conn_id,
+                        proto::OpenBuffer {
+                            worktree_id: 2,
+                            path: "path/two".to_string(),
+                        },
+                    )
+                    .await
+                    .unwrap(),
+                proto::OpenBufferResponse {
+                    buffer: Some(proto::Buffer {
+                        id: 102,
+                        content: "path/two content".to_string(),
+                        history: vec![],
+                        selections: vec![],
+                    }),
+                }
+            );
+
+            client1.disconnect(client1_conn_id).await;
+            client2.disconnect(client1_conn_id).await;
+
+            async fn handle_messages(
+                mut messages: mpsc::Receiver<Box<dyn Any + Sync + Send>>,
+                peer: Arc<Peer>,
+            ) -> Result<()> {
+                while let Some(envelope) = messages.next().await {
+                    if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
+                        let receipt = envelope.receipt();
+                        peer.respond(
+                            receipt,
+                            proto::Pong {
+                                id: envelope.payload.id,
+                            },
+                        )
+                        .await?
+                    } else if let Some(envelope) =
+                        envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
+                    {
+                        let message = &envelope.payload;
+                        let receipt = envelope.receipt();
+                        let response = match message.path.as_str() {
+                            "path/one" => {
+                                assert_eq!(message.worktree_id, 1);
+                                proto::OpenBufferResponse {
+                                    buffer: Some(proto::Buffer {
+                                        id: 101,
+                                        content: "path/one content".to_string(),
+                                        history: vec![],
+                                        selections: vec![],
+                                    }),
+                                }
+                            }
+                            "path/two" => {
+                                assert_eq!(message.worktree_id, 2);
+                                proto::OpenBufferResponse {
+                                    buffer: Some(proto::Buffer {
+                                        id: 102,
+                                        content: "path/two content".to_string(),
+                                        history: vec![],
+                                        selections: vec![],
+                                    }),
+                                }
+                            }
+                            _ => {
+                                panic!("unexpected path {}", message.path);
+                            }
+                        };
+
+                        peer.respond(receipt, response).await?
+                    } else {
+                        panic!("unknown message type");
+                    }
+                }
+
+                Ok(())
+            }
+        });
+    }
+
+    #[test]
+    fn test_disconnect() {
+        smol::block_on(async move {
+            let (client_conn, mut server_conn) = test::Channel::bidirectional();
+
+            let client = Peer::new();
+            let (connection_id, io_handler, mut incoming) =
+                client.add_connection(client_conn).await;
+
+            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
+            smol::spawn(async move {
+                io_handler.await.ok();
+                io_ended_tx.send(()).await.unwrap();
+            })
+            .detach();
+
+            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
+            smol::spawn(async move {
+                incoming.next().await;
+                messages_ended_tx.send(()).await.unwrap();
+            })
+            .detach();
+
+            client.disconnect(connection_id).await;
+
+            io_ended_rx.recv().await;
+            messages_ended_rx.recv().await;
+            assert!(
+                futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
+                    .await
+                    .is_err()
+            );
+        });
+    }
+
+    #[test]
+    fn test_io_error() {
+        smol::block_on(async move {
+            let (client_conn, server_conn) = test::Channel::bidirectional();
+            drop(server_conn);
+
+            let client = Peer::new();
+            let (connection_id, io_handler, mut incoming) =
+                client.add_connection(client_conn).await;
+            smol::spawn(io_handler).detach();
+            smol::spawn(async move { incoming.next().await }).detach();
+
+            let err = client
+                .request(
+                    connection_id,
+                    proto::Auth {
+                        user_id: 42,
+                        access_token: "token".to_string(),
+                    },
+                )
+                .await
+                .unwrap_err();
+            assert_eq!(err.to_string(), "connection was closed");
+        });
+    }
+}

zrpc/src/proto.rs 🔗

@@ -4,7 +4,6 @@ use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSock
 use futures::{SinkExt as _, StreamExt as _};
 use prost::Message;
 use std::any::Any;
-use std::sync::Arc;
 use std::{
     io,
     time::{Duration, SystemTime, UNIX_EPOCH},
@@ -34,14 +33,16 @@ pub trait RequestMessage: EnvelopedMessage {
 
 macro_rules! messages {
     ($($name:ident),* $(,)?) => {
-        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Arc<dyn Any + Send + Sync>> {
+        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn Any + Send + Sync>> {
             match envelope.payload {
-                $(Some(envelope::Payload::$name(payload)) => Some(Arc::new(TypedEnvelope {
-                    sender_id,
-                    original_sender_id: envelope.original_sender_id.map(PeerId),
-                    message_id: envelope.id,
-                    payload,
-                })), )*
+                $(Some(envelope::Payload::$name(payload)) => {
+                    Some(Box::new(TypedEnvelope {
+                        sender_id,
+                        original_sender_id: envelope.original_sender_id.map(PeerId),
+                        message_id: envelope.id,
+                        payload,
+                    }))
+                }, )*
                 _ => None
             }
         }
@@ -116,6 +117,8 @@ messages!(
     OpenBufferResponse,
     OpenWorktree,
     OpenWorktreeResponse,
+    Ping,
+    Pong,
     RemovePeer,
     SaveBuffer,
     SendChannelMessage,
@@ -132,6 +135,7 @@ request_messages!(
     (JoinChannel, JoinChannelResponse),
     (OpenBuffer, OpenBufferResponse),
     (OpenWorktree, OpenWorktreeResponse),
+    (Ping, Pong),
     (SaveBuffer, BufferSaved),
     (ShareWorktree, ShareWorktreeResponse),
 );