xmlstream: split initiation reset in two phases

Jonas Schรคfer created

Change summary

tokio-xmpp/src/xmlstream/initiator.rs | 23 ++++++++++++++++++++++
tokio-xmpp/src/xmlstream/mod.rs       | 30 +++++++---------------------
tokio-xmpp/src/xmlstream/tests.rs     |  5 ++-
3 files changed, 34 insertions(+), 24 deletions(-)

Detailed changes

tokio-xmpp/src/xmlstream/initiator.rs ๐Ÿ”—

@@ -8,6 +8,8 @@ use core::pin::Pin;
 use std::borrow::Cow;
 use std::io;
 
+use futures::SinkExt;
+
 use tokio::io::{AsyncBufRead, AsyncWrite};
 
 use xmpp_parsers::stream_features::StreamFeatures;
@@ -19,6 +21,27 @@ use super::{
     XmlStream,
 };
 
+/// Type state for an initiator stream which has not yet sent its stream
+/// header.
+///
+/// To continue stream setup, call [`send_header`][`Self::send_header`].
+pub struct InitiatingStream<Io>(pub(super) RawXmlStream<Io>);
+
+impl<Io: AsyncBufRead + AsyncWrite + Unpin> InitiatingStream<Io> {
+    /// Send the stream header.
+    pub async fn send_header(
+        self,
+        header: StreamHeader<'_>,
+    ) -> io::Result<PendingFeaturesRecv<Io>> {
+        let Self(mut stream) = self;
+
+        header.send(Pin::new(&mut stream)).await?;
+        stream.flush().await?;
+        let header = StreamHeader::recv(Pin::new(&mut stream)).await?;
+        Ok(PendingFeaturesRecv { stream, header })
+    }
+}
+
 /// Type state for an initiator stream which has sent and received the stream
 /// header.
 ///

tokio-xmpp/src/xmlstream/mod.rs ๐Ÿ”—

@@ -41,7 +41,7 @@ use core::pin::Pin;
 use core::task::{Context, Poll};
 use std::io;
 
-use futures::{ready, Sink, SinkExt, Stream};
+use futures::{ready, Sink, Stream};
 
 use tokio::io::{AsyncBufRead, AsyncWrite};
 
@@ -54,7 +54,7 @@ mod responder;
 mod tests;
 
 use self::common::{RawXmlStream, ReadXsoError, ReadXsoState, StreamHeader};
-pub use self::initiator::PendingFeaturesRecv;
+pub use self::initiator::{InitiatingStream, PendingFeaturesRecv};
 pub use self::responder::{AcceptedStream, PendingFeaturesSend};
 
 /// Initiate a new stream
@@ -70,16 +70,8 @@ pub async fn initiate_stream<Io: AsyncBufRead + AsyncWrite + Unpin>(
     stream_ns: &'static str,
     stream_header: StreamHeader<'_>,
 ) -> Result<PendingFeaturesRecv<Io>, io::Error> {
-    let mut raw_stream = RawXmlStream::new(io, stream_ns);
-    stream_header.send(Pin::new(&mut raw_stream)).await?;
-    raw_stream.flush().await?;
-
-    let header = StreamHeader::recv(Pin::new(&mut raw_stream)).await?;
-
-    Ok(PendingFeaturesRecv {
-        stream: raw_stream,
-        header,
-    })
+    let stream = InitiatingStream(RawXmlStream::new(io, stream_ns));
+    stream.send_header(stream_header).await
 }
 
 /// Accept a new XML stream as responder
@@ -194,8 +186,8 @@ impl<Io: AsyncBufRead, T: FromXml + AsXml> XmlStream<Io, T> {
 impl<Io: AsyncBufRead + AsyncWrite + Unpin, T: FromXml + AsXml> XmlStream<Io, T> {
     /// Initiate a stream reset
     ///
-    /// The `header` is the new stream header which is sent to the remote
-    /// party.
+    /// To actually send the stream header, call
+    /// [`send_header`][`InitiatingStream::send_header`] on the result.
     ///
     /// # Panics
     ///
@@ -205,18 +197,12 @@ impl<Io: AsyncBufRead + AsyncWrite + Unpin, T: FromXml + AsXml> XmlStream<Io, T>
     ///
     /// In addition, attempting to reset a stream which has been closed by
     /// either side or which has had an I/O error will also cause a panic.
-    pub async fn initiate_reset(
-        self,
-        header: StreamHeader<'_>,
-    ) -> io::Result<PendingFeaturesRecv<Io>> {
+    pub fn initiate_reset(self) -> InitiatingStream<Io> {
         self.assert_retypable();
 
         let mut stream = self.inner;
         Pin::new(&mut stream).reset_state();
-        header.send(Pin::new(&mut stream)).await?;
-        stream.flush().await?;
-        let header = StreamHeader::recv(Pin::new(&mut stream)).await?;
-        Ok(PendingFeaturesRecv { stream, header })
+        InitiatingStream(stream)
     }
 
     /// Anticipate a new stream header sent by the remote party.

tokio-xmpp/src/xmlstream/tests.rs ๐Ÿ”—

@@ -4,7 +4,7 @@
 // License, v. 2.0. If a copy of the MPL was not distributed with this
 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
 
-use futures::StreamExt;
+use futures::{SinkExt, StreamExt};
 
 use xmpp_parsers::stream_features::StreamFeatures;
 
@@ -185,7 +185,8 @@ async fn test_exchange_data_stream_reset_and_shutdown() {
             other => panic!("unexpected stream message: {:?}", other),
         }
         let stream = stream
-            .initiate_reset(StreamHeader {
+            .initiate_reset()
+            .send_header(StreamHeader {
                 from: Some("client".into()),
                 to: Some("server".into()),
                 id: Some("client-id".into()),