tokio_xmpp: properly process <stream:error/> while waiting for features

Jonas Schäfer created

Before this, a stream error would not be readable by user code, as the
`recv_features()` function would not even attempt to parse it. This
change allows application code to react to stream errors which are
received before stream features are received.

skip-changelog, because there's no release with the xmlstream module
yet.

Change summary

parsers/src/iq.rs                     |  9 +++
tokio-xmpp/src/error.rs               | 16 +++++
tokio-xmpp/src/xmlstream/initiator.rs | 74 ++++++++++++++++++++++++++--
tokio-xmpp/src/xmlstream/mod.rs       |  2 
tokio-xmpp/src/xmlstream/responder.rs | 15 +++++
tokio-xmpp/src/xmlstream/tests.rs     | 58 +++++++++++++++++++--
6 files changed, 157 insertions(+), 17 deletions(-)

Detailed changes

parsers/src/iq.rs 🔗

@@ -39,6 +39,15 @@ impl IqHeader {
     }
 }
 
+/// Payload of an IQ request stanza.
+pub enum IqRequestPayload {
+    /// Payload of a type='get' stanza.
+    Get(Element),
+
+    /// Payload of a type='set' stanza.
+    Set(Element),
+}
+
 /// Payload of an IQ stanza, by type.
 pub enum IqPayload {
     /// Payload of a type='get' stanza.

tokio-xmpp/src/error.rs 🔗

@@ -6,9 +6,11 @@ use hickory_resolver::{
 use sasl::client::MechanismError as SaslMechanismError;
 use std::io;
 
+use xmpp_parsers::stream_error::ReceivedStreamError;
+
 use crate::{
     connect::ServerConnectorError, jid, minidom,
-    parsers::sasl::DefinedCondition as SaslDefinedCondition,
+    parsers::sasl::DefinedCondition as SaslDefinedCondition, xmlstream::RecvFeaturesError,
 };
 
 /// Top-level error type
@@ -44,6 +46,8 @@ pub enum Error {
     Idna,
     /// Invalid IP/Port address
     Addr(AddrParseError),
+    /// Received a stream error
+    StreamError(ReceivedStreamError),
 }
 
 impl fmt::Display for Error {
@@ -65,6 +69,7 @@ impl fmt::Display for Error {
             #[cfg(feature = "dns")]
             Error::Idna => write!(fmt, "IDNA error"),
             Error::Addr(e) => write!(fmt, "Wrong network address: {e}"),
+            Error::StreamError(e) => write!(fmt, "{e}"),
         }
     }
 }
@@ -140,6 +145,15 @@ impl From<AddrParseError> for Error {
     }
 }
 
+impl From<RecvFeaturesError> for Error {
+    fn from(e: RecvFeaturesError) -> Self {
+        match e {
+            RecvFeaturesError::Io(e) => e.into(),
+            RecvFeaturesError::StreamError(e) => Self::StreamError(e),
+        }
+    }
+}
+
 /// XMPP protocol-level error
 #[derive(Debug)]
 pub enum ProtocolError {

tokio-xmpp/src/xmlstream/initiator.rs 🔗

@@ -4,6 +4,7 @@
 // License, v. 2.0. If a copy of the MPL was not distributed with this
 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
 
+use core::fmt;
 use core::pin::Pin;
 use std::borrow::Cow;
 use std::io;
@@ -12,7 +13,10 @@ use futures::SinkExt;
 
 use tokio::io::{AsyncBufRead, AsyncWrite};
 
-use xmpp_parsers::stream_features::StreamFeatures;
+use xmpp_parsers::{
+    stream_error::{ReceivedStreamError, StreamError},
+    stream_features::StreamFeatures,
+};
 
 use xso::{AsXml, FromXml};
 
@@ -42,6 +46,49 @@ impl<Io: AsyncBufRead + AsyncWrite + Unpin> InitiatingStream<Io> {
     }
 }
 
