Insert random delays when sending and receiving websocket messages in tests

Max Brunsfeld and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

crates/client/src/test.rs   |  2 
crates/gpui/src/executor.rs | 13 +++++++++++
crates/rpc/Cargo.toml       |  3 +
crates/rpc/src/conn.rs      | 40 +++++++++++++++++++++++++++-----------
crates/rpc/src/peer.rs      | 13 +++++++----
crates/server/src/rpc.rs    |  3 +
6 files changed, 53 insertions(+), 21 deletions(-)

Detailed changes

crates/client/src/test.rs 🔗

@@ -94,7 +94,7 @@ impl FakeServer {
             Err(EstablishConnectionError::Unauthorized)?
         }
 
-        let (client_conn, server_conn, _) = Connection::in_memory();
+        let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
         let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
         cx.background().spawn(io).detach();
         *self.incoming.lock() = Some(incoming);

crates/gpui/src/executor.rs 🔗

@@ -5,7 +5,7 @@ use collections::HashMap;
 use parking_lot::Mutex;
 use postage::{barrier, prelude::Stream as _};
 use rand::prelude::*;
-use smol::{channel, prelude::*, Executor, Timer};
+use smol::{channel, future::yield_now, prelude::*, Executor, Timer};
 use std::{
     any::Any,
     fmt::{self, Debug, Display},
@@ -528,6 +528,17 @@ impl Background {
             task.await;
         }
     }
+
+    pub async fn simulate_random_delay(&self) {
+        match self {
+            Self::Deterministic { executor, .. } => {
+                if executor.state.lock().rng.gen_range(0..100) < 20 {
+                    yield_now().await;
+                }
+            }
+            _ => panic!("this method can only be called on a deterministic executor"),
+        }
+    }
 }
 
 pub struct Scope<'a> {

crates/rpc/Cargo.toml 🔗

@@ -8,7 +8,7 @@ version = "0.1.0"
 path = "src/rpc.rs"
 
 [features]
-test-support = []
+test-support = ["gpui"]
 
 [dependencies]
 anyhow = "1.0"
@@ -25,6 +25,7 @@ rsa = "0.4"
 serde = { version = "1", features = ["derive"] }
 smol-timeout = "0.6"
 zstd = "0.9"
+gpui = { path = "../gpui", features = ["test-support"], optional = true }
 
 [build-dependencies]
 prost-build = "0.8"

crates/rpc/src/conn.rs 🔗

@@ -34,12 +34,14 @@ impl Connection {
     }
 
     #[cfg(any(test, feature = "test-support"))]
