xmlstream: allow simplex stream shutdown

Jonas Schรคfer created

Unlike poll_close, poll_shutdown will only kill the sending side of the
stream. This is relevant to perform a fully clean shutdown procedure in
XMPP.

Change summary

tokio-xmpp/src/xmlstream/common.rs | 16 ++++++
tokio-xmpp/src/xmlstream/mod.rs    | 81 +++++++++++++++++++++++--------
tokio-xmpp/src/xmlstream/tests.rs  | 70 +++++++++++++++++++++++++++
3 files changed, 146 insertions(+), 21 deletions(-)

Detailed changes

tokio-xmpp/src/xmlstream/common.rs ๐Ÿ”—

@@ -404,6 +404,22 @@ impl<'x, Io: AsyncWrite> RawXmlStreamProj<'x, Io> {
     }
 }
 
+impl<Io: AsyncWrite> RawXmlStream<Io> {
+    /// Flush all buffered data and shut down the sender side of the
+    /// underlying transport.
+    ///
+    /// Unlike `poll_close` (from the `Sink` impls), this will not close the
+    /// receiving side of the underlying the transport. It is advisable to call
+    /// `poll_close` eventually after `poll_shutdown` in order to gracefully
+    /// handle situations where the remote side does not close the stream
+    /// cleanly.
+    pub fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        ready!(self.as_mut().poll_flush(cx))?;
+        let this = self.project();
+        this.parser.inner_pinned().poll_shutdown(cx)
+    }
+}
+
 impl<'x, Io: AsyncWrite> Sink<xso::Item<'x>> for RawXmlStream<Io> {
     type Error = io::Error;
 

tokio-xmpp/src/xmlstream/mod.rs ๐Ÿ”—

@@ -57,6 +57,7 @@
 //! resetting the stream in a single step.
 
 use core::fmt;
+use core::future::Future;
 use core::pin::Pin;
 use core::task::{Context, Poll};
 use std::io;
@@ -357,22 +358,11 @@ impl<Io: AsyncBufRead, T: FromXml + AsXml + fmt::Debug> Stream for XmlStream<Io,
     }
 }
 
-impl<'x, Io: AsyncWrite, T: FromXml + AsXml + fmt::Debug> Sink<&'x T> for XmlStream<Io, T> {
-    type Error = io::Error;
-
-    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
-        let this = self.project();
-        this.write_state.check_writable()?;
-        this.inner.poll_ready(cx)
-    }
-
-    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
-        let this = self.project();
-        this.write_state.check_writable()?;
-        this.inner.poll_flush(cx)
-    }
-
-    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+impl<Io: AsyncWrite, T: FromXml + AsXml> XmlStream<Io, T> {
+    /// Initiate stream shutdown and poll for completion.
+    ///
+    /// Please see [`Self::shutdown`] for details.
+    pub fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
         let mut this = self.project();
         this.write_state.check_ok()?;
         loop {
@@ -401,14 +391,49 @@ impl<'x, Io: AsyncWrite, T: FromXml + AsXml + fmt::Debug> Sink<&'x T> for XmlStr
                     }
                     *this.write_state = WriteState::FooterSent;
                 }
-                // Footer sent => just poll the inner sink for flush.
-                WriteState::FooterSent => {
-                    ready!(this.inner.as_mut().poll_flush(cx)?);
-                    break;
-                }
+                // Footer sent => just close the inner stream.
+                WriteState::FooterSent => break,
                 WriteState::Failed => unreachable!(), // caught by check_ok()
             }
         }