+#[derive(xso::FromXml)]
+#[xml()]
+enum StreamFeaturesPayload {
+    #[xml(transparent)]
+    Features(StreamFeatures),
+    #[xml(transparent)]
+    Error(StreamError),
+}
+
+/// Error conditions when receiving stream features
+#[derive(Debug)]
+pub enum RecvFeaturesError {
+    /// I/o error while receiving stream features
+    Io(io::Error),
+
+    /// Received a stream error instead of stream features
+    StreamError(ReceivedStreamError),
+}
+
+impl fmt::Display for RecvFeaturesError {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            Self::Io(e) => write!(f, "i/o error: {e}"),
+            Self::StreamError(e) => fmt::Display::fmt(&e, f),
+        }
+    }
+}
+
+impl core::error::Error for RecvFeaturesError {
+    fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
+        match self {
+            Self::Io(e) => Some(e),
+            Self::StreamError(e) => Some(e),
+        }
+    }
+}
+
+impl From<io::Error> for RecvFeaturesError {
+    fn from(other: io::Error) -> Self {
+        Self::Io(other)
+    }
+}
+
 /// Type state for an initiator stream which has sent and received the stream
 /// header.
 ///
@@ -73,26 +120,39 @@ impl<Io: AsyncBufRead + AsyncWrite + Unpin> PendingFeaturesRecv<Io> {
     /// After the stream features have been received, the stream can be used
     /// for exchanging stream-level elements (stanzas or "nonzas"). The Rust
     /// type for these elements must be given as type parameter `T`.
+    ///
+    /// If the peer sends a stream error instead of features, the error is
+    /// returned as [`RecvFeaturesError::StreamError`].
+    ///
+    /// If the peer sends any payload which is neither stream features nor
+    /// a stream error, an [`io::Error`][`std::io::Error`] with
+    /// [`InvalidData`][`io::ErrorKind::InvalidData`] kind is returned.
     pub async fn recv_features<T: FromXml + AsXml>(
         self,
-    ) -> io::Result<(StreamFeatures, XmlStream<Io, T>)> {
+    ) -> Result<(StreamFeatures, XmlStream<Io, T>), RecvFeaturesError> {
         let Self {
             mut stream,
             header: _,
         } = self;
         let features = loop {
             match ReadXso::read_from(Pin::new(&mut stream)).await {
-                Ok(v) => break v,
+                Ok(StreamFeaturesPayload::Features(v)) => break v,
+                Ok(StreamFeaturesPayload::Error(v)) => {
+                    return Err(RecvFeaturesError::StreamError(ReceivedStreamError(v)))
+                }
                 Err(ReadXsoError::SoftTimeout) => (),
-                Err(ReadXsoError::Hard(e)) => return Err(e),
+                Err(ReadXsoError::Hard(e)) => return Err(RecvFeaturesError::Io(e)),
                 Err(ReadXsoError::Parse(e)) => {
-                    return Err(io::Error::new(io::ErrorKind::InvalidData, e))
+                    return Err(RecvFeaturesError::Io(io::Error::new(
+                        io::ErrorKind::InvalidData,
+                        e,
+                    )))
                 }
                 Err(ReadXsoError::Footer) => {
-                    return Err(io::Error::new(
+                    return Err(RecvFeaturesError::Io(io::Error::new(
                         io::ErrorKind::UnexpectedEof,
                         "unexpected stream footer",
-                    ))
+                    )))
                 }
             }
         };

tokio-xmpp/src/xmlstream/mod.rs 🔗

@@ -83,7 +83,7 @@ pub(crate) mod xmpp;
 
 use self::common::{RawError, RawXmlStream, ReadXsoError, ReadXsoState};
 pub use self::common::{StreamHeader, Timeouts};
-pub use self::initiator::{InitiatingStream, PendingFeaturesRecv};
+pub use self::initiator::{InitiatingStream, PendingFeaturesRecv, RecvFeaturesError};
 pub use self::responder::{AcceptedStream, PendingFeaturesSend};
 pub use self::xmpp::XmppStreamElement;
 

tokio-xmpp/src/xmlstream/responder.rs 🔗

@@ -12,7 +12,7 @@ use futures::SinkExt;
 
 use tokio::io::{AsyncBufRead, AsyncWrite};
 
-use xmpp_parsers::stream_features::StreamFeatures;
+use xmpp_parsers::{stream_error::StreamError, stream_features::StreamFeatures};
 
 use xso::{AsXml, FromXml};
 
@@ -88,4 +88,17 @@ impl<Io: AsyncBufRead + AsyncWrite + Unpin> PendingFeaturesSend<Io> {
 
         Ok(XmlStream::wrap(stream))
     }
+
+    /// Send a stream error and shut the stream down.
+    ///
+    /// Sends the given stream error to the peer and cleanly closes the stream
+    /// by sending a stream footer.
+    pub async fn send_error(self, error: &'_ StreamError) -> io::Result<()> {
+        let Self { mut stream } = self;
+        Pin::new(&mut stream).start_send_xso(error)?;
+        stream.send(xso::Item::ElementFoot).await?;
+        stream.close().await?;
+
+        Ok(())
+    }
 }

tokio-xmpp/src/xmlstream/tests.rs 🔗

@@ -8,7 +8,10 @@ use core::time::Duration;
 
 use futures::{SinkExt, StreamExt};
 
-use xmpp_parsers::stream_features::StreamFeatures;
+use xmpp_parsers::{
+    stream_error::{DefinedCondition, StreamError},
+    stream_features::StreamFeatures,
+};
 
 use super::*;
 
@@ -73,7 +76,7 @@ async fn test_exchange_stream_features() {
         )
         .await?;
         let (features, _) = stream.recv_features::<Data>().await?;
-        Ok::<_, io::Error>(features)
+        Ok::<_, RecvFeaturesError>(features)
     });
     let responder = tokio::spawn(async move {
         let stream = accept_stream(
@@ -93,6 +96,47 @@ async fn test_exchange_stream_features() {
     assert_eq!(features, StreamFeatures::default());
 }
 
+#[tokio::test]
+async fn test_handle_early_stream_error() {
+    let (lhs, rhs) = tokio::io::duplex(65536);
+    let err = StreamError {
+        condition: DefinedCondition::InternalServerError,
+        text: None,
+        application_specific: Vec::new(),
+    };
+    let initiator = tokio::spawn(async move {
+        let stream = initiate_stream(
+            tokio::io::BufStream::new(lhs),
+            "jabber:client",
+            StreamHeader::default(),
+            Timeouts::tight(),
+        )
+        .await?;
+        match stream.recv_features::<Data>().await {
+            Ok((v, ..)) => panic!("test expected stream error, got features {v:?}"),
+            Err(RecvFeaturesError::Io(e)) => Err(e),
+            Err(RecvFeaturesError::StreamError(e)) => Ok(e),
+        }
+    });
+    let responder = {
+        let err = err.clone();
+        tokio::spawn(async move {
+            let stream = accept_stream(
+                tokio::io::BufStream::new(rhs),
+                "jabber:client",
+                Timeouts::tight(),
+            )
+            .await?;
+            let stream = stream.send_header(StreamHeader::default()).await?;
+            stream.send_error(&err).await?;
+            Ok::<_, io::Error>(())
+        })
+    };
+    responder.await.unwrap().expect("responder failed");
+    let received = initiator.await.unwrap().expect("initiator failed");
+    assert_eq!(received.0, err);
+}
+
 #[tokio::test]
 async fn test_exchange_data() {
     let (lhs, rhs) = tokio::io::duplex(65536);
@@ -115,7 +159,7 @@ async fn test_exchange_data() {
             Some(Ok(Data { contents })) => assert_eq!(contents, "world!"),
             other => panic!("unexpected stream message: {:?}", other),
         }
-        Ok::<_, io::Error>(())
+        Ok::<_, RecvFeaturesError>(())
     });
 
     let responder = tokio::spawn(async move {
@@ -163,7 +207,7 @@ async fn test_clean_shutdown() {
             Some(Err(ReadError::StreamFooterReceived)) => (),
             other => panic!("unexpected stream message: {:?}", other),
         }
-        Ok::<_, io::Error>(())
+        Ok::<_, RecvFeaturesError>(())
     });
 
     let responder = tokio::spawn(async move {
@@ -238,7 +282,7 @@ async fn test_exchange_data_stream_reset_and_shutdown() {
             Some(Err(ReadError::StreamFooterReceived)) => (),
             other => panic!("unexpected stream message: {:?}", other),
         }
-        Ok::<_, io::Error>(())
+        Ok::<_, RecvFeaturesError>(())
     });
 
     let responder = tokio::spawn(async move {
@@ -355,7 +399,7 @@ async fn test_emits_soft_timeout_after_silence() {
             Some(Err(ReadError::HardError(e))) if e.kind() == io::ErrorKind::TimedOut => (),
             other => panic!("unexpected stream message: {:?}", other),
         }
-        Ok::<_, io::Error>(())
+        Ok::<_, RecvFeaturesError>(())
     });
 
     let responder = tokio::spawn(async move {
@@ -428,7 +472,7 @@ async fn test_can_receive_after_shutdown() {
             })
             .await?;
         <XmlStream<_, _> as SinkExt<&Data>>::close(&mut stream).await?;
-        Ok::<_, io::Error>(())
+        Ok::<_, RecvFeaturesError>(())
     });
 
     let responder = tokio::spawn(async move {