-    pub fn in_memory() -> (Self, Self, postage::watch::Sender<Option<()>>) {
+    pub fn in_memory(
+        executor: std::sync::Arc<gpui::executor::Background>,
+    ) -> (Self, Self, postage::watch::Sender<Option<()>>) {
         let (kill_tx, mut kill_rx) = postage::watch::channel_with(None);
         postage::stream::Stream::try_recv(&mut kill_rx).unwrap();
 
-        let (a_tx, a_rx) = Self::channel(kill_rx.clone());
-        let (b_tx, b_rx) = Self::channel(kill_rx);
+        let (a_tx, a_rx) = Self::channel(kill_rx.clone(), executor.clone());
+        let (b_tx, b_rx) = Self::channel(kill_rx, executor);
         (
             Self { tx: a_tx, rx: b_rx },
             Self { tx: b_tx, rx: a_rx },
@@ -50,11 +52,12 @@ impl Connection {
     #[cfg(any(test, feature = "test-support"))]
     fn channel(
         kill_rx: postage::watch::Receiver<Option<()>>,
+        executor: std::sync::Arc<gpui::executor::Background>,
     ) -> (
         Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
         Box<dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>>,
     ) {
-        use futures::{future, SinkExt as _};
+        use futures::SinkExt as _;
         use io::{Error, ErrorKind};
 
         let (tx, rx) = mpsc::unbounded::<WebSocketMessage>();
@@ -62,26 +65,39 @@ impl Connection {
             .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e)))
             .with({
                 let kill_rx = kill_rx.clone();
+                let executor = executor.clone();
                 move |msg| {
-                    if kill_rx.borrow().is_none() {
-                        future::ready(Ok(msg))
-                    } else {
-                        future::ready(Err(Error::new(ErrorKind::Other, "connection killed").into()))
-                    }
+                    let kill_rx = kill_rx.clone();
+                    let executor = executor.clone();
+                    Box::pin(async move {
+                        executor.simulate_random_delay().await;
+                        if kill_rx.borrow().is_none() {
+                            Ok(msg)
+                        } else {
+                            Err(Error::new(ErrorKind::Other, "connection killed").into())
+                        }
+                    })
                 }
             });
+        let rx = rx.then(move |msg| {
+            let executor = executor.clone();
+            Box::pin(async move {
+                executor.simulate_random_delay().await;
+                msg
+            })
+        });
         let rx = KillableReceiver { kill_rx, rx };
 
         (Box::new(tx), Box::new(rx))
     }
 }
 
-struct KillableReceiver {
-    rx: mpsc::UnboundedReceiver<WebSocketMessage>,
+struct KillableReceiver<S> {
+    rx: S,
     kill_rx: postage::watch::Receiver<Option<()>>,
 }
 
-impl Stream for KillableReceiver {
+impl<S: Unpin + Stream<Item = WebSocketMessage>> Stream for KillableReceiver<S> {
     type Item = Result<WebSocketMessage, WebSocketError>;
 
     fn poll_next(

crates/rpc/src/peer.rs 🔗

@@ -353,12 +353,14 @@ mod tests {
         let client1 = Peer::new();
         let client2 = Peer::new();
 
-        let (client1_to_server_conn, server_to_client_1_conn, _) = Connection::in_memory();
+        let (client1_to_server_conn, server_to_client_1_conn, _) =
+            Connection::in_memory(cx.background());
         let (client1_conn_id, io_task1, client1_incoming) =
             client1.add_connection(client1_to_server_conn).await;
         let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await;
 
-        let (client2_to_server_conn, server_to_client_2_conn, _) = Connection::in_memory();
+        let (client2_to_server_conn, server_to_client_2_conn, _) =
+            Connection::in_memory(cx.background());
         let (client2_conn_id, io_task3, client2_incoming) =
             client2.add_connection(client2_to_server_conn).await;
         let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await;
@@ -489,7 +491,8 @@ mod tests {
         let server = Peer::new();
         let client = Peer::new();
 
-        let (client_to_server_conn, server_to_client_conn, _) = Connection::in_memory();
+        let (client_to_server_conn, server_to_client_conn, _) =
+            Connection::in_memory(cx.background());
         let (client_to_server_conn_id, io_task1, mut client_incoming) =
             client.add_connection(client_to_server_conn).await;
         let (server_to_client_conn_id, io_task2, mut server_incoming) =
@@ -589,7 +592,7 @@ mod tests {
     async fn test_disconnect(cx: TestAppContext) {
         let executor = cx.foreground();
 
-        let (client_conn, mut server_conn, _) = Connection::in_memory();
+        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());
 
         let client = Peer::new();
         let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
@@ -623,7 +626,7 @@ mod tests {
     #[gpui::test(iterations = 10)]
     async fn test_io_error(cx: TestAppContext) {
         let executor = cx.foreground();
-        let (client_conn, mut server_conn, _) = Connection::in_memory();
+        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());
 
         let client = Peer::new();
         let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;

crates/server/src/rpc.rs 🔗

@@ -3242,7 +3242,8 @@ mod tests {
                                 "server is forbidding connections"
                             )))
                         } else {
-                            let (client_conn, server_conn, kill_conn) = Connection::in_memory();
+                            let (client_conn, server_conn, kill_conn) =
+                                Connection::in_memory(cx.background());
                             connection_killers.lock().insert(user_id, kill_conn);
                             cx.background()
                                 .spawn(server.handle_connection(