From 4cfe4f842967f145bb3fa973a1e0d3a8820c63e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonas=20Sch=C3=A4fer?= Date: Sun, 18 Aug 2024 17:40:39 +0200 Subject: [PATCH] xmlstream: implement simple timeout logic This allows to detect and handle dying streams without getting stuck forever. Timeouts are always wrong, though, so we put the burden of choosing the right values (mostly) on the creator of a stream. --- tokio-xmpp/Cargo.toml | 1 + tokio-xmpp/examples/echo_component.rs | 4 +- tokio-xmpp/examples/echo_server.rs | 3 +- tokio-xmpp/src/client/login.rs | 7 +- tokio-xmpp/src/client/mod.rs | 22 ++- tokio-xmpp/src/client/stream.rs | 1 + tokio-xmpp/src/component/login.rs | 5 +- tokio-xmpp/src/component/mod.rs | 26 ++- tokio-xmpp/src/connect/mod.rs | 3 +- tokio-xmpp/src/connect/starttls.rs | 5 +- tokio-xmpp/src/connect/tcp.rs | 4 +- tokio-xmpp/src/xmlstream/common.rs | 254 ++++++++++++++++++++++---- tokio-xmpp/src/xmlstream/initiator.rs | 19 +- tokio-xmpp/src/xmlstream/mod.rs | 28 ++- tokio-xmpp/src/xmlstream/tests.rs | 143 ++++++++++++++- xmpp/src/builder.rs | 20 +- 16 files changed, 469 insertions(+), 76 deletions(-) diff --git a/tokio-xmpp/Cargo.toml b/tokio-xmpp/Cargo.toml index 49e9b9f58ff76cf9eebd500abc4f76459c9aa2f9..1a07dd3563c0c8c1b02f98851547ad187facf092 100644 --- a/tokio-xmpp/Cargo.toml +++ b/tokio-xmpp/Cargo.toml @@ -39,6 +39,7 @@ tokio-rustls = { version = "0.26", optional = true } [dev-dependencies] env_logger = { version = "0.11", default-features = false, features = ["auto-color", "humantime"] } # this is needed for echo-component example +tokio = { version = "1", features = ["test-util"] } tokio-xmpp = { path = ".", features = ["insecure-tcp"]} [features] diff --git a/tokio-xmpp/examples/echo_component.rs b/tokio-xmpp/examples/echo_component.rs index 70ffa538fe62dea3162fa55bbd78d0efa301770a..6c5493492b37b8df094fa8053c4b68d3f875cbcc 100644 --- a/tokio-xmpp/examples/echo_component.rs +++ b/tokio-xmpp/examples/echo_component.rs @@ -31,9 +31,7 @@ async fn main() { // If you don't need a custom server but default localhost:5347, you can use // Component::new() directly - let mut component = Component::new_plaintext(jid, password, server) - .await - .unwrap(); + let mut component = Component::new(jid, password).await.unwrap(); // Make the two interfaces for sending and receiving independent // of each other so we can move one into a closure. diff --git a/tokio-xmpp/examples/echo_server.rs b/tokio-xmpp/examples/echo_server.rs index dbf2eb0570f42a7644eb5a1b5597dcc68332c8e6..6cbb75cd43ac678d58640f3b81b3a8e9850bfd2f 100644 --- a/tokio-xmpp/examples/echo_server.rs +++ b/tokio-xmpp/examples/echo_server.rs @@ -2,7 +2,7 @@ use futures::{SinkExt, StreamExt}; use tokio::{self, io, net::TcpSocket}; use tokio_xmpp::parsers::stream_features::StreamFeatures; -use tokio_xmpp::xmlstream::{accept_stream, StreamHeader}; +use tokio_xmpp::xmlstream::{accept_stream, StreamHeader, Timeouts}; #[tokio::main] async fn main() -> Result<(), io::Error> { @@ -19,6 +19,7 @@ async fn main() -> Result<(), io::Error> { let stream = accept_stream( tokio::io::BufStream::new(stream), tokio_xmpp::parsers::ns::DEFAULT_NS, + Timeouts::default(), ) .await?; let stream = stream.send_header(StreamHeader::default()).await?; diff --git a/tokio-xmpp/src/client/login.rs b/tokio-xmpp/src/client/login.rs index 14cb029318d94afdea60dd5d94412d1fd5244bd7..90aa0a34811c759f082991343e8cb379c78d31cf 100644 --- a/tokio-xmpp/src/client/login.rs +++ b/tokio-xmpp/src/client/login.rs @@ -19,7 +19,9 @@ use crate::{ client::bind::bind, connect::ServerConnector, error::{AuthError, Error, ProtocolError}, - xmlstream::{xmpp::XmppStreamElement, InitiatingStream, ReadError, StreamHeader, XmppStream}, + xmlstream::{ + xmpp::XmppStreamElement, InitiatingStream, ReadError, StreamHeader, Timeouts, XmppStream, + }, }; pub async fn auth( @@ -107,11 +109,12 @@ pub async fn client_login( server: C, jid: Jid, password: String, + timeouts: Timeouts, ) -> Result<(Option, StreamFeatures, XmppStream), Error> { let username = jid.node().unwrap().as_str(); let password = password; - let xmpp_stream = server.connect(&jid, ns::JABBER_CLIENT).await?; + let xmpp_stream = server.connect(&jid, ns::JABBER_CLIENT, timeouts).await?; let (features, xmpp_stream) = xmpp_stream.recv_features().await?; let channel_binding = C::channel_binding(xmpp_stream.get_stream())?; diff --git a/tokio-xmpp/src/client/mod.rs b/tokio-xmpp/src/client/mod.rs index 13fcf6f289f9b43076b08462fa0193155cee47bc..0197ccd30fa5b06c141aa7318d045a34cfcacf1e 100644 --- a/tokio-xmpp/src/client/mod.rs +++ b/tokio-xmpp/src/client/mod.rs @@ -5,6 +5,7 @@ use crate::{ client::{login::client_login, stream::ClientState}, connect::ServerConnector, error::Error, + xmlstream::Timeouts, Stanza, }; @@ -30,6 +31,7 @@ pub struct Client { password: String, connector: C, state: ClientState, + timeouts: Timeouts, reconnect: bool, // TODO: tls_required=true } @@ -95,6 +97,7 @@ impl Client { jid.clone(), password, DnsConfig::srv(&jid.domain().to_string(), "_xmpp-client._tcp", 5222), + Timeouts::default(), ); client.set_reconnect(true); client @@ -105,8 +108,14 @@ impl Client { jid: J, password: P, dns_config: DnsConfig, + timeouts: Timeouts, ) -> Self { - Self::new_with_connector(jid, password, StartTlsServerConnector::from(dns_config)) + Self::new_with_connector( + jid, + password, + StartTlsServerConnector::from(dns_config), + timeouts, + ) } } @@ -117,8 +126,14 @@ impl Client { jid: J, password: P, dns_config: DnsConfig, + timeouts: Timeouts, ) -> Self { - Self::new_with_connector(jid, password, TcpServerConnector::from(dns_config)) + Self::new_with_connector( + jid, + password, + TcpServerConnector::from(dns_config), + timeouts, + ) } } @@ -128,6 +143,7 @@ impl Client { jid: J, password: P, connector: C, + timeouts: Timeouts, ) -> Self { let jid = jid.into(); let password = password.into(); @@ -136,6 +152,7 @@ impl Client { connector.clone(), jid.clone(), password.clone(), + timeouts, )); let client = Client { jid, @@ -143,6 +160,7 @@ impl Client { connector, state: ClientState::Connecting(connect), reconnect: false, + timeouts, }; client } diff --git a/tokio-xmpp/src/client/stream.rs b/tokio-xmpp/src/client/stream.rs index 7b19ecb8de5a6d22bdea28f6d3f0849a2a898a6e..8120a8702cb97e80b5dbd99edb295d3de56f9245 100644 --- a/tokio-xmpp/src/client/stream.rs +++ b/tokio-xmpp/src/client/stream.rs @@ -56,6 +56,7 @@ impl Stream for Client { self.connector.clone(), self.jid.clone(), self.password.clone(), + self.timeouts, )); self.state = ClientState::Connecting(connect); self.poll_next(cx) diff --git a/tokio-xmpp/src/component/login.rs b/tokio-xmpp/src/component/login.rs index 427ef89ad8e1ccadace5884cf64be807911d6d2b..33b743e49d177e730b539e8a420c170f85e1fdaa 100644 --- a/tokio-xmpp/src/component/login.rs +++ b/tokio-xmpp/src/component/login.rs @@ -6,16 +6,17 @@ use xmpp_parsers::{component::Handshake, jid::Jid, ns}; use crate::component::ServerConnector; use crate::error::{AuthError, Error}; -use crate::xmlstream::{ReadError, XmppStream, XmppStreamElement}; +use crate::xmlstream::{ReadError, Timeouts, XmppStream, XmppStreamElement}; /// Log into an XMPP server as a client with a jid+pass pub async fn component_login( connector: C, jid: Jid, password: String, + timeouts: Timeouts, ) -> Result, Error> { let password = password; - let mut stream = connector.connect(&jid, ns::COMPONENT).await?; + let mut stream = connector.connect(&jid, ns::COMPONENT, timeouts).await?; let header = stream.take_header(); let mut stream = stream.skip_features(); let stream_id = match header.id { diff --git a/tokio-xmpp/src/component/mod.rs b/tokio-xmpp/src/component/mod.rs index efd58a40cd1e0d773804536357115d4c0d4673e0..d041fe4684e7f6a24825b36ae40d1f28c4201d87 100644 --- a/tokio-xmpp/src/component/mod.rs +++ b/tokio-xmpp/src/component/mod.rs @@ -6,8 +6,10 @@ use std::str::FromStr; use xmpp_parsers::jid::Jid; use crate::{ - component::login::component_login, connect::ServerConnector, xmlstream::XmppStream, Error, - Stanza, + component::login::component_login, + connect::ServerConnector, + xmlstream::{Timeouts, XmppStream}, + Error, Stanza, }; #[cfg(any(feature = "starttls", feature = "insecure-tcp"))] @@ -46,7 +48,13 @@ impl Component { /// Start a new XMPP component over plaintext TCP to localhost:5347 #[cfg(feature = "insecure-tcp")] pub async fn new(jid: &str, password: &str) -> Result { - Self::new_plaintext(jid, password, DnsConfig::addr("127.0.0.1:5347")).await + Self::new_plaintext( + jid, + password, + DnsConfig::addr("127.0.0.1:5347"), + Timeouts::tight(), + ) + .await } /// Start a new XMPP component over plaintext TCP @@ -55,8 +63,15 @@ impl Component { jid: &str, password: &str, dns_config: DnsConfig, + timeouts: Timeouts, ) -> Result { - Component::new_with_connector(jid, password, TcpServerConnector::from(dns_config)).await + Component::new_with_connector( + jid, + password, + TcpServerConnector::from(dns_config), + timeouts, + ) + .await } } @@ -69,10 +84,11 @@ impl Component { jid: &str, password: &str, connector: C, + timeouts: Timeouts, ) -> Result { let jid = Jid::from_str(jid)?; let password = password.to_owned(); - let stream = component_login(connector, jid.clone(), password).await?; + let stream = component_login(connector, jid.clone(), password, timeouts).await?; Ok(Component { jid, stream }) } } diff --git a/tokio-xmpp/src/connect/mod.rs b/tokio-xmpp/src/connect/mod.rs index 937b04026914484522244dbd7217fb3bb3470a6c..ca1ae61f420fa0e78513e01e51cb7f1bf87597df 100644 --- a/tokio-xmpp/src/connect/mod.rs +++ b/tokio-xmpp/src/connect/mod.rs @@ -4,7 +4,7 @@ use sasl::common::ChannelBinding; use tokio::io::{AsyncBufRead, AsyncWrite}; use xmpp_parsers::jid::Jid; -use crate::xmlstream::PendingFeaturesRecv; +use crate::xmlstream::{PendingFeaturesRecv, Timeouts}; use crate::Error; #[cfg(feature = "starttls")] @@ -36,6 +36,7 @@ pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static { &self, jid: &Jid, ns: &'static str, + timeouts: Timeouts, ) -> impl std::future::Future, Error>> + Send; /// Return channel binding data if available diff --git a/tokio-xmpp/src/connect/starttls.rs b/tokio-xmpp/src/connect/starttls.rs index 1d0d26e4f157e711b188d6171e42b72c0bf9e5c3..7df476c15c46cc449a1c75dd0b1f218713d9a66f 100644 --- a/tokio-xmpp/src/connect/starttls.rs +++ b/tokio-xmpp/src/connect/starttls.rs @@ -44,7 +44,7 @@ use crate::{ connect::{DnsConfig, ServerConnector, ServerConnectorError}, error::{Error, ProtocolError}, xmlstream::{ - initiate_stream, PendingFeaturesRecv, ReadError, StreamHeader, XmppStream, + initiate_stream, PendingFeaturesRecv, ReadError, StreamHeader, Timeouts, XmppStream, XmppStreamElement, }, Client, @@ -70,6 +70,7 @@ impl ServerConnector for StartTlsServerConnector { &self, jid: &Jid, ns: &'static str, + timeouts: Timeouts, ) -> Result, Error> { let tcp_stream = tokio::io::BufStream::new(self.0.resolve().await?); @@ -82,6 +83,7 @@ impl ServerConnector for StartTlsServerConnector { from: None, id: None, }, + timeouts, ) .await?; let (features, xmpp_stream) = xmpp_stream.recv_features().await?; @@ -98,6 +100,7 @@ impl ServerConnector for StartTlsServerConnector { from: None, id: None, }, + timeouts, ) .await?) } else { diff --git a/tokio-xmpp/src/connect/tcp.rs b/tokio-xmpp/src/connect/tcp.rs index 89a9e12dd02528d37a395e24271ebb100b985caa..474e25b711c006060bfac3b531b5aef88f7b44fb 100644 --- a/tokio-xmpp/src/connect/tcp.rs +++ b/tokio-xmpp/src/connect/tcp.rs @@ -6,7 +6,7 @@ use tokio::{io::BufStream, net::TcpStream}; use crate::{ connect::{DnsConfig, ServerConnector}, - xmlstream::{initiate_stream, PendingFeaturesRecv, StreamHeader}, + xmlstream::{initiate_stream, PendingFeaturesRecv, StreamHeader, Timeouts}, Client, Component, Error, }; @@ -35,6 +35,7 @@ impl ServerConnector for TcpServerConnector { &self, jid: &xmpp_parsers::jid::Jid, ns: &'static str, + timeouts: Timeouts, ) -> Result, Error> { let stream = BufStream::new(self.0.resolve().await?); Ok(initiate_stream( @@ -45,6 +46,7 @@ impl ServerConnector for TcpServerConnector { from: None, id: None, }, + timeouts, ) .await?) } diff --git a/tokio-xmpp/src/xmlstream/common.rs b/tokio-xmpp/src/xmlstream/common.rs index 2ea0526b930f738b7c052c9bf3420ab530f5e95c..ff509824745f13fd8045b1abc60d726fae316815 100644 --- a/tokio-xmpp/src/xmlstream/common.rs +++ b/tokio-xmpp/src/xmlstream/common.rs @@ -7,6 +7,7 @@ use core::future::Future; use core::pin::Pin; use core::task::{Context, Poll}; +use core::time::Duration; use std::borrow::Cow; use std::io; @@ -14,7 +15,10 @@ use futures::{ready, Sink, SinkExt, Stream, StreamExt}; use bytes::{Buf, BytesMut}; -use tokio::io::{AsyncBufRead, AsyncWrite}; +use tokio::{ + io::{AsyncBufRead, AsyncWrite}, + time::Instant, +}; use xso::{ exports::rxml::{self, writer::TrackNamespace, xml_ncname, Event, Namespace}, @@ -25,6 +29,129 @@ use super::capture::{log_enabled, log_recv, log_send, CaptureBufRead}; use xmpp_parsers::ns::STREAM as XML_STREAM_NS; +/// Configuration for timeouts on an XML stream. +/// +/// The defaults are tuned toward common desktop/laptop use and may not hold +/// up to extreme conditions (arctic sattelite link, mobile internet on a +/// train in Brandenburg, Germany, and similar) and may be inefficient in +/// other conditions (stable server link, localhost communication). +#[derive(Debug, Clone, Copy)] +pub struct Timeouts { + /// Maximum silence time before a + /// [`ReadError::SoftTimeout`][`super::ReadError::SoftTimeout`] is + /// returned. + /// + /// Soft timeouts are not fatal, but they must be handled by user code so + /// that more data is read after at most [`Self::response_timeout`], + /// starting from the moment the soft timeout is returned. + pub read_timeout: Duration, + + /// Maximum silence after a soft timeout. + /// + /// If the stream is silent for longer than this time after a soft timeout + /// has been emitted, a hard [`TimedOut`][`std::io::ErrorKind::TimedOut`] + /// I/O error is returned and the stream is to be considered dead. + pub response_timeout: Duration, +} + +impl Default for Timeouts { + fn default() -> Self { + Self { + read_timeout: Duration::new(300, 0), + response_timeout: Duration::new(300, 0), + } + } +} + +impl Timeouts { + /// Tight timeouts suitable for communicating on a fast LAN or localhost. + pub fn tight() -> Self { + Self { + read_timeout: Duration::new(60, 0), + response_timeout: Duration::new(15, 0), + } + } + + fn data_to_soft(&self) -> Duration { + self.read_timeout + } + + fn soft_to_warn(&self) -> Duration { + self.response_timeout / 2 + } + + fn warn_to_hard(&self) -> Duration { + self.response_timeout / 2 + } +} + +#[derive(Clone, Copy)] +enum TimeoutLevel { + Soft, + Warn, + Hard, +} + +#[derive(Debug)] +pub(super) enum RawError { + Io(io::Error), + SoftTimeout, +} + +impl From for RawError { + fn from(other: io::Error) -> Self { + Self::Io(other) + } +} + +struct TimeoutState { + /// Configuration for the timeouts. + timeouts: Timeouts, + + /// Level of the next timeout which will trip. + level: TimeoutLevel, + + /// Sleep timer used for read timeouts. + // NOTE: even though we pretend we could deal with an !Unpin + // RawXmlStream, we really can't: box_stream for example needs it, + // but also all the typestate around the initial stream setup needs + // to be able to move the stream around. + deadline: Pin>, +} + +impl TimeoutState { + fn new(timeouts: Timeouts) -> Self { + Self { + deadline: Box::pin(tokio::time::sleep(timeouts.data_to_soft())), + level: TimeoutLevel::Soft, + timeouts, + } + } + + fn poll(&mut self, cx: &mut Context) -> Poll { + ready!(self.deadline.as_mut().poll(cx)); + // Deadline elapsed! + let to_return = self.level; + let (next_level, next_duration) = match self.level { + TimeoutLevel::Soft => (TimeoutLevel::Warn, self.timeouts.soft_to_warn()), + TimeoutLevel::Warn => (TimeoutLevel::Hard, self.timeouts.warn_to_hard()), + // Something short-ish so that we fire this over and over until + // someone finally kills the stream for good. + TimeoutLevel::Hard => (TimeoutLevel::Hard, Duration::new(1, 0)), + }; + self.level = next_level; + self.deadline.as_mut().reset(Instant::now() + next_duration); + Poll::Ready(to_return) + } + + fn reset(&mut self) { + self.level = TimeoutLevel::Soft; + self.deadline + .as_mut() + .reset((Instant::now() + self.timeouts.data_to_soft()).into()); + } +} + pin_project_lite::pin_project! { // NOTE: due to limitations of pin_project_lite, the field comments are // no doc comments. Luckily, this struct is only `pub(super)` anyway. @@ -37,6 +164,8 @@ pin_project_lite::pin_project! { // The writer used for serialising data. writer: rxml::writer::Encoder, + timeouts: TimeoutState, + // The default namespace to declare on the stream header. stream_ns: &'static str, @@ -112,7 +241,7 @@ impl RawXmlStream { writer } - pub(super) fn new(io: Io, stream_ns: &'static str) -> Self { + pub(super) fn new(io: Io, stream_ns: &'static str, timeouts: Timeouts) -> Self { let parser = rxml::Parser::default(); let mut io = CaptureBufRead::wrap(io); if log_enabled() { @@ -121,6 +250,7 @@ impl RawXmlStream { Self { parser: rxml::AsyncReader::wrap(io, parser), writer: Self::new_writer(stream_ns), + timeouts: TimeoutState::new(timeouts), tx_buffer_logged: 0, stream_ns, tx_buffer: BytesMut::new(), @@ -189,18 +319,34 @@ impl RawXmlStream { } impl Stream for RawXmlStream { - type Item = Result; + type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.project(); loop { - return Poll::Ready( - match ready!(this.parser.as_mut().poll_read(cx)).transpose() { - // Skip the XML declaration, nobody wants to hear about that. - Some(Ok(rxml::Event::XmlDeclaration(_, _))) => continue, - other => other, - }, - ); + match this.parser.as_mut().poll_read(cx) { + Poll::Pending => (), + Poll::Ready(v) => { + this.timeouts.reset(); + match v.transpose() { + // Skip the XML declaration, nobody wants to hear about that. + Some(Ok(rxml::Event::XmlDeclaration(_, _))) => continue, + other => return Poll::Ready(other.map(|x| x.map_err(RawError::Io))), + } + } + }; + + // poll_read returned pending... what do the timeouts have to say? + match ready!(this.timeouts.poll(cx)) { + TimeoutLevel::Soft => return Poll::Ready(Some(Err(RawError::SoftTimeout))), + TimeoutLevel::Warn => (), + TimeoutLevel::Hard => { + return Poll::Ready(Some(Err(RawError::Io(io::Error::new( + io::ErrorKind::TimedOut, + "read and response timeouts elapsed", + ))))) + } + } } } } @@ -312,6 +458,20 @@ pub(super) enum ReadXsoError { /// not well-formed document. Hard(io::Error), + /// The underlying stream signalled a soft read timeout before a child + /// element could be read. + /// + /// Note that soft timeouts which are triggered in the middle of receiving + /// an element are converted to hard timeouts (i.e. I/O errors). + /// + /// This masking is intentional, because: + /// - Returning a [`Self::SoftTimeout`] from the middle of parsing is not + /// possible without complicating the API. + /// - There is no reason why the remote side should interrupt sending data + /// in the middle of an element except if it or the transport has failed + /// fatally. + SoftTimeout, + /// A parse error occurred. /// /// The XML structure was well-formed, but the data contained did not @@ -324,19 +484,6 @@ pub(super) enum ReadXsoError { Parse(xso::error::Error), } -impl From for io::Error { - fn from(other: ReadXsoError) -> Self { - match other { - ReadXsoError::Hard(v) => v, - ReadXsoError::Parse(e) => io::Error::new(io::ErrorKind::InvalidData, e), - ReadXsoError::Footer => io::Error::new( - io::ErrorKind::UnexpectedEof, - "element footer while waiting for XSO element start", - ), - } - } -} - impl From for ReadXsoError { fn from(other: io::Error) -> Self { Self::Hard(other) @@ -425,13 +572,13 @@ impl ReadXsoState { .parser_pinned() .set_text_buffering(text_buffering); - let ev = ready!(source.as_mut().poll_next(cx)).transpose()?; + let ev = ready!(source.as_mut().poll_next(cx)).transpose(); match self { ReadXsoState::PreData => { log::trace!("ReadXsoState::PreData ev = {:?}", ev); match ev { - Some(rxml::Event::XmlDeclaration(_, _)) => (), - Some(rxml::Event::Text(_, data)) => { + Ok(Some(rxml::Event::XmlDeclaration(_, _))) => (), + Ok(Some(rxml::Event::Text(_, data))) => { if xso::is_xml_whitespace(data.as_bytes()) { log::trace!("Received {} bytes of whitespace", data.len()); source.as_mut().stream_pinned().discard_capture(); @@ -445,18 +592,18 @@ impl ReadXsoState { .into())); } } - Some(rxml::Event::StartElement(_, name, attrs)) => { + Ok(Some(rxml::Event::StartElement(_, name, attrs))) => { *self = ReadXsoState::Parsing( as FromXml>::from_events(name, attrs) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?, ); } // Amounts to EOF, as we expect to start on the stream level. - Some(rxml::Event::EndElement(_)) => { + Ok(Some(rxml::Event::EndElement(_))) => { *self = ReadXsoState::Done; return Poll::Ready(Err(ReadXsoError::Footer)); } - None => { + Ok(None) => { *self = ReadXsoState::Done; return Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidData, @@ -464,17 +611,42 @@ impl ReadXsoState { ) .into())); } + Err(RawError::SoftTimeout) => { + *self = ReadXsoState::Done; + return Poll::Ready(Err(ReadXsoError::SoftTimeout)); + } + Err(RawError::Io(e)) => { + *self = ReadXsoState::Done; + return Poll::Ready(Err(ReadXsoError::Hard(e))); + } } } ReadXsoState::Parsing(builder) => { log::trace!("ReadXsoState::Parsing ev = {:?}", ev); - let Some(ev) = ev else { - *self = ReadXsoState::Done; - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - "eof during XSO parsing", - ) - .into())); + let ev = match ev { + Ok(Some(ev)) => ev, + Ok(None) => { + *self = ReadXsoState::Done; + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "eof during XSO parsing", + ) + .into())); + } + Err(RawError::Io(e)) => { + *self = ReadXsoState::Done; + return Poll::Ready(Err(e.into())); + } + Err(RawError::SoftTimeout) => { + // See also [`ReadXsoError::SoftTimeout`] for why + // we mask the SoftTimeout condition here. + *self = ReadXsoState::Done; + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::TimedOut, + "read timeout during XSO parsing", + ) + .into())); + } }; match builder.feed(ev) { @@ -622,8 +794,10 @@ impl StreamHeader<'static> { mut stream: Pin<&mut RawXmlStream>, ) -> io::Result { loop { - match stream.as_mut().next().await.transpose()? { - Some(Event::StartElement(_, (ns, name), mut attrs)) => { + match stream.as_mut().next().await { + Some(Err(RawError::Io(e))) => return Err(e), + Some(Err(RawError::SoftTimeout)) => (), + Some(Ok(Event::StartElement(_, (ns, name), mut attrs))) => { if ns != XML_STREAM_NS || name != "stream" { return Err(io::Error::new( io::ErrorKind::InvalidData, @@ -666,7 +840,7 @@ impl StreamHeader<'static> { id: id.map(Cow::Owned), }); } - Some(Event::Text(_, _)) | Some(Event::EndElement(_)) => { + Some(Ok(Event::Text(_, _))) | Some(Ok(Event::EndElement(_))) => { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, "unexpected content before stream header", @@ -674,7 +848,7 @@ impl StreamHeader<'static> { } // We cannot loop infinitely here because the XML parser will // prevent more than one XML declaration from being parsed. - Some(Event::XmlDeclaration(_, _)) => (), + Some(Ok(Event::XmlDeclaration(_, _))) => (), None => { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, diff --git a/tokio-xmpp/src/xmlstream/initiator.rs b/tokio-xmpp/src/xmlstream/initiator.rs index 0829c9a0f62679c4c3ddee49683e7db33627851e..4682548d96a776905678b18c444b61fd4cce8454 100644 --- a/tokio-xmpp/src/xmlstream/initiator.rs +++ b/tokio-xmpp/src/xmlstream/initiator.rs @@ -17,7 +17,7 @@ use xmpp_parsers::stream_features::StreamFeatures; use xso::{AsXml, FromXml}; use super::{ - common::{RawXmlStream, ReadXso, StreamHeader}, + common::{RawXmlStream, ReadXso, ReadXsoError, StreamHeader}, XmlStream, }; @@ -80,7 +80,22 @@ impl PendingFeaturesRecv { mut stream, header: _, } = self; - let features = ReadXso::read_from(Pin::new(&mut stream)).await?; + let features = loop { + match ReadXso::read_from(Pin::new(&mut stream)).await { + Ok(v) => break v, + Err(ReadXsoError::SoftTimeout) => (), + Err(ReadXsoError::Hard(e)) => return Err(e), + Err(ReadXsoError::Parse(e)) => { + return Err(io::Error::new(io::ErrorKind::InvalidData, e)) + } + Err(ReadXsoError::Footer) => { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "unexpected stream footer", + )) + } + } + }; Ok((features, XmlStream::wrap(stream))) } diff --git a/tokio-xmpp/src/xmlstream/mod.rs b/tokio-xmpp/src/xmlstream/mod.rs index a4d88a6aec7aba46e150a9284fde8beb6c99d040..5e4149660a5b9f7d8ba953da11fd1669a8473683 100644 --- a/tokio-xmpp/src/xmlstream/mod.rs +++ b/tokio-xmpp/src/xmlstream/mod.rs @@ -77,8 +77,8 @@ mod responder; mod tests; pub(crate) mod xmpp; -pub use self::common::StreamHeader; -use self::common::{RawXmlStream, ReadXsoError, ReadXsoState}; +use self::common::{RawError, RawXmlStream, ReadXsoError, ReadXsoState}; +pub use self::common::{StreamHeader, Timeouts}; pub use self::initiator::{InitiatingStream, PendingFeaturesRecv}; pub use self::responder::{AcceptedStream, PendingFeaturesSend}; pub use self::xmpp::XmppStreamElement; @@ -129,8 +129,9 @@ pub async fn initiate_stream( io: Io, stream_ns: &'static str, stream_header: StreamHeader<'_>, + timeouts: Timeouts, ) -> Result, io::Error> { - let stream = InitiatingStream(RawXmlStream::new(io, stream_ns)); + let stream = InitiatingStream(RawXmlStream::new(io, stream_ns, timeouts)); stream.send_header(stream_header).await } @@ -144,8 +145,9 @@ pub async fn initiate_stream( pub async fn accept_stream( io: Io, stream_ns: &'static str, + timeouts: Timeouts, ) -> Result, io::Error> { - let mut stream = RawXmlStream::new(io, stream_ns); + let mut stream = RawXmlStream::new(io, stream_ns, timeouts); let header = StreamHeader::recv(Pin::new(&mut stream)).await?; Ok(AcceptedStream { stream, header }) } @@ -319,14 +321,21 @@ impl Stream for XmlStream; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); + let mut this = self.project(); let result = match this.read_state.as_mut() { None => { // awaiting eof. - return match ready!(this.inner.poll_next(cx)) { - None => Poll::Ready(None), - Some(Ok(_)) => unreachable!("xml parser allowed data after stream footer"), - Some(Err(e)) => Poll::Ready(Some(Err(ReadError::HardError(e)))), + return loop { + match ready!(this.inner.as_mut().poll_next(cx)) { + None => break Poll::Ready(None), + Some(Ok(_)) => unreachable!("xml parser allowed data after stream footer"), + Some(Err(RawError::Io(e))) => { + break Poll::Ready(Some(Err(ReadError::HardError(e)))) + } + // Swallow soft timeout, we don't want the user to trigger + // anything here. + Some(Err(RawError::SoftTimeout)) => continue, + } }; } Some(read_state) => ready!(read_state.poll_advance(this.inner, cx)), @@ -341,6 +350,7 @@ impl Stream for XmlStream Poll::Ready(Some(Err(ReadError::SoftTimeout))), }; *this.read_state = Some(ReadXsoState::default()); result diff --git a/tokio-xmpp/src/xmlstream/tests.rs b/tokio-xmpp/src/xmlstream/tests.rs index 2270547ccb0e18b66bb7117dd25aa894f5011139..44c0256fba7c993ce2525236e4c005d6eb9dd08b 100644 --- a/tokio-xmpp/src/xmlstream/tests.rs +++ b/tokio-xmpp/src/xmlstream/tests.rs @@ -4,6 +4,8 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. +use std::time::Duration; + use futures::{SinkExt, StreamExt}; use xmpp_parsers::stream_features::StreamFeatures; @@ -29,12 +31,18 @@ async fn test_initiate_accept_stream() { to: Some("server".into()), id: Some("client-id".into()), }, + Timeouts::tight(), ) .await?; Ok::<_, io::Error>(stream.take_header()) }); let responder = tokio::spawn(async move { - let stream = accept_stream(tokio::io::BufStream::new(rhs), "jabber:client").await?; + let stream = accept_stream( + tokio::io::BufStream::new(rhs), + "jabber:client", + Timeouts::tight(), + ) + .await?; assert_eq!(stream.header().from.unwrap(), "client"); assert_eq!(stream.header().to.unwrap(), "server"); assert_eq!(stream.header().id.unwrap(), "client-id"); @@ -61,13 +69,19 @@ async fn test_exchange_stream_features() { tokio::io::BufStream::new(lhs), "jabber:client", StreamHeader::default(), + Timeouts::tight(), ) .await?; let (features, _) = stream.recv_features::().await?; Ok::<_, io::Error>(features) }); let responder = tokio::spawn(async move { - let stream = accept_stream(tokio::io::BufStream::new(rhs), "jabber:client").await?; + let stream = accept_stream( + tokio::io::BufStream::new(rhs), + "jabber:client", + Timeouts::tight(), + ) + .await?; let stream = stream.send_header(StreamHeader::default()).await?; stream .send_features::(&StreamFeatures::default()) @@ -88,6 +102,7 @@ async fn test_exchange_data() { tokio::io::BufStream::new(lhs), "jabber:client", StreamHeader::default(), + Timeouts::tight(), ) .await?; let (_, mut stream) = stream.recv_features::().await?; @@ -104,7 +119,12 @@ async fn test_exchange_data() { }); let responder = tokio::spawn(async move { - let stream = accept_stream(tokio::io::BufStream::new(rhs), "jabber:client").await?; + let stream = accept_stream( + tokio::io::BufStream::new(rhs), + "jabber:client", + Timeouts::tight(), + ) + .await?; let stream = stream.send_header(StreamHeader::default()).await?; let mut stream = stream .send_features::(&StreamFeatures::default()) @@ -134,6 +154,7 @@ async fn test_clean_shutdown() { tokio::io::BufStream::new(lhs), "jabber:client", StreamHeader::default(), + Timeouts::tight(), ) .await?; let (_, mut stream) = stream.recv_features::().await?; @@ -146,7 +167,12 @@ async fn test_clean_shutdown() { }); let responder = tokio::spawn(async move { - let stream = accept_stream(tokio::io::BufStream::new(rhs), "jabber:client").await?; + let stream = accept_stream( + tokio::io::BufStream::new(rhs), + "jabber:client", + Timeouts::tight(), + ) + .await?; let stream = stream.send_header(StreamHeader::default()).await?; let mut stream = stream .send_features::(&StreamFeatures::default()) @@ -172,6 +198,7 @@ async fn test_exchange_data_stream_reset_and_shutdown() { tokio::io::BufStream::new(lhs), "jabber:client", StreamHeader::default(), + Timeouts::tight(), ) .await?; let (_, mut stream) = stream.recv_features::().await?; @@ -215,7 +242,12 @@ async fn test_exchange_data_stream_reset_and_shutdown() { }); let responder = tokio::spawn(async move { - let stream = accept_stream(tokio::io::BufStream::new(rhs), "jabber:client").await?; + let stream = accept_stream( + tokio::io::BufStream::new(rhs), + "jabber:client", + Timeouts::tight(), + ) + .await?; let stream = stream.send_header(StreamHeader::default()).await?; let mut stream = stream .send_features::(&StreamFeatures::default()) @@ -262,3 +294,104 @@ async fn test_exchange_data_stream_reset_and_shutdown() { responder.await.unwrap().expect("responder failed"); initiator.await.unwrap().expect("initiator failed"); } + +#[tokio::test(start_paused = true)] +async fn test_emits_soft_timeout_after_silence() { + let (lhs, rhs) = tokio::io::duplex(65536); + + let client_timeouts = Timeouts { + read_timeout: Duration::new(300, 0), + response_timeout: Duration::new(15, 0), + }; + + // We do want to trigger only one set of timeouts, so we set the server + // timeouts much longer than the client timeouts + let server_timeouts = Timeouts { + read_timeout: Duration::new(900, 0), + response_timeout: Duration::new(15, 0), + }; + + let initiator = tokio::spawn(async move { + let stream = initiate_stream( + tokio::io::BufStream::new(lhs), + "jabber:client", + StreamHeader::default(), + client_timeouts, + ) + .await?; + let (_, mut stream) = stream.recv_features::().await?; + stream + .send(&Data { + contents: "hello".to_owned(), + }) + .await?; + match stream.next().await { + Some(Ok(Data { contents })) => assert_eq!(contents, "world!"), + other => panic!("unexpected stream message: {:?}", other), + } + // Here we prove that the stream doesn't see any data and also does + // not see the SoftTimeout too early. + // (Well, not exactly a proof: We only check until half of the read + // timeout, because that was easy to write and I deem it good enough.) + match tokio::time::timeout(client_timeouts.read_timeout / 2, stream.next()).await { + Err(_) => (), + Ok(ev) => panic!("early stream message (before soft timeout): {:?}", ev), + }; + // Now the next thing that happens is the soft timeout ... + match stream.next().await { + Some(Err(ReadError::SoftTimeout)) => (), + other => panic!("unexpected stream message: {:?}", other), + } + // Another check that the there is some time between soft and hard + // timeout. + match tokio::time::timeout(client_timeouts.response_timeout / 3, stream.next()).await { + Err(_) => (), + Ok(ev) => { + panic!("early stream message (before hard timeout): {:?}", ev); + } + }; + // ... and thereafter the hard timeout in form of an I/O error. + match stream.next().await { + Some(Err(ReadError::HardError(e))) if e.kind() == io::ErrorKind::TimedOut => (), + other => panic!("unexpected stream message: {:?}", other), + } + Ok::<_, io::Error>(()) + }); + + let responder = tokio::spawn(async move { + let stream = accept_stream( + tokio::io::BufStream::new(rhs), + "jabber:client", + server_timeouts, + ) + .await?; + let stream = stream.send_header(StreamHeader::default()).await?; + let mut stream = stream + .send_features::(&StreamFeatures::default()) + .await?; + stream + .send(&Data { + contents: "world!".to_owned(), + }) + .await?; + match stream.next().await { + Some(Ok(Data { contents })) => assert_eq!(contents, "hello"), + other => panic!("unexpected stream message: {:?}", other), + } + match stream.next().await { + Some(Err(ReadError::HardError(e))) if e.kind() == io::ErrorKind::InvalidData => { + match e.downcast::() { + // the initiator closes the stream by dropping it once the + // timeout trips, so we get a hard eof here. + Ok(rxml::Error::InvalidEof(_)) => (), + other => panic!("unexpected error: {:?}", other), + } + } + other => panic!("unexpected stream message: {:?}", other), + } + Ok::<_, io::Error>(()) + }); + + responder.await.unwrap().expect("responder failed"); + initiator.await.unwrap().expect("initiator failed"); +} diff --git a/xmpp/src/builder.rs b/xmpp/src/builder.rs index 8b0e769d9ca7a389823149c75e7b529ec1a8f77b..f8f29ce3fa0bf06ccc0d4e608ab31aad6949e3a9 100644 --- a/xmpp/src/builder.rs +++ b/xmpp/src/builder.rs @@ -15,6 +15,7 @@ use tokio_xmpp::{ disco::{DiscoInfoResult, Feature, Identity}, ns, }, + xmlstream::Timeouts, Client as TokioXmppClient, }; @@ -51,6 +52,7 @@ pub struct ClientBuilder<'a, C: ServerConnector> { disco: (ClientType, String), features: Vec, resource: Option, + timeouts: Timeouts, } #[cfg(any(feature = "starttls-rust", feature = "starttls-native"))] @@ -80,6 +82,7 @@ impl ClientBuilder<'_, C> { disco: (ClientType::default(), String::from("tokio-xmpp")), features: vec![], resource: None, + timeouts: Timeouts::default(), } } @@ -109,6 +112,15 @@ impl ClientBuilder<'_, C> { self } + /// Configure the timeouts used. + /// + /// See [`Timeouts`] for more information on the semantics and the + /// defaults (which are used unless you call this method). + pub fn set_timeouts(mut self, timeouts: Timeouts) -> Self { + self.timeouts = timeouts; + self + } + pub fn enable_feature(mut self, feature: ClientFeature) -> Self { self.features.push(feature); self @@ -146,8 +158,12 @@ impl ClientBuilder<'_, C> { self.jid.clone().into() }; - let mut client = - TokioXmppClient::new_with_connector(jid, self.password, self.server_connector.clone()); + let mut client = TokioXmppClient::new_with_connector( + jid, + self.password, + self.server_connector.clone(), + self.timeouts, + ); client.set_reconnect(true); self.build_impl(client) }