Ensure response barrier is always dropped, even if request is canceled

Antonio Scandurra created

Change summary

crates/gpui/src/executor.rs |   7 +
crates/rpc/src/peer.rs      | 124 +++++++++++++++++++++++++++++++++++++-
2 files changed, 125 insertions(+), 6 deletions(-)

Detailed changes

crates/gpui/src/executor.rs 🔗

@@ -550,8 +550,11 @@ impl Background {
     pub async fn simulate_random_delay(&self) {
         match self {
             Self::Deterministic { executor, .. } => {
-                if executor.state.lock().rng.gen_range(0..100) < 20 {
-                    yield_now().await;
+                if executor.state.lock().rng.gen_bool(0.2) {
+                    let yields = executor.state.lock().rng.gen_range(1..=10);
+                    for _ in 0..yields {
+                        yield_now().await;
+                    }
                 }
             }
             _ => panic!("this method can only be called on a deterministic executor"),

crates/rpc/src/peer.rs 🔗

@@ -180,6 +180,10 @@ impl Peer {
                     if let Some(mut tx) = channel {
                         let mut requester_resumed = barrier::channel();
                         tx.send((incoming, requester_resumed.0)).await.ok();
+                        // Drop response channel before awaiting on the barrier. This allows the
+                        // barrier to get dropped even if the request's future is dropped before it
+                        // has a chance to observe the response.
+                        drop(tx);
                         requester_resumed.1.recv().await;
                     } else {
                         log::warn!("received RPC response to unknown request {}", responding_to);
@@ -337,7 +341,7 @@ mod tests {
     use async_tungstenite::tungstenite::Message as WebSocketMessage;
     use gpui::TestAppContext;
 
-    #[gpui::test(iterations = 10)]
+    #[gpui::test(iterations = 50)]
     async fn test_request_response(cx: TestAppContext) {
         let executor = cx.foreground();
 
@@ -478,7 +482,7 @@ mod tests {
         }
     }
 
-    #[gpui::test(iterations = 10)]
+    #[gpui::test(iterations = 50)]
     async fn test_order_of_response_and_incoming(cx: TestAppContext) {
         let executor = cx.foreground();
         let server = Peer::new();
@@ -576,7 +580,119 @@ mod tests {
         );
     }
 
-    #[gpui::test(iterations = 10)]
+    #[gpui::test(iterations = 50)]
+    async fn test_dropping_request_before_completion(cx: TestAppContext) {
+        let executor = cx.foreground();
+        let server = Peer::new();
+        let client = Peer::new();
+
+        let (client_to_server_conn, server_to_client_conn, _) =
+            Connection::in_memory(cx.background());
+        let (client_to_server_conn_id, io_task1, mut client_incoming) =
+            client.add_connection(client_to_server_conn).await;
+        let (server_to_client_conn_id, io_task2, mut server_incoming) =
+            server.add_connection(server_to_client_conn).await;
+
+        executor.spawn(io_task1).detach();
+        executor.spawn(io_task2).detach();
+
+        executor
+            .spawn(async move {
+                let request1 = server_incoming
+                    .next()
+                    .await
+                    .unwrap()
+                    .into_any()
+                    .downcast::<TypedEnvelope<proto::Ping>>()
+                    .unwrap();
+                let request2 = server_incoming
+                    .next()
+                    .await
+                    .unwrap()
+                    .into_any()
+                    .downcast::<TypedEnvelope<proto::Ping>>()
+                    .unwrap();
+
+                server
+                    .send(
+                        server_to_client_conn_id,
+                        proto::Error {
+                            message: "message 1".to_string(),
+                        },
+                    )
+                    .unwrap();
+                server
+                    .send(
+                        server_to_client_conn_id,
+                        proto::Error {
+                            message: "message 2".to_string(),
+                        },
+                    )
+                    .unwrap();
+                server.respond(request1.receipt(), proto::Ack {}).unwrap();
+                server.respond(request2.receipt(), proto::Ack {}).unwrap();
+
+                // Prevent the connection from being dropped
+                server_incoming.next().await;
+            })
+            .detach();
+
+        let events = Arc::new(Mutex::new(Vec::new()));
+
+        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
+        let request1_task = executor.spawn(request1);
+        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
+        let request2_task = executor.spawn({
+            let events = events.clone();
+            async move {
+                request2.await.unwrap();
+                events.lock().push("response 2".to_string());
+            }
+        });
+
+        executor
+            .spawn({
+                let events = events.clone();
+                async move {
+                    let incoming1 = client_incoming
+                        .next()
+                        .await
+                        .unwrap()
+                        .into_any()
+                        .downcast::<TypedEnvelope<proto::Error>>()
+                        .unwrap();
+                    events.lock().push(incoming1.payload.message);
+                    let incoming2 = client_incoming
+                        .next()
+                        .await
+                        .unwrap()
+                        .into_any()
+                        .downcast::<TypedEnvelope<proto::Error>>()
+                        .unwrap();
+                    events.lock().push(incoming2.payload.message);
+
+                    // Prevent the connection from being dropped
+                    client_incoming.next().await;
+                }
+            })
+            .detach();
+
+        // Allow the request to make some progress before dropping it.
+        cx.background().simulate_random_delay().await;
+        drop(request1_task);
+
+        request2_task.await;
+        assert_eq!(
+            &*events.lock(),
+            &[
+                "message 1".to_string(),
+                "message 2".to_string(),
+                "response 2".to_string()
+            ]
+        );
+    }
+
+    #[gpui::test(iterations = 50)]
     async fn test_disconnect(cx: TestAppContext) {
         let executor = cx.foreground();
 
@@ -611,7 +727,7 @@ mod tests {
             .is_err());
     }
 
-    #[gpui::test(iterations = 10)]
+    #[gpui::test(iterations = 50)]
     async fn test_io_error(cx: TestAppContext) {
         let executor = cx.foreground();
         let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());