If a test connection has been killed, never return a message

Nathan Sobo created

Change summary

zrpc/src/conn.rs | 46 ++++++++++++++++++++++++++++++----------------
1 file changed, 30 insertions(+), 16 deletions(-)

Detailed changes

zrpc/src/conn.rs 🔗

@@ -1,5 +1,6 @@
 use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
-use futures::{SinkExt as _, StreamExt as _};
+use futures::{channel::mpsc, SinkExt as _, Stream, StreamExt as _};
+use std::{io, task::Poll};
 
 pub struct Conn {
     pub(crate) tx:
@@ -53,10 +54,10 @@ impl Conn {
         Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
         Box<dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>>,
     ) {
-        use futures::{future, stream, SinkExt as _, StreamExt as _};
-        use std::io::{Error, ErrorKind};
+        use futures::{future, SinkExt as _};
+        use io::{Error, ErrorKind};
 
-        let (tx, rx) = futures::channel::mpsc::unbounded::<WebSocketMessage>();
+        let (tx, rx) = mpsc::unbounded::<WebSocketMessage>();
         let tx = tx
             .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e)))
             .with({
@@ -69,19 +70,32 @@ impl Conn {
                     }
                 }
             });
-        let rx = stream::select(
-            rx.map(Ok),
-            kill_rx.filter_map(|kill| {
-                if kill.is_none() {
-                    future::ready(None)
-                } else {
-                    future::ready(Some(Err(
-                        Error::new(ErrorKind::Other, "connection killed").into()
-                    )))
-                }
-            }),
-        );
+        let rx = KillableReceiver { kill_rx, rx };
 
         (Box::new(tx), Box::new(rx))
     }
 }
+
+struct KillableReceiver {
+    rx: mpsc::UnboundedReceiver<WebSocketMessage>,
+    kill_rx: postage::watch::Receiver<Option<()>>,
+}
+
+impl Stream for KillableReceiver {
+    type Item = Result<WebSocketMessage, WebSocketError>;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        if let Poll::Ready(Some(Some(()))) = self.kill_rx.poll_next_unpin(cx) {
+            Poll::Ready(Some(Err(io::Error::new(
+                io::ErrorKind::Other,
+                "connection killed",
+            )
+            .into())))
+        } else {
+            self.rx.poll_next_unpin(cx).map(|value| value.map(Ok))
+        }
+    }
+}