Prevent requests from hanging when shutting down a connection

Antonio Scandurra created

When closing a connection (either due to an error or simply because the
user wanted to), we will now  *take* `response_channels` as opposed to
clearing them. This ensures that `Peer::request` can't succeed in both
adding the oneshot channel in `response_channels` map _and_ submit the
message onto the `outgoing_tx` channel.

This also streamlines how we close a connection by unifying all the exit
code paths of the IO handling future.

Change summary

zrpc/src/peer.rs | 32 ++++++++++++++------------------
1 file changed, 14 insertions(+), 18 deletions(-)

Detailed changes

zrpc/src/peer.rs 🔗

@@ -87,7 +87,7 @@ pub struct Peer {
 struct ConnectionState {
     outgoing_tx: mpsc::Sender<proto::Envelope>,
     next_message_id: Arc<AtomicU32>,
-    response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
+    response_channels: Arc<Mutex<Option<HashMap<u32, mpsc::Sender<proto::Envelope>>>>>,
 }
 
 impl Peer {
@@ -123,7 +123,7 @@ impl Peer {
         let this = self.clone();
         let response_channels = connection_state.response_channels.clone();
         let handle_io = async move {
-            loop {
+            let result = 'outer: loop {
                 let read_message = reader.read_message().fuse();
                 futures::pin_mut!(read_message);
                 loop {
@@ -131,7 +131,7 @@ impl Peer {
                         incoming = read_message => match incoming {
                             Ok(incoming) => {
                                 if let Some(responding_to) = incoming.responding_to {
-                                    let channel = response_channels.lock().await.remove(&responding_to);
+                                    let channel = response_channels.lock().await.as_mut().unwrap().remove(&responding_to);
                                     if let Some(mut tx) = channel {
                                         tx.send(incoming).await.ok();
                                     } else {
@@ -140,9 +140,7 @@ impl Peer {
                                 } else {
                                     if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
                                         if incoming_tx.send(envelope).await.is_err() {
-                                            response_channels.lock().await.clear();
-                                            this.connections.write().await.remove(&connection_id);
-                                            return Ok(())
+                                            break 'outer Ok(())
                                         }
                                     } else {
                                         log::error!("unable to construct a typed envelope");
@@ -152,28 +150,24 @@ impl Peer {
                                 break;
                             }
                             Err(error) => {
-                                response_channels.lock().await.clear();
-                                this.connections.write().await.remove(&connection_id);
-                                Err(error).context("received invalid RPC message")?;
+                                break 'outer Err(error).context("received invalid RPC message")
                             }
                         },
                         outgoing = outgoing_rx.recv().fuse() => match outgoing {
                             Some(outgoing) => {
                                 if let Err(result) = writer.write_message(&outgoing).await {
-                                    response_channels.lock().await.clear();
-                                    this.connections.write().await.remove(&connection_id);
-                                    Err(result).context("failed to write RPC message")?;
+                                    break 'outer Err(result).context("failed to write RPC message")
                                 }
                             }
-                            None => {
-                                response_channels.lock().await.clear();
-                                this.connections.write().await.remove(&connection_id);
-                                return Ok(())
-                            }
+                            None => break 'outer Ok(()),
                         }
                     }
                 }
-            }
+            };
+
+            response_channels.lock().await.take();
+            this.connections.write().await.remove(&connection_id);
+            result
         };
 
         self.connections
@@ -226,6 +220,8 @@ impl Peer {
                 .response_channels
                 .lock()
                 .await
+                .as_mut()
+                .ok_or_else(|| anyhow!("connection was closed"))?
                 .insert(message_id, tx);
             connection
                 .outgoing_tx