diff --git a/tokio-xmpp/src/xmlstream/common.rs b/tokio-xmpp/src/xmlstream/common.rs index ff509824745f13fd8045b1abc60d726fae316815..c94a6c0e9084eefee3ad1e69bf43a9d9a3af7ddc 100644 --- a/tokio-xmpp/src/xmlstream/common.rs +++ b/tokio-xmpp/src/xmlstream/common.rs @@ -404,6 +404,22 @@ impl<'x, Io: AsyncWrite> RawXmlStreamProj<'x, Io> { } } +impl RawXmlStream { + /// Flush all buffered data and shut down the sender side of the + /// underlying transport. + /// + /// Unlike `poll_close` (from the `Sink` impls), this will not close the + /// receiving side of the underlying the transport. It is advisable to call + /// `poll_close` eventually after `poll_shutdown` in order to gracefully + /// handle situations where the remote side does not close the stream + /// cleanly. + pub fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.as_mut().poll_flush(cx))?; + let this = self.project(); + this.parser.inner_pinned().poll_shutdown(cx) + } +} + impl<'x, Io: AsyncWrite> Sink> for RawXmlStream { type Error = io::Error; diff --git a/tokio-xmpp/src/xmlstream/mod.rs b/tokio-xmpp/src/xmlstream/mod.rs index 5e4149660a5b9f7d8ba953da11fd1669a8473683..f077fe11da6ca026a9b0208806a3f20333e53747 100644 --- a/tokio-xmpp/src/xmlstream/mod.rs +++ b/tokio-xmpp/src/xmlstream/mod.rs @@ -57,6 +57,7 @@ //! resetting the stream in a single step. use core::fmt; +use core::future::Future; use core::pin::Pin; use core::task::{Context, Poll}; use std::io; @@ -357,22 +358,11 @@ impl Stream for XmlStream Sink<&'x T> for XmlStream { - type Error = io::Error; - - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - this.write_state.check_writable()?; - this.inner.poll_ready(cx) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - this.write_state.check_writable()?; - this.inner.poll_flush(cx) - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { +impl XmlStream { + /// Initiate stream shutdown and poll for completion. + /// + /// Please see [`Self::shutdown`] for details. + pub fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { let mut this = self.project(); this.write_state.check_ok()?; loop { @@ -401,14 +391,49 @@ impl<'x, Io: AsyncWrite, T: FromXml + AsXml + fmt::Debug> Sink<&'x T> for XmlStr } *this.write_state = WriteState::FooterSent; } - // Footer sent => just poll the inner sink for flush. - WriteState::FooterSent => { - ready!(this.inner.as_mut().poll_flush(cx)?); - break; - } + // Footer sent => just close the inner stream. + WriteState::FooterSent => break, WriteState::Failed => unreachable!(), // caught by check_ok() } } + this.inner.poll_shutdown(cx) + } +} + +impl XmlStream { + /// Send the stream footer and close the sender side of the underlying + /// transport. + /// + /// Unlike `poll_close` (from the `Sink` impls), this will not close the + /// receiving side of the underlying the transport. It is advisable to call + /// `poll_close` eventually after `poll_shutdown` in order to gracefully + /// handle situations where the remote side does not close the stream + /// cleanly. + pub fn shutdown(&mut self) -> Shutdown<'_, Io, T> { + Shutdown { + stream: Pin::new(self), + } + } +} + +impl<'x, Io: AsyncWrite, T: FromXml + AsXml> Sink<&'x T> for XmlStream { + type Error = io::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + this.write_state.check_writable()?; + this.inner.poll_ready(cx) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + this.write_state.check_writable()?; + this.inner.poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.as_mut().poll_shutdown(cx))?; + let this = self.project(); this.inner.poll_close(cx) } @@ -419,5 +444,19 @@ impl<'x, Io: AsyncWrite, T: FromXml + AsXml + fmt::Debug> Sink<&'x T> for XmlStr } } +/// Future implementing [`XmlStream::shutdown`] using +/// [`XmlStream::poll_shutdown`]. +pub struct Shutdown<'a, Io: AsyncWrite, T: FromXml + AsXml> { + stream: Pin<&'a mut XmlStream>, +} + +impl<'a, Io: AsyncWrite, T: FromXml + AsXml> Future for Shutdown<'a, Io, T> { + type Output = io::Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + self.stream.as_mut().poll_shutdown(cx) + } +} + /// Convenience alias for an XML stream using [`XmppStreamElement`]. pub type XmppStream = XmlStream; diff --git a/tokio-xmpp/src/xmlstream/tests.rs b/tokio-xmpp/src/xmlstream/tests.rs index 44c0256fba7c993ce2525236e4c005d6eb9dd08b..3c8d8745608b32bac19f20fd0fb71ccc1a06bc33 100644 --- a/tokio-xmpp/src/xmlstream/tests.rs +++ b/tokio-xmpp/src/xmlstream/tests.rs @@ -395,3 +395,73 @@ async fn test_emits_soft_timeout_after_silence() { responder.await.unwrap().expect("responder failed"); initiator.await.unwrap().expect("initiator failed"); } + +#[tokio::test] +async fn test_can_receive_after_shutdown() { + let (lhs, rhs) = tokio::io::duplex(65536); + + let initiator = tokio::spawn(async move { + let stream = initiate_stream( + tokio::io::BufStream::new(lhs), + "jabber:client", + StreamHeader::default(), + Timeouts::tight(), + ) + .await?; + let (_, mut stream) = stream.recv_features::().await?; + match stream.next().await { + Some(Err(ReadError::StreamFooterReceived)) => (), + other => panic!("unexpected stream message: {:?}", other), + } + match stream.next().await { + None => (), + other => panic!("unexpected stream message: {:?}", other), + } + stream + .send(&Data { + contents: "hello".to_owned(), + }) + .await?; + stream + .send(&Data { + contents: "world!".to_owned(), + }) + .await?; + stream.close().await?; + Ok::<_, io::Error>(()) + }); + + let responder = tokio::spawn(async move { + let stream = accept_stream( + tokio::io::BufStream::new(rhs), + "jabber:client", + Timeouts::tight(), + ) + .await?; + let stream = stream.send_header(StreamHeader::default()).await?; + let mut stream = stream + .send_features::(&StreamFeatures::default()) + .await?; + stream.shutdown().await?; + match stream.next().await { + Some(Ok(Data { contents })) => assert_eq!(contents, "hello"), + other => panic!("unexpected stream message: {:?}", other), + } + match stream.next().await { + Some(Ok(Data { contents })) => assert_eq!(contents, "world!"), + other => panic!("unexpected stream message: {:?}", other), + } + match stream.next().await { + Some(Err(ReadError::StreamFooterReceived)) => (), + other => panic!("unexpected stream message: {:?}", other), + } + match stream.next().await { + None => (), + other => panic!("unexpected stream message: {:?}", other), + } + Ok::<_, io::Error>(()) + }); + + responder.await.unwrap().expect("responder failed"); + initiator.await.unwrap().expect("initiator failed"); +}