diff --git a/server/src/auth.rs b/server/src/auth.rs index d61428fa371a0f26bded6a42250aab37a1e1d9f5..ac326b15defe9cc89ae67886e6d0add79e9c111f 100644 --- a/server/src/auth.rs +++ b/server/src/auth.rs @@ -137,6 +137,34 @@ impl PeerExt for Peer { } } +#[async_trait] +impl PeerExt for zrpc::peer2::Peer { + async fn sign_out( + self: &Arc, + connection_id: zrpc::ConnectionId, + state: &AppState, + ) -> tide::Result<()> { + self.disconnect(connection_id).await; + let worktree_ids = state.rpc.write().await.remove_connection(connection_id); + for worktree_id in worktree_ids { + let state = state.rpc.read().await; + if let Some(worktree) = state.worktrees.get(&worktree_id) { + rpc::broadcast(connection_id, worktree.connection_ids(), |conn_id| { + self.send( + conn_id, + proto::RemovePeer { + worktree_id, + peer_id: connection_id.0, + }, + ) + }) + .await?; + } + } + Ok(()) + } +} + pub fn build_client(client_id: &str, client_secret: &str) -> Client { Client::new( ClientId::new(client_id.to_string()), diff --git a/server/src/rpc.rs b/server/src/rpc.rs index 8696f0369130ef71b0f5e1f2ab9c9c7ed931091d..b7be90d3485cbd8e2e0a611e327efe1b5e02dd05 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -9,8 +9,11 @@ use async_tungstenite::{ tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage}, WebSocketStream, }; +use futures::{future::BoxFuture, FutureExt}; +use postage::prelude::Stream as _; use sha1::{Digest as _, Sha1}; use std::{ + any::{Any, TypeId}, collections::{HashMap, HashSet}, future::Future, mem, @@ -32,6 +35,119 @@ use zrpc::{ type ReplicaId = u16; +type Handler = Box< + dyn Fn(&mut Option>, Arc) -> Option>, +>; + +#[derive(Default)] +struct ServerBuilder { + handlers: Vec, + handler_types: HashSet, +} + +impl ServerBuilder { + pub fn on_message(&mut self, handler: F) -> &mut Self + where + F: 'static + Fn(Box>, Arc) -> Fut, + Fut: 'static + Send + Future, + M: EnvelopedMessage, + { + if self.handler_types.insert(TypeId::of::()) { + panic!("registered a handler for the same message twice"); + } + + self.handlers + .push(Box::new(move |untyped_envelope, server| { + if let Some(typed_envelope) = untyped_envelope.take() { + match typed_envelope.downcast::>() { + Ok(typed_envelope) => Some((handler)(typed_envelope, server).boxed()), + Err(envelope) => { + *untyped_envelope = Some(envelope); + None + } + } + } else { + None + } + })); + self + } + + pub fn build(self, rpc: Arc, state: Arc) -> Arc { + Arc::new(Server { + rpc, + state, + handlers: self.handlers, + }) + } +} + +struct Server { + rpc: Arc, + state: Arc, + handlers: Vec, +} + +impl Server { + pub async fn add_connection( + self: &Arc, + connection: Conn, + addr: String, + user_id: UserId, + ) where + Conn: 'static + + futures::Sink + + futures::Stream> + + Send + + Unpin, + { + let this = self.clone(); + let (connection_id, handle_io, mut incoming_rx) = this.rpc.add_connection(connection).await; + this.state + .rpc + .write() + .await + .add_connection(connection_id, user_id); + + let handle_io = handle_io.fuse(); + futures::pin_mut!(handle_io); + loop { + let next_message = incoming_rx.recv().fuse(); + futures::pin_mut!(next_message); + futures::select_biased! { + message = next_message => { + if let Some(message) = message { + let mut message = Some(message); + for handler in &this.handlers { + if let Some(future) = (handler)(&mut message, this.clone()) { + future.await; + break; + } + } + + if let Some(message) = message { + log::warn!("unhandled message: {:?}", message); + } + } else { + log::info!("rpc connection closed {:?}", addr); + break; + } + } + handle_io = handle_io => { + if let Err(err) = handle_io { + log::error!("error handling rpc connection {:?} - {:?}", addr, err); + } + break; + } + } + } + + if let Err(err) = this.rpc.sign_out(connection_id, &this.state).await { + log::error!("error signing out connection {:?} - {:?}", addr, err); + } + } +} + #[derive(Default)] pub struct State { connections: HashMap, diff --git a/zrpc/src/lib.rs b/zrpc/src/lib.rs index be3625e51f23765effb83b9079a774597553761f..67132cf299253bd63737bd0ad7474cde8ff9e3e9 100644 --- a/zrpc/src/lib.rs +++ b/zrpc/src/lib.rs @@ -1,6 +1,6 @@ pub mod auth; mod peer; -mod peer2; +pub mod peer2; pub mod proto; #[cfg(any(test, feature = "test-support"))] pub mod test;