Refactor `add_request_handler` to respond via a `Response` struct

Antonio Scandurra and Nathan Sobo created

This also removes `add_sync_request_handler`.

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

crates/collab/src/rpc.rs | 218 +++++++++++++++++++++++------------------
1 file changed, 123 insertions(+), 95 deletions(-)

Detailed changes

crates/collab/src/rpc.rs 🔗

@@ -18,7 +18,7 @@ use axum::{
     headers::{Header, HeaderName},
     http::StatusCode,
     middleware,
-    response::{IntoResponse, Response},
+    response::IntoResponse,
     routing::get,
     Extension, Router, TypedHeader,
 };
@@ -27,7 +27,7 @@ use futures::{channel::mpsc, future::BoxFuture, FutureExt, SinkExt, StreamExt, T
 use lazy_static::lazy_static;
 use rpc::{
     proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
-    Connection, ConnectionId, Peer, TypedEnvelope,
+    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
 };
 use std::{
     any::TypeId,
@@ -36,7 +36,10 @@ use std::{
     net::SocketAddr,
     ops::{Deref, DerefMut},
     rc::Rc,
-    sync::Arc,
+    sync::{
+        atomic::{AtomicBool, Ordering::SeqCst},
+        Arc,
+    },
     time::Duration,
 };
 use store::{Store, Worktree};
@@ -51,6 +54,20 @@ use tracing::{info_span, instrument, Instrument};
 type MessageHandler =
     Box<dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>>;
 
+struct Response<R> {
+    server: Arc<Server>,
+    receipt: Receipt<R>,
+    responded: Arc<AtomicBool>,
+}
+
+impl<R: RequestMessage> Response<R> {
+    fn send(self, payload: R::Response) -> Result<()> {
+        self.responded.store(true, SeqCst);
+        self.server.peer.respond(self.receipt, payload)?;
+        Ok(())
+    }
+}
+
 pub struct Server {
     peer: Arc<Peer>,
     store: RwLock<Store>,
@@ -100,7 +117,7 @@ impl Server {
             .add_message_handler(Server::unregister_project)
             .add_request_handler(Server::share_project)
             .add_message_handler(Server::unshare_project)
-            .add_sync_request_handler(Server::join_project)
+            .add_request_handler(Server::join_project)
             .add_message_handler(Server::leave_project)
             .add_request_handler(Server::register_worktree)
             .add_message_handler(Server::unregister_worktree)
@@ -179,43 +196,12 @@ impl Server {
         self
     }
 
-    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
-    where
-        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
-        Fut: 'static + Send + Future<Output = Result<M::Response>>,
-        M: RequestMessage,
-    {
-        self.add_message_handler(move |server, envelope| {
-            let receipt = envelope.receipt();
-            let response = (handler)(server.clone(), envelope);
-            async move {
-                match response.await {
-                    Ok(response) => {
-                        server.peer.respond(receipt, response)?;
-                        Ok(())
-                    }
-                    Err(error) => {
-                        server.peer.respond_with_error(
-                            receipt,
-                            proto::Error {
-                                message: error.to_string(),
-                            },
-                        )?;
-                        Err(error)
-                    }
-                }
-            }
-        })
-    }
-
     /// Handle a request while holding a lock to the store. This is useful when we're registering
     /// a connection but we want to respond on the connection before anybody else can send on it.
-    fn add_sync_request_handler<F, M>(&mut self, handler: F) -> &mut Self
+    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
     where
-        F: 'static
-            + Send
-            + Sync
-            + Fn(Arc<Self>, &mut Store, TypedEnvelope<M>) -> Result<M::Response>,
+        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>, Response<M>) -> Fut,
+        Fut: Send + Future<Output = Result<()>>,
         M: RequestMessage,
     {
         let handler = Arc::new(handler);
@@ -223,12 +209,19 @@ impl Server {
             let receipt = envelope.receipt();
             let handler = handler.clone();
             async move {
-                let mut store = server.state_mut().await;
-                let response = (handler)(server.clone(), &mut *store, envelope);
-                match response {
-                    Ok(response) => {
-                        server.peer.respond(receipt, response)?;
-                        Ok(())
+                let responded = Arc::new(AtomicBool::default());
+                let response = Response {
+                    server: server.clone(),
+                    responded: responded.clone(),
+                    receipt: envelope.receipt(),
+                };
+                match (handler)(server.clone(), envelope, response).await {
+                    Ok(()) => {
+                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
+                            Ok(())
+                        } else {
+                            Err(anyhow!("handler did not send a response"))?
+                        }
                     }
                     Err(error) => {
                         server.peer.respond_with_error(
@@ -364,20 +357,27 @@ impl Server {
         Ok(())
     }
 
-    async fn ping(self: Arc<Server>, _: TypedEnvelope<proto::Ping>) -> Result<proto::Ack> {
-        Ok(proto::Ack {})
+    async fn ping(
+        self: Arc<Server>,
+        _: TypedEnvelope<proto::Ping>,
+        response: Response<proto::Ping>,
+    ) -> Result<()> {
+        response.send(proto::Ack {})?;
+        Ok(())
     }
 
     async fn register_project(
         self: Arc<Server>,
         request: TypedEnvelope<proto::RegisterProject>,
-    ) -> Result<proto::RegisterProjectResponse> {
+        response: Response<proto::RegisterProject>,
+    ) -> Result<()> {
         let project_id = {
             let mut state = self.state_mut().await;
             let user_id = state.user_id_for_connection(request.sender_id)?;
             state.register_project(request.sender_id, user_id)
         };
-        Ok(proto::RegisterProjectResponse { project_id })
+        response.send(proto::RegisterProjectResponse { project_id })?;
+        Ok(())
     }
 
     async fn unregister_project(
@@ -393,11 +393,13 @@ impl Server {
     async fn share_project(
         self: Arc<Server>,
         request: TypedEnvelope<proto::ShareProject>,
-    ) -> Result<proto::Ack> {
+        response: Response<proto::ShareProject>,
+    ) -> Result<()> {
         let mut state = self.state_mut().await;
         let project = state.share_project(request.payload.project_id, request.sender_id)?;
         self.update_contacts_for_users(&mut *state, &project.authorized_user_ids);
-        Ok(proto::Ack {})
+        response.send(proto::Ack {})?;
+        Ok(())
     }
 
     async fn unshare_project(
@@ -415,15 +417,16 @@ impl Server {
         Ok(())
     }
 
-    fn join_project(
+    async fn join_project(
         self: Arc<Server>,
-        state: &mut Store,
         request: TypedEnvelope<proto::JoinProject>,
-    ) -> Result<proto::JoinProjectResponse> {
+        response: Response<proto::JoinProject>,
+    ) -> Result<()> {
         let project_id = request.payload.project_id;
 
+        let state = &mut *self.state_mut().await;
         let user_id = state.user_id_for_connection(request.sender_id)?;
-        let (response, connection_ids, contact_user_ids) = state
+        let (response_payload, connection_ids, contact_user_ids) = state
             .join_project(request.sender_id, user_id, project_id)
             .and_then(|joined| {
                 let share = joined.project.share()?;
@@ -480,14 +483,15 @@ impl Server {
                     project_id,
                     collaborator: Some(proto::Collaborator {
                         peer_id: request.sender_id.0,
-                        replica_id: response.replica_id,
+                        replica_id: response_payload.replica_id,
                         user_id: user_id.to_proto(),
                     }),
                 },
             )
         });
         self.update_contacts_for_users(state, &contact_user_ids);
-        Ok(response)
+        response.send(response_payload)?;
+        Ok(())
     }
 
     async fn leave_project(
@@ -514,7 +518,8 @@ impl Server {
     async fn register_worktree(
         self: Arc<Server>,
         request: TypedEnvelope<proto::RegisterWorktree>,
-    ) -> Result<proto::Ack> {
+        response: Response<proto::RegisterWorktree>,
+    ) -> Result<()> {
         let mut contact_user_ids = HashSet::default();
         for github_login in &request.payload.authorized_logins {
             let contact_user_id = self.app_state.db.create_user(github_login, false).await?;
@@ -545,7 +550,8 @@ impl Server {
                 .forward_send(request.sender_id, connection_id, request.payload.clone())
         });
         self.update_contacts_for_users(&*state, &contact_user_ids);
-        Ok(proto::Ack {})
+        response.send(proto::Ack {})?;
+        Ok(())
     }
 
     async fn unregister_worktree(
@@ -573,7 +579,8 @@ impl Server {
     async fn update_worktree(
         self: Arc<Server>,
         request: TypedEnvelope<proto::UpdateWorktree>,
-    ) -> Result<proto::Ack> {
+        response: Response<proto::UpdateWorktree>,
+    ) -> Result<()> {
         let connection_ids = self.state_mut().await.update_worktree(
             request.sender_id,
             request.payload.project_id,
@@ -587,8 +594,8 @@ impl Server {
             self.peer
                 .forward_send(request.sender_id, connection_id, request.payload.clone())
         });
-
-        Ok(proto::Ack {})
+        response.send(proto::Ack {})?;
+        Ok(())
     }
 
     async fn update_diagnostic_summary(
@@ -652,7 +659,8 @@ impl Server {
     async fn forward_project_request<T>(
         self: Arc<Server>,
         request: TypedEnvelope<T>,
-    ) -> Result<T::Response>
+        response: Response<T>,
+    ) -> Result<()>
     where
         T: EntityMessage + RequestMessage,
     {
@@ -661,22 +669,26 @@ impl Server {
             .await
             .read_project(request.payload.remote_entity_id(), request.sender_id)?
             .host_connection_id;
-        Ok(self
-            .peer
-            .forward_request(request.sender_id, host_connection_id, request.payload)
-            .await?)
+
+        response.send(
+            self.peer
+                .forward_request(request.sender_id, host_connection_id, request.payload)
+                .await?,
+        )?;
+        Ok(())
     }
 
     async fn save_buffer(
         self: Arc<Server>,
         request: TypedEnvelope<proto::SaveBuffer>,
-    ) -> Result<proto::BufferSaved> {
+        response: Response<proto::SaveBuffer>,
+    ) -> Result<()> {
         let host = self
             .state()
             .await
             .read_project(request.payload.project_id, request.sender_id)?
             .host_connection_id;
-        let response = self
+        let response_payload = self
             .peer
             .forward_request(request.sender_id, host, request.payload.clone())
             .await?;
@@ -688,16 +700,18 @@ impl Server {
             .connection_ids();
         guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id);
         broadcast(host, guests, |conn_id| {
-            self.peer.forward_send(host, conn_id, response.clone())
+            self.peer
+                .forward_send(host, conn_id, response_payload.clone())
         });
-
-        Ok(response)
+        response.send(response_payload)?;
+        Ok(())
     }
 
     async fn update_buffer(
         self: Arc<Server>,
         request: TypedEnvelope<proto::UpdateBuffer>,
-    ) -> Result<proto::Ack> {
+        response: Response<proto::UpdateBuffer>,
+    ) -> Result<()> {
         let receiver_ids = self
             .state()
             .await
@@ -706,7 +720,8 @@ impl Server {
             self.peer
                 .forward_send(request.sender_id, connection_id, request.payload.clone())
         });
-        Ok(proto::Ack {})
+        response.send(proto::Ack {})?;
+        Ok(())
     }
 
     async fn update_buffer_file(
@@ -757,7 +772,8 @@ impl Server {
     async fn follow(
         self: Arc<Self>,
         request: TypedEnvelope<proto::Follow>,
-    ) -> Result<proto::FollowResponse> {
+        response: Response<proto::Follow>,
+    ) -> Result<()> {
         let leader_id = ConnectionId(request.payload.leader_id);
         let follower_id = request.sender_id;
         if !self
@@ -768,14 +784,15 @@ impl Server {
         {
             Err(anyhow!("no such peer"))?;
         }
-        let mut response = self
+        let mut response_payload = self
             .peer
             .forward_request(request.sender_id, leader_id, request.payload)
             .await?;
-        response
+        response_payload
             .views
             .retain(|view| view.leader_id != Some(follower_id.0));
-        Ok(response)
+        response.send(response_payload)?;
+        Ok(())
     }
 
     async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
@@ -823,13 +840,14 @@ impl Server {
     async fn get_channels(
         self: Arc<Server>,
         request: TypedEnvelope<proto::GetChannels>,
-    ) -> Result<proto::GetChannelsResponse> {
+        response: Response<proto::GetChannels>,
+    ) -> Result<()> {
         let user_id = self
             .state()
             .await
             .user_id_for_connection(request.sender_id)?;
         let channels = self.app_state.db.get_accessible_channels(user_id).await?;
-        Ok(proto::GetChannelsResponse {
+        response.send(proto::GetChannelsResponse {
             channels: channels
                 .into_iter()
                 .map(|chan| proto::Channel {
@@ -837,13 +855,15 @@ impl Server {
                     name: chan.name,
                 })
                 .collect(),
-        })
+        })?;
+        Ok(())
     }
 
     async fn get_users(
         self: Arc<Server>,
         request: TypedEnvelope<proto::GetUsers>,
-    ) -> Result<proto::UsersResponse> {
+        response: Response<proto::GetUsers>,
+    ) -> Result<()> {
         let user_ids = request
             .payload
             .user_ids
@@ -862,13 +882,15 @@ impl Server {
                 github_login: user.github_login,
             })
             .collect();
-        Ok(proto::UsersResponse { users })
+        response.send(proto::UsersResponse { users })?;
+        Ok(())
     }
 
     async fn fuzzy_search_users(
         self: Arc<Server>,
         request: TypedEnvelope<proto::FuzzySearchUsers>,
-    ) -> Result<proto::UsersResponse> {
+        response: Response<proto::FuzzySearchUsers>,
+    ) -> Result<()> {
         let query = request.payload.query;
         let db = &self.app_state.db;
         let users = match query.len() {
@@ -888,7 +910,8 @@ impl Server {
                 github_login: user.github_login,
             })
             .collect();
-        Ok(proto::UsersResponse { users })
+        response.send(proto::UsersResponse { users })?;
+        Ok(())
     }
 
     #[instrument(skip(self, state, user_ids))]
@@ -917,7 +940,8 @@ impl Server {
     async fn join_channel(
         self: Arc<Self>,
         request: TypedEnvelope<proto::JoinChannel>,
-    ) -> Result<proto::JoinChannelResponse> {
+        response: Response<proto::JoinChannel>,
+    ) -> Result<()> {
         let user_id = self
             .state()
             .await
@@ -949,10 +973,11 @@ impl Server {
                 nonce: Some(msg.nonce.as_u128().into()),
             })
             .collect::<Vec<_>>();
-        Ok(proto::JoinChannelResponse {
+        response.send(proto::JoinChannelResponse {
             done: messages.len() < MESSAGE_COUNT_PER_PAGE,
             messages,
-        })
+        })?;
+        Ok(())
     }
 
     async fn leave_channel(
@@ -983,7 +1008,8 @@ impl Server {
     async fn send_channel_message(
         self: Arc<Self>,
         request: TypedEnvelope<proto::SendChannelMessage>,
-    ) -> Result<proto::SendChannelMessageResponse> {
+        response: Response<proto::SendChannelMessage>,
+    ) -> Result<()> {
         let channel_id = ChannelId::from_proto(request.payload.channel_id);
         let user_id;
         let connection_ids;
@@ -1030,15 +1056,17 @@ impl Server {
                 },
             )
         });
-        Ok(proto::SendChannelMessageResponse {
+        response.send(proto::SendChannelMessageResponse {
             message: Some(message),
-        })
+        })?;
+        Ok(())
     }
 
     async fn get_channel_messages(
         self: Arc<Self>,
         request: TypedEnvelope<proto::GetChannelMessages>,
-    ) -> Result<proto::GetChannelMessagesResponse> {
+        response: Response<proto::GetChannelMessages>,
+    ) -> Result<()> {
         let user_id = self
             .state()
             .await
@@ -1071,11 +1099,11 @@ impl Server {
                 nonce: Some(msg.nonce.as_u128().into()),
             })
             .collect::<Vec<_>>();
-
-        Ok(proto::GetChannelMessagesResponse {
+        response.send(proto::GetChannelMessagesResponse {
             done: messages.len() < MESSAGE_COUNT_PER_PAGE,
             messages,
-        })
+        })?;
+        Ok(())
     }
 
     async fn state<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
@@ -1213,7 +1241,7 @@ pub async fn handle_websocket_request(
     Extension(server): Extension<Arc<Server>>,
     Extension(user_id): Extension<UserId>,
     ws: WebSocketUpgrade,
-) -> Response {
+) -> axum::response::Response {
     if protocol_version != rpc::PROTOCOL_VERSION {
         return (
             StatusCode::UPGRADE_REQUIRED,