+        this.inner.poll_shutdown(cx)
+    }
+}
+
+impl<Io: AsyncWrite + Unpin, T: FromXml + AsXml> XmlStream<Io, T> {
+    /// Send the stream footer and close the sender side of the underlying
+    /// transport.
+    ///
+    /// Unlike `poll_close` (from the `Sink` impls), this will not close the
+    /// receiving side of the underlying the transport. It is advisable to call
+    /// `poll_close` eventually after `poll_shutdown` in order to gracefully
+    /// handle situations where the remote side does not close the stream
+    /// cleanly.
+    pub fn shutdown(&mut self) -> Shutdown<'_, Io, T> {
+        Shutdown {
+            stream: Pin::new(self),
+        }
+    }
+}
+
+impl<'x, Io: AsyncWrite, T: FromXml + AsXml> Sink<&'x T> for XmlStream<Io, T> {
+    type Error = io::Error;
+
+    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        let this = self.project();
+        this.write_state.check_writable()?;
+        this.inner.poll_ready(cx)
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        let this = self.project();
+        this.write_state.check_writable()?;
+        this.inner.poll_flush(cx)
+    }
+
+    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        ready!(self.as_mut().poll_shutdown(cx))?;
+        let this = self.project();
         this.inner.poll_close(cx)
     }
 
@@ -419,5 +444,19 @@ impl<'x, Io: AsyncWrite, T: FromXml + AsXml + fmt::Debug> Sink<&'x T> for XmlStr
     }
 }
 
+/// Future implementing [`XmlStream::shutdown`] using
+/// [`XmlStream::poll_shutdown`].
+pub struct Shutdown<'a, Io: AsyncWrite, T: FromXml + AsXml> {
+    stream: Pin<&'a mut XmlStream<Io, T>>,
+}
+
+impl<'a, Io: AsyncWrite, T: FromXml + AsXml> Future for Shutdown<'a, Io, T> {
+    type Output = io::Result<()>;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
+        self.stream.as_mut().poll_shutdown(cx)
+    }
+}
+
 /// Convenience alias for an XML stream using [`XmppStreamElement`].
 pub type XmppStream<Io> = XmlStream<Io, XmppStreamElement>;

tokio-xmpp/src/xmlstream/tests.rs ๐Ÿ”—

@@ -395,3 +395,73 @@ async fn test_emits_soft_timeout_after_silence() {
     responder.await.unwrap().expect("responder failed");
     initiator.await.unwrap().expect("initiator failed");
 }
+
+#[tokio::test]
+async fn test_can_receive_after_shutdown() {
+    let (lhs, rhs) = tokio::io::duplex(65536);
+
+    let initiator = tokio::spawn(async move {
+        let stream = initiate_stream(
+            tokio::io::BufStream::new(lhs),
+            "jabber:client",
+            StreamHeader::default(),
+            Timeouts::tight(),
+        )
+        .await?;
+        let (_, mut stream) = stream.recv_features::<Data>().await?;
+        match stream.next().await {
+            Some(Err(ReadError::StreamFooterReceived)) => (),
+            other => panic!("unexpected stream message: {:?}", other),
+        }
+        match stream.next().await {
+            None => (),
+            other => panic!("unexpected stream message: {:?}", other),
+        }
+        stream
+            .send(&Data {
+                contents: "hello".to_owned(),
+            })
+            .await?;
+        stream
+            .send(&Data {
+                contents: "world!".to_owned(),
+            })
+            .await?;
+        stream.close().await?;
+        Ok::<_, io::Error>(())
+    });
+
+    let responder = tokio::spawn(async move {
+        let stream = accept_stream(
+            tokio::io::BufStream::new(rhs),
+            "jabber:client",
+            Timeouts::tight(),
+        )
+        .await?;
+        let stream = stream.send_header(StreamHeader::default()).await?;
+        let mut stream = stream
+            .send_features::<Data>(&StreamFeatures::default())
+            .await?;
+        stream.shutdown().await?;
+        match stream.next().await {
+            Some(Ok(Data { contents })) => assert_eq!(contents, "hello"),
+            other => panic!("unexpected stream message: {:?}", other),
+        }
+        match stream.next().await {
+            Some(Ok(Data { contents })) => assert_eq!(contents, "world!"),
+            other => panic!("unexpected stream message: {:?}", other),
+        }
+        match stream.next().await {
+            Some(Err(ReadError::StreamFooterReceived)) => (),
+            other => panic!("unexpected stream message: {:?}", other),
+        }
+        match stream.next().await {
+            None => (),
+            other => panic!("unexpected stream message: {:?}", other),
+        }
+        Ok::<_, io::Error>(())
+    });
+
+    responder.await.unwrap().expect("responder failed");
+    initiator.await.unwrap().expect("initiator failed");
+}