Re-register message handlers in RPC server

Antonio Scandurra created

Change summary

server/src/rpc.rs | 305 ++++++++++++++++++++++++++----------------------
1 file changed, 164 insertions(+), 141 deletions(-)

Detailed changes

server/src/rpc.rs 🔗

@@ -35,23 +35,26 @@ use zrpc::{
 
 type ReplicaId = u16;
 
-type Handler = Box<
+type MessageHandler = Box<
     dyn Send
         + Sync
-        + Fn(&mut Option<Box<dyn Any + Send + Sync>>, Arc<Server>) -> Option<BoxFuture<'static, ()>>,
+        + Fn(
+            &mut Option<Box<dyn Any + Send + Sync>>,
+            Arc<Server>,
+        ) -> Option<BoxFuture<'static, tide::Result<()>>>,
 >;
 
 #[derive(Default)]
 struct ServerBuilder {
-    handlers: Vec<Handler>,
+    handlers: Vec<MessageHandler>,
     handler_types: HashSet<TypeId>,
 }
 
 impl ServerBuilder {
-    pub fn on_message<F, Fut, M>(&mut self, handler: F) -> &mut Self
+    pub fn on_message<F, Fut, M>(mut self, handler: F) -> Self
     where
         F: 'static + Send + Sync + Fn(Box<TypedEnvelope<M>>, Arc<Server>) -> Fut,
-        Fut: 'static + Send + Future<Output = ()>,
+        Fut: 'static + Send + Future<Output = tide::Result<()>>,
         M: EnvelopedMessage,
     {
         if self.handler_types.insert(TypeId::of::<M>()) {
@@ -87,7 +90,7 @@ impl ServerBuilder {
 pub struct Server {
     rpc: Arc<Peer>,
     state: Arc<AppState>,
-    handlers: Vec<Handler>,
+    handlers: Vec<MessageHandler>,
 }
 
 impl Server {
@@ -119,10 +122,16 @@ impl Server {
             futures::select_biased! {
                 message = next_message => {
                     if let Some(message) = message {
+                        let start_time = Instant::now();
+                        log::info!("RPC message received");
                         let mut message = Some(message);
                         for handler in &this.handlers {
                             if let Some(future) = (handler)(&mut message, this.clone()) {
-                                future.await;
+                                 if let Err(err) = future.await {
+                                    log::error!("error handling message: {:?}", err);
+                                } else {
+                                    log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
+                                }
                                 break;
                             }
                         }
@@ -336,26 +345,24 @@ impl State {
 
 pub fn build_server(state: &Arc<AppState>, rpc: &Arc<Peer>) -> Arc<Server> {
     ServerBuilder::default()
-        // .on_message(share_worktree)
-        // .on_message(join_worktree)
-        // .on_message(update_worktree)
-        // .on_message(close_worktree)
-        // .on_message(open_buffer)
-        // .on_message(close_buffer)
-        // .on_message(update_buffer)
-        // .on_message(buffer_saved)
-        // .on_message(save_buffer)
-        // .on_message(get_channels)
-        // .on_message(get_users)
-        // .on_message(join_channel)
-        // .on_message(send_channel_message)
+        .on_message(share_worktree)
+        .on_message(join_worktree)
+        .on_message(update_worktree)
+        .on_message(close_worktree)
+        .on_message(open_buffer)
+        .on_message(close_buffer)
+        .on_message(update_buffer)
+        .on_message(buffer_saved)
+        .on_message(save_buffer)
+        .on_message(get_channels)
+        .on_message(get_users)
+        .on_message(join_channel)
+        .on_message(send_channel_message)
         .build(rpc, state)
 }
 
 pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
     let server = build_server(app.state(), rpc);
-
-    let rpc = rpc.clone();
     app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
         let user_id = request.ext::<UserId>().copied();
         let server = server.clone();
@@ -399,11 +406,10 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
 }
 
 async fn share_worktree(
-    mut request: TypedEnvelope<proto::ShareWorktree>,
-    rpc: &Arc<Peer>,
-    state: &Arc<AppState>,
+    mut request: Box<TypedEnvelope<proto::ShareWorktree>>,
+    server: Arc<Server>,
 ) -> tide::Result<()> {
-    let mut state = state.rpc.write().await;
+    let mut state = server.state.rpc.write().await;
     let worktree_id = state.next_worktree_id;
     state.next_worktree_id += 1;
     let access_token = random_token();
@@ -428,26 +434,27 @@ async fn share_worktree(
         },
     );
 
-    rpc.respond(
-        request.receipt(),
-        proto::ShareWorktreeResponse {
-            worktree_id,
-            access_token,
-        },
-    )
-    .await?;
+    server
+        .rpc
+        .respond(
+            request.receipt(),
+            proto::ShareWorktreeResponse {
+                worktree_id,
+                access_token,
+            },
+        )
+        .await?;
     Ok(())
 }
 
 async fn join_worktree(
-    request: TypedEnvelope<proto::OpenWorktree>,
-    rpc: &Arc<Peer>,
-    state: &Arc<AppState>,
+    request: Box<TypedEnvelope<proto::OpenWorktree>>,
+    server: Arc<Server>,
 ) -> tide::Result<()> {
     let worktree_id = request.payload.worktree_id;
     let access_token = &request.payload.access_token;
 
-    let mut state = state.rpc.write().await;
+    let mut state = server.state.rpc.write().await;
     if let Some((peer_replica_id, worktree)) =
         state.join_worktree(request.sender_id, worktree_id, access_token)
     {
@@ -468,7 +475,7 @@ async fn join_worktree(
         }
 
         broadcast(request.sender_id, worktree.connection_ids(), |conn_id| {
-            rpc.send(
+            server.rpc.send(
                 conn_id,
                 proto::AddPeer {
                     worktree_id,
@@ -480,42 +487,45 @@ async fn join_worktree(
             )
         })
         .await?;
-        rpc.respond(
-            request.receipt(),
-            proto::OpenWorktreeResponse {
-                worktree_id,
-                worktree: Some(proto::Worktree {
-                    root_name: worktree.root_name.clone(),
-                    entries: worktree.entries.values().cloned().collect(),
-                }),
-                replica_id: peer_replica_id as u32,
-                peers,
-            },
-        )
-        .await?;
+        server
+            .rpc
+            .respond(
+                request.receipt(),
+                proto::OpenWorktreeResponse {
+                    worktree_id,
+                    worktree: Some(proto::Worktree {
+                        root_name: worktree.root_name.clone(),
+                        entries: worktree.entries.values().cloned().collect(),
+                    }),
+                    replica_id: peer_replica_id as u32,
+                    peers,
+                },
+            )
+            .await?;
     } else {
-        rpc.respond(
-            request.receipt(),
-            proto::OpenWorktreeResponse {
-                worktree_id,
-                worktree: None,
-                replica_id: 0,
-                peers: Vec::new(),
-            },
-        )
-        .await?;
+        server
+            .rpc
+            .respond(
+                request.receipt(),
+                proto::OpenWorktreeResponse {
+                    worktree_id,
+                    worktree: None,
+                    replica_id: 0,
+                    peers: Vec::new(),
+                },
+            )
+            .await?;
     }
 
     Ok(())
 }
 
 async fn update_worktree(
-    request: TypedEnvelope<proto::UpdateWorktree>,
-    rpc: &Arc<Peer>,
-    state: &Arc<AppState>,
+    request: Box<TypedEnvelope<proto::UpdateWorktree>>,
+    server: Arc<Server>,
 ) -> tide::Result<()> {
     {
-        let mut state = state.rpc.write().await;
+        let mut state = server.state.rpc.write().await;
         let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
         for entry_id in &request.payload.removed_entries {
             worktree.entries.remove(&entry_id);
@@ -526,18 +536,17 @@ async fn update_worktree(
         }
     }
 
-    broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await?;
+    broadcast_in_worktree(request.payload.worktree_id, &request, &server).await?;
     Ok(())
 }
 
 async fn close_worktree(
-    request: TypedEnvelope<proto::CloseWorktree>,
-    rpc: &Arc<Peer>,
-    state: &Arc<AppState>,
+    request: Box<TypedEnvelope<proto::CloseWorktree>>,
+    server: Arc<Server>,
 ) -> tide::Result<()> {
     let connection_ids;
     {
-        let mut state = state.rpc.write().await;
+        let mut state = server.state.rpc.write().await;
         let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
         connection_ids = worktree.connection_ids();
         if worktree.host_connection_id == Some(request.sender_id) {
@@ -548,7 +557,7 @@ async fn close_worktree(
     }
 
     broadcast(request.sender_id, connection_ids, |conn_id| {
-        rpc.send(
+        server.rpc.send(
             conn_id,
             proto::RemovePeer {
                 worktree_id: request.payload.worktree_id,
@@ -562,53 +571,55 @@ async fn close_worktree(
 }
 
 async fn open_buffer(
-    request: TypedEnvelope<proto::OpenBuffer>,
-    rpc: &Arc<Peer>,
-    state: &Arc<AppState>,
+    request: Box<TypedEnvelope<proto::OpenBuffer>>,
+    server: Arc<Server>,
 ) -> tide::Result<()> {
     let receipt = request.receipt();
     let worktree_id = request.payload.worktree_id;
-    let host_connection_id = state
+    let host_connection_id = server
+        .state
         .rpc
         .read()
         .await
         .read_worktree(worktree_id, request.sender_id)?
         .host_connection_id()?;
 
-    let response = rpc
+    let response = server
+        .rpc
         .forward_request(request.sender_id, host_connection_id, request.payload)
         .await?;
-    rpc.respond(receipt, response).await?;
+    server.rpc.respond(receipt, response).await?;
     Ok(())
 }
 
 async fn close_buffer(
-    request: TypedEnvelope<proto::CloseBuffer>,
-    rpc: &Arc<Peer>,
-    state: &Arc<AppState>,
+    request: Box<TypedEnvelope<proto::CloseBuffer>>,
+    server: Arc<Server>,
 ) -> tide::Result<()> {
-    let host_connection_id = state
+    let host_connection_id = server
+        .state
         .rpc
         .read()
         .await
         .read_worktree(request.payload.worktree_id, request.sender_id)?
         .host_connection_id()?;
 
-    rpc.forward_send(request.sender_id, host_connection_id, request.payload)
+    server
+        .rpc
+        .forward_send(request.sender_id, host_connection_id, request.payload)
         .await?;
 
     Ok(())
 }
 
 async fn save_buffer(
-    request: TypedEnvelope<proto::SaveBuffer>,
-    rpc: &Arc<Peer>,
-    state: &Arc<AppState>,
+    request: Box<TypedEnvelope<proto::SaveBuffer>>,
+    server: Arc<Server>,
 ) -> tide::Result<()> {
     let host;
     let guests;
     {
-        let state = state.rpc.read().await;
+        let state = server.state.rpc.read().await;
         let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
         host = worktree.host_connection_id()?;
         guests = worktree
@@ -620,17 +631,19 @@ async fn save_buffer(
 
     let sender = request.sender_id;
     let receipt = request.receipt();
-    let response = rpc
+    let response = server
+        .rpc
         .forward_request(sender, host, request.payload.clone())
         .await?;
 
     broadcast(host, guests, |conn_id| {
         let response = response.clone();
+        let server = &server;
         async move {
             if conn_id == sender {
-                rpc.respond(receipt, response).await
+                server.rpc.respond(receipt, response).await
             } else {
-                rpc.forward_send(host, conn_id, response).await
+                server.rpc.forward_send(host, conn_id, response).await
             }
         }
     })
@@ -640,61 +653,62 @@ async fn save_buffer(
 }
 
 async fn update_buffer(
-    request: TypedEnvelope<proto::UpdateBuffer>,
-    rpc: &Arc<Peer>,
-    state: &Arc<AppState>,
+    request: Box<TypedEnvelope<proto::UpdateBuffer>>,
+    server: Arc<Server>,
 ) -> tide::Result<()> {
-    broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await
+    broadcast_in_worktree(request.payload.worktree_id, &request, &server).await
 }
 
 async fn buffer_saved(
-    request: TypedEnvelope<proto::BufferSaved>,
-    rpc: &Arc<Peer>,
-    state: &Arc<AppState>,
+    request: Box<TypedEnvelope<proto::BufferSaved>>,
+    server: Arc<Server>,
 ) -> tide::Result<()> {
-    broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await
+    broadcast_in_worktree(request.payload.worktree_id, &request, &server).await
 }
 
 async fn get_channels(
-    request: TypedEnvelope<proto::GetChannels>,
-    rpc: &Arc<Peer>,
-    state: &Arc<AppState>,
+    request: Box<TypedEnvelope<proto::GetChannels>>,
+    server: Arc<Server>,
 ) -> tide::Result<()> {
-    let user_id = state
+    let user_id = server
+        .state
         .rpc
         .read()
         .await
         .user_id_for_connection(request.sender_id)?;
-    let channels = state.db.get_channels_for_user(user_id).await?;
-    rpc.respond(
-        request.receipt(),
-        proto::GetChannelsResponse {
-            channels: channels
-                .into_iter()
-                .map(|chan| proto::Channel {
-                    id: chan.id.to_proto(),
-                    name: chan.name,
-                })
-                .collect(),
-        },
-    )
-    .await?;
+    let channels = server.state.db.get_channels_for_user(user_id).await?;
+    server
+        .rpc
+        .respond(
+            request.receipt(),
+            proto::GetChannelsResponse {
+                channels: channels
+                    .into_iter()
+                    .map(|chan| proto::Channel {
+                        id: chan.id.to_proto(),
+                        name: chan.name,
+                    })
+                    .collect(),
+            },
+        )
+        .await?;
     Ok(())
 }
 
 async fn get_users(
-    request: TypedEnvelope<proto::GetUsers>,
-    rpc: &Arc<Peer>,
-    state: &Arc<AppState>,
+    request: Box<TypedEnvelope<proto::GetUsers>>,
+    server: Arc<Server>,
 ) -> tide::Result<()> {
-    let user_id = state
+    let user_id = server
+        .state
         .rpc
         .read()
         .await
         .user_id_for_connection(request.sender_id)?;
     let receipt = request.receipt();
     let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
-    let users = state
+    let users = server
+        .state
         .db
         .get_users_by_ids(user_id, user_ids)
         .await?
@@ -705,23 +719,26 @@ async fn get_users(
             avatar_url: String::new(),
         })
         .collect();
-    rpc.respond(receipt, proto::GetUsersResponse { users })
+    server
+        .rpc
+        .respond(receipt, proto::GetUsersResponse { users })
         .await?;
     Ok(())
 }
 
 async fn join_channel(
-    request: TypedEnvelope<proto::JoinChannel>,
-    rpc: &Arc<Peer>,
-    state: &Arc<AppState>,
+    request: Box<TypedEnvelope<proto::JoinChannel>>,
+    server: Arc<Server>,
 ) -> tide::Result<()> {
-    let user_id = state
+    let user_id = server
+        .state
         .rpc
         .read()
         .await
         .user_id_for_connection(request.sender_id)?;
     let channel_id = ChannelId::from_proto(request.payload.channel_id);
-    if !state
+    if !server
+        .state
         .db
         .can_user_access_channel(user_id, channel_id)
         .await?
@@ -729,12 +746,14 @@ async fn join_channel(
         Err(anyhow!("access denied"))?;
     }
 
-    state
+    server
+        .state
         .rpc
         .write()
         .await
         .join_channel(request.sender_id, channel_id);
-    let messages = state
+    let messages = server
+        .state
         .db
         .get_recent_channel_messages(channel_id, 50)
         .await?
@@ -746,21 +765,22 @@ async fn join_channel(
             sender_id: msg.sender_id.to_proto(),
         })
         .collect();
-    rpc.respond(request.receipt(), proto::JoinChannelResponse { messages })
+    server
+        .rpc
+        .respond(request.receipt(), proto::JoinChannelResponse { messages })
         .await?;
     Ok(())
 }
 
 async fn send_channel_message(
-    request: TypedEnvelope<proto::SendChannelMessage>,
-    peer: &Arc<Peer>,
-    app: &Arc<AppState>,
+    request: Box<TypedEnvelope<proto::SendChannelMessage>>,
+    server: Arc<Server>,
 ) -> tide::Result<()> {
     let channel_id = ChannelId::from_proto(request.payload.channel_id);
     let user_id;
     let connection_ids;
     {
-        let state = app.rpc.read().await;
+        let state = server.state.rpc.read().await;
         user_id = state.user_id_for_connection(request.sender_id)?;
         if let Some(channel) = state.channels.get(&channel_id) {
             connection_ids = channel.connection_ids();
@@ -770,7 +790,8 @@ async fn send_channel_message(
     }
 
     let timestamp = OffsetDateTime::now_utc();
-    let message_id = app
+    let message_id = server
+        .state
         .db
         .create_channel_message(channel_id, user_id, &request.payload.body, timestamp)
         .await?;
@@ -784,7 +805,7 @@ async fn send_channel_message(
         }),
     };
     broadcast(request.sender_id, connection_ids, |conn_id| {
-        peer.send(conn_id, message.clone())
+        server.rpc.send(conn_id, message.clone())
     })
     .await?;
 
@@ -793,11 +814,11 @@ async fn send_channel_message(
 
 async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
     worktree_id: u64,
-    request: TypedEnvelope<T>,
-    rpc: &Arc<Peer>,
-    state: &Arc<AppState>,
+    request: &TypedEnvelope<T>,
+    server: &Arc<Server>,
 ) -> tide::Result<()> {
-    let connection_ids = state
+    let connection_ids = server
+        .state
         .rpc
         .read()
         .await
@@ -805,7 +826,9 @@ async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
         .connection_ids();
 
     broadcast(request.sender_id, connection_ids, |conn_id| {
-        rpc.forward_send(request.sender_id, conn_id, request.payload.clone())
+        server
+            .rpc
+            .forward_send(request.sender_id, conn_id, request.payload.clone())
     })
     .await?;