Merge pull request #180 from zed-industries/peer-test-io-error-flaky

Antonio Scandurra created

Fix flaky `zrpc::tests::peer::test_io_error` test

Change summary

zrpc/src/peer.rs | 50 ++++++++++++++++++++++++--------------------------
1 file changed, 24 insertions(+), 26 deletions(-)

Detailed changes

zrpc/src/peer.rs 🔗

@@ -87,7 +87,7 @@ pub struct Peer {
 struct ConnectionState {
     outgoing_tx: mpsc::Sender<proto::Envelope>,
     next_message_id: Arc<AtomicU32>,
-    response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
+    response_channels: Arc<Mutex<Option<HashMap<u32, mpsc::Sender<proto::Envelope>>>>>,
 }
 
 impl Peer {
@@ -115,7 +115,7 @@ impl Peer {
         let connection_state = ConnectionState {
             outgoing_tx,
             next_message_id: Default::default(),
-            response_channels: Default::default(),
+            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
         };
         let mut writer = MessageStream::new(connection.tx);
         let mut reader = MessageStream::new(connection.rx);
@@ -123,7 +123,7 @@ impl Peer {
         let this = self.clone();
         let response_channels = connection_state.response_channels.clone();
         let handle_io = async move {
-            loop {
+            let result = 'outer: loop {
                 let read_message = reader.read_message().fuse();
                 futures::pin_mut!(read_message);
                 loop {
@@ -131,7 +131,7 @@ impl Peer {
                         incoming = read_message => match incoming {
                             Ok(incoming) => {
                                 if let Some(responding_to) = incoming.responding_to {
-                                    let channel = response_channels.lock().await.remove(&responding_to);
+                                    let channel = response_channels.lock().await.as_mut().unwrap().remove(&responding_to);
                                     if let Some(mut tx) = channel {
                                         tx.send(incoming).await.ok();
                                     } else {
@@ -140,9 +140,7 @@ impl Peer {
                                 } else {
                                     if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
                                         if incoming_tx.send(envelope).await.is_err() {
-                                            response_channels.lock().await.clear();
-                                            this.connections.write().await.remove(&connection_id);
-                                            return Ok(())
+                                            break 'outer Ok(())
                                         }
                                     } else {
                                         log::error!("unable to construct a typed envelope");
@@ -152,28 +150,24 @@ impl Peer {
                                 break;
                             }
                             Err(error) => {
-                                response_channels.lock().await.clear();
-                                this.connections.write().await.remove(&connection_id);
-                                Err(error).context("received invalid RPC message")?;
+                                break 'outer Err(error).context("received invalid RPC message")
                             }
                         },
                         outgoing = outgoing_rx.recv().fuse() => match outgoing {
                             Some(outgoing) => {
                                 if let Err(result) = writer.write_message(&outgoing).await {
-                                    response_channels.lock().await.clear();
-                                    this.connections.write().await.remove(&connection_id);
-                                    Err(result).context("failed to write RPC message")?;
+                                    break 'outer Err(result).context("failed to write RPC message")
                                 }
                             }
-                            None => {
-                                response_channels.lock().await.clear();
-                                this.connections.write().await.remove(&connection_id);
-                                return Ok(())
-                            }
+                            None => break 'outer Ok(()),
                         }
                     }
                 }
-            }
+            };
+
+            response_channels.lock().await.take();
+            this.connections.write().await.remove(&connection_id);
+            result
         };
 
         self.connections
@@ -226,6 +220,8 @@ impl Peer {
                 .response_channels
                 .lock()
                 .await
+                .as_mut()
+                .ok_or_else(|| anyhow!("connection was closed"))?
                 .insert(message_id, tx);
             connection
                 .outgoing_tx
@@ -520,8 +516,7 @@ mod tests {
     #[test]
     fn test_io_error() {
         smol::block_on(async move {
-            let (client_conn, server_conn, _) = Connection::in_memory();
-            drop(server_conn);
+            let (client_conn, mut server_conn, _) = Connection::in_memory();
 
             let client = Peer::new();
             let (connection_id, io_handler, mut incoming) =
@@ -529,11 +524,14 @@ mod tests {
             smol::spawn(io_handler).detach();
             smol::spawn(async move { incoming.next().await }).detach();
 
-            let err = client
-                .request(connection_id, proto::Ping {})
-                .await
-                .unwrap_err();
-            assert_eq!(err.to_string(), "connection was closed");
+            let response = smol::spawn(client.request(connection_id, proto::Ping {}));
+            let _request = server_conn.rx.next().await.unwrap().unwrap();
+
+            drop(server_conn);
+            assert_eq!(
+                response.await.unwrap_err().to_string(),
+                "connection was closed"
+            );
         });
     }
 }