Preserve ordering between responses and other incoming messages

Max Brunsfeld and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

Cargo.lock                  |  2 +
crates/client/src/client.rs | 45 +++++++++++++++++++++++++++++++++-----
crates/rpc/src/peer.rs      | 39 +++++++++++++++++----------------
crates/server/Cargo.toml    |  2 +
crates/server/src/rpc.rs    | 16 ++++++++++---
5 files changed, 75 insertions(+), 29 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -5822,7 +5822,9 @@ dependencies = [
  "clap 3.0.0-beta.2",
  "collections",
  "comrak",
+ "ctor",
  "either",
+ "env_logger",
  "envy",
  "futures",
  "gpui",

crates/client/src/client.rs 🔗

@@ -24,7 +24,10 @@ use std::{
     collections::HashMap,
     convert::TryFrom,
     fmt::Write as _,
-    sync::{Arc, Weak},
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc, Weak,
+    },
     time::{Duration, Instant},
 };
 use surf::{http::Method, Url};
@@ -54,6 +57,7 @@ pub fn init(rpc: Arc<Client>, cx: &mut MutableAppContext) {
 }
 
 pub struct Client {
+    id: usize,
     peer: Arc<Peer>,
     http: Arc<dyn HttpClient>,
     state: RwLock<ClientState>,
@@ -166,7 +170,12 @@ impl Drop for Subscription {
 
 impl Client {
     pub fn new(http: Arc<dyn HttpClient>) -> Arc<Self> {
+        lazy_static! {
+            static ref NEXT_CLIENT_ID: AtomicUsize = AtomicUsize::default();
+        }
+
         Arc::new(Self {
+            id: NEXT_CLIENT_ID.fetch_add(1, Ordering::SeqCst),
             peer: Peer::new(),
             http,
             state: Default::default(),
@@ -447,21 +456,31 @@ impl Client {
                             None
                         };
 
+                        let type_name = message.payload_type_name();
+
                         let handler_key = (payload_type_id, entity_id);
                         if let Some(handler) = state.model_handlers.get_mut(&handler_key) {
                             let mut handler = handler.take().unwrap();
                             drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
-                            let start_time = Instant::now();
-                            log::info!("RPC client message {}", message.payload_type_name());
+
+                            log::debug!(
+                                "rpc message received. client_id:{}, name:{}",
+                                this.id,
+                                type_name
+                            );
                             (handler)(message, &mut cx);
-                            log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
+                            log::debug!(
+                                "rpc message handled. client_id:{}, name:{}",
+                                this.id,
+                                type_name
+                            );
 
                             let mut state = this.state.write();
                             if state.model_handlers.contains_key(&handler_key) {
                                 state.model_handlers.insert(handler_key, Some(handler));
                             }
                         } else {
-                            log::info!("unhandled message {}", message.payload_type_name());
+                            log::info!("unhandled message {}", type_name);
                         }
                     }
                 }
@@ -677,11 +696,23 @@ impl Client {
     }
 
     pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
+        log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
         self.peer.send(self.connection_id()?, message)
     }
 
     pub async fn request<T: RequestMessage>(&self, request: T) -> Result<T::Response> {
-        self.peer.request(self.connection_id()?, request).await
+        log::debug!(
+            "rpc request start. client_id: {}. name:{}",
+            self.id,
+            T::NAME
+        );
+        let response = self.peer.request(self.connection_id()?, request).await;
+        log::debug!(
+            "rpc request finish. client_id: {}. name:{}",
+            self.id,
+            T::NAME
+        );
+        response
     }
 
     pub fn respond<T: RequestMessage>(
@@ -689,6 +720,7 @@ impl Client {
         receipt: Receipt<T>,
         response: T::Response,
     ) -> Result<()> {
+        log::debug!("rpc respond. client_id: {}. name:{}", self.id, T::NAME);
         self.peer.respond(receipt, response)
     }
 
@@ -697,6 +729,7 @@ impl Client {
         receipt: Receipt<T>,
         error: proto::Error,
     ) -> Result<()> {
+        log::debug!("rpc respond. client_id: {}. name:{}", self.id, T::NAME);
         self.peer.respond_with_error(receipt, error)
     }
 }

crates/rpc/src/peer.rs 🔗

@@ -5,7 +5,7 @@ use futures::stream::BoxStream;
 use futures::{FutureExt as _, StreamExt};
 use parking_lot::{Mutex, RwLock};
 use postage::{
-    mpsc,
+    barrier, mpsc,
     prelude::{Sink as _, Stream as _},
 };
 use smol_timeout::TimeoutExt as _;
@@ -91,7 +91,8 @@ pub struct Peer {
 pub struct ConnectionState {
     outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Envelope>,
     next_message_id: Arc<AtomicU32>,
-    response_channels: Arc<Mutex<Option<HashMap<u32, mpsc::Sender<proto::Envelope>>>>>,
+    response_channels:
+        Arc<Mutex<Option<HashMap<u32, mpsc::Sender<(proto::Envelope, barrier::Sender)>>>>>,
 }
 
 const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
@@ -177,7 +178,9 @@ impl Peer {
                 if let Some(responding_to) = incoming.responding_to {
                     let channel = response_channels.lock().as_mut()?.remove(&responding_to);
                     if let Some(mut tx) = channel {
-                        tx.send(incoming).await.ok();
+                        let mut requester_resumed = barrier::channel();
+                        tx.send((incoming, requester_resumed.0)).await.ok();
+                        requester_resumed.1.recv().await;
                     } else {
                         log::warn!("received RPC response to unknown request {}", responding_to);
                     }
@@ -205,7 +208,7 @@ impl Peer {
     }
 
     pub fn request<T: RequestMessage>(
-        self: &Arc<Self>,
+        &self,
         receiver_id: ConnectionId,
         request: T,
     ) -> impl Future<Output = Result<T::Response>> {
@@ -213,7 +216,7 @@ impl Peer {
     }
 
     pub fn forward_request<T: RequestMessage>(
-        self: &Arc<Self>,
+        &self,
         sender_id: ConnectionId,
         receiver_id: ConnectionId,
         request: T,
@@ -222,15 +225,13 @@ impl Peer {
     }
 
     pub fn request_internal<T: RequestMessage>(
-        self: &Arc<Self>,
+        &self,
         original_sender_id: Option<ConnectionId>,
         receiver_id: ConnectionId,
         request: T,
     ) -> impl Future<Output = Result<T::Response>> {
-        let this = self.clone();
-        async move {
-            let (tx, mut rx) = mpsc::channel(1);
-            let connection = this.connection_state(receiver_id)?;
+        let (tx, mut rx) = mpsc::channel(1);
+        let send = self.connection_state(receiver_id).and_then(|connection| {
             let message_id = connection.next_message_id.fetch_add(1, SeqCst);
             connection
                 .response_channels
@@ -246,7 +247,11 @@ impl Peer {
                     original_sender_id.map(|id| id.0),
                 ))
                 .map_err(|_| anyhow!("connection was closed"))?;
-            let response = rx
+            Ok(())
+        });
+        async move {
+            send?;
+            let (response, _barrier) = rx
                 .recv()
                 .await
                 .ok_or_else(|| anyhow!("connection was closed"))?;
@@ -259,11 +264,7 @@ impl Peer {
         }
     }
 
-    pub fn send<T: EnvelopedMessage>(
-        self: &Arc<Self>,
-        receiver_id: ConnectionId,
-        message: T,
-    ) -> Result<()> {
+    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
         let connection = self.connection_state(receiver_id)?;
         let message_id = connection
             .next_message_id
@@ -275,7 +276,7 @@ impl Peer {
     }
 
     pub fn forward_send<T: EnvelopedMessage>(
-        self: &Arc<Self>,
+        &self,
         sender_id: ConnectionId,
         receiver_id: ConnectionId,
         message: T,
@@ -291,7 +292,7 @@ impl Peer {
     }
 
     pub fn respond<T: RequestMessage>(
-        self: &Arc<Self>,
+        &self,
         receipt: Receipt<T>,
         response: T::Response,
     ) -> Result<()> {
@@ -306,7 +307,7 @@ impl Peer {
     }
 
     pub fn respond_with_error<T: RequestMessage>(
-        self: &Arc<Self>,
+        &self,
         receipt: Receipt<T>,
         response: proto::Error,
     ) -> Result<()> {

crates/server/Cargo.toml 🔗

@@ -59,6 +59,8 @@ features = ["runtime-async-std-rustls", "postgres", "time", "uuid"]
 collections = { path = "../collections", features = ["test-support"] }
 gpui = { path = "../gpui" }
 zed = { path = "../zed", features = ["test-support"] }
+ctor = "0.1"
+env_logger = "0.8"
 
 lazy_static = "1.4"
 serde_json = { version = "1.0.64", features = ["preserve_order"] }

crates/server/src/rpc.rs 🔗

@@ -150,19 +150,20 @@ impl Server {
                     message = next_message => {
                         if let Some(message) = message {
                             let start_time = Instant::now();
-                            log::info!("RPC message received: {}", message.payload_type_name());
+                            let type_name = message.payload_type_name();
+                            log::info!("rpc message received. connection:{}, type:{}", connection_id, type_name);
                             if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                 if let Err(err) = (handler)(this.clone(), message).await {
-                                    log::error!("error handling message: {:?}", err);
+                                    log::error!("rpc message error. connection:{}, type:{}, error:{:?}", connection_id, type_name, err);
                                 } else {
-                                    log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
+                                    log::info!("rpc message handled. connection:{}, type:{}, duration:{:?}", connection_id, type_name, start_time.elapsed());
                                 }
 
                                 if let Some(mut notifications) = this.notifications.clone() {
                                     let _ = notifications.send(()).await;
                                 }
                             } else {
-                                log::warn!("unhandled message: {}", message.payload_type_name());
+                                log::warn!("unhandled message: {}", type_name);
                             }
                         } else {
                             log::info!("rpc connection closed {:?}", addr);
@@ -1192,6 +1193,13 @@ mod tests {
         project::{DiagnosticSummary, Project, ProjectPath},
     };
 
+    #[cfg(test)]
+    #[ctor::ctor]
+    fn init_logger() {
+        // std::env::set_var("RUST_LOG", "info");
+        env_logger::init();
+    }
+
     #[gpui::test]
     async fn test_share_project(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
         let (window_b, _) = cx_b.add_window(|_| EmptyView);