Start using the new `zrpc::peer2::Peer` in Server

Antonio Scandurra created

Change summary

server/src/auth.rs |  28 +++++++++++
server/src/rpc.rs  | 116 ++++++++++++++++++++++++++++++++++++++++++++++++
zrpc/src/lib.rs    |   2 
3 files changed, 145 insertions(+), 1 deletion(-)

Detailed changes

server/src/auth.rs 🔗

@@ -137,6 +137,34 @@ impl PeerExt for Peer {
     }
 }
 
+#[async_trait]
+impl PeerExt for zrpc::peer2::Peer {
+    async fn sign_out(
+        self: &Arc<Self>,
+        connection_id: zrpc::ConnectionId,
+        state: &AppState,
+    ) -> tide::Result<()> {
+        self.disconnect(connection_id).await;
+        let worktree_ids = state.rpc.write().await.remove_connection(connection_id);
+        for worktree_id in worktree_ids {
+            let state = state.rpc.read().await;
+            if let Some(worktree) = state.worktrees.get(&worktree_id) {
+                rpc::broadcast(connection_id, worktree.connection_ids(), |conn_id| {
+                    self.send(
+                        conn_id,
+                        proto::RemovePeer {
+                            worktree_id,
+                            peer_id: connection_id.0,
+                        },
+                    )
+                })
+                .await?;
+            }
+        }
+        Ok(())
+    }
+}
+
 pub fn build_client(client_id: &str, client_secret: &str) -> Client {
     Client::new(
         ClientId::new(client_id.to_string()),

server/src/rpc.rs 🔗

@@ -9,8 +9,11 @@ use async_tungstenite::{
     tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage},
     WebSocketStream,
 };
+use futures::{future::BoxFuture, FutureExt};
+use postage::prelude::Stream as _;
 use sha1::{Digest as _, Sha1};
 use std::{
+    any::{Any, TypeId},
     collections::{HashMap, HashSet},
     future::Future,
     mem,
@@ -32,6 +35,119 @@ use zrpc::{
 
 type ReplicaId = u16;
 
+type Handler = Box<
+    dyn Fn(&mut Option<Box<dyn Any + Send + Sync>>, Arc<Server>) -> Option<BoxFuture<'static, ()>>,
+>;
+
+#[derive(Default)]
+struct ServerBuilder {
+    handlers: Vec<Handler>,
+    handler_types: HashSet<TypeId>,
+}
+
+impl ServerBuilder {
+    pub fn on_message<F, Fut, M>(&mut self, handler: F) -> &mut Self
+    where
+        F: 'static + Fn(Box<TypedEnvelope<M>>, Arc<Server>) -> Fut,
+        Fut: 'static + Send + Future<Output = ()>,
+        M: EnvelopedMessage,
+    {
+        if self.handler_types.insert(TypeId::of::<M>()) {
+            panic!("registered a handler for the same message twice");
+        }
+
+        self.handlers
+            .push(Box::new(move |untyped_envelope, server| {
+                if let Some(typed_envelope) = untyped_envelope.take() {
+                    match typed_envelope.downcast::<TypedEnvelope<M>>() {
+                        Ok(typed_envelope) => Some((handler)(typed_envelope, server).boxed()),
+                        Err(envelope) => {
+                            *untyped_envelope = Some(envelope);
+                            None
+                        }
+                    }
+                } else {
+                    None
+                }
+            }));
+        self
+    }
+
+    pub fn build(self, rpc: Arc<zrpc::peer2::Peer>, state: Arc<AppState>) -> Arc<Server> {
+        Arc::new(Server {
+            rpc,
+            state,
+            handlers: self.handlers,
+        })
+    }
+}
+
+struct Server {
+    rpc: Arc<zrpc::peer2::Peer>,
+    state: Arc<AppState>,
+    handlers: Vec<Handler>,
+}
+
+impl Server {
+    pub async fn add_connection<Conn>(
+        self: &Arc<Self>,
+        connection: Conn,
+        addr: String,
+        user_id: UserId,
+    ) where
+        Conn: 'static
+            + futures::Sink<WebSocketMessage, Error = WebSocketError>
+            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
+            + Send
+            + Unpin,
+    {
+        let this = self.clone();
+        let (connection_id, handle_io, mut incoming_rx) = this.rpc.add_connection(connection).await;
+        this.state
+            .rpc
+            .write()
+            .await
+            .add_connection(connection_id, user_id);
+
+        let handle_io = handle_io.fuse();
+        futures::pin_mut!(handle_io);
+        loop {
+            let next_message = incoming_rx.recv().fuse();
+            futures::pin_mut!(next_message);
+            futures::select_biased! {
+                message = next_message => {
+                    if let Some(message) = message {
+                        let mut message = Some(message);
+                        for handler in &this.handlers {
+                            if let Some(future) = (handler)(&mut message, this.clone()) {
+                                future.await;
+                                break;
+                            }
+                        }
+
+                        if let Some(message) = message {
+                            log::warn!("unhandled message: {:?}", message);
+                        }
+                    } else {
+                        log::info!("rpc connection closed {:?}", addr);
+                        break;
+                    }
+                }
+                handle_io = handle_io => {
+                    if let Err(err) = handle_io {
+                        log::error!("error handling rpc connection {:?} - {:?}", addr, err);
+                    }
+                    break;
+                }
+            }
+        }
+
+        if let Err(err) = this.rpc.sign_out(connection_id, &this.state).await {
+            log::error!("error signing out connection {:?} - {:?}", addr, err);
+        }
+    }
+}
+
 #[derive(Default)]
 pub struct State {
     connections: HashMap<ConnectionId, Connection>,

zrpc/src/lib.rs 🔗

@@ -1,6 +1,6 @@
 pub mod auth;
 mod peer;
-mod peer2;
+pub mod peer2;
 pub mod proto;
 #[cfg(any(test, feature = "test-support"))]
 pub mod test;