Introduce a new `Session` struct to server message handlers

Antonio Scandurra and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

crates/collab/src/rpc.rs | 457 +++++++++++++++++++----------------------
1 file changed, 211 insertions(+), 246 deletions(-)

Detailed changes

crates/collab/src/rpc.rs 🔗

@@ -68,21 +68,20 @@ lazy_static! {
 }
 
 type MessageHandler = Box<
-    dyn Send + Sync + Fn(Arc<Server>, UserId, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>,
+    dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>,
 >;
 
-struct Message<T> {
-    sender_user_id: UserId,
-    sender_connection_id: ConnectionId,
-    payload: T,
-}
-
 struct Response<R> {
     server: Arc<Server>,
     receipt: Receipt<R>,
     responded: Arc<AtomicBool>,
 }
 
+struct Session {
+    user_id: UserId,
+    connection_id: ConnectionId,
+}
+
 impl<R: RequestMessage> Response<R> {
     fn send(self, payload: R::Response) -> Result<()> {
         self.responded.store(true, SeqCst);
@@ -201,13 +200,13 @@ impl Server {
 
     fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
     where
-        F: 'static + Send + Sync + Fn(Arc<Self>, UserId, TypedEnvelope<M>) -> Fut,
+        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>, Session) -> Fut,
         Fut: 'static + Send + Future<Output = Result<()>>,
         M: EnvelopedMessage,
     {
         let prev_handler = self.handlers.insert(
             TypeId::of::<M>(),
-            Box::new(move |server, sender_user_id, envelope| {
+            Box::new(move |server, envelope, session| {
                 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                 let span = info_span!(
                     "handle message",
@@ -219,7 +218,7 @@ impl Server {
                         "message received"
                     );
                 });
-                let future = (handler)(server, sender_user_id, *envelope);
+                let future = (handler)(server, *envelope, session);
                 async move {
                     if let Err(error) = future.await {
                         tracing::error!(%error, "error handling message");
@@ -237,19 +236,12 @@ impl Server {
 
     fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
     where
-        F: 'static + Send + Sync + Fn(Arc<Self>, Message<M>) -> Fut,
+        F: 'static + Send + Sync + Fn(Arc<Self>, M, Session) -> Fut,
         Fut: 'static + Send + Future<Output = Result<()>>,
         M: EnvelopedMessage,
     {
-        self.add_handler(move |server, sender_user_id, envelope| {
-            handler(
-                server,
-                Message {
-                    sender_user_id,
-                    sender_connection_id: envelope.sender_id,
-                    payload: envelope.payload,
-                },
-            )
+        self.add_handler(move |server, envelope, session| {
+            handler(server, envelope.payload, session)
         });
         self
     }
@@ -258,27 +250,22 @@ impl Server {
     /// a connection but we want to respond on the connection before anybody else can send on it.
     fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
     where
-        F: 'static + Send + Sync + Fn(Arc<Self>, Message<M>, Response<M>) -> Fut,
+        F: 'static + Send + Sync + Fn(Arc<Self>, M, Response<M>, Session) -> Fut,
         Fut: Send + Future<Output = Result<()>>,
         M: RequestMessage,
     {
         let handler = Arc::new(handler);
-        self.add_handler(move |server, sender_user_id, envelope| {
+        self.add_handler(move |server, envelope, session| {
             let receipt = envelope.receipt();
             let handler = handler.clone();
             async move {
-                let request = Message {
-                    sender_user_id,
-                    sender_connection_id: envelope.sender_id,
-                    payload: envelope.payload,
-                };
                 let responded = Arc::new(AtomicBool::default());
                 let response = Response {
                     server: server.clone(),
                     responded: responded.clone(),
                     receipt,
                 };
-                match (handler)(server.clone(), request, response).await {
+                match (handler)(server.clone(), envelope.payload, response, session).await {
                     Ok(()) => {
                         if responded.load(std::sync::atomic::Ordering::SeqCst) {
                             Ok(())
@@ -392,7 +379,11 @@ impl Server {
                             let span_enter = span.enter();
                             if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                 let is_background = message.is_background();
-                                let handle_message = (handler)(this.clone(), user_id, message);
+                                let session = Session {
+                                    user_id,
+                                    connection_id,
+                                };
+                                let handle_message = (handler)(this.clone(), message, session);
                                 drop(span_enter);
 
                                 let handle_message = handle_message.instrument(span);
@@ -509,8 +500,9 @@ impl Server {
 
     async fn ping(
         self: Arc<Server>,
-        _: Message<proto::Ping>,
+        _: proto::Ping,
         response: Response<proto::Ping>,
+        _session: Session,
     ) -> Result<()> {
         response.send(proto::Ack {})?;
         Ok(())
@@ -518,13 +510,14 @@ impl Server {
 
     async fn create_room(
         self: Arc<Server>,
-        request: Message<proto::CreateRoom>,
+        _request: proto::CreateRoom,
         response: Response<proto::CreateRoom>,
+        session: Session,
     ) -> Result<()> {
         let room = self
             .app_state
             .db
-            .create_room(request.sender_user_id, request.sender_connection_id)
+            .create_room(session.user_id, session.connection_id)
             .await?;
 
         let live_kit_connection_info =
@@ -535,10 +528,7 @@ impl Server {
                     .trace_err()
                 {
                     if let Some(token) = live_kit
-                        .room_token(
-                            &room.live_kit_room,
-                            &request.sender_connection_id.to_string(),
-                        )
+                        .room_token(&room.live_kit_room, &session.connection_id.to_string())
                         .trace_err()
                     {
                         Some(proto::LiveKitConnectionInfo {
@@ -559,29 +549,26 @@ impl Server {
             room: Some(room),
             live_kit_connection_info,
         })?;
-        self.update_user_contacts(request.sender_user_id).await?;
+        self.update_user_contacts(session.user_id).await?;
         Ok(())
     }
 
     async fn join_room(
         self: Arc<Server>,
-        request: Message<proto::JoinRoom>,
+        request: proto::JoinRoom,
         response: Response<proto::JoinRoom>,
+        session: Session,
     ) -> Result<()> {
         let room = self
             .app_state
             .db
             .join_room(
-                RoomId::from_proto(request.payload.id),
-                request.sender_user_id,
-                request.sender_connection_id,
+                RoomId::from_proto(request.id),
+                session.user_id,
+                session.connection_id,
             )
             .await?;
-        for connection_id in self
-            .store()
-            .await
-            .connection_ids_for_user(request.sender_user_id)
-        {
+        for connection_id in self.store().await.connection_ids_for_user(session.user_id) {
             self.peer
                 .send(connection_id, proto::CallCanceled {})
                 .trace_err();
@@ -590,10 +577,7 @@ impl Server {
         let live_kit_connection_info =
             if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
                 if let Some(token) = live_kit
-                    .room_token(
-                        &room.live_kit_room,
-                        &request.sender_connection_id.to_string(),
-                    )
+                    .room_token(&room.live_kit_room, &session.connection_id.to_string())
                     .trace_err()
                 {
                     Some(proto::LiveKitConnectionInfo {
@@ -613,12 +597,16 @@ impl Server {
             live_kit_connection_info,
         })?;
 
-        self.update_user_contacts(request.sender_user_id).await?;
+        self.update_user_contacts(session.user_id).await?;
         Ok(())
     }
 
-    async fn leave_room(self: Arc<Server>, message: Message<proto::LeaveRoom>) -> Result<()> {
-        self.leave_room_for_connection(message.sender_connection_id, message.sender_user_id)
+    async fn leave_room(
+        self: Arc<Server>,
+        _message: proto::LeaveRoom,
+        session: Session,
+    ) -> Result<()> {
+        self.leave_room_for_connection(session.connection_id, session.user_id)
             .await
     }
 
@@ -707,17 +695,15 @@ impl Server {
 
     async fn call(
         self: Arc<Server>,
-        request: Message<proto::Call>,
+        request: proto::Call,
         response: Response<proto::Call>,
+        session: Session,
     ) -> Result<()> {
-        let room_id = RoomId::from_proto(request.payload.room_id);
-        let calling_user_id = request.sender_user_id;
-        let calling_connection_id = request.sender_connection_id;
-        let called_user_id = UserId::from_proto(request.payload.called_user_id);
-        let initial_project_id = request
-            .payload
-            .initial_project_id
-            .map(ProjectId::from_proto);
+        let room_id = RoomId::from_proto(request.room_id);
+        let calling_user_id = session.user_id;
+        let calling_connection_id = session.connection_id;
+        let called_user_id = UserId::from_proto(request.called_user_id);
+        let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
         if !self
             .app_state
             .db
@@ -773,15 +759,16 @@ impl Server {
 
     async fn cancel_call(
         self: Arc<Server>,
-        request: Message<proto::CancelCall>,
+        request: proto::CancelCall,
         response: Response<proto::CancelCall>,
+        session: Session,
     ) -> Result<()> {
-        let called_user_id = UserId::from_proto(request.payload.called_user_id);
-        let room_id = RoomId::from_proto(request.payload.room_id);
+        let called_user_id = UserId::from_proto(request.called_user_id);
+        let room_id = RoomId::from_proto(request.room_id);
         let room = self
             .app_state
             .db
-            .cancel_call(Some(room_id), request.sender_connection_id, called_user_id)
+            .cancel_call(Some(room_id), session.connection_id, called_user_id)
             .await?;
         for connection_id in self.store().await.connection_ids_for_user(called_user_id) {
             self.peer
@@ -795,41 +782,41 @@ impl Server {
         Ok(())
     }
 
-    async fn decline_call(self: Arc<Server>, message: Message<proto::DeclineCall>) -> Result<()> {
-        let room_id = RoomId::from_proto(message.payload.room_id);
+    async fn decline_call(
+        self: Arc<Server>,
+        message: proto::DeclineCall,
+        session: Session,
+    ) -> Result<()> {
+        let room_id = RoomId::from_proto(message.room_id);
         let room = self
             .app_state
             .db
-            .decline_call(Some(room_id), message.sender_user_id)
+            .decline_call(Some(room_id), session.user_id)
             .await?;
-        for connection_id in self
-            .store()
-            .await
-            .connection_ids_for_user(message.sender_user_id)
-        {
+        for connection_id in self.store().await.connection_ids_for_user(session.user_id) {
             self.peer
                 .send(connection_id, proto::CallCanceled {})
                 .trace_err();
         }
         self.room_updated(&room);
-        self.update_user_contacts(message.sender_user_id).await?;
+        self.update_user_contacts(session.user_id).await?;
         Ok(())
     }
 
     async fn update_participant_location(
         self: Arc<Server>,
-        request: Message<proto::UpdateParticipantLocation>,
+        request: proto::UpdateParticipantLocation,
         response: Response<proto::UpdateParticipantLocation>,
+        session: Session,
     ) -> Result<()> {
-        let room_id = RoomId::from_proto(request.payload.room_id);
+        let room_id = RoomId::from_proto(request.room_id);
         let location = request
-            .payload
             .location
             .ok_or_else(|| anyhow!("invalid location"))?;
         let room = self
             .app_state
             .db
-            .update_room_participant_location(room_id, request.sender_connection_id, location)
+            .update_room_participant_location(room_id, session.connection_id, location)
             .await?;
         self.room_updated(&room);
         response.send(proto::Ack {})?;
@@ -851,16 +838,17 @@ impl Server {
 
     async fn share_project(
         self: Arc<Server>,
-        request: Message<proto::ShareProject>,
+        request: proto::ShareProject,
         response: Response<proto::ShareProject>,
+        session: Session,
     ) -> Result<()> {
         let (project_id, room) = self
             .app_state
             .db
             .share_project(
-                RoomId::from_proto(request.payload.room_id),
-                request.sender_connection_id,
-                &request.payload.worktrees,
+                RoomId::from_proto(request.room_id),
+                session.connection_id,
+                &request.worktrees,
             )
             .await?;
         response.send(proto::ShareProjectResponse {
@@ -873,21 +861,20 @@ impl Server {
 
     async fn unshare_project(
         self: Arc<Server>,
-        message: Message<proto::UnshareProject>,
+        message: proto::UnshareProject,
+        session: Session,
     ) -> Result<()> {
-        let project_id = ProjectId::from_proto(message.payload.project_id);
+        let project_id = ProjectId::from_proto(message.project_id);
 
         let (room, guest_connection_ids) = self
             .app_state
             .db
-            .unshare_project(project_id, message.sender_connection_id)
+            .unshare_project(project_id, session.connection_id)
             .await?;
 
-        broadcast(
-            message.sender_connection_id,
-            guest_connection_ids,
-            |conn_id| self.peer.send(conn_id, message.payload.clone()),
-        );
+        broadcast(session.connection_id, guest_connection_ids, |conn_id| {
+            self.peer.send(conn_id, message.clone())
+        });
         self.room_updated(&room);
 
         Ok(())
@@ -926,26 +913,25 @@ impl Server {
 
     async fn join_project(
         self: Arc<Server>,
-        request: Message<proto::JoinProject>,
+        request: proto::JoinProject,
         response: Response<proto::JoinProject>,
+        session: Session,
     ) -> Result<()> {
-        let project_id = ProjectId::from_proto(request.payload.project_id);
-        let guest_user_id = request.sender_user_id;
+        let project_id = ProjectId::from_proto(request.project_id);
+        let guest_user_id = session.user_id;
 
         tracing::info!(%project_id, "join project");
 
         let (project, replica_id) = self
             .app_state
             .db
-            .join_project(project_id, request.sender_connection_id)
+            .join_project(project_id, session.connection_id)
             .await?;
 
         let collaborators = project
             .collaborators
             .iter()
-            .filter(|collaborator| {
-                collaborator.connection_id != request.sender_connection_id.0 as i32
-            })
+            .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32)
             .map(|collaborator| proto::Collaborator {
                 peer_id: collaborator.connection_id as u32,
                 replica_id: collaborator.replica_id.0 as u32,
@@ -970,7 +956,7 @@ impl Server {
                     proto::AddProjectCollaborator {
                         project_id: project_id.to_proto(),
                         collaborator: Some(proto::Collaborator {
-                            peer_id: request.sender_connection_id.0,
+                            peer_id: session.connection_id.0,
                             replica_id: replica_id.0 as u32,
                             user_id: guest_user_id.to_proto(),
                         }),
@@ -1005,14 +991,13 @@ impl Server {
                 is_last_update: worktree.is_complete,
             };
             for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
-                self.peer
-                    .send(request.sender_connection_id, update.clone())?;
+                self.peer.send(session.connection_id, update.clone())?;
             }
 
             // Stream this worktree's diagnostics.
             for summary in worktree.diagnostic_summaries {
                 self.peer.send(
-                    request.sender_connection_id,
+                    session.connection_id,
                     proto::UpdateDiagnosticSummary {
                         project_id: project_id.to_proto(),
                         worktree_id: worktree.id.to_proto(),
@@ -1024,7 +1009,7 @@ impl Server {
 
         for language_server in &project.language_servers {
             self.peer.send(
-                request.sender_connection_id,
+                session.connection_id,
                 proto::UpdateLanguageServer {
                     project_id: project_id.to_proto(),
                     language_server_id: language_server.id,
@@ -1040,9 +1025,13 @@ impl Server {
         Ok(())
     }
 
-    async fn leave_project(self: Arc<Server>, request: Message<proto::LeaveProject>) -> Result<()> {
-        let sender_id = request.sender_connection_id;
-        let project_id = ProjectId::from_proto(request.payload.project_id);
+    async fn leave_project(
+        self: Arc<Server>,
+        request: proto::LeaveProject,
+        session: Session,
+    ) -> Result<()> {
+        let sender_id = session.connection_id;
+        let project_id = ProjectId::from_proto(request.project_id);
         let project;
         {
             project = self
@@ -1073,28 +1062,22 @@ impl Server {
 
     async fn update_project(
         self: Arc<Server>,
-        request: Message<proto::UpdateProject>,
+        request: proto::UpdateProject,
         response: Response<proto::UpdateProject>,
+        session: Session,
     ) -> Result<()> {
-        let project_id = ProjectId::from_proto(request.payload.project_id);
+        let project_id = ProjectId::from_proto(request.project_id);
         let (room, guest_connection_ids) = self
             .app_state
             .db
-            .update_project(
-                project_id,
-                request.sender_connection_id,
-                &request.payload.worktrees,
-            )
+            .update_project(project_id, session.connection_id, &request.worktrees)
             .await?;
         broadcast(
-            request.sender_connection_id,
+            session.connection_id,
             guest_connection_ids,
             |connection_id| {
-                self.peer.forward_send(
-                    request.sender_connection_id,
-                    connection_id,
-                    request.payload.clone(),
-                )
+                self.peer
+                    .forward_send(session.connection_id, connection_id, request.clone())
             },
         );
         self.room_updated(&room);
@@ -1105,24 +1088,22 @@ impl Server {
 
     async fn update_worktree(
         self: Arc<Server>,
-        request: Message<proto::UpdateWorktree>,
+        request: proto::UpdateWorktree,
         response: Response<proto::UpdateWorktree>,
+        session: Session,
     ) -> Result<()> {
         let guest_connection_ids = self
             .app_state
             .db
-            .update_worktree(&request.payload, request.sender_connection_id)
+            .update_worktree(&request, session.connection_id)
             .await?;
 
         broadcast(
-            request.sender_connection_id,
+            session.connection_id,
             guest_connection_ids,
             |connection_id| {
-                self.peer.forward_send(
-                    request.sender_connection_id,
-                    connection_id,
-                    request.payload.clone(),
-                )
+                self.peer
+                    .forward_send(session.connection_id, connection_id, request.clone())
             },
         );
         response.send(proto::Ack {})?;
@@ -1131,24 +1112,22 @@ impl Server {
 
     async fn update_diagnostic_summary(
         self: Arc<Server>,
-        request: Message<proto::UpdateDiagnosticSummary>,
+        request: proto::UpdateDiagnosticSummary,
         response: Response<proto::UpdateDiagnosticSummary>,
+        session: Session,
     ) -> Result<()> {
         let guest_connection_ids = self
             .app_state
             .db
-            .update_diagnostic_summary(&request.payload, request.sender_connection_id)
+            .update_diagnostic_summary(&request, session.connection_id)
             .await?;
 
         broadcast(
-            request.sender_connection_id,
+            session.connection_id,
             guest_connection_ids,
             |connection_id| {
-                self.peer.forward_send(
-                    request.sender_connection_id,
-                    connection_id,
-                    request.payload.clone(),
-                )
+                self.peer
+                    .forward_send(session.connection_id, connection_id, request.clone())
             },
         );
 
@@ -1158,23 +1137,21 @@ impl Server {
 
     async fn start_language_server(
         self: Arc<Server>,
-        request: Message<proto::StartLanguageServer>,
+        request: proto::StartLanguageServer,
+        session: Session,
     ) -> Result<()> {
         let guest_connection_ids = self
             .app_state
             .db
-            .start_language_server(&request.payload, request.sender_connection_id)
+            .start_language_server(&request, session.connection_id)
             .await?;
 
         broadcast(
-            request.sender_connection_id,
+            session.connection_id,
             guest_connection_ids,
             |connection_id| {
-                self.peer.forward_send(
-                    request.sender_connection_id,
-                    connection_id,
-                    request.payload.clone(),
-                )
+                self.peer
+                    .forward_send(session.connection_id, connection_id, request.clone())
             },
         );
         Ok(())
@@ -1182,23 +1159,21 @@ impl Server {
 
     async fn update_language_server(
         self: Arc<Server>,
-        request: Message<proto::UpdateLanguageServer>,
+        request: proto::UpdateLanguageServer,
+        session: Session,
     ) -> Result<()> {
-        let project_id = ProjectId::from_proto(request.payload.project_id);
+        let project_id = ProjectId::from_proto(request.project_id);
         let project_connection_ids = self
             .app_state
             .db
-            .project_connection_ids(project_id, request.sender_connection_id)
+            .project_connection_ids(project_id, session.connection_id)
             .await?;
         broadcast(
-            request.sender_connection_id,
+            session.connection_id,
             project_connection_ids,
             |connection_id| {
-                self.peer.forward_send(
-                    request.sender_connection_id,
-                    connection_id,
-                    request.payload.clone(),
-                )
+                self.peer
+                    .forward_send(session.connection_id, connection_id, request.clone())
             },
         );
         Ok(())
@@ -1206,17 +1181,18 @@ impl Server {
 
     async fn forward_project_request<T>(
         self: Arc<Server>,
-        request: Message<T>,
+        request: T,
         response: Response<T>,
+        session: Session,
     ) -> Result<()>
     where
         T: EntityMessage + RequestMessage,
     {
-        let project_id = ProjectId::from_proto(request.payload.remote_entity_id());
+        let project_id = ProjectId::from_proto(request.remote_entity_id());
         let collaborators = self
             .app_state
             .db
-            .project_collaborators(project_id, request.sender_connection_id)
+            .project_collaborators(project_id, session.connection_id)
             .await?;
         let host = collaborators
             .iter()
@@ -1226,9 +1202,9 @@ impl Server {
         let payload = self
             .peer
             .forward_request(
-                request.sender_connection_id,
+                session.connection_id,
                 ConnectionId(host.connection_id as u32),
-                request.payload,
+                request,
             )
             .await?;
 
@@ -1238,14 +1214,15 @@ impl Server {
 
     async fn save_buffer(
         self: Arc<Server>,
-        request: Message<proto::SaveBuffer>,
+        request: proto::SaveBuffer,
         response: Response<proto::SaveBuffer>,
+        session: Session,
     ) -> Result<()> {
-        let project_id = ProjectId::from_proto(request.payload.project_id);
+        let project_id = ProjectId::from_proto(request.project_id);
         let collaborators = self
             .app_state
             .db
-            .project_collaborators(project_id, request.sender_connection_id)
+            .project_collaborators(project_id, session.connection_id)
             .await?;
         let host = collaborators
             .into_iter()
@@ -1254,21 +1231,16 @@ impl Server {
         let host_connection_id = ConnectionId(host.connection_id as u32);
         let response_payload = self
             .peer
-            .forward_request(
-                request.sender_connection_id,
-                host_connection_id,
-                request.payload.clone(),
-            )
+            .forward_request(session.connection_id, host_connection_id, request.clone())
             .await?;
 
         let mut collaborators = self
             .app_state
             .db
-            .project_collaborators(project_id, request.sender_connection_id)
+            .project_collaborators(project_id, session.connection_id)
             .await?;
-        collaborators.retain(|collaborator| {
-            collaborator.connection_id != request.sender_connection_id.0 as i32
-        });
+        collaborators
+            .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
         let project_connection_ids = collaborators
             .into_iter()
             .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
@@ -1282,37 +1254,36 @@ impl Server {
 
     async fn create_buffer_for_peer(
         self: Arc<Server>,
-        request: Message<proto::CreateBufferForPeer>,
+        request: proto::CreateBufferForPeer,
+        session: Session,
     ) -> Result<()> {
         self.peer.forward_send(
-            request.sender_connection_id,
-            ConnectionId(request.payload.peer_id),
-            request.payload,
+            session.connection_id,
+            ConnectionId(request.peer_id),
+            request,
         )?;
         Ok(())
     }
 
     async fn update_buffer(
         self: Arc<Server>,
-        request: Message<proto::UpdateBuffer>,
+        request: proto::UpdateBuffer,
         response: Response<proto::UpdateBuffer>,
+        session: Session,
     ) -> Result<()> {
-        let project_id = ProjectId::from_proto(request.payload.project_id);
+        let project_id = ProjectId::from_proto(request.project_id);
         let project_connection_ids = self
             .app_state
             .db
-            .project_connection_ids(project_id, request.sender_connection_id)
+            .project_connection_ids(project_id, session.connection_id)
             .await?;
 
         broadcast(
-            request.sender_connection_id,
+            session.connection_id,
             project_connection_ids,
             |connection_id| {
-                self.peer.forward_send(
-                    request.sender_connection_id,
-                    connection_id,
-                    request.payload.clone(),
-                )
+                self.peer
+                    .forward_send(session.connection_id, connection_id, request.clone())
             },
         );
         response.send(proto::Ack {})?;
@@ -1321,24 +1292,22 @@ impl Server {
 
     async fn update_buffer_file(
         self: Arc<Server>,
-        request: Message<proto::UpdateBufferFile>,
+        request: proto::UpdateBufferFile,
+        session: Session,
     ) -> Result<()> {
-        let project_id = ProjectId::from_proto(request.payload.project_id);
+        let project_id = ProjectId::from_proto(request.project_id);
         let project_connection_ids = self
             .app_state
             .db
-            .project_connection_ids(project_id, request.sender_connection_id)
+            .project_connection_ids(project_id, session.connection_id)
             .await?;
 
         broadcast(
-            request.sender_connection_id,
+            session.connection_id,
             project_connection_ids,
             |connection_id| {
-                self.peer.forward_send(
-                    request.sender_connection_id,
-                    connection_id,
-                    request.payload.clone(),
-                )
+                self.peer
+                    .forward_send(session.connection_id, connection_id, request.clone())
             },
         );
         Ok(())
@@ -1346,44 +1315,43 @@ impl Server {
 
     async fn buffer_reloaded(
         self: Arc<Server>,
-        request: Message<proto::BufferReloaded>,
+        request: proto::BufferReloaded,
+        session: Session,
     ) -> Result<()> {
-        let project_id = ProjectId::from_proto(request.payload.project_id);
+        let project_id = ProjectId::from_proto(request.project_id);
         let project_connection_ids = self
             .app_state
             .db
-            .project_connection_ids(project_id, request.sender_connection_id)
+            .project_connection_ids(project_id, session.connection_id)
             .await?;
         broadcast(
-            request.sender_connection_id,
+            session.connection_id,
             project_connection_ids,
             |connection_id| {
-                self.peer.forward_send(
-                    request.sender_connection_id,
-                    connection_id,
-                    request.payload.clone(),
-                )
+                self.peer
+                    .forward_send(session.connection_id, connection_id, request.clone())
             },
         );
         Ok(())
     }
 
-    async fn buffer_saved(self: Arc<Server>, request: Message<proto::BufferSaved>) -> Result<()> {
-        let project_id = ProjectId::from_proto(request.payload.project_id);
+    async fn buffer_saved(
+        self: Arc<Server>,
+        request: proto::BufferSaved,
+        session: Session,
+    ) -> Result<()> {
+        let project_id = ProjectId::from_proto(request.project_id);
         let project_connection_ids = self
             .app_state
             .db
-            .project_connection_ids(project_id, request.sender_connection_id)
+            .project_connection_ids(project_id, session.connection_id)
             .await?;
         broadcast(
-            request.sender_connection_id,
+            session.connection_id,
             project_connection_ids,
             |connection_id| {
-                self.peer.forward_send(
-                    request.sender_connection_id,
-                    connection_id,
-                    request.payload.clone(),
-                )
+                self.peer
+                    .forward_send(session.connection_id, connection_id, request.clone())
             },
         );
         Ok(())
@@ -1391,16 +1359,17 @@ impl Server {
 
     async fn follow(
         self: Arc<Self>,
-        request: Message<proto::Follow>,
+        request: proto::Follow,
         response: Response<proto::Follow>,
+        session: Session,
     ) -> Result<()> {
-        let project_id = ProjectId::from_proto(request.payload.project_id);
-        let leader_id = ConnectionId(request.payload.leader_id);
-        let follower_id = request.sender_connection_id;
+        let project_id = ProjectId::from_proto(request.project_id);
+        let leader_id = ConnectionId(request.leader_id);
+        let follower_id = session.connection_id;
         let project_connection_ids = self
             .app_state
             .db
-            .project_connection_ids(project_id, request.sender_connection_id)
+            .project_connection_ids(project_id, session.connection_id)
             .await?;
 
         if !project_connection_ids.contains(&leader_id) {
@@ -1409,7 +1378,7 @@ impl Server {
 
         let mut response_payload = self
             .peer
-            .forward_request(request.sender_connection_id, leader_id, request.payload)
+            .forward_request(session.connection_id, leader_id, request)
             .await?;
         response_payload
             .views
@@ -1418,50 +1387,44 @@ impl Server {
         Ok(())
     }
 
-    async fn unfollow(self: Arc<Self>, request: Message<proto::Unfollow>) -> Result<()> {
-        let project_id = ProjectId::from_proto(request.payload.project_id);
-        let leader_id = ConnectionId(request.payload.leader_id);
+    async fn unfollow(self: Arc<Self>, request: proto::Unfollow, session: Session) -> Result<()> {
+        let project_id = ProjectId::from_proto(request.project_id);
+        let leader_id = ConnectionId(request.leader_id);
         let project_connection_ids = self
             .app_state
             .db
-            .project_connection_ids(project_id, request.sender_connection_id)
+            .project_connection_ids(project_id, session.connection_id)
             .await?;
         if !project_connection_ids.contains(&leader_id) {
             Err(anyhow!("no such peer"))?;
         }
         self.peer
-            .forward_send(request.sender_connection_id, leader_id, request.payload)?;
+            .forward_send(session.connection_id, leader_id, request)?;
         Ok(())
     }
 
     async fn update_followers(
         self: Arc<Self>,
-        request: Message<proto::UpdateFollowers>,
+        request: proto::UpdateFollowers,
+        session: Session,
     ) -> Result<()> {
-        let project_id = ProjectId::from_proto(request.payload.project_id);
+        let project_id = ProjectId::from_proto(request.project_id);
         let project_connection_ids = self
             .app_state
             .db
-            .project_connection_ids(project_id, request.sender_connection_id)
+            .project_connection_ids(project_id, session.connection_id)
             .await?;
 
-        let leader_id = request
-            .payload
-            .variant
-            .as_ref()
-            .and_then(|variant| match variant {
-                proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
-                proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
-                proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
-            });
-        for follower_id in &request.payload.follower_ids {
+        let leader_id = request.variant.as_ref().and_then(|variant| match variant {
+            proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
+            proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
+            proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
+        });
+        for follower_id in &request.follower_ids {
             let follower_id = ConnectionId(*follower_id);
             if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
-                self.peer.forward_send(
-                    request.sender_connection_id,
-                    follower_id,
-                    request.payload.clone(),
-                )?;
+                self.peer
+                    .forward_send(session.connection_id, follower_id, request.clone())?;
             }
         }
         Ok(())
@@ -1469,11 +1432,11 @@ impl Server {
 
     async fn get_users(
         self: Arc<Server>,
-        request: Message<proto::GetUsers>,
+        request: proto::GetUsers,
         response: Response<proto::GetUsers>,
+        _session: Session,
     ) -> Result<()> {
         let user_ids = request
-            .payload
             .user_ids
             .into_iter()
             .map(UserId::from_proto)
@@ -1496,10 +1459,11 @@ impl Server {
 
     async fn fuzzy_search_users(
         self: Arc<Server>,
-        request: Message<proto::FuzzySearchUsers>,
+        request: proto::FuzzySearchUsers,
         response: Response<proto::FuzzySearchUsers>,
+        session: Session,
     ) -> Result<()> {
-        let query = request.payload.query;
+        let query = request.query;
         let db = &self.app_state.db;
         let users = match query.len() {
             0 => vec![],
@@ -1512,7 +1476,7 @@ impl Server {
         };
         let users = users
             .into_iter()
-            .filter(|user| user.id != request.sender_user_id)
+            .filter(|user| user.id != session.user_id)
             .map(|user| proto::User {
                 id: user.id.to_proto(),
                 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
@@ -1525,11 +1489,12 @@ impl Server {
 
     async fn request_contact(
         self: Arc<Server>,
-        request: Message<proto::RequestContact>,
+        request: proto::RequestContact,
         response: Response<proto::RequestContact>,
+        session: Session,
     ) -> Result<()> {
-        let requester_id = request.sender_user_id;
-        let responder_id = UserId::from_proto(request.payload.responder_id);
+        let requester_id = session.user_id;
+        let responder_id = UserId::from_proto(request.responder_id);
         if requester_id == responder_id {
             return Err(anyhow!("cannot add yourself as a contact"))?;
         }