@@ -137,6 +137,34 @@ impl PeerExt for Peer {
}
}
+#[async_trait]
+impl PeerExt for zrpc::peer2::Peer {
+ async fn sign_out(
+ self: &Arc<Self>,
+ connection_id: zrpc::ConnectionId,
+ state: &AppState,
+ ) -> tide::Result<()> {
+ self.disconnect(connection_id).await;
+ let worktree_ids = state.rpc.write().await.remove_connection(connection_id);
+ for worktree_id in worktree_ids {
+ let state = state.rpc.read().await;
+ if let Some(worktree) = state.worktrees.get(&worktree_id) {
+ rpc::broadcast(connection_id, worktree.connection_ids(), |conn_id| {
+ self.send(
+ conn_id,
+ proto::RemovePeer {
+ worktree_id,
+ peer_id: connection_id.0,
+ },
+ )
+ })
+ .await?;
+ }
+ }
+ Ok(())
+ }
+}
+
pub fn build_client(client_id: &str, client_secret: &str) -> Client {
Client::new(
ClientId::new(client_id.to_string()),
@@ -9,8 +9,11 @@ use async_tungstenite::{
tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage},
WebSocketStream,
};
+use futures::{future::BoxFuture, FutureExt};
+use postage::prelude::Stream as _;
use sha1::{Digest as _, Sha1};
use std::{
+ any::{Any, TypeId},
collections::{HashMap, HashSet},
future::Future,
mem,
@@ -32,6 +35,119 @@ use zrpc::{
type ReplicaId = u16;
+type Handler = Box<
+ dyn Fn(&mut Option<Box<dyn Any + Send + Sync>>, Arc<Server>) -> Option<BoxFuture<'static, ()>>,
+>;
+
+#[derive(Default)]
+struct ServerBuilder {
+ handlers: Vec<Handler>,
+ handler_types: HashSet<TypeId>,
+}
+
+impl ServerBuilder {
+ pub fn on_message<F, Fut, M>(&mut self, handler: F) -> &mut Self
+ where
+ F: 'static + Fn(Box<TypedEnvelope<M>>, Arc<Server>) -> Fut,
+ Fut: 'static + Send + Future<Output = ()>,
+ M: EnvelopedMessage,
+ {
+ if self.handler_types.insert(TypeId::of::<M>()) {
+ panic!("registered a handler for the same message twice");
+ }
+
+ self.handlers
+ .push(Box::new(move |untyped_envelope, server| {
+ if let Some(typed_envelope) = untyped_envelope.take() {
+ match typed_envelope.downcast::<TypedEnvelope<M>>() {
+ Ok(typed_envelope) => Some((handler)(typed_envelope, server).boxed()),
+ Err(envelope) => {
+ *untyped_envelope = Some(envelope);
+ None
+ }
+ }
+ } else {
+ None
+ }
+ }));
+ self
+ }
+
+ pub fn build(self, rpc: Arc<zrpc::peer2::Peer>, state: Arc<AppState>) -> Arc<Server> {
+ Arc::new(Server {
+ rpc,
+ state,
+ handlers: self.handlers,
+ })
+ }
+}
+
+struct Server {
+ rpc: Arc<zrpc::peer2::Peer>,
+ state: Arc<AppState>,
+ handlers: Vec<Handler>,
+}
+
+impl Server {
+ pub async fn add_connection<Conn>(
+ self: &Arc<Self>,
+ connection: Conn,
+ addr: String,
+ user_id: UserId,
+ ) where
+ Conn: 'static
+ + futures::Sink<WebSocketMessage, Error = WebSocketError>
+ + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
+ + Send
+ + Unpin,
+ {
+ let this = self.clone();
+ let (connection_id, handle_io, mut incoming_rx) = this.rpc.add_connection(connection).await;
+ this.state
+ .rpc
+ .write()
+ .await
+ .add_connection(connection_id, user_id);
+
+ let handle_io = handle_io.fuse();
+ futures::pin_mut!(handle_io);
+ loop {
+ let next_message = incoming_rx.recv().fuse();
+ futures::pin_mut!(next_message);
+ futures::select_biased! {
+ message = next_message => {
+ if let Some(message) = message {
+ let mut message = Some(message);
+ for handler in &this.handlers {
+ if let Some(future) = (handler)(&mut message, this.clone()) {
+ future.await;
+ break;
+ }
+ }
+
+ if let Some(message) = message {
+ log::warn!("unhandled message: {:?}", message);
+ }
+ } else {
+ log::info!("rpc connection closed {:?}", addr);
+ break;
+ }
+ }
+ handle_io = handle_io => {
+ if let Err(err) = handle_io {
+ log::error!("error handling rpc connection {:?} - {:?}", addr, err);
+ }
+ break;
+ }
+ }
+ }
+
+ if let Err(err) = this.rpc.sign_out(connection_id, &this.state).await {
+ log::error!("error signing out connection {:?} - {:?}", addr, err);
+ }
+ }
+}
+
#[derive(Default)]
pub struct State {
connections: HashMap<ConnectionId, Connection>,