Merge pull request #1224 from zed-industries/forward-deadlock

Antonio Scandurra created

Prevent deadlocks when two clients simultaneously send requests to each other
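
To make the failure mode concrete, here is a minimal, hypothetical sketch (not the actual collab code) of a per-connection loop that processes foreground messages strictly one at a time. If two peers each send the other a request whose handler awaits the other's response, both connections sit inside their handler and never reach the message carrying the other peer's request, so neither response is ever produced.

use futures::{channel::mpsc, StreamExt};

// Hypothetical stand-in for a foreground message handler; in the deadlock
// scenario this handler itself awaits a response from the other peer.
async fn handle_foreground_message(message: String) {
    println!("handled {message}");
}

// Processing messages serially like this is what allows two peers that
// request something from each other to block forever.
async fn run_connection(mut incoming: mpsc::UnboundedReceiver<String>) {
    while let Some(message) = incoming.next().await {
        handle_foreground_message(message).await;
    }
}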

Change summary

crates/collab/src/integration_tests.rs | 81 +++++++++++++++++++++++++--
crates/collab/src/rpc.rs               | 49 +++++++++++-----
crates/gpui/src/executor.rs            | 17 +++--
crates/rpc/src/peer.rs                 | 59 +++++++++++--------
4 files changed, 150 insertions(+), 56 deletions(-)

Detailed changes

crates/collab/src/integration_tests.rs 🔗

@@ -4336,13 +4336,71 @@ async fn test_auto_unfollowing(cx_a: &mut TestAppContext, cx_b: &mut TestAppCont
     );
 }
 
+#[gpui::test(iterations = 10)]
+async fn test_peers_simultaneously_following_each_other(
+    deterministic: Arc<Deterministic>,
+    cx_a: &mut TestAppContext,
+    cx_b: &mut TestAppContext,
+) {
+    deterministic.forbid_parking();
+
+    let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+    let client_a = server.create_client(cx_a, "user_a").await;
+    let client_b = server.create_client(cx_b, "user_b").await;
+    server
+        .make_contacts(vec![(&client_a, cx_a), (&client_b, cx_b)])
+        .await;
+    cx_a.update(editor::init);
+    cx_b.update(editor::init);
+
+    client_a.fs.insert_tree("/a", json!({})).await;
+    let (project_a, _) = client_a.build_local_project("/a", cx_a).await;
+    let workspace_a = client_a.build_workspace(&project_a, cx_a);
+
+    let project_b = client_b.build_remote_project(&project_a, cx_a, cx_b).await;
+    let workspace_b = client_b.build_workspace(&project_b, cx_b);
+
+    deterministic.run_until_parked();
+    let client_a_id = project_b.read_with(cx_b, |project, _| {
+        project.collaborators().values().next().unwrap().peer_id
+    });
+    let client_b_id = project_a.read_with(cx_a, |project, _| {
+        project.collaborators().values().next().unwrap().peer_id
+    });
+
+    let a_follow_b = workspace_a.update(cx_a, |workspace, cx| {
+        workspace
+            .toggle_follow(&ToggleFollow(client_b_id), cx)
+            .unwrap()
+    });
+    let b_follow_a = workspace_b.update(cx_b, |workspace, cx| {
+        workspace
+            .toggle_follow(&ToggleFollow(client_a_id), cx)
+            .unwrap()
+    });
+
+    futures::try_join!(a_follow_b, b_follow_a).unwrap();
+    workspace_a.read_with(cx_a, |workspace, _| {
+        assert_eq!(
+            workspace.leader_for_pane(&workspace.active_pane()),
+            Some(client_b_id)
+        );
+    });
+    workspace_b.read_with(cx_b, |workspace, _| {
+        assert_eq!(
+            workspace.leader_for_pane(&workspace.active_pane()),
+            Some(client_a_id)
+        );
+    });
+}
+
 #[gpui::test(iterations = 100)]
 async fn test_random_collaboration(
     cx: &mut TestAppContext,
     deterministic: Arc<Deterministic>,
     rng: StdRng,
 ) {
-    cx.foreground().forbid_parking();
+    deterministic.forbid_parking();
     let max_peers = env::var("MAX_PEERS")
         .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
         .unwrap_or(5);
@@ -4568,10 +4626,13 @@ async fn test_random_collaboration(
     while operations < max_operations {
         if operations == disconnect_host_at {
             server.disconnect_client(user_ids[0]);
-            cx.foreground().advance_clock(RECEIVE_TIMEOUT);
+            deterministic.advance_clock(RECEIVE_TIMEOUT);
             drop(op_start_signals);
+
+            deterministic.start_waiting();
             let mut clients = futures::future::join_all(clients).await;
-            cx.foreground().run_until_parked();
+            deterministic.finish_waiting();
+            deterministic.run_until_parked();
 
             let (host, host_project, mut host_cx, host_err) = clients.remove(0);
             if let Some(host_err) = host_err {
@@ -4623,6 +4684,8 @@ async fn test_random_collaboration(
                     cx.leak_detector(),
                     next_entity_id,
                 );
+
+                deterministic.start_waiting();
                 let guest = server.create_client(&mut guest_cx, &guest_username).await;
                 let guest_project = Project::remote(
                     host_project_id,
@@ -4635,6 +4698,8 @@ async fn test_random_collaboration(
                 )
                 .await
                 .unwrap();
+                deterministic.finish_waiting();
+
                 let op_start_signal = futures::channel::mpsc::unbounded();
                 user_ids.push(guest.current_user_id(&guest_cx));
                 op_start_signals.push(op_start_signal.0);
@@ -4657,8 +4722,10 @@ async fn test_random_collaboration(
                 op_start_signals.remove(guest_ix);
                 server.forbid_connections();
                 server.disconnect_client(removed_guest_id);
-                cx.foreground().advance_clock(RECEIVE_TIMEOUT);
+                deterministic.advance_clock(RECEIVE_TIMEOUT);
+                deterministic.start_waiting();
                 let (guest, guest_project, mut guest_cx, guest_err) = guest.await;
+                deterministic.finish_waiting();
                 server.allow_connections();
 
                 if let Some(guest_err) = guest_err {
@@ -4708,15 +4775,17 @@ async fn test_random_collaboration(
                 }
 
                 if rng.lock().gen_bool(0.8) {
-                    cx.foreground().run_until_parked();
+                    deterministic.run_until_parked();
                 }
             }
         }
     }
 
     drop(op_start_signals);
+    deterministic.start_waiting();
     let mut clients = futures::future::join_all(clients).await;
-    cx.foreground().run_until_parked();
+    deterministic.finish_waiting();
+    deterministic.run_until_parked();
 
     let (host_client, host_project, mut host_cx, host_err) = clients.remove(0);
     if let Some(host_err) = host_err {

crates/collab/src/rpc.rs 🔗

@@ -26,6 +26,7 @@ use collections::HashMap;
 use futures::{
     channel::mpsc,
     future::{self, BoxFuture},
+    stream::FuturesUnordered,
     FutureExt, SinkExt, StreamExt, TryStreamExt,
 };
 use lazy_static::lazy_static;
@@ -398,6 +399,16 @@ impl Server {
 
             let handle_io = handle_io.fuse();
             futures::pin_mut!(handle_io);
+
+            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
+            // This prevents deadlocks when, e.g., client A performs a request to client B and
+            // client B performs a request to client A. If both clients stopped processing further
+            // messages until their respective request completed, neither would ever get a chance
+            // to respond to the other client's request, resulting in a deadlock.
+            //
+            // This arrangement ensures we attempt to process earlier messages first, but fall
+            // back to processing messages that arrived later in the spirit of making progress.
+            let mut foreground_message_handlers = FuturesUnordered::new();
             loop {
                 let next_message = incoming_rx.next().fuse();
                 futures::pin_mut!(next_message);
@@ -408,30 +419,33 @@ impl Server {
                         }
                         break;
                     }
+                    _ = foreground_message_handlers.next() => {}
                     message = next_message => {
                         if let Some(message) = message {
                             let type_name = message.payload_type_name();
                             let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
-                            async {
-                                if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
-                                    let notifications = this.notifications.clone();
-                                    let is_background = message.is_background();
-                                    let handle_message = (handler)(this.clone(), message);
-                                    let handle_message = async move {
-                                        handle_message.await;
-                                        if let Some(mut notifications) = notifications {
-                                            let _ = notifications.send(()).await;
-                                        }
-                                    };
-                                    if is_background {
-                                        executor.spawn_detached(handle_message);
-                                    } else {
-                                        handle_message.await;
+                            let span_enter = span.enter();
+                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
+                                let notifications = this.notifications.clone();
+                                let is_background = message.is_background();
+                                let handle_message = (handler)(this.clone(), message);
+
+                                drop(span_enter);
+                                let handle_message = async move {
+                                    handle_message.await;
+                                    if let Some(mut notifications) = notifications {
+                                        let _ = notifications.send(()).await;
                                     }
+                                }.instrument(span);
+
+                                if is_background {
+                                    executor.spawn_detached(handle_message);
                                 } else {
-                                    tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
+                                    foreground_message_handlers.push(handle_message);
                                 }
-                            }.instrument(span).await;
+                            } else {
+                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
+                            }
                         } else {
                             tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                             break;
@@ -440,6 +454,7 @@ impl Server {
                 }
             }
 
+            drop(foreground_message_handlers);
             tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
             if let Err(error) = this.sign_out(connection_id).await {
                 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
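
The comment added above describes the core of the fix: in-flight foreground handlers are parked in a `FuturesUnordered` that is polled alongside the incoming message stream, so a connection can still answer new requests while earlier handlers are awaiting responses of their own. A condensed, hypothetical sketch of that shape (names simplified; not the server's real types):

use futures::{channel::mpsc, stream::FuturesUnordered, StreamExt};

async fn handle_foreground_message(message: String) {
    println!("handled {message}");
}

async fn run_connection(mut incoming: mpsc::UnboundedReceiver<String>) {
    // In-flight foreground handlers; polling this keeps the loop responsive
    // while earlier handlers await responses of their own.
    let mut foreground_message_handlers = FuturesUnordered::new();
    loop {
        futures::select_biased! {
            _ = foreground_message_handlers.next() => {}
            message = incoming.next() => match message {
                Some(message) => {
                    foreground_message_handlers.push(handle_foreground_message(message));
                }
                None => break,
            },
        }
    }
    // Dropping the set cancels any handlers still in flight once the
    // connection closes, mirroring the explicit `drop` in the real code.
    drop(foreground_message_handlers);
}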

crates/gpui/src/executor.rs 🔗

@@ -366,6 +366,14 @@ impl Deterministic {
         self.state.lock().now = new_now;
     }
 
+    pub fn start_waiting(&self) {
+        self.state.lock().waiting_backtrace = Some(backtrace::Backtrace::new_unresolved());
+    }
+
+    pub fn finish_waiting(&self) {
+        self.state.lock().waiting_backtrace.take();
+    }
+
     pub fn forbid_parking(&self) {
         use rand::prelude::*;
 
@@ -500,10 +508,7 @@ impl Foreground {
     #[cfg(any(test, feature = "test-support"))]
     pub fn start_waiting(&self) {
         match self {
-            Self::Deterministic { executor, .. } => {
-                executor.state.lock().waiting_backtrace =
-                    Some(backtrace::Backtrace::new_unresolved());
-            }
+            Self::Deterministic { executor, .. } => executor.start_waiting(),
             _ => panic!("this method can only be called on a deterministic executor"),
         }
     }
@@ -511,9 +516,7 @@ impl Foreground {
     #[cfg(any(test, feature = "test-support"))]
     pub fn finish_waiting(&self) {
         match self {
-            Self::Deterministic { executor, .. } => {
-                executor.state.lock().waiting_backtrace.take();
-            }
+            Self::Deterministic { executor, .. } => executor.finish_waiting(),
             _ => panic!("this method can only be called on a deterministic executor"),
         }
     }
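
start_waiting and finish_waiting now live directly on Deterministic, with the Foreground methods simply delegating: start_waiting records an unresolved backtrace and finish_waiting clears it, presumably so that a forbid_parking failure can report where the test was blocked. The integration test changes above use them to bracket awaits that are expected to block on real work, e.g. (excerpted from the test diff):

deterministic.start_waiting();
let mut clients = futures::future::join_all(clients).await;
deterministic.finish_waiting();
deterministic.run_until_parked();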

crates/rpc/src/peer.rs 🔗

@@ -11,7 +11,6 @@ use futures::{
 };
 use parking_lot::{Mutex, RwLock};
 use serde::{ser::SerializeStruct, Serialize};
-use smol_timeout::TimeoutExt;
 use std::sync::atomic::Ordering::SeqCst;
 use std::{
     fmt,
@@ -177,14 +176,17 @@ impl Peer {
                         outgoing = outgoing_rx.next().fuse() => match outgoing {
                             Some(outgoing) => {
                                 tracing::debug!(%connection_id, "outgoing rpc message: writing");
-                                if let Some(result) = writer.write(outgoing).timeout(WRITE_TIMEOUT).await {
-                                    tracing::debug!(%connection_id, "outgoing rpc message: done writing");
-                                    result.context("failed to write RPC message")?;
-                                    tracing::debug!(%connection_id, "keepalive interval: resetting after sending message");
-                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
-                                } else {
-                                    tracing::debug!(%connection_id, "outgoing rpc message: writing timed out");
-                                    Err(anyhow!("timed out writing message"))?;
+                                futures::select_biased! {
+                                    result = writer.write(outgoing).fuse() => {
+                                        tracing::debug!(%connection_id, "outgoing rpc message: done writing");
+                                        result.context("failed to write RPC message")?;
+                                        tracing::debug!(%connection_id, "keepalive interval: resetting after sending message");
+                                        keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
+                                    }
+                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
+                                        tracing::debug!(%connection_id, "outgoing rpc message: writing timed out");
+                                        Err(anyhow!("timed out writing message"))?;
+                                    }
                                 }
                             }
                             None => {
@@ -199,32 +201,37 @@ impl Peer {
                             receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
                             if let proto::Message::Envelope(incoming) = incoming {
                                 tracing::debug!(%connection_id, "incoming rpc message: processing");
-                                match incoming_tx.send(incoming).timeout(RECEIVE_TIMEOUT).await {
-                                    Some(Ok(_)) => {
-                                        tracing::debug!(%connection_id, "incoming rpc message: processed");
-                                    },
-                                    Some(Err(_)) => {
-                                        tracing::debug!(%connection_id, "incoming rpc message: channel closed");
-                                        return Ok(())
+                                futures::select_biased! {
+                                    result = incoming_tx.send(incoming).fuse() => match result {
+                                        Ok(_) => {
+                                            tracing::debug!(%connection_id, "incoming rpc message: processed");
+                                        }
+                                        Err(_) => {
+                                            tracing::debug!(%connection_id, "incoming rpc message: channel closed");
+                                            return Ok(())
+                                        }
                                     },
-                                    None => {
+                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                         tracing::debug!(%connection_id, "incoming rpc message: processing timed out");
                                         Err(anyhow!("timed out processing incoming message"))?
-                                    },
+                                    }
                                 }
                             }
                             break;
                         },
                         _ = keepalive_timer => {
                             tracing::debug!(%connection_id, "keepalive interval: pinging");
-                            if let Some(result) = writer.write(proto::Message::Ping).timeout(WRITE_TIMEOUT).await {
-                                tracing::debug!(%connection_id, "keepalive interval: done pinging");
-                                result.context("failed to send keepalive")?;
-                                tracing::debug!(%connection_id, "keepalive interval: resetting after pinging");
-                                keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
-                            } else {
-                                tracing::debug!(%connection_id, "keepalive interval: pinging timed out");
-                                Err(anyhow!("timed out sending keepalive"))?;
+                            futures::select_biased! {
+                                result = writer.write(proto::Message::Ping).fuse() => {
+                                    tracing::debug!(%connection_id, "keepalive interval: done pinging");
+                                    result.context("failed to send keepalive")?;
+                                    tracing::debug!(%connection_id, "keepalive interval: resetting after pinging");
+                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
+                                }
+                                _ = create_timer(WRITE_TIMEOUT).fuse() => {
+                                    tracing::debug!(%connection_id, "keepalive interval: pinging timed out");
+                                    Err(anyhow!("timed out sending keepalive"))?;
+                                }
                             }
                         }
                         _ = receive_timeout => {
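
Throughout peer.rs, the smol_timeout::TimeoutExt calls are replaced by racing each operation against a timer inside futures::select_biased!. A generic, hypothetical sketch of that pattern, using futures_timer::Delay as a stand-in for the crate's internal create_timer helper:

use std::time::Duration;

use anyhow::{anyhow, Context as _, Result};
use futures::{Future, FutureExt};
use futures_timer::Delay;

// Race a write against a timer, preferring the write's result when both are
// ready; the real code also resets its keepalive timer after a successful write.
async fn write_with_timeout(
    write: impl Future<Output = std::io::Result<()>>,
    timeout: Duration,
) -> Result<()> {
    futures::select_biased! {
        result = write.fuse() => result.context("failed to write RPC message"),
        _ = Delay::new(timeout).fuse() => Err(anyhow!("timed out writing message")),
    }
}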