@@ -404,6 +404,22 @@ impl<'x, Io: AsyncWrite> RawXmlStreamProj<'x, Io> {
}
}
+impl<Io: AsyncWrite> RawXmlStream<Io> {
+ /// Flush all buffered data and shut down the sender side of the
+ /// underlying transport.
+ ///
+ /// Unlike `poll_close` (from the `Sink` impls), this will not close the
+ /// receiving side of the underlying the transport. It is advisable to call
+ /// `poll_close` eventually after `poll_shutdown` in order to gracefully
+ /// handle situations where the remote side does not close the stream
+ /// cleanly.
+ pub fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ ready!(self.as_mut().poll_flush(cx))?;
+ let this = self.project();
+ this.parser.inner_pinned().poll_shutdown(cx)
+ }
+}
+
impl<'x, Io: AsyncWrite> Sink<xso::Item<'x>> for RawXmlStream<Io> {
type Error = io::Error;
@@ -57,6 +57,7 @@
//! resetting the stream in a single step.
use core::fmt;
+use core::future::Future;
use core::pin::Pin;
use core::task::{Context, Poll};
use std::io;
@@ -357,22 +358,11 @@ impl<Io: AsyncBufRead, T: FromXml + AsXml + fmt::Debug> Stream for XmlStream<Io,
}
}
-impl<'x, Io: AsyncWrite, T: FromXml + AsXml + fmt::Debug> Sink<&'x T> for XmlStream<Io, T> {
- type Error = io::Error;
-
- fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
- let this = self.project();
- this.write_state.check_writable()?;
- this.inner.poll_ready(cx)
- }
-
- fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
- let this = self.project();
- this.write_state.check_writable()?;
- this.inner.poll_flush(cx)
- }
-
- fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+impl<Io: AsyncWrite, T: FromXml + AsXml> XmlStream<Io, T> {
+ /// Initiate stream shutdown and poll for completion.
+ ///
+ /// Please see [`Self::shutdown`] for details.
+ pub fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
let mut this = self.project();
this.write_state.check_ok()?;
loop {
@@ -401,14 +391,49 @@ impl<'x, Io: AsyncWrite, T: FromXml + AsXml + fmt::Debug> Sink<&'x T> for XmlStr
}
*this.write_state = WriteState::FooterSent;
}
- // Footer sent => just poll the inner sink for flush.
- WriteState::FooterSent => {
- ready!(this.inner.as_mut().poll_flush(cx)?);
- break;
- }
+ // Footer sent => just close the inner stream.
+ WriteState::FooterSent => break,
WriteState::Failed => unreachable!(), // caught by check_ok()
}
}
+ this.inner.poll_shutdown(cx)
+ }
+}
+
+impl<Io: AsyncWrite + Unpin, T: FromXml + AsXml> XmlStream<Io, T> {
+ /// Send the stream footer and close the sender side of the underlying
+ /// transport.
+ ///
+ /// Unlike `poll_close` (from the `Sink` impls), this will not close the
+ /// receiving side of the underlying the transport. It is advisable to call
+ /// `poll_close` eventually after `poll_shutdown` in order to gracefully
+ /// handle situations where the remote side does not close the stream
+ /// cleanly.
+ pub fn shutdown(&mut self) -> Shutdown<'_, Io, T> {
+ Shutdown {
+ stream: Pin::new(self),
+ }
+ }
+}
+
+impl<'x, Io: AsyncWrite, T: FromXml + AsXml> Sink<&'x T> for XmlStream<Io, T> {
+ type Error = io::Error;
+
+ fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ let this = self.project();
+ this.write_state.check_writable()?;
+ this.inner.poll_ready(cx)
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ let this = self.project();
+ this.write_state.check_writable()?;
+ this.inner.poll_flush(cx)
+ }
+
+ fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ ready!(self.as_mut().poll_shutdown(cx))?;
+ let this = self.project();
this.inner.poll_close(cx)
}
@@ -419,5 +444,19 @@ impl<'x, Io: AsyncWrite, T: FromXml + AsXml + fmt::Debug> Sink<&'x T> for XmlStr
}
}
+/// Future implementing [`XmlStream::shutdown`] using
+/// [`XmlStream::poll_shutdown`].
+pub struct Shutdown<'a, Io: AsyncWrite, T: FromXml + AsXml> {
+ stream: Pin<&'a mut XmlStream<Io, T>>,
+}
+
+impl<'a, Io: AsyncWrite, T: FromXml + AsXml> Future for Shutdown<'a, Io, T> {
+ type Output = io::Result<()>;
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
+ self.stream.as_mut().poll_shutdown(cx)
+ }
+}
+
/// Convenience alias for an XML stream using [`XmppStreamElement`].
pub type XmppStream<Io> = XmlStream<Io, XmppStreamElement>;
@@ -395,3 +395,73 @@ async fn test_emits_soft_timeout_after_silence() {
responder.await.unwrap().expect("responder failed");
initiator.await.unwrap().expect("initiator failed");
}
+
+#[tokio::test]
+async fn test_can_receive_after_shutdown() {
+ let (lhs, rhs) = tokio::io::duplex(65536);
+
+ let initiator = tokio::spawn(async move {
+ let stream = initiate_stream(
+ tokio::io::BufStream::new(lhs),
+ "jabber:client",
+ StreamHeader::default(),
+ Timeouts::tight(),
+ )
+ .await?;
+ let (_, mut stream) = stream.recv_features::<Data>().await?;
+ match stream.next().await {
+ Some(Err(ReadError::StreamFooterReceived)) => (),
+ other => panic!("unexpected stream message: {:?}", other),
+ }
+ match stream.next().await {
+ None => (),
+ other => panic!("unexpected stream message: {:?}", other),
+ }
+ stream
+ .send(&Data {
+ contents: "hello".to_owned(),
+ })
+ .await?;
+ stream
+ .send(&Data {
+ contents: "world!".to_owned(),
+ })
+ .await?;
+ stream.close().await?;
+ Ok::<_, io::Error>(())
+ });
+
+ let responder = tokio::spawn(async move {
+ let stream = accept_stream(
+ tokio::io::BufStream::new(rhs),
+ "jabber:client",
+ Timeouts::tight(),
+ )
+ .await?;
+ let stream = stream.send_header(StreamHeader::default()).await?;
+ let mut stream = stream
+ .send_features::<Data>(&StreamFeatures::default())
+ .await?;
+ stream.shutdown().await?;
+ match stream.next().await {
+ Some(Ok(Data { contents })) => assert_eq!(contents, "hello"),
+ other => panic!("unexpected stream message: {:?}", other),
+ }
+ match stream.next().await {
+ Some(Ok(Data { contents })) => assert_eq!(contents, "world!"),
+ other => panic!("unexpected stream message: {:?}", other),
+ }
+ match stream.next().await {
+ Some(Err(ReadError::StreamFooterReceived)) => (),
+ other => panic!("unexpected stream message: {:?}", other),
+ }
+ match stream.next().await {
+ None => (),
+ other => panic!("unexpected stream message: {:?}", other),
+ }
+ Ok::<_, io::Error>(())
+ });
+
+ responder.await.unwrap().expect("responder failed");
+ initiator.await.unwrap().expect("initiator failed");
+}