Use a `FuturesUnordered` to process foreground messages

Antonio Scandurra created

This prevents deadlocks when e.g., client A performs a request to client B and
client B performs a request to client A. If both clients stop processing further
messages until their respective request completes, they won't have a chance to
respond to the other client's request and cause a deadlock.

This arrangement ensures we will attempt to process earlier messages first, but fall
back to processing messages arrived later in the spirit of making progress.

Change summary

crates/collab/src/rpc.rs | 49 +++++++++++++++++++++++++++--------------
1 file changed, 32 insertions(+), 17 deletions(-)

Detailed changes

crates/collab/src/rpc.rs 🔗

@@ -26,6 +26,7 @@ use collections::HashMap;
 use futures::{
     channel::mpsc,
     future::{self, BoxFuture},
+    stream::FuturesUnordered,
     FutureExt, SinkExt, StreamExt, TryStreamExt,
 };
 use lazy_static::lazy_static;
@@ -398,6 +399,16 @@ impl Server {
 
             let handle_io = handle_io.fuse();
             futures::pin_mut!(handle_io);
+
+            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
+            // This prevents deadlocks when e.g., client A performs a request to client B and
+            // client B performs a request to client A. If both clients stop processing further
+            // messages until their respective request completes, they won't have a chance to
+            // respond to the other client's request and cause a deadlock.
+            //
+            // This arrangement ensures we will attempt to process earlier messages first, but fall
+            // back to processing messages arrived later in the spirit of making progress.
+            let mut foreground_message_handlers = FuturesUnordered::new();
             loop {
                 let next_message = incoming_rx.next().fuse();
                 futures::pin_mut!(next_message);
@@ -408,30 +419,33 @@ impl Server {
                         }
                         break;
                     }
+                    _ = foreground_message_handlers.next() => {}
                     message = next_message => {
                         if let Some(message) = message {
                             let type_name = message.payload_type_name();
                             let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
-                            async {
-                                if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
-                                    let notifications = this.notifications.clone();
-                                    let is_background = message.is_background();
-                                    let handle_message = (handler)(this.clone(), message);
-                                    let handle_message = async move {
-                                        handle_message.await;
-                                        if let Some(mut notifications) = notifications {
-                                            let _ = notifications.send(()).await;
-                                        }
-                                    };
-                                    if is_background {
-                                        executor.spawn_detached(handle_message);
-                                    } else {
-                                        handle_message.await;
+                            let span_enter = span.enter();
+                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
+                                let notifications = this.notifications.clone();
+                                let is_background = message.is_background();
+                                let handle_message = (handler)(this.clone(), message);
+
+                                drop(span_enter);
+                                let handle_message = async move {
+                                    handle_message.await;
+                                    if let Some(mut notifications) = notifications {
+                                        let _ = notifications.send(()).await;
                                     }
+                                }.instrument(span);
+
+                                if is_background {
+                                    executor.spawn_detached(handle_message);
                                 } else {
-                                    tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
+                                    foreground_message_handlers.push(handle_message);
                                 }
-                            }.instrument(span).await;
+                            } else {
+                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
+                            }
                         } else {
                             tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                             break;
@@ -440,6 +454,7 @@ impl Server {
                 }
             }
 
+            drop(foreground_message_handlers);
             tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
             if let Err(error) = this.sign_out(connection_id).await {
                 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");