diff --git a/parsers/src/iq.rs b/parsers/src/iq.rs index 77bfd8c17e5b5c684b04717f1fd918167b48d822..bcd02ba18915578228c116edba4ac1ff989d1674 100644 --- a/parsers/src/iq.rs +++ b/parsers/src/iq.rs @@ -39,6 +39,15 @@ impl IqHeader { } } +/// Payload of an IQ request stanza. +pub enum IqRequestPayload { + /// Payload of a type='get' stanza. + Get(Element), + + /// Payload of a type='set' stanza. + Set(Element), +} + /// Payload of an IQ stanza, by type. pub enum IqPayload { /// Payload of a type='get' stanza. diff --git a/tokio-xmpp/src/error.rs b/tokio-xmpp/src/error.rs index 6d81a160f8d445b23d509519c61793a7de0d80f8..a8b25f81e12c470923255b48676943178426c646 100644 --- a/tokio-xmpp/src/error.rs +++ b/tokio-xmpp/src/error.rs @@ -6,9 +6,11 @@ use hickory_resolver::{ use sasl::client::MechanismError as SaslMechanismError; use std::io; +use xmpp_parsers::stream_error::ReceivedStreamError; + use crate::{ connect::ServerConnectorError, jid, minidom, - parsers::sasl::DefinedCondition as SaslDefinedCondition, + parsers::sasl::DefinedCondition as SaslDefinedCondition, xmlstream::RecvFeaturesError, }; /// Top-level error type @@ -44,6 +46,8 @@ pub enum Error { Idna, /// Invalid IP/Port address Addr(AddrParseError), + /// Received a stream error + StreamError(ReceivedStreamError), } impl fmt::Display for Error { @@ -65,6 +69,7 @@ impl fmt::Display for Error { #[cfg(feature = "dns")] Error::Idna => write!(fmt, "IDNA error"), Error::Addr(e) => write!(fmt, "Wrong network address: {e}"), + Error::StreamError(e) => write!(fmt, "{e}"), } } } @@ -140,6 +145,15 @@ impl From for Error { } } +impl From for Error { + fn from(e: RecvFeaturesError) -> Self { + match e { + RecvFeaturesError::Io(e) => e.into(), + RecvFeaturesError::StreamError(e) => Self::StreamError(e), + } + } +} + /// XMPP protocol-level error #[derive(Debug)] pub enum ProtocolError { diff --git a/tokio-xmpp/src/xmlstream/initiator.rs b/tokio-xmpp/src/xmlstream/initiator.rs index 53f24ffa1a273ee807b2136363e83b757347b302..e0698d8adcd7bb92b7d324c59ee4fb23f941c533 100644 --- a/tokio-xmpp/src/xmlstream/initiator.rs +++ b/tokio-xmpp/src/xmlstream/initiator.rs @@ -4,6 +4,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. +use core::fmt; use core::pin::Pin; use std::borrow::Cow; use std::io; @@ -12,7 +13,10 @@ use futures::SinkExt; use tokio::io::{AsyncBufRead, AsyncWrite}; -use xmpp_parsers::stream_features::StreamFeatures; +use xmpp_parsers::{ + stream_error::{ReceivedStreamError, StreamError}, + stream_features::StreamFeatures, +}; use xso::{AsXml, FromXml}; @@ -42,6 +46,49 @@ impl InitiatingStream { } } +#[derive(xso::FromXml)] +#[xml()] +enum StreamFeaturesPayload { + #[xml(transparent)] + Features(StreamFeatures), + #[xml(transparent)] + Error(StreamError), +} + +/// Error conditions when receiving stream features +#[derive(Debug)] +pub enum RecvFeaturesError { + /// I/o error while receiving stream features + Io(io::Error), + + /// Received a stream error instead of stream features + StreamError(ReceivedStreamError), +} + +impl fmt::Display for RecvFeaturesError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Io(e) => write!(f, "i/o error: {e}"), + Self::StreamError(e) => fmt::Display::fmt(&e, f), + } + } +} + +impl core::error::Error for RecvFeaturesError { + fn source(&self) -> Option<&(dyn core::error::Error + 'static)> { + match self { + Self::Io(e) => Some(e), + Self::StreamError(e) => Some(e), + } + } +} + +impl From for RecvFeaturesError { + fn from(other: io::Error) -> Self { + Self::Io(other) + } +} + /// Type state for an initiator stream which has sent and received the stream /// header. /// @@ -73,26 +120,39 @@ impl PendingFeaturesRecv { /// After the stream features have been received, the stream can be used /// for exchanging stream-level elements (stanzas or "nonzas"). The Rust /// type for these elements must be given as type parameter `T`. + /// + /// If the peer sends a stream error instead of features, the error is + /// returned as [`RecvFeaturesError::StreamError`]. + /// + /// If the peer sends any payload which is neither stream features nor + /// a stream error, an [`io::Error`][`std::io::Error`] with + /// [`InvalidData`][`io::ErrorKind::InvalidData`] kind is returned. pub async fn recv_features( self, - ) -> io::Result<(StreamFeatures, XmlStream)> { + ) -> Result<(StreamFeatures, XmlStream), RecvFeaturesError> { let Self { mut stream, header: _, } = self; let features = loop { match ReadXso::read_from(Pin::new(&mut stream)).await { - Ok(v) => break v, + Ok(StreamFeaturesPayload::Features(v)) => break v, + Ok(StreamFeaturesPayload::Error(v)) => { + return Err(RecvFeaturesError::StreamError(ReceivedStreamError(v))) + } Err(ReadXsoError::SoftTimeout) => (), - Err(ReadXsoError::Hard(e)) => return Err(e), + Err(ReadXsoError::Hard(e)) => return Err(RecvFeaturesError::Io(e)), Err(ReadXsoError::Parse(e)) => { - return Err(io::Error::new(io::ErrorKind::InvalidData, e)) + return Err(RecvFeaturesError::Io(io::Error::new( + io::ErrorKind::InvalidData, + e, + ))) } Err(ReadXsoError::Footer) => { - return Err(io::Error::new( + return Err(RecvFeaturesError::Io(io::Error::new( io::ErrorKind::UnexpectedEof, "unexpected stream footer", - )) + ))) } } }; diff --git a/tokio-xmpp/src/xmlstream/mod.rs b/tokio-xmpp/src/xmlstream/mod.rs index 2233ad770022044587fec5da083a51034c2dc6c9..c5d219bd8bea067b7267945b193b21cce6b7ddc3 100644 --- a/tokio-xmpp/src/xmlstream/mod.rs +++ b/tokio-xmpp/src/xmlstream/mod.rs @@ -83,7 +83,7 @@ pub(crate) mod xmpp; use self::common::{RawError, RawXmlStream, ReadXsoError, ReadXsoState}; pub use self::common::{StreamHeader, Timeouts}; -pub use self::initiator::{InitiatingStream, PendingFeaturesRecv}; +pub use self::initiator::{InitiatingStream, PendingFeaturesRecv, RecvFeaturesError}; pub use self::responder::{AcceptedStream, PendingFeaturesSend}; pub use self::xmpp::XmppStreamElement; diff --git a/tokio-xmpp/src/xmlstream/responder.rs b/tokio-xmpp/src/xmlstream/responder.rs index dd28ba8c0959d74609ffe51445c3d8d9f7670479..20c65eea3b4bd67e54b232ffb844fca8a0779d4d 100644 --- a/tokio-xmpp/src/xmlstream/responder.rs +++ b/tokio-xmpp/src/xmlstream/responder.rs @@ -12,7 +12,7 @@ use futures::SinkExt; use tokio::io::{AsyncBufRead, AsyncWrite}; -use xmpp_parsers::stream_features::StreamFeatures; +use xmpp_parsers::{stream_error::StreamError, stream_features::StreamFeatures}; use xso::{AsXml, FromXml}; @@ -88,4 +88,17 @@ impl PendingFeaturesSend { Ok(XmlStream::wrap(stream)) } + + /// Send a stream error and shut the stream down. + /// + /// Sends the given stream error to the peer and cleanly closes the stream + /// by sending a stream footer. + pub async fn send_error(self, error: &'_ StreamError) -> io::Result<()> { + let Self { mut stream } = self; + Pin::new(&mut stream).start_send_xso(error)?; + stream.send(xso::Item::ElementFoot).await?; + stream.close().await?; + + Ok(()) + } } diff --git a/tokio-xmpp/src/xmlstream/tests.rs b/tokio-xmpp/src/xmlstream/tests.rs index e81045f8e711b8f52b4c664e091e5ac6c90d9692..fa6319122272a15fa2601f8acbe42f10b1020ea9 100644 --- a/tokio-xmpp/src/xmlstream/tests.rs +++ b/tokio-xmpp/src/xmlstream/tests.rs @@ -8,7 +8,10 @@ use core::time::Duration; use futures::{SinkExt, StreamExt}; -use xmpp_parsers::stream_features::StreamFeatures; +use xmpp_parsers::{ + stream_error::{DefinedCondition, StreamError}, + stream_features::StreamFeatures, +}; use super::*; @@ -73,7 +76,7 @@ async fn test_exchange_stream_features() { ) .await?; let (features, _) = stream.recv_features::().await?; - Ok::<_, io::Error>(features) + Ok::<_, RecvFeaturesError>(features) }); let responder = tokio::spawn(async move { let stream = accept_stream( @@ -93,6 +96,47 @@ async fn test_exchange_stream_features() { assert_eq!(features, StreamFeatures::default()); } +#[tokio::test] +async fn test_handle_early_stream_error() { + let (lhs, rhs) = tokio::io::duplex(65536); + let err = StreamError { + condition: DefinedCondition::InternalServerError, + text: None, + application_specific: Vec::new(), + }; + let initiator = tokio::spawn(async move { + let stream = initiate_stream( + tokio::io::BufStream::new(lhs), + "jabber:client", + StreamHeader::default(), + Timeouts::tight(), + ) + .await?; + match stream.recv_features::().await { + Ok((v, ..)) => panic!("test expected stream error, got features {v:?}"), + Err(RecvFeaturesError::Io(e)) => Err(e), + Err(RecvFeaturesError::StreamError(e)) => Ok(e), + } + }); + let responder = { + let err = err.clone(); + tokio::spawn(async move { + let stream = accept_stream( + tokio::io::BufStream::new(rhs), + "jabber:client", + Timeouts::tight(), + ) + .await?; + let stream = stream.send_header(StreamHeader::default()).await?; + stream.send_error(&err).await?; + Ok::<_, io::Error>(()) + }) + }; + responder.await.unwrap().expect("responder failed"); + let received = initiator.await.unwrap().expect("initiator failed"); + assert_eq!(received.0, err); +} + #[tokio::test] async fn test_exchange_data() { let (lhs, rhs) = tokio::io::duplex(65536); @@ -115,7 +159,7 @@ async fn test_exchange_data() { Some(Ok(Data { contents })) => assert_eq!(contents, "world!"), other => panic!("unexpected stream message: {:?}", other), } - Ok::<_, io::Error>(()) + Ok::<_, RecvFeaturesError>(()) }); let responder = tokio::spawn(async move { @@ -163,7 +207,7 @@ async fn test_clean_shutdown() { Some(Err(ReadError::StreamFooterReceived)) => (), other => panic!("unexpected stream message: {:?}", other), } - Ok::<_, io::Error>(()) + Ok::<_, RecvFeaturesError>(()) }); let responder = tokio::spawn(async move { @@ -238,7 +282,7 @@ async fn test_exchange_data_stream_reset_and_shutdown() { Some(Err(ReadError::StreamFooterReceived)) => (), other => panic!("unexpected stream message: {:?}", other), } - Ok::<_, io::Error>(()) + Ok::<_, RecvFeaturesError>(()) }); let responder = tokio::spawn(async move { @@ -355,7 +399,7 @@ async fn test_emits_soft_timeout_after_silence() { Some(Err(ReadError::HardError(e))) if e.kind() == io::ErrorKind::TimedOut => (), other => panic!("unexpected stream message: {:?}", other), } - Ok::<_, io::Error>(()) + Ok::<_, RecvFeaturesError>(()) }); let responder = tokio::spawn(async move { @@ -428,7 +472,7 @@ async fn test_can_receive_after_shutdown() { }) .await?; as SinkExt<&Data>>::close(&mut stream).await?; - Ok::<_, io::Error>(()) + Ok::<_, RecvFeaturesError>(()) }); let responder = tokio::spawn(async move {