Restore synchronization between responses and incoming messages

Max Brunsfeld and Antonio Scandurra created

This removes the need to buffer pending messages in Client.

Co-Authored-By: Antonio Scandurra <me@as-cii.com>

Change summary

crates/client/src/client.rs   | 68 ++++--------------------------------
crates/project/src/project.rs | 42 +++++++++++++---------
crates/rpc/src/peer.rs        | 17 ++++++--
3 files changed, 44 insertions(+), 83 deletions(-)

Detailed changes

crates/client/src/client.rs 🔗

@@ -134,7 +134,6 @@ struct ClientState {
     _maintain_connection: Option<Task<()>>,
     heartbeat_interval: Duration,
 
-    pending_messages: HashMap<(TypeId, u64), Vec<Box<dyn AnyTypedEnvelope>>>,
     models_by_entity_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakModelHandle>,
     models_by_message_type: HashMap<TypeId, AnyModelHandle>,
     model_types_by_message_type: HashMap<TypeId, TypeId>,
@@ -169,7 +168,6 @@ impl Default for ClientState {
             models_by_message_type: Default::default(),
             models_by_entity_type_and_remote_id: Default::default(),
             model_types_by_message_type: Default::default(),
-            pending_messages: Default::default(),
             message_handlers: Default::default(),
         }
     }
@@ -311,46 +309,6 @@ impl Client {
         state
             .models_by_entity_type_and_remote_id
             .insert(id, handle.downgrade());
-        let pending_messages = state.pending_messages.remove(&id);
-        drop(state);
-
-        let client_id = self.id;
-        for message in pending_messages.into_iter().flatten() {
-            let type_id = message.payload_type_id();
-            let type_name = message.payload_type_name();
-            let state = self.state.read();
-            if let Some(handler) = state.message_handlers.get(&type_id).cloned() {
-                let future = (handler)(handle.clone(), message, cx.to_async());
-                drop(state);
-                log::debug!(
-                    "deferred rpc message received. client_id:{}, name:{}",
-                    client_id,
-                    type_name
-                );
-                cx.foreground()
-                    .spawn(async move {
-                        match future.await {
-                            Ok(()) => {
-                                log::debug!(
-                                    "deferred rpc message handled. client_id:{}, name:{}",
-                                    client_id,
-                                    type_name
-                                );
-                            }
-                            Err(error) => {
-                                log::error!(
-                                    "error handling deferred message. client_id:{}, name:{}, {}",
-                                    client_id,
-                                    type_name,
-                                    error
-                                );
-                            }
-                        }
-                    })
-                    .detach();
-            }
-        }
-
         Subscription::Entity {
             client: Arc::downgrade(self),
             id,
@@ -568,22 +526,20 @@ impl Client {
                         let mut state = this.state.write();
                         let payload_type_id = message.payload_type_id();
                         let type_name = message.payload_type_name();
-                        let model_type_id = state
-                            .model_types_by_message_type
-                            .get(&payload_type_id)
-                            .copied();
-                        let entity_id = state
-                            .entity_id_extractors
-                            .get(&message.payload_type_id())
-                            .map(|extract_entity_id| (extract_entity_id)(message.as_ref()));
 
                         let model = state
                             .models_by_message_type
                             .get(&payload_type_id)
                             .cloned()
                             .or_else(|| {
-                                let model_type_id = model_type_id?;
-                                let entity_id = entity_id?;
+                                let model_type_id =
+                                    *state.model_types_by_message_type.get(&payload_type_id)?;
+                                let entity_id = state
+                                    .entity_id_extractors
+                                    .get(&message.payload_type_id())
+                                    .map(|extract_entity_id| {
+                                        (extract_entity_id)(message.as_ref())
+                                    })?;
                                 let model = state
                                     .models_by_entity_type_and_remote_id
                                     .get(&(model_type_id, entity_id))?;
@@ -601,14 +557,6 @@ impl Client {
                             model
                         } else {
                             log::info!("unhandled message {}", type_name);
-                            if let Some((model_type_id, entity_id)) = model_type_id.zip(entity_id) {
-                                state
-                                    .pending_messages
-                                    .entry((model_type_id, entity_id))
-                                    .or_default()
-                                    .push(message);
-                            }
-
                             continue;
                         };
 

crates/project/src/project.rs 🔗

@@ -286,21 +286,7 @@ impl Project {
             load_task.detach();
         }
 
-        let user_ids = response
-            .collaborators
-            .iter()
-            .map(|peer| peer.user_id)
-            .collect();
-        user_store
-            .update(cx, |user_store, cx| user_store.load_users(user_ids, cx))
-            .await?;
-        let mut collaborators = HashMap::default();
-        for message in response.collaborators {
-            let collaborator = Collaborator::from_proto(message, &user_store, cx).await?;
-            collaborators.insert(collaborator.peer_id, collaborator);
-        }
-
-        Ok(cx.add_model(|cx| {
+        let this = cx.add_model(|cx| {
             let mut this = Self {
                 worktrees: Vec::new(),
                 open_buffers: Default::default(),
@@ -308,9 +294,9 @@ impl Project {
                 opened_buffer: broadcast::channel(1).0,
                 shared_buffers: Default::default(),
                 active_entry: None,
-                collaborators,
+                collaborators: Default::default(),
                 languages,
-                user_store,
+                user_store: user_store.clone(),
                 fs,
                 subscriptions: vec![client.add_model_for_remote_entity(remote_id, cx)],
                 client,
@@ -326,7 +312,27 @@ impl Project {
                 this.add_worktree(&worktree, cx);
             }
             this
-        }))
+        });
+
+        let user_ids = response
+            .collaborators
+            .iter()
+            .map(|peer| peer.user_id)
+            .collect();
+        user_store
+            .update(cx, |user_store, cx| user_store.load_users(user_ids, cx))
+            .await?;
+        let mut collaborators = HashMap::default();
+        for message in response.collaborators {
+            let collaborator = Collaborator::from_proto(message, &user_store, cx).await?;
+            collaborators.insert(collaborator.peer_id, collaborator);
+        }
+
+        this.update(cx, |this, _| {
+            this.collaborators = collaborators;
+        });
+
+        Ok(this)
     }
 
     #[cfg(any(test, feature = "test-support"))]

crates/rpc/src/peer.rs 🔗

@@ -5,7 +5,7 @@ use futures::stream::BoxStream;
 use futures::{FutureExt as _, StreamExt};
 use parking_lot::{Mutex, RwLock};
 use postage::{
-    mpsc,
+    barrier, mpsc,
     prelude::{Sink as _, Stream as _},
 };
 use smol_timeout::TimeoutExt as _;
@@ -91,7 +91,8 @@ pub struct Peer {
 pub struct ConnectionState {
     outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Envelope>,
     next_message_id: Arc<AtomicU32>,
-    response_channels: Arc<Mutex<Option<HashMap<u32, mpsc::Sender<proto::Envelope>>>>>,
+    response_channels:
+        Arc<Mutex<Option<HashMap<u32, mpsc::Sender<(proto::Envelope, barrier::Sender)>>>>>,
 }
 
 const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
@@ -177,12 +178,18 @@ impl Peer {
                 if let Some(responding_to) = incoming.responding_to {
                     let channel = response_channels.lock().as_mut()?.remove(&responding_to);
                     if let Some(mut tx) = channel {
-                        if let Err(error) = tx.send(incoming).await {
+                        let mut requester_resumed = barrier::channel();
+                        if let Err(error) = tx.send((incoming, requester_resumed.0)).await {
                             log::debug!(
                                 "received RPC but request future was dropped {:?}",
-                                error.0
+                                error.0 .0
                             );
                         }
+                        // Drop response channel before awaiting on the barrier. This allows the
+                        // barrier to get dropped even if the request's future is dropped before it
+                        // has a chance to observe the response.
+                        drop(tx);
+                        requester_resumed.1.recv().await;
                     } else {
                         log::warn!("received RPC response to unknown request {}", responding_to);
                     }
@@ -253,7 +260,7 @@ impl Peer {
         });
         async move {
             send?;
-            let response = rx
+            let (response, _barrier) = rx
                 .recv()
                 .await
                 .ok_or_else(|| anyhow!("connection was closed"))